Repository: YaoApp/yao Branch: main Commit: e9892784f510 Files: 1900 Total size: 25.3 MB Directory structure: gitextract_ug5_rm40/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── config.yml │ │ └── issue_report.md │ ├── actions/ │ │ ├── setup-db/ │ │ │ ├── Dockerfile │ │ │ └── action.yml │ │ └── setup-yao/ │ │ └── action.yml │ ├── codesign/ │ │ └── entitlements.plist │ ├── env/ │ │ └── sandbox-v2.env │ └── workflows/ │ ├── build-docker.yml │ ├── build-linux.yml │ ├── build-macos.yml │ ├── notarize-macos.yml │ ├── pr-receive.yml │ ├── pr-test.yml │ ├── release-linux.yml │ ├── release-macos.yml │ ├── release.yml │ ├── unit-test-v1.yml │ └── unit-test.yml ├── .gitignore ├── COMMERCIAL_LICENSE.md ├── COMMERCIAL_LICENSE.zh-CN.md ├── LICENSE ├── Makefile ├── README.md ├── README.zh-CN.md ├── agent/ │ ├── README.md │ ├── agent_test.go │ ├── assistant/ │ │ ├── agent.go │ │ ├── agent_interrupt_test.go │ │ ├── agent_next_test.go │ │ ├── assistant.go │ │ ├── build.go │ │ ├── build_content.go │ │ ├── build_mcp_test.go │ │ ├── build_prompts_test.go │ │ ├── build_test.go │ │ ├── cache.go │ │ ├── cache_test.go │ │ ├── chat.go │ │ ├── chat_test.go │ │ ├── handlers/ │ │ │ └── stream.go │ │ ├── history.go │ │ ├── history_test.go │ │ ├── hook/ │ │ │ ├── REALWORLD_PERFORMANCE_REPORT.md │ │ │ ├── create.go │ │ │ ├── create_bench_test.go │ │ │ ├── create_mem_test.go │ │ │ ├── create_nested_test.go │ │ │ ├── create_test.go │ │ │ ├── goroutine_leak_test.go │ │ │ ├── hook.go │ │ │ ├── next.go │ │ │ ├── next_test.go │ │ │ ├── realworld_next_test.go │ │ │ ├── realworld_stress_test.go │ │ │ ├── script.go │ │ │ └── types.go │ │ ├── llm.go │ │ ├── load.go │ │ ├── load_merge_test.go │ │ ├── load_process_test.go │ │ ├── load_store_test.go │ │ ├── load_system.go │ │ ├── load_test.go │ │ ├── mcp.go │ │ ├── mcp_test.go │ │ ├── next.go │ │ ├── permission.go │ │ ├── sandbox.go │ │ ├── sandbox_debug_test.go │ │ ├── sandbox_e2e_test.go │ │ ├── sandbox_integration_test.go │ │ ├── sandbox_test.go │ │ ├── 
sandbox_v2.go │ │ ├── scripts.go │ │ ├── scripts_process_test.go │ │ ├── scripts_test.go │ │ ├── search.go │ │ ├── search_auth_db.go │ │ ├── search_auth_integration_test.go │ │ ├── search_auth_kb.go │ │ ├── search_auto_disabled_test.go │ │ ├── search_auto_full_test.go │ │ ├── search_auto_hook_disable_test.go │ │ ├── search_auto_keyword_test.go │ │ ├── search_auto_web_test.go │ │ ├── source.go │ │ ├── trace.go │ │ ├── types.go │ │ └── utils.go │ ├── caller/ │ │ ├── caller.go │ │ ├── context.go │ │ ├── integration_test.go │ │ ├── jsapi.go │ │ ├── jsapi_test.go │ │ ├── orchestrator.go │ │ ├── orchestrator_test.go │ │ ├── process.go │ │ ├── process_e2e_test.go │ │ ├── process_test.go │ │ ├── sandbox_integration_test.go │ │ ├── types.go │ │ └── types_test.go │ ├── content/ │ │ ├── content.go │ │ ├── docx/ │ │ │ ├── docx.go │ │ │ └── docx_test.go │ │ ├── image/ │ │ │ ├── image.go │ │ │ └── image_test.go │ │ ├── link/ │ │ │ └── link.go │ │ ├── pdf/ │ │ │ ├── pdf.go │ │ │ └── pdf_test.go │ │ ├── pptx/ │ │ │ ├── pptx.go │ │ │ └── pptx_test.go │ │ ├── text/ │ │ │ ├── text.go │ │ │ └── text_test.go │ │ ├── tools/ │ │ │ └── tools.go │ │ └── types/ │ │ └── types.go │ ├── context/ │ │ ├── JSAPI.md │ │ ├── RESOURCE_MANAGEMENT.md │ │ ├── authorized_test.go │ │ ├── buffer.go │ │ ├── buffer_test.go │ │ ├── chat.go │ │ ├── chat_test.go │ │ ├── context.go │ │ ├── context_test.go │ │ ├── grpc.go │ │ ├── interfaces.go │ │ ├── interrupt.go │ │ ├── interrupt_test.go │ │ ├── jsapi.go │ │ ├── jsapi_agent.go │ │ ├── jsapi_agent_test.go │ │ ├── jsapi_agent_v8_test.go │ │ ├── jsapi_computer.go │ │ ├── jsapi_helpers.go │ │ ├── jsapi_llm.go │ │ ├── jsapi_llm_v8_test.go │ │ ├── jsapi_mcp.go │ │ ├── jsapi_mcp_test.go │ │ ├── jsapi_mcp_v8_test.go │ │ ├── jsapi_memory_test.go │ │ ├── jsapi_output_test.go │ │ ├── jsapi_release_test.go │ │ ├── jsapi_sandbox.go │ │ ├── jsapi_sandbox_test.go │ │ ├── jsapi_search.go │ │ ├── jsapi_search_test.go │ │ ├── jsapi_stress_test.go │ │ ├── jsapi_test.go │ │ ├── 
jsapi_workspace.go │ │ ├── log.go │ │ ├── mcp.go │ │ ├── mcp_test.go │ │ ├── message.go │ │ ├── message_events_test.go │ │ ├── message_test.go │ │ ├── openapi.go │ │ ├── openapi_test.go │ │ ├── options.go │ │ ├── output.go │ │ ├── stack.go │ │ ├── stack_test.go │ │ ├── types.go │ │ ├── types_llm.go │ │ └── utils.go │ ├── docs/ │ │ ├── configuration.md │ │ ├── context-api.md │ │ ├── hooks.md │ │ ├── i18n.md │ │ ├── iframe.md │ │ ├── mcp.md │ │ ├── models.md │ │ ├── pages.md │ │ ├── prompts.md │ │ ├── search.md │ │ └── testing.md │ ├── i18n/ │ │ ├── builtin.go │ │ ├── i18n.go │ │ └── i18n_test.go │ ├── llm/ │ │ ├── adapters/ │ │ │ ├── adapter.go │ │ │ ├── audio.go │ │ │ ├── reasoning.go │ │ │ ├── toolcall.go │ │ │ └── vision.go │ │ ├── capabilities.go │ │ ├── interfaces.go │ │ ├── jsapi.go │ │ ├── jsapi_types.go │ │ ├── llm.go │ │ ├── process.go │ │ └── providers/ │ │ ├── ANTHROPIC_PROVIDER_PROPOSAL.md │ │ ├── README.md │ │ ├── anthropic/ │ │ │ ├── anthropic.go │ │ │ ├── anthropic_test.go │ │ │ └── types.go │ │ ├── base/ │ │ │ └── base.go │ │ ├── factory.go │ │ └── openai/ │ │ ├── claude_test.go │ │ ├── deepseek_r1_test.go │ │ ├── deepseek_v3_test.go │ │ ├── gpt5_test.go │ │ ├── openai.go │ │ ├── openai_test.go │ │ ├── temperature_test.go │ │ └── types.go │ ├── load.go │ ├── load_test.go │ ├── memory/ │ │ ├── interfaces.go │ │ ├── manager.go │ │ ├── memory.go │ │ ├── memory_test.go │ │ ├── namespace.go │ │ └── types.go │ ├── output/ │ │ ├── BUILTIN_TYPES.md │ │ ├── README.md │ │ ├── adapters/ │ │ │ ├── cui/ │ │ │ │ ├── adapter.go │ │ │ │ └── writer.go │ │ │ └── openai/ │ │ │ ├── README.md │ │ │ ├── adapter.go │ │ │ ├── converter.go │ │ │ ├── types.go │ │ │ └── writer.go │ │ ├── builtin.go │ │ ├── jsapi/ │ │ │ ├── README.md │ │ │ ├── output.go │ │ │ └── output_test.go │ │ ├── message/ │ │ │ ├── STREAMING.md │ │ │ ├── interfaces.go │ │ │ ├── types.go │ │ │ ├── utils.go │ │ │ └── utils_test.go │ │ ├── output.go │ │ ├── safe_writer.go │ │ └── safe_writer_test.go │ ├── 
robot/ │ │ ├── DESIGN-V2-REVIEW-FINDINGS.md │ │ ├── DESIGN.md │ │ ├── TECHNICAL.md │ │ ├── TODO.md │ │ ├── V2-IMPROVEMENTS.md │ │ ├── api/ │ │ │ ├── README.md │ │ │ ├── activities.go │ │ │ ├── api_test.go │ │ │ ├── e2e_clock_test.go │ │ │ ├── e2e_concurrent_test.go │ │ │ ├── e2e_control_test.go │ │ │ ├── e2e_event_test.go │ │ │ ├── e2e_human_test.go │ │ │ ├── e2e_interact_test.go │ │ │ ├── e2e_suspend_test.go │ │ │ ├── execution.go │ │ │ ├── execution_test.go │ │ │ ├── interact.go │ │ │ ├── interact_test.go │ │ │ ├── lifecycle.go │ │ │ ├── lifecycle_test.go │ │ │ ├── results.go │ │ │ ├── robot.go │ │ │ ├── robot_test.go │ │ │ ├── trigger.go │ │ │ ├── trigger_test.go │ │ │ └── types.go │ │ ├── cache/ │ │ │ ├── cache.go │ │ │ ├── cache_test.go │ │ │ ├── load.go │ │ │ └── refresh.go │ │ ├── dedup/ │ │ │ └── dedup.go │ │ ├── events/ │ │ │ ├── delivery.go │ │ │ ├── event_push_test.go │ │ │ ├── events.go │ │ │ ├── events_test.go │ │ │ ├── handlers.go │ │ │ ├── handlers_test.go │ │ │ ├── integrations/ │ │ │ │ ├── dingtalk/ │ │ │ │ │ ├── dedup.go │ │ │ │ │ ├── dingtalk.go │ │ │ │ │ ├── e2e_test.go │ │ │ │ │ ├── message.go │ │ │ │ │ ├── reply.go │ │ │ │ │ └── stream.go │ │ │ │ ├── discord/ │ │ │ │ │ ├── dedup.go │ │ │ │ │ ├── discord.go │ │ │ │ │ ├── e2e_test.go │ │ │ │ │ ├── gateway.go │ │ │ │ │ ├── message.go │ │ │ │ │ └── reply.go │ │ │ │ ├── dispatcher.go │ │ │ │ ├── dispatcher_test.go │ │ │ │ ├── feishu/ │ │ │ │ │ ├── dedup.go │ │ │ │ │ ├── e2e_test.go │ │ │ │ │ ├── feishu.go │ │ │ │ │ ├── message.go │ │ │ │ │ ├── reply.go │ │ │ │ │ └── stream.go │ │ │ │ └── telegram/ │ │ │ │ ├── dedup.go │ │ │ │ ├── e2e_test.go │ │ │ │ ├── message.go │ │ │ │ ├── polling.go │ │ │ │ ├── reply.go │ │ │ │ ├── telegram.go │ │ │ │ └── webhook.go │ │ │ ├── log.go │ │ │ └── message.go │ │ ├── executor/ │ │ │ ├── README.md │ │ │ ├── dryrun/ │ │ │ │ └── executor.go │ │ │ ├── executor.go │ │ │ ├── executor_test.go │ │ │ ├── sandbox/ │ │ │ │ └── executor.go │ │ │ ├── standard/ │ │ │ │ ├── 
agent.go │ │ │ │ ├── agent_stream_test.go │ │ │ │ ├── agent_test.go │ │ │ │ ├── delivery.go │ │ │ │ ├── delivery_test.go │ │ │ │ ├── executor.go │ │ │ │ ├── executor_test.go │ │ │ │ ├── goals.go │ │ │ │ ├── goals_test.go │ │ │ │ ├── host.go │ │ │ │ ├── host_test.go │ │ │ │ ├── input.go │ │ │ │ ├── input_integration_test.go │ │ │ │ ├── input_test.go │ │ │ │ ├── inspiration.go │ │ │ │ ├── inspiration_test.go │ │ │ │ ├── learning.go │ │ │ │ ├── log.go │ │ │ │ ├── resume_test.go │ │ │ │ ├── run.go │ │ │ │ ├── run_test.go │ │ │ │ ├── runner.go │ │ │ │ ├── runner_test.go │ │ │ │ ├── suspend_resume_test.go │ │ │ │ ├── suspend_test.go │ │ │ │ ├── tasks.go │ │ │ │ ├── tasks_test.go │ │ │ │ ├── ui_fields_test.go │ │ │ │ ├── validator.go │ │ │ │ └── validator_test.go │ │ │ └── types/ │ │ │ ├── helpers.go │ │ │ └── types.go │ │ ├── logger/ │ │ │ └── logger.go │ │ ├── manager/ │ │ │ ├── integration_clock_test.go │ │ │ ├── integration_concurrent_test.go │ │ │ ├── integration_control_test.go │ │ │ ├── integration_event_test.go │ │ │ ├── integration_human_test.go │ │ │ ├── integration_test.go │ │ │ ├── interact.go │ │ │ ├── interact_helpers_test.go │ │ │ ├── interact_test.go │ │ │ ├── manager.go │ │ │ └── manager_test.go │ │ ├── plan/ │ │ │ └── plan.go │ │ ├── pool/ │ │ │ ├── goroutine_test.go │ │ │ ├── pool.go │ │ │ ├── pool_test.go │ │ │ ├── queue.go │ │ │ ├── queue_test.go │ │ │ ├── worker.go │ │ │ └── worker_test.go │ │ ├── process.go │ │ ├── process_test.go │ │ ├── robot.go │ │ ├── store/ │ │ │ ├── execution.go │ │ │ ├── execution_test.go │ │ │ ├── robot.go │ │ │ ├── robot_test.go │ │ │ └── store.go │ │ ├── trigger/ │ │ │ ├── clock.go │ │ │ ├── clock_test.go │ │ │ ├── control.go │ │ │ ├── control_test.go │ │ │ ├── trigger.go │ │ │ └── trigger_test.go │ │ ├── types/ │ │ │ ├── clock.go │ │ │ ├── clock_test.go │ │ │ ├── config.go │ │ │ ├── config_global.go │ │ │ ├── config_test.go │ │ │ ├── context.go │ │ │ ├── enums.go │ │ │ ├── enums_test.go │ │ │ ├── errors.go │ │ │ ├── 
host.go │ │ │ ├── host_test.go │ │ │ ├── inspiration.go │ │ │ ├── interfaces.go │ │ │ ├── request.go │ │ │ ├── robot.go │ │ │ └── robot_test.go │ │ └── utils/ │ │ ├── convert.go │ │ ├── convert_test.go │ │ ├── id.go │ │ ├── time.go │ │ ├── utils_test.go │ │ └── validate.go │ ├── sandbox/ │ │ ├── DESIGN.md │ │ ├── PLAN.md │ │ ├── claude/ │ │ │ ├── attachments_test.go │ │ │ ├── command.go │ │ │ ├── command_test.go │ │ │ ├── e2e_test.go │ │ │ ├── executor.go │ │ │ ├── executor_test.go │ │ │ ├── real_e2e_test.go │ │ │ └── types.go │ │ ├── cursor/ │ │ │ └── README.md │ │ ├── executor.go │ │ ├── executor_test.go │ │ ├── integration_test.go │ │ ├── types.go │ │ ├── types_test.go │ │ └── v2/ │ │ ├── claude/ │ │ │ ├── attachments.go │ │ │ ├── oscompat.go │ │ │ ├── parse.go │ │ │ ├── runner.go │ │ │ ├── runner_test.go │ │ │ └── testdata/ │ │ │ └── code.ts │ │ ├── init.go │ │ ├── lifecycle.go │ │ ├── lifecycle_test.go │ │ ├── options.go │ │ ├── prepare.go │ │ ├── prepare_test.go │ │ ├── runner.go │ │ ├── shell.go │ │ ├── stream.go │ │ ├── testutils/ │ │ │ └── testutils.go │ │ ├── testutils_remote_test.go │ │ ├── testutils_test.go │ │ ├── testutils_wintest_test.go │ │ ├── token.go │ │ ├── types/ │ │ │ ├── config.go │ │ │ ├── runner.go │ │ │ └── token.go │ │ └── yao/ │ │ ├── runner.go │ │ └── runner_test.go │ ├── search/ │ │ ├── DESIGN.md │ │ ├── citation.go │ │ ├── citation_test.go │ │ ├── defaults/ │ │ │ └── defaults.go │ │ ├── handlers/ │ │ │ ├── db/ │ │ │ │ ├── handler.go │ │ │ │ ├── handler_integration_test.go │ │ │ │ └── handler_test.go │ │ │ ├── kb/ │ │ │ │ ├── handler.go │ │ │ │ └── handler_test.go │ │ │ └── web/ │ │ │ ├── agent.go │ │ │ ├── agent_test.go │ │ │ ├── handler.go │ │ │ ├── mcp.go │ │ │ ├── mcp_test.go │ │ │ ├── serpapi.go │ │ │ ├── serpapi_test.go │ │ │ ├── serper.go │ │ │ ├── serper_test.go │ │ │ ├── tavily.go │ │ │ └── tavily_test.go │ │ ├── interfaces/ │ │ │ ├── handler.go │ │ │ ├── nlp.go │ │ │ ├── reranker.go │ │ │ └── searcher.go │ │ ├── jsapi.go │ │ 
├── jsapi_db_test.go │ │ ├── jsapi_test.go │ │ ├── nlp/ │ │ │ ├── keyword/ │ │ │ │ ├── agent.go │ │ │ │ ├── agent_test.go │ │ │ │ ├── extractor.go │ │ │ │ ├── extractor_test.go │ │ │ │ ├── mcp.go │ │ │ │ └── mcp_test.go │ │ │ └── querydsl/ │ │ │ ├── agent.go │ │ │ ├── agent_test.go │ │ │ ├── generator.go │ │ │ ├── generator_test.go │ │ │ ├── mcp.go │ │ │ ├── mcp_test.go │ │ │ └── types.go │ │ ├── reference.go │ │ ├── reference_test.go │ │ ├── registry.go │ │ ├── registry_test.go │ │ ├── rerank/ │ │ │ ├── agent.go │ │ │ ├── agent_test.go │ │ │ ├── builtin.go │ │ │ ├── builtin_test.go │ │ │ ├── mcp.go │ │ │ ├── mcp_test.go │ │ │ ├── reranker.go │ │ │ └── reranker_test.go │ │ ├── search.go │ │ ├── search_test.go │ │ ├── search_web_test.go │ │ └── types/ │ │ ├── config.go │ │ ├── graph.go │ │ ├── reference.go │ │ └── types.go │ ├── store/ │ │ ├── CHAT_STORAGE_DESIGN.md │ │ ├── README.md │ │ ├── mongo/ │ │ │ └── mongo.go │ │ ├── redis/ │ │ │ └── redis.go │ │ ├── types/ │ │ │ ├── convert.go │ │ │ ├── convert_test.go │ │ │ ├── fields.go │ │ │ ├── fields_test.go │ │ │ ├── mcp_test.go │ │ │ ├── prompt.go │ │ │ ├── prompt_test.go │ │ │ ├── sandbox_v2.go │ │ │ ├── store.go │ │ │ └── types.go │ │ └── xun/ │ │ ├── assistant.go │ │ ├── assistant_test.go │ │ ├── chat.go │ │ ├── chat_test.go │ │ ├── message.go │ │ ├── message_test.go │ │ ├── resume.go │ │ ├── resume_test.go │ │ ├── search.go │ │ ├── search_test.go │ │ ├── utils.go │ │ ├── utils_test.go │ │ └── xun.go │ ├── test/ │ │ ├── DESIGN.md │ │ ├── DESIGN_V2.md │ │ ├── README.md │ │ ├── assert.go │ │ ├── assert_agent_test.go │ │ ├── assert_test.go │ │ ├── context.go │ │ ├── dynamic_integration_test.go │ │ ├── dynamic_runner.go │ │ ├── dynamic_runner_test.go │ │ ├── dynamic_types.go │ │ ├── extract.go │ │ ├── input.go │ │ ├── input_source.go │ │ ├── input_source_test.go │ │ ├── input_test.go │ │ ├── interfaces.go │ │ ├── loader.go │ │ ├── output.go │ │ ├── reporter.go │ │ ├── resolver.go │ │ ├── runner.go │ │ ├── 
runner_integration_test.go │ │ ├── script.go │ │ ├── script_assert.go │ │ ├── script_hooks.go │ │ ├── script_hooks_test.go │ │ ├── script_types.go │ │ └── types.go │ ├── testutils/ │ │ └── testutils.go │ └── types/ │ ├── dsl.go │ └── types.go ├── aigc/ │ ├── aigc.go │ ├── aigc_test.go │ ├── load.go │ ├── load_test.go │ ├── process.go │ ├── process_test.go │ └── types.go ├── api/ │ ├── README.md │ ├── api.go │ └── api_test.go ├── assert/ │ ├── asserter.go │ ├── asserter_test.go │ ├── helpers.go │ └── types.go ├── attachment/ │ ├── README.md │ ├── compresses.go │ ├── convert.go │ ├── example_usage.go │ ├── fileheader.go │ ├── gzip.go │ ├── load.go │ ├── load_test.go │ ├── local/ │ │ ├── storage.go │ │ └── storage_test.go │ ├── manager.go │ ├── manager_test.go │ ├── process.go │ ├── process_test.go │ ├── s3/ │ │ ├── storage.go │ │ └── storage_test.go │ └── types.go ├── audit/ │ └── README.md ├── bin/ │ └── yao-dev ├── cert/ │ ├── cert.go │ └── cert_test.go ├── cmd/ │ ├── README.md │ ├── agent/ │ │ ├── add.go │ │ ├── agent.go │ │ ├── extract.go │ │ ├── fork.go │ │ ├── push.go │ │ ├── test.go │ │ └── update.go │ ├── ci-token/ │ │ └── main.go │ ├── credential.go │ ├── dump.go │ ├── get/ │ │ ├── get.go │ │ └── get_test.go │ ├── get.go │ ├── help.go │ ├── init.go │ ├── inspect.go │ ├── login.go │ ├── logout.go │ ├── mcp/ │ │ ├── add.go │ │ ├── fork.go │ │ ├── mcp.go │ │ ├── push.go │ │ └── update.go │ ├── migrate.go │ ├── pack.go │ ├── restore.go │ ├── robot/ │ │ ├── add.go │ │ └── robot.go │ ├── root.go │ ├── run.go │ ├── socket.go │ ├── start.go │ ├── sui/ │ │ ├── build.go │ │ ├── sui.go │ │ ├── trans.go │ │ ├── utils.go │ │ └── watch.go │ ├── tea.go │ ├── upgrade.go │ ├── version.go │ └── websocket.go ├── config/ │ ├── config.go │ ├── config_test.go │ └── types.go ├── connector/ │ ├── connector.go │ └── connector_test.go ├── crypto/ │ ├── aes.go │ ├── aes_test.go │ ├── crypto.go │ ├── crypto_test.go │ └── process.go ├── cui/ │ ├── setup/ │ │ └── index.html │ ├── v0.9/ │ 
│ └── index.html │ └── v1.0/ │ ├── index.html │ ├── layouts__index.async.js │ └── umi.js ├── data/ │ ├── bindata.go │ ├── data.go │ └── data_test.go ├── docker/ │ ├── build/ │ │ └── Dockerfile │ ├── development/ │ │ └── Dockerfile │ └── production/ │ └── Dockerfile ├── docs/ │ └── README.md ├── dsl/ │ ├── api/ │ │ └── api.go │ ├── connector/ │ │ ├── cases_test.go │ │ ├── connector.go │ │ └── connector_test.go │ ├── dsl.go │ ├── dsl_test.go │ ├── io/ │ │ ├── cases_test.go │ │ ├── db.go │ │ ├── db_test.go │ │ ├── fs.go │ │ ├── fs_test.go │ │ └── utils.go │ ├── mcp/ │ │ ├── cases_test.go │ │ ├── client.go │ │ ├── client_test.go │ │ └── server.go │ ├── model/ │ │ ├── cases_test.go │ │ ├── model.go │ │ └── model_test.go │ └── types/ │ ├── interfaces.go │ ├── types.go │ ├── utils.go │ └── utils_test.go ├── engine/ │ ├── load.go │ ├── load_test.go │ ├── machine.go │ ├── machine_darwin.go │ ├── machine_linux.go │ ├── machine_test.go │ ├── machine_windows.go │ ├── process.go │ └── process_test.go ├── event/ │ ├── README.md │ ├── bench_test.go │ ├── bus.go │ ├── bus_test.go │ ├── leak_test.go │ ├── listener.go │ ├── listener_test.go │ ├── option.go │ ├── queue.go │ ├── queue_test.go │ ├── service.go │ ├── service_test.go │ ├── sub.go │ ├── sub_test.go │ ├── types/ │ │ ├── interfaces.go │ │ ├── types.go │ │ └── types_test.go │ ├── worker.go │ └── worker_test.go ├── excel/ │ ├── README.md │ ├── each.go │ ├── each_test.go │ ├── excel.go │ ├── excel_test.go │ ├── process.go │ ├── process_test.go │ ├── sheet.go │ ├── sheet_test.go │ ├── write.go │ └── write_test.go ├── flow/ │ ├── README.md │ ├── flow.go │ └── flow_test.go ├── fs/ │ ├── fs.go │ └── fs_test.go ├── go.mod ├── go.sum ├── grpc/ │ ├── DESIGN.md │ ├── IMPL.md │ ├── TEST.md │ ├── agent/ │ │ ├── agent.go │ │ └── agent_test.go │ ├── api/ │ │ ├── api.go │ │ └── api_test.go │ ├── auth/ │ │ ├── endpoint.go │ │ ├── endpoint_test.go │ │ ├── guard.go │ │ ├── guard_test.go │ │ └── scope.go │ ├── client/ │ │ ├── client.go │ │ └── 
token.go │ ├── grpc.go │ ├── health/ │ │ ├── health.go │ │ └── health_test.go │ ├── llm/ │ │ ├── llm.go │ │ └── llm_test.go │ ├── mcp/ │ │ ├── mcp.go │ │ └── mcp_test.go │ ├── pb/ │ │ ├── yao.pb.go │ │ ├── yao.proto │ │ └── yao_grpc.pb.go │ ├── run/ │ │ ├── run.go │ │ └── run_test.go │ ├── sandbox/ │ │ ├── heartbeat.go │ │ └── heartbeat_test.go │ ├── shell/ │ │ ├── shell.go │ │ └── shell_test.go │ └── tests/ │ └── testutils/ │ └── testutils.go ├── helper/ │ ├── array.go │ ├── array.process.go │ ├── array_test.go │ ├── captcha.go │ ├── captcha_test.go │ ├── case.go │ ├── case_test.go │ ├── condition.go │ ├── condition_test.go │ ├── control.process.go │ ├── control_test.go │ ├── env.process.go │ ├── env_test.go │ ├── hex.process.go │ ├── hex_test.go │ ├── if.go │ ├── if_test.go │ ├── jwt.go │ ├── jwt_test.go │ ├── map.go │ ├── map.process.go │ ├── map_test.go │ ├── password.go │ ├── password_test.go │ ├── process.go │ ├── range.go │ ├── string.process.go │ └── string_test.go ├── i18n/ │ ├── i18n.go │ └── i18n_test.go ├── importer/ │ ├── column.go │ ├── column_test.go │ ├── csv/ │ │ └── csv.go │ ├── from/ │ │ └── source.go │ ├── importer.go │ ├── importer_test.go │ ├── option.go │ ├── option_test.go │ ├── process.go │ ├── process_test.go │ ├── types.go │ └── xlsx/ │ └── xlsx.go ├── integrations/ │ ├── dingtalk/ │ │ ├── bot.go │ │ ├── bot_test.go │ │ ├── convert.go │ │ ├── convert_test.go │ │ ├── dingtalk_test.go │ │ ├── e2e_test.go │ │ ├── file.go │ │ ├── format.go │ │ └── message.go │ ├── discord/ │ │ ├── bot.go │ │ ├── bot_test.go │ │ ├── convert.go │ │ ├── convert_test.go │ │ ├── discord_test.go │ │ ├── e2e_test.go │ │ ├── file.go │ │ ├── format.go │ │ ├── message.go │ │ └── message_test.go │ ├── feishu/ │ │ ├── bot.go │ │ ├── bot_test.go │ │ ├── convert.go │ │ ├── convert_test.go │ │ ├── e2e_test.go │ │ ├── feishu_test.go │ │ ├── file.go │ │ ├── format.go │ │ └── message.go │ ├── telegram/ │ │ ├── bot.go │ │ ├── bot_test.go │ │ ├── convert.go │ │ ├── e2e_test.go │ 
│ ├── file.go │ │ ├── file_test.go │ │ ├── format.go │ │ ├── media_e2e_test.go │ │ ├── message.go │ │ ├── message_test.go │ │ ├── polling.go │ │ ├── telegram_test.go │ │ ├── types.go │ │ ├── verify.go │ │ ├── verify_test.go │ │ ├── webhook.go │ │ └── webhook_e2e_test.go │ └── testdata/ │ ├── test.docx │ ├── test.ogg │ └── test.pptx ├── job/ │ ├── README.md │ ├── data.go │ ├── data_test.go │ ├── execution.go │ ├── goroutine.go │ ├── health.go │ ├── health_test.go │ ├── interfaces.go │ ├── job.go │ ├── job_test.go │ ├── jsapi/ │ │ ├── jsapi.go │ │ └── jsapi_test.go │ ├── process.go │ ├── progress.go │ ├── types.go │ └── worker.go ├── kb/ │ ├── README.md │ ├── api/ │ │ ├── README.md │ │ ├── addfile.go │ │ ├── addfile_test.go │ │ ├── addtext.go │ │ ├── addtext_test.go │ │ ├── addurl.go │ │ ├── addurl_test.go │ │ ├── api.go │ │ ├── collection.go │ │ ├── collection_test.go │ │ ├── consts.go │ │ ├── document.go │ │ ├── document_test.go │ │ ├── interfaces.go │ │ ├── search.go │ │ ├── search_setup_test.go │ │ ├── search_test.go │ │ ├── types.go │ │ └── utils.go │ ├── kb.go │ ├── kb_test.go │ ├── providers/ │ │ ├── README.md │ │ ├── chunking.go │ │ ├── chunking_test.go │ │ ├── converter.go │ │ ├── converters/ │ │ │ ├── mcp.go │ │ │ ├── mcp_test.go │ │ │ ├── ocr.go │ │ │ ├── ocr_test.go │ │ │ ├── office.go │ │ │ ├── office_test.go │ │ │ ├── utf8.go │ │ │ ├── utf8_test.go │ │ │ ├── utils.go │ │ │ ├── utils_test.go │ │ │ ├── video.go │ │ │ ├── video_test.go │ │ │ ├── vision.go │ │ │ ├── vision_test.go │ │ │ ├── whisper.go │ │ │ └── whisper_test.go │ │ ├── embedding.go │ │ ├── embedding_test.go │ │ ├── extraction.go │ │ ├── extraction_test.go │ │ ├── factory/ │ │ │ ├── factory.go │ │ │ ├── interfaces.go │ │ │ └── utils.go │ │ ├── fetcher.go │ │ └── fetcher_test.go │ └── types/ │ ├── collection.go │ ├── config.go │ ├── config_test.go │ ├── document.go │ ├── provider.go │ └── types.go ├── main.go ├── mcp/ │ ├── README.md │ ├── mcp.go │ └── mcp_test.go ├── messenger/ │ ├── 
messenger.go │ ├── messenger_onreceive_test.go │ ├── messenger_receiver_test.go │ ├── messenger_sendt_test.go │ ├── messenger_test.go │ ├── providers/ │ │ ├── mailer/ │ │ │ ├── mailer.go │ │ │ ├── mailer_batch_test.go │ │ │ ├── mailer_receive.go │ │ │ ├── mailer_template_test.go │ │ │ └── mailer_test.go │ │ ├── mailgun/ │ │ │ ├── mailgun.go │ │ │ ├── mailgun_batch_test.go │ │ │ ├── mailgun_receive.go │ │ │ ├── mailgun_receive_test.go │ │ │ ├── mailgun_template_test.go │ │ │ └── mailgun_test.go │ │ └── twilio/ │ │ ├── twilio.go │ │ ├── twilio_batch_test.go │ │ ├── twilio_receive.go │ │ ├── twilio_receive_test.go │ │ ├── twilio_sms_test.go │ │ ├── twilio_template_test.go │ │ ├── twilio_test.go │ │ └── twilio_whatsapp_test.go │ ├── template/ │ │ ├── debug_test.go │ │ ├── load_test.go │ │ ├── render_test.go │ │ ├── template.go │ │ ├── template_test.go │ │ └── walk_test.go │ └── types/ │ ├── interfaces.go │ ├── template.go │ └── types.go ├── model/ │ ├── migrate.go │ ├── migrate_test.go │ ├── model.go │ └── model_test.go ├── monitor/ │ ├── README.md │ ├── logger.go │ ├── service.go │ ├── service_test.go │ └── types.go ├── openai/ │ ├── openai.go │ ├── openai_test.go │ ├── process.go │ ├── process_test.go │ └── types.go ├── openapi/ │ ├── agent/ │ │ ├── agent.go │ │ ├── assistant.go │ │ ├── filter.go │ │ ├── models.go │ │ ├── robot/ │ │ │ ├── DESIGN.md │ │ │ ├── GAPS.md │ │ │ ├── TODO.md │ │ │ ├── activities.go │ │ │ ├── completions.go │ │ │ ├── detail.go │ │ │ ├── execute.go │ │ │ ├── execution.go │ │ │ ├── interact.go │ │ │ ├── interact_test.go │ │ │ ├── list.go │ │ │ ├── permission.go │ │ │ ├── results.go │ │ │ ├── robot.go │ │ │ ├── trigger.go │ │ │ ├── types.go │ │ │ ├── utils.go │ │ │ └── verify.go │ │ └── types.go │ ├── app/ │ │ └── app.go │ ├── audit/ │ │ └── audit.go │ ├── captcha/ │ │ └── captcha.go │ ├── chat/ │ │ ├── chat.go │ │ ├── completions.go │ │ ├── reference.go │ │ ├── session.go │ │ └── types.go │ ├── computer/ │ │ └── computer.go │ ├── config.go │ 
├── docs/ │ │ ├── migration-guide.md │ │ └── oauth.md │ ├── dsl/ │ │ ├── README.md │ │ └── dsl.go │ ├── file/ │ │ ├── README.md │ │ ├── file.go │ │ └── filter.go │ ├── hello/ │ │ ├── README.md │ │ └── hello.go │ ├── integrations/ │ │ └── integrations.go │ ├── job/ │ │ ├── categories.go │ │ ├── executions.go │ │ ├── filter.go │ │ ├── job.go │ │ ├── jobs.go │ │ └── logs.go │ ├── kb/ │ │ ├── addfile.go │ │ ├── addtext.go │ │ ├── addurl.go │ │ ├── backup.go │ │ ├── collection.go │ │ ├── collection_process.go │ │ ├── document.go │ │ ├── document_process.go │ │ ├── filter.go │ │ ├── graph.go │ │ ├── hit.go │ │ ├── kb.go │ │ ├── provider.go │ │ ├── score.go │ │ ├── search.go │ │ ├── segment.go │ │ ├── types.go │ │ ├── utils.go │ │ ├── vote.go │ │ └── weight.go │ ├── llm/ │ │ └── llm.go │ ├── mcp/ │ │ └── mcp.go │ ├── messenger/ │ │ └── messenger.go │ ├── nodes/ │ │ └── nodes.go │ ├── oauth/ │ │ ├── ERRORS.md │ │ ├── TESTING_GUIDE.md │ │ ├── acl/ │ │ │ ├── DESIGN.md │ │ │ ├── FEATURES_CONFIGURATION.md │ │ │ ├── README.md │ │ │ ├── SCOPES_CONFIGURATION.md │ │ │ ├── acl.go │ │ │ ├── enforce.go │ │ │ ├── errors.go │ │ │ ├── feature.go │ │ │ ├── feature_integration_test.go │ │ │ ├── feature_test.go │ │ │ ├── interfaces.go │ │ │ ├── role/ │ │ │ │ ├── cache.go │ │ │ │ ├── role.go │ │ │ │ ├── types.go │ │ │ │ └── utils.go │ │ │ ├── scope.go │ │ │ ├── scope_test.go │ │ │ └── types.go │ │ ├── apikey.go │ │ ├── authenticate.go │ │ ├── authorized/ │ │ │ ├── utils.go │ │ │ └── utils_test.go │ │ ├── client.go │ │ ├── client_test.go │ │ ├── core.go │ │ ├── core_test.go │ │ ├── device.go │ │ ├── discovery.go │ │ ├── guard.go │ │ ├── mcp.go │ │ ├── oauth.go │ │ ├── oauth_test.go │ │ ├── providers/ │ │ │ ├── client/ │ │ │ │ ├── default.go │ │ │ │ └── default_test.go │ │ │ └── user/ │ │ │ ├── default.go │ │ │ ├── exists_test.go │ │ │ ├── invitation.go │ │ │ ├── invitation_test.go │ │ │ ├── member.go │ │ │ ├── member_test.go │ │ │ ├── oauth_account.go │ │ │ ├── oauth_account_test.go │ │ │ 
├── role.go │ │ │ ├── role_test.go │ │ │ ├── team.go │ │ │ ├── team_test.go │ │ │ ├── type.go │ │ │ ├── type_test.go │ │ │ ├── user_basic.go │ │ │ ├── user_basic_test.go │ │ │ ├── user_list.go │ │ │ ├── user_list_test.go │ │ │ ├── user_mfa.go │ │ │ ├── user_mfa_test.go │ │ │ ├── user_role_type.go │ │ │ ├── user_role_type_test.go │ │ │ ├── user_test.go │ │ │ ├── utils.go │ │ │ └── utils_test.go │ │ ├── security.go │ │ ├── security_test.go │ │ ├── signing.go │ │ ├── token.go │ │ ├── token_test.go │ │ ├── types/ │ │ │ ├── authorized.go │ │ │ ├── authorized_test.go │ │ │ ├── errors.go │ │ │ ├── interfaces.go │ │ │ ├── oidc.go │ │ │ ├── types.go │ │ │ └── utils.go │ │ ├── user.go │ │ └── user_test.go │ ├── oauth.go │ ├── openapi.go │ ├── otp/ │ │ ├── DESIGN.md │ │ ├── README.md │ │ ├── generate.go │ │ ├── handler.go │ │ ├── login.go │ │ ├── otp.go │ │ ├── process.go │ │ ├── revoke.go │ │ └── verify.go │ ├── request/ │ │ └── REQUEST_DESIGN.md │ ├── response/ │ │ └── response.go │ ├── sandbox/ │ │ ├── manage.go │ │ └── sandbox.go │ ├── tai/ │ │ ├── proxy.go │ │ ├── tai.go │ │ ├── util.go │ │ └── vnc.go │ ├── team/ │ │ └── team.go │ ├── tests/ │ │ ├── agent/ │ │ │ ├── assistant_create_test.go │ │ │ ├── assistant_test.go │ │ │ ├── assistant_update_test.go │ │ │ ├── models_test.go │ │ │ ├── robot_execution_test.go │ │ │ ├── robot_host_test.go │ │ │ ├── robot_interact_test.go │ │ │ ├── robot_results_activities_test.go │ │ │ ├── robot_test.go │ │ │ └── robot_trigger_test.go │ │ ├── chat/ │ │ │ ├── reference_test.go │ │ │ └── session_test.go │ │ ├── config_test.go │ │ ├── dsl/ │ │ │ └── dsl_test.go │ │ ├── file/ │ │ │ └── file_test.go │ │ ├── hello/ │ │ │ └── hello_test.go │ │ ├── integrations_webhook_test.go │ │ ├── kb/ │ │ │ ├── addfile_test.go │ │ │ ├── addtext_test.go │ │ │ ├── addurl_test.go │ │ │ ├── collection_process_test.go │ │ │ ├── collection_test.go │ │ │ ├── document_test.go │ │ │ └── utils_test.go │ │ ├── nodes/ │ │ │ └── nodes_test.go │ │ ├── oauth/ │ │ │ ├── 
acl/ │ │ │ │ ├── acl_test.go │ │ │ │ ├── enforce_test.go │ │ │ │ ├── role/ │ │ │ │ │ └── role_test.go │ │ │ │ ├── scope_atomic_test.go │ │ │ │ └── scope_test.go │ │ │ ├── authorized_test.go │ │ │ ├── device_test.go │ │ │ ├── guard_test.go │ │ │ ├── oauth_test.go │ │ │ └── token_test.go │ │ ├── openapi_test.go │ │ ├── otp/ │ │ │ └── otp_test.go │ │ ├── sandbox/ │ │ │ └── sandbox_test.go │ │ ├── testutils/ │ │ │ └── testutils.go │ │ ├── trace/ │ │ │ ├── common_test.go │ │ │ ├── events_test.go │ │ │ ├── info_test.go │ │ │ ├── logs_test.go │ │ │ ├── nodes_test.go │ │ │ └── spaces_test.go │ │ ├── user/ │ │ │ ├── config_functions_test.go │ │ │ ├── config_loading_test.go │ │ │ ├── config_validation_test.go │ │ │ ├── entry_test.go │ │ │ ├── env_test.go │ │ │ ├── env_var_extraction_test.go │ │ │ ├── invitation_test.go │ │ │ ├── login_config_test.go │ │ │ ├── login_test.go │ │ │ ├── member_test.go │ │ │ ├── oauth_authorize_test.go │ │ │ ├── oauth_callback_test.go │ │ │ ├── profile_test.go │ │ │ ├── team_config_robot_test.go │ │ │ ├── team_config_test.go │ │ │ ├── team_test.go │ │ │ └── utils_test.go │ │ └── workspace/ │ │ └── workspace_test.go │ ├── trace/ │ │ ├── README.md │ │ ├── events.go │ │ ├── helpers.go │ │ ├── info.go │ │ ├── logs.go │ │ ├── nodes.go │ │ ├── spaces.go │ │ └── trace.go │ ├── types.go │ ├── user/ │ │ ├── README.md │ │ ├── TODO.md │ │ ├── account.go │ │ ├── config.go │ │ ├── entry.go │ │ ├── features.go │ │ ├── login.go │ │ ├── member.go │ │ ├── oauth.go │ │ ├── profile.go │ │ ├── provider.go │ │ ├── team.go │ │ ├── team_invitation.go │ │ ├── types.go │ │ ├── user.go │ │ └── utils.go │ ├── utils/ │ │ ├── convert.go │ │ └── session.go │ ├── well-known.go │ └── workspace/ │ └── workspace.go ├── pack/ │ └── pack.go ├── pipe/ │ ├── README.md │ ├── context.go │ ├── expression.go │ ├── json.go │ ├── node.go │ ├── pipe.go │ ├── pipe_test.go │ ├── process.go │ ├── process_test.go │ ├── types.go │ ├── ui/ │ │ └── cli/ │ │ └── cli.go │ └── utils.go ├── plugin/ │ 
├── README.md │ ├── plugin.go │ └── plugin_test.go ├── query/ │ ├── README.md │ ├── query.go │ └── query_test.go ├── registry/ │ ├── README.md │ ├── client.go │ ├── client_test.go │ ├── manager/ │ │ ├── agent/ │ │ │ ├── add.go │ │ │ ├── agent.go │ │ │ ├── agent_test.go │ │ │ ├── fork.go │ │ │ ├── push.go │ │ │ ├── scan.go │ │ │ └── update.go │ │ ├── agent_e2e_test.go │ │ ├── common/ │ │ │ ├── deps.go │ │ │ ├── deps_test.go │ │ │ ├── hash.go │ │ │ ├── hash_test.go │ │ │ ├── lockfile.go │ │ │ ├── lockfile_test.go │ │ │ ├── packer.go │ │ │ ├── packer_test.go │ │ │ ├── path.go │ │ │ ├── path_test.go │ │ │ ├── prompt.go │ │ │ ├── prompt_test.go │ │ │ └── types.go │ │ ├── e2e_helpers_test.go │ │ ├── mcp/ │ │ │ ├── add.go │ │ │ ├── fork.go │ │ │ ├── mcp.go │ │ │ ├── mcp_test.go │ │ │ ├── push.go │ │ │ ├── script.go │ │ │ └── update.go │ │ ├── mcp_e2e_test.go │ │ ├── robot/ │ │ │ ├── add.go │ │ │ ├── deps.go │ │ │ ├── robot.go │ │ │ └── robot_test.go │ │ └── robot_e2e_test.go │ └── testdata/ │ └── build.go ├── rss/ │ ├── README.md │ ├── atom.go │ ├── atom_test.go │ ├── build.go │ ├── build_atom.go │ ├── build_rss.go │ ├── build_test.go │ ├── convert.go │ ├── discover.go │ ├── discover_test.go │ ├── fetch.go │ ├── fetch_test.go │ ├── parse.go │ ├── parse_test.go │ ├── process.go │ ├── rss.go │ ├── rss_test.go │ └── types.go ├── runtime/ │ ├── runtime.go │ └── runtime_test.go ├── sandbox/ │ ├── DESIGN-PLAYWRIGHT-VNC.md │ ├── DESIGN.md │ ├── PLAN.md │ ├── README.md │ ├── SPEC.md │ ├── bridge/ │ │ └── main.go │ ├── config.go │ ├── config_test.go │ ├── docker/ │ │ ├── base/ │ │ │ └── Dockerfile.base │ │ ├── browser/ │ │ │ └── Dockerfile │ │ ├── build.sh │ │ ├── chrome/ │ │ │ ├── Dockerfile │ │ │ ├── config/ │ │ │ │ ├── chrome-preferences.json │ │ │ │ └── stealth-init.js │ │ │ └── tests/ │ │ │ ├── README.md │ │ │ ├── demo-baidu.py │ │ │ ├── demo-duckduckgo.py │ │ │ └── demo-llm-vision.py │ │ ├── claude/ │ │ │ ├── Dockerfile │ │ │ └── Dockerfile.full │ │ ├── desktop/ │ │ │ ├── 
Dockerfile │ │ │ └── config/ │ │ │ ├── panel-launcher-chromium.desktop │ │ │ ├── setup-xfce.sh │ │ │ └── workspace.desktop │ │ └── vnc/ │ │ ├── entrypoint-vnc.sh │ │ └── start-vnc.sh │ ├── errors.go │ ├── helpers.go │ ├── helpers_test.go │ ├── ipc/ │ │ ├── jsonrpc_test.go │ │ ├── manager.go │ │ ├── manager_test.go │ │ ├── session.go │ │ ├── session_test.go │ │ └── types.go │ ├── manager.go │ ├── manager_test.go │ ├── proxy/ │ │ ├── README.md │ │ ├── cmd/ │ │ │ └── claude-proxy/ │ │ │ └── main.go │ │ ├── convert.go │ │ ├── main.go │ │ └── types.go │ ├── types.go │ ├── v2/ │ │ ├── DESIGN.md │ │ ├── IMPL.md │ │ ├── Makefile │ │ ├── TEST.md │ │ ├── bench_test.go │ │ ├── box.go │ │ ├── box_attach_test.go │ │ ├── box_image_test.go │ │ ├── box_test.go │ │ ├── box_workspace_test.go │ │ ├── docs/ │ │ │ └── API.md │ │ ├── errors.go │ │ ├── export_test.go │ │ ├── grpc.go │ │ ├── grpc_test.go │ │ ├── host.go │ │ ├── host_test.go │ │ ├── jsapi/ │ │ │ ├── API.md │ │ │ ├── computer.go │ │ │ ├── jsapi.go │ │ │ ├── jsapi_test.go │ │ │ └── node.go │ │ ├── manager.go │ │ ├── manager_lifecycle_test.go │ │ ├── manager_test.go │ │ ├── sandbox.go │ │ ├── sandbox_test.go │ │ ├── testutils_containerized_test.go │ │ ├── testutils_k8s_test.go │ │ ├── testutils_remote_test.go │ │ ├── testutils_test.go │ │ ├── testutils_wintest_test.go │ │ ├── types.go │ │ └── watcher.go │ └── vncproxy/ │ ├── config.go │ ├── proxy.go │ └── proxy_test.go ├── schedule/ │ ├── schedule.go │ └── schedule_test.go ├── script/ │ ├── script.go │ └── script_test.go ├── seed/ │ ├── process.go │ ├── process_test.go │ ├── seed.go │ ├── seed_reset_test.go │ ├── seed_test.go │ └── types.go ├── service/ │ ├── dynamic.go │ ├── dynamic_test.go │ ├── fs/ │ │ ├── default.go │ │ └── utils.go │ ├── guards.go │ ├── gzip.go │ ├── log/ │ │ ├── access.go │ │ └── access_test.go │ ├── middleware.go │ ├── service.go │ ├── service_test.go │ ├── static.go │ ├── watch.go │ └── watch_test.go ├── setup/ │ ├── check.go │ ├── check_test.go │ ├── 
install.go │ ├── install_test.go │ └── setup.go ├── share/ │ ├── api.go │ ├── api_test.go │ ├── app.go │ ├── columns.go │ ├── const.go │ ├── db.go │ ├── filters.go │ ├── importable.go │ ├── importable_test.go │ ├── session.go │ ├── types.go │ ├── utils.go │ ├── watch.go │ └── watch_test.go ├── sitemap/ │ ├── README.md │ ├── build.go │ ├── build_test.go │ ├── convert.go │ ├── convert_test.go │ ├── discover.go │ ├── fetch.go │ ├── fetch_test.go │ ├── parse.go │ ├── parse_test.go │ ├── process.go │ ├── robots.go │ ├── robots_test.go │ └── types.go ├── store/ │ ├── store.go │ └── store_test.go ├── sui/ │ ├── README.md │ ├── api/ │ │ ├── api.go │ │ ├── build_test.go │ │ ├── guards.go │ │ ├── process.go │ │ ├── process_test.go │ │ ├── render.go │ │ ├── render_test.go │ │ ├── request.go │ │ ├── request_test.go │ │ ├── run.go │ │ ├── sui.go │ │ └── sui_test.go │ ├── core/ │ │ ├── block.go │ │ ├── build.go │ │ ├── cache.go │ │ ├── compile.go │ │ ├── component.go │ │ ├── context.go │ │ ├── core.go │ │ ├── data.go │ │ ├── editor.go │ │ ├── event.go │ │ ├── fs.go │ │ ├── injections.go │ │ ├── interfaces.go │ │ ├── jit.go │ │ ├── json.go │ │ ├── json_test.go │ │ ├── locale.go │ │ ├── locale_test.go │ │ ├── matcher.go │ │ ├── page.go │ │ ├── page_test.go │ │ ├── parser.go │ │ ├── parser_test.go │ │ ├── preview.go │ │ ├── request.go │ │ ├── script.go │ │ ├── sui.go │ │ ├── sui_test.go │ │ ├── token.go │ │ ├── token_test.go │ │ ├── translate.go │ │ ├── types.go │ │ └── utils.go │ ├── docs/ │ │ ├── agent-sui.md │ │ ├── backend-scripts.md │ │ ├── components.md │ │ ├── data-binding.md │ │ ├── event-handling.md │ │ ├── frontend-api.md │ │ ├── i18n.md │ │ ├── page-config.md │ │ ├── routing.md │ │ └── template-syntax.md │ ├── libsui/ │ │ ├── index.ts │ │ ├── openapi.ts │ │ ├── utils.ts │ │ └── yao.ts │ └── storages/ │ ├── agent/ │ │ ├── agent.go │ │ ├── agent_test.go │ │ ├── page.go │ │ ├── template.go │ │ └── types.go │ ├── azure/ │ │ └── azure.go │ └── local/ │ ├── block.go │ ├── 
block_test.go │ ├── build.go │ ├── build_test.go │ ├── component.go │ ├── component_test.go │ ├── copy.go │ ├── local.go │ ├── local_test.go │ ├── page.go │ ├── page_render_test.go │ ├── page_test.go │ ├── template.go │ ├── template_test.go │ └── types.go ├── tai/ │ ├── api/ │ │ ├── register.go │ │ └── register_test.go │ ├── conn.go │ ├── dial.go │ ├── docs/ │ │ ├── README.md │ │ ├── api.md │ │ ├── proxy.md │ │ ├── registry.md │ │ ├── sandbox.md │ │ ├── tunnel.md │ │ ├── vnc.md │ │ ├── volume.md │ │ └── workspace.md │ ├── heartbeat.go │ ├── hostexec/ │ │ ├── local.go │ │ └── pb/ │ │ ├── hostexec.pb.go │ │ ├── hostexec.proto │ │ └── hostexec_grpc.pb.go │ ├── proxy/ │ │ ├── connect.go │ │ ├── proxy.go │ │ └── proxy_test.go │ ├── registry/ │ │ ├── registry.go │ │ ├── registry_test.go │ │ └── testing.go │ ├── runtime/ │ │ ├── client_accessor.go │ │ ├── docker.go │ │ ├── docker_core.go │ │ ├── image.go │ │ ├── image_docker.go │ │ ├── image_k8s.go │ │ ├── k8s.go │ │ ├── local.go │ │ ├── runtime_test.go │ │ └── sandbox.go │ ├── serverinfo/ │ │ └── pb/ │ │ ├── serverinfo.pb.go │ │ ├── serverinfo.proto │ │ └── serverinfo_grpc.pb.go │ ├── sysinfo.go │ ├── tai.go │ ├── tai_test.go │ ├── taiid/ │ │ ├── taiid.go │ │ └── taiid_test.go │ ├── token.go │ ├── tunnel/ │ │ ├── forward.go │ │ ├── forward_test.go │ │ ├── grpc_handler.go │ │ ├── grpc_handler_test.go │ │ ├── proto/ │ │ │ └── tunnel.proto │ │ ├── server.go │ │ ├── server_test.go │ │ └── taipb/ │ │ ├── tunnel.pb.go │ │ └── tunnel_grpc.pb.go │ ├── types/ │ │ └── types.go │ ├── vnc/ │ │ ├── vnc.go │ │ └── vnc_test.go │ ├── volume/ │ │ ├── local.go │ │ ├── mock_test.go │ │ ├── pb/ │ │ │ ├── volume.pb.go │ │ │ ├── volume.proto │ │ │ └── volume_grpc.pb.go │ │ ├── remote.go │ │ ├── volume.go │ │ └── volume_test.go │ ├── workspace/ │ │ ├── copy.go │ │ ├── uri.go │ │ ├── workspace.go │ │ └── workspace_test.go │ └── yao.go ├── task/ │ ├── task.go │ └── task_test.go ├── test/ │ ├── request.go │ └── utils.go ├── trace/ │ ├── BUGFIX.md 
│ ├── KNOWN_ISSUES.md │ ├── README.md │ ├── event_listener.go │ ├── handler.go │ ├── jsapi/ │ │ ├── jsapi.go │ │ ├── jsapi_test.go │ │ ├── node.go │ │ ├── space.go │ │ └── trace.go │ ├── local/ │ │ └── driver.go │ ├── manager.go │ ├── node.go │ ├── space.go │ ├── state.go │ ├── store/ │ │ └── driver.go │ ├── subscription.go │ ├── test_helpers.go │ ├── trace.go │ ├── trace_archive_test.go │ ├── trace_autocomplete_test.go │ ├── trace_basic_test.go │ ├── trace_bench_test.go │ ├── trace_concurrent_test.go │ ├── trace_lifecycle_test.go │ ├── trace_mem_test.go │ ├── trace_node_test.go │ ├── trace_resource_test.go │ ├── trace_space_test.go │ ├── trace_subscription_leak_test.go │ ├── trace_subscription_test.go │ └── types/ │ ├── driver.go │ ├── events.go │ ├── interfaces.go │ └── types.go ├── utils/ │ ├── README.md │ ├── captcha/ │ │ ├── captcha.go │ │ ├── captcha_test.go │ │ └── process.go │ ├── datetime/ │ │ └── now.go │ ├── datetime_test.go │ ├── fmt/ │ │ └── fmt.go │ ├── json/ │ │ └── json.go │ ├── jsonschema/ │ │ ├── jsonschema.go │ │ └── jsonschema_test.go │ ├── otp/ │ │ ├── otp.go │ │ ├── otp_test.go │ │ └── process.go │ ├── process.go │ ├── str/ │ │ └── str.go │ ├── str_test.go │ ├── throw/ │ │ └── throw.go │ ├── throw_test.go │ ├── tree/ │ │ └── tree.go │ ├── tree_test.go │ ├── url/ │ │ └── url.go │ └── url_test.go ├── wework/ │ ├── process.go │ ├── wework.go │ ├── wework_test.go │ └── xml.go ├── widget/ │ ├── driver/ │ │ ├── connector.go │ │ └── source.go │ ├── instance.go │ ├── load.go │ ├── load_test.go │ ├── process.go │ ├── process_test.go │ ├── types.go │ ├── widget.go │ └── widget_test.go ├── widgets/ │ ├── action/ │ │ ├── action.go │ │ ├── action_test.go │ │ ├── guard.go │ │ ├── process.go │ │ ├── process_test.go │ │ └── types.go │ ├── action.go │ ├── api.go │ ├── app/ │ │ ├── app.go │ │ ├── app_test.go │ │ └── types.go │ ├── chart/ │ │ ├── action.go │ │ ├── api.go │ │ ├── chart.go │ │ ├── chart_test.go │ │ ├── export.go │ │ ├── fields.go │ │ ├── 
handler.go │ │ ├── layout.go │ │ ├── mapping.go │ │ ├── process.go │ │ ├── process_test.go │ │ ├── types.go │ │ └── vaildate.go │ ├── component/ │ │ ├── action.go │ │ ├── action_test.go │ │ ├── component.go │ │ ├── compute.go │ │ ├── compute_test.go │ │ ├── handlers.go │ │ ├── process.go │ │ ├── process_test.go │ │ ├── props.go │ │ └── types.go │ ├── compute/ │ │ ├── compute.go │ │ └── types.go │ ├── dashboard/ │ │ ├── action.go │ │ ├── api.go │ │ ├── dashboard.go │ │ ├── dashboard_test.go │ │ ├── export.go │ │ ├── fields.go │ │ ├── handler.go │ │ ├── layout.go │ │ ├── mapping.go │ │ ├── process.go │ │ ├── process_test.go │ │ ├── types.go │ │ └── vaildate.go │ ├── expression/ │ │ ├── expression.go │ │ ├── expression_test.go │ │ └── process.go │ ├── field/ │ │ ├── column.go │ │ ├── column_test.go │ │ ├── field.go │ │ ├── field_test.go │ │ ├── filter.go │ │ ├── filter_test.go │ │ ├── transform.go │ │ ├── transform_test.go │ │ └── types.go │ ├── field.go │ ├── form/ │ │ ├── action.go │ │ ├── api.go │ │ ├── bind.go │ │ ├── export.go │ │ ├── fields.go │ │ ├── form.go │ │ ├── form_test.go │ │ ├── handler.go │ │ ├── layout.go │ │ ├── mapping.go │ │ ├── process.go │ │ ├── process_test.go │ │ ├── types.go │ │ └── vaildate.go │ ├── hook/ │ │ ├── hook.go │ │ └── types.go │ ├── item.go │ ├── list/ │ │ ├── action.go │ │ ├── api.go │ │ ├── bind.go │ │ ├── export.go │ │ ├── fields.go │ │ ├── handler.go │ │ ├── layout.go │ │ ├── list.go │ │ ├── list_test.go │ │ ├── mapping.go │ │ ├── process.go │ │ ├── process_test.go │ │ ├── types.go │ │ └── vaildate.go │ ├── login/ │ │ ├── login.go │ │ ├── login_test.go │ │ ├── process.go │ │ └── types.go │ ├── mapping/ │ │ └── mapping.go │ ├── models.go │ ├── process.go │ ├── process_test.go │ ├── table/ │ │ ├── README.md │ │ ├── action.go │ │ ├── api.go │ │ ├── api_test.go │ │ ├── bind.go │ │ ├── excel.go │ │ ├── export.go │ │ ├── fields.go │ │ ├── fields_test.go │ │ ├── handler.go │ │ ├── layout.go │ │ ├── mapping.go │ │ ├── mapping_test.go │ 
│ ├── process.go │ │ ├── process_test.go │ │ ├── table.go │ │ ├── table_test.go │ │ ├── types.go │ │ └── validate.go │ ├── widgets.go │ └── widgets_test.go ├── workspace/ │ ├── DESIGN.md │ ├── Makefile │ ├── TEST.md │ ├── bench_test.go │ ├── errors.go │ ├── fileio_test.go │ ├── jsapi/ │ │ ├── API.md │ │ ├── fs.go │ │ ├── jsapi.go │ │ └── jsapi_test.go │ ├── manager.go │ ├── testutils_test.go │ ├── workspace.go │ └── workspace_test.go └── yao/ ├── assistants/ │ ├── entity/ │ │ ├── package.yao │ │ └── prompts.yml │ ├── keyword/ │ │ ├── package.yao │ │ ├── prompts.yml │ │ └── src/ │ │ └── index.ts │ ├── needsearch/ │ │ ├── package.yao │ │ ├── prompts.yml │ │ └── src/ │ │ └── index.ts │ ├── prompt/ │ │ ├── package.yao │ │ └── prompts.yml │ ├── querydsl/ │ │ ├── package.yao │ │ ├── prompts/ │ │ │ ├── aggregation.yml │ │ │ ├── complex.yml │ │ │ ├── filter.yml │ │ │ └── join.yml │ │ ├── prompts.yml │ │ └── src/ │ │ └── index.ts │ ├── robot_prompt/ │ │ ├── package.yao │ │ └── prompts.yml │ └── title/ │ ├── package.yao │ └── prompts.yml ├── data/ │ ├── icons/ │ │ └── icon.icns │ ├── index.html │ └── kb/ │ └── providers/ │ ├── chunking/ │ │ ├── semantic/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ └── structured/ │ │ ├── en.json │ │ └── zh-cn.json │ ├── converter/ │ │ ├── mcp/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ ├── ocr/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ ├── office/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ ├── utf8/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ ├── video/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ ├── vision/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ └── whisper/ │ │ ├── en.json │ │ └── zh-cn.json │ ├── embedding/ │ │ ├── fastembed/ │ │ │ ├── en.json │ │ │ └── zh-cn.json │ │ └── openai/ │ │ ├── en.json │ │ └── zh-cn.json │ ├── extraction/ │ │ └── openai/ │ │ ├── en.json │ │ └── zh-cn.json │ └── fetcher/ │ ├── http/ │ │ ├── en.json │ │ └── zh-cn.json │ └── mcp/ │ ├── en.json │ └── zh-cn.json ├── fields/ │ └── model.trans.json ├── langs/ │ ├── 
en-US.json │ ├── zh-cn/ │ │ ├── global.yml │ │ └── logins/ │ │ ├── admin.login.yml │ │ └── user.login.yml │ └── zh-hk/ │ ├── global.yml │ └── logins/ │ ├── admin.login.yml │ └── user.login.yml ├── models/ │ ├── agent/ │ │ ├── assistant.mod.yao │ │ ├── chat.mod.yao │ │ ├── execution.mod.yao │ │ ├── message.mod.yao │ │ ├── resume.mod.yao │ │ └── search.mod.yao │ ├── attachment.mod.yao │ ├── audit.mod.yao │ ├── config.mod.yao │ ├── dsl.mod.yao │ ├── invitation.mod.yao │ ├── job/ │ │ ├── category.mod.yao │ │ ├── execution.mod.yao │ │ ├── job.mod.yao │ │ └── log.mod.yao │ ├── kb/ │ │ ├── collection.mod.yao │ │ └── document.mod.yao │ ├── member.mod.yao │ ├── role.mod.yao │ ├── team.mod.yao │ ├── user/ │ │ ├── oauth_account.mod.yao │ │ └── type.mod.yao │ └── user.mod.yao ├── release/ │ └── app.yaz ├── stores/ │ ├── agent/ │ │ ├── cache.lru.yao │ │ └── memory/ │ │ ├── chat.xun.yao │ │ ├── context.xun.yao │ │ ├── team.xun.yao │ │ └── user.xun.yao │ ├── cache.lru.yao │ ├── kb/ │ │ ├── cache.lru.yao │ │ └── store.xun.yao │ ├── oauth/ │ │ ├── cache.lru.yao │ │ ├── client.xun.yao │ │ └── store.xun.yao │ └── store.xun.yao └── uploaders/ └── attachment.local.yao ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false contact_links: - name: Ask a question or discuss a topic url: https://discord.com/invite/BkMR2NUsjU about: Ask questions and discuss with other community members - name: Join Yao community url: https://yaoapps.com/community about: Join the community to get updates and news ================================================ FILE: .github/ISSUE_TEMPLATE/issue_report.md ================================================ --- name: "Report an issue" about: "Report an issue to help us improve" labels: "" assignees: "" --- ## Description ## Context - **Yao 
Version( yao version --all )**: - **Platform**: ================================================ FILE: .github/actions/setup-db/Dockerfile ================================================ FROM docker:latest COPY entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] ================================================ FILE: .github/actions/setup-db/action.yml ================================================ inputs: kind: description: "Choose the kind of database (MySQL8.0, MySQL5.7, Postgres9.6, Postgres14.0, SQLite3)" required: true db: description: "The name of database" required: false default: "github" port: description: "The port of database" required: false user: description: "The user of database" required: false default: "github" password: description: "The password of database" required: false default: "123456" runs: using: "docker" image: "Dockerfile" ================================================ FILE: .github/actions/setup-yao/action.yml ================================================ name: "Setup Yao Build Environment" description: "Checkout dependency repos, setup Go toolchain, and install build tools (v1.0.0)" inputs: go-version: description: "Go version to install" default: "1.25" repo-kun: description: "Kun repository (owner/repo)" required: true repo-xun: description: "Xun repository (owner/repo)" required: true repo-gou: description: "Gou repository (owner/repo)" required: true checkout-app: description: "Checkout yao-dev-app (demo application for tests)" default: "true" checkout-init: description: "Checkout yao-init (for Yao server startup in CI)" default: "false" apple-private-key: description: "Apple private key content for OAuth certs (optional)" default: "" runs: using: "composite" steps: # -- Dependency repositories -- - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ inputs.repo-kun }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ inputs.repo-xun }} path: xun 
- name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ inputs.repo-gou }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 shell: bash run: | for file in $(find ./v8go -name "libv8*.zip"); do dir=$(dirname "$file") echo "Extracting $file to $dir" unzip -o -d "$dir" "$file" rm -rf "$dir/__MACOSX" done - name: Checkout Demo App if: ${{ inputs.checkout-app == 'true' }} uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Checkout yao-init if: ${{ inputs.checkout-init == 'true' }} uses: actions/checkout@v4 with: repository: yaoapp/yao-init path: yao-init # -- Move all dependencies to parent directory (Go workspace layout) -- - name: Move Dependencies shell: bash run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ [ -d app ] && mv app ../ mv extension ../ [ -d yao-init ] && mv yao-init ../ # -- Setup Apple Private Key (if provided) -- - name: Setup Apple Private Key if: ${{ inputs.apple-private-key != '' }} shell: bash run: | mkdir -p ../app/openapi/certs/apple echo "${{ inputs.apple-private-key }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 # -- Go toolchain -- - name: Setup Go ${{ inputs.go-version }} uses: actions/setup-go@v5 with: go-version: ${{ inputs.go-version }} - name: Setup Go Tools shell: bash run: make tools ================================================ FILE: .github/codesign/entitlements.plist ================================================ com.apple.security.cs.allow-jit com.apple.security.cs.allow-unsigned-executable-memory ================================================ FILE: .github/env/sandbox-v2.env ================================================ # ============================================================ # Yao CI Environment — sandbox-v2 (v1.0.0) # Loaded via: cat .github/env/sandbox-v2.env >> $GITHUB_ENV 
# ============================================================ # ======================================== # Yao Runtime (YAO_ prefix, read by Yao) # ======================================== YAO_HOST=0.0.0.0 YAO_PORT=5099 YAO_GRPC_HOST=0.0.0.0 YAO_GRPC_PORT=9099 YAO_DB_DRIVER=sqlite3 YAO_SESSION=memory YAO_ENV=development # ======================================== # CI Test Parameters (YAO_CI_ prefix) # ======================================== # -- Network -- YAO_CI_BRIDGE_IP=172.17.0.1 # -- Yao service ports (tests read these, not YAO_PORT/YAO_GRPC_PORT) -- YAO_CI_HTTP_PORT=5099 YAO_CI_GRPC_PORT=9099 YAO_CI_URL=http://127.0.0.1:5099 YAO_CI_GRPC=127.0.0.1:9099 # -- OAuth token generation (ci-token tool) -- YAO_CI_OAUTH_SUBJECT=ci-test-user YAO_CI_OAUTH_USER_ID=ci-test-user YAO_CI_OAUTH_TEAM_ID=ci-test-team YAO_CI_OAUTH_SCOPE=tai:tunnel YAO_CI_OAUTH_TTL=24h # ======================================== # Tai Instances # # Connection modes: # tai-local → DIRECT (--direct, Yao dials Tai gRPC directly) # tai-docker → TUNNEL (default, Tai dials Yao gRPC, reverse tunnel) # tai-k8s → TUNNEL (same as above, with K8s runtime) # tai-hostexec → TUNNEL (same as above, no container runtime) # ======================================== # -- tai-local (DIRECT mode, auto-detect Docker via /var/run/docker.sock) -- # Yao dials tai-local gRPC directly, so gRPC port must be reachable. # Docker proxy on 12376 (not default 12375) to avoid conflict with tai-docker. YAO_CI_TAI_LOCAL_HOST=127.0.0.1 YAO_CI_TAI_LOCAL_GRPC_PORT=19103 YAO_CI_TAI_LOCAL_HTTP_PORT=8102 YAO_CI_TAI_LOCAL_VNC_PORT=16083 YAO_CI_TAI_LOCAL_DOCKER_PORT=12376 YAO_CI_TAI_LOCAL_GRPC=127.0.0.1:19103 YAO_CI_TAI_LOCAL_DOCKER_API=tcp://127.0.0.1:12376 # -- tai-docker (TUNNEL mode, explicit Docker API proxy) -- # Tunnel: Tai connects to Yao gRPC. Sandbox connects to Yao, traffic forwarded via tunnel. # gRPC port used only for Tai's own listener; Yao accesses via tunnel, not direct dial. 
YAO_CI_TAI_DOCKER_HOST=127.0.0.1 YAO_CI_TAI_DOCKER_GRPC_PORT=19100 YAO_CI_TAI_DOCKER_HTTP_PORT=8099 YAO_CI_TAI_DOCKER_VNC_PORT=16080 YAO_CI_TAI_DOCKER_API_PORT=12375 YAO_CI_TAI_DOCKER_API=tcp://127.0.0.1:12375 # -- tai-k8s (TUNNEL mode, K8s API proxy via k3d) -- # K8s proxy on 16444 (not 16443) because k3d --api-port already binds 16443. # TAI_K8S_UPSTREAM points to k3d at 127.0.0.1:16443; proxy exposes on 16444. YAO_CI_TAI_K8S_HOST=127.0.0.1 YAO_CI_TAI_K8S_GRPC_PORT=19101 YAO_CI_TAI_K8S_HTTP_PORT=8100 YAO_CI_TAI_K8S_VNC_PORT=16081 YAO_CI_TAI_K8S_API_PORT=16444 YAO_CI_K3D_API_PORT=16443 # -- tai-hostexec (TUNNEL mode, no container runtime, HostExec only) -- YAO_CI_TAI_HOSTEXEC_HOST=127.0.0.1 YAO_CI_TAI_HOSTEXEC_GRPC_PORT=19102 YAO_CI_TAI_HOSTEXEC_HTTP_PORT=8101 YAO_CI_TAI_HOSTEXEC_VNC_PORT=16082 # ======================================== # Sandbox V2 addresses (used by test code) # ======================================== YAO_CI_SANDBOX_LOCAL_ADDR=tai://127.0.0.1:19103 YAO_CI_SANDBOX_DOCKER_ADDR=tai://127.0.0.1:19100 YAO_CI_SANDBOX_K8S_ADDR=tai://127.0.0.1:19101 YAO_CI_SANDBOX_IMAGE=yaoapp/tai-sandbox-test:latest # ======================================== # HostExec addresses # ======================================== YAO_CI_HOSTEXEC_LOCAL_ADDR=127.0.0.1:19103 YAO_CI_HOSTEXEC_DOCKER_ADDR=127.0.0.1:19100 YAO_CI_HOSTEXEC_K8S_ADDR=127.0.0.1:19101 YAO_CI_HOSTEXEC_ONLY_ADDR=127.0.0.1:19102 # -- Tunnel -- YAO_CI_TUNNEL=true # ======================================== # Database (MySQL / PostgreSQL / SQLite) # ======================================== MYSQL_TEST_HOST=127.0.0.1 MYSQL_TEST_PORT=3308 MYSQL_TEST_USER=test MYSQL_TEST_PASS=123456 PG_TEST_HOST=127.0.0.1 PG_TEST_PORT=5432 PG_TEST_USER=test PG_TEST_PASS=123456 SQLITE_DB=./app/db/yao.db # ======================================== # Legacy variable mapping (migrate later) # ======================================== TAI_TEST_HOST=127.0.0.1 TAI_TEST_DOCKER=tcp://127.0.0.1:12375 TAI_TEST_GRPC_PORT=19100 
TAI_TEST_HTTP_PORT=8099 TAI_TEST_VNC_PORT=16080 TAI_TEST_DOCKER_PORT=12375 TAI_TEST_K8S_HOST=127.0.0.1 TAI_TEST_K8S_PORT=16444 TAI_TEST_K8S_GRPC_PORT=19101 TAI_TEST_K8S_HTTP_PORT=8100 TAI_TEST_K8S_VNC_PORT=16081 TAI_TEST_HOST_IP=172.17.0.1 TAI_TEST_TUNNEL=true TAI_TEST_YAO_URL=http://127.0.0.1:5099 TAI_TEST_YAO_GRPC=127.0.0.1:9099 SANDBOX_TEST_LOCAL_ADDR=tai://127.0.0.1:19103 SANDBOX_TEST_REMOTE_ADDR=tai://127.0.0.1:19100 SANDBOX_TEST_K8S_REMOTE_ADDR=tai://127.0.0.1:19101 SANDBOX_TEST_HOSTEXEC_ADDR=tai://127.0.0.1:19102 SANDBOX_TEST_IMAGE=yaoapp/tai-sandbox-test:latest DOCKER_BRIDGE_IP=172.17.0.1 ================================================ FILE: .github/workflows/build-docker.yml ================================================ name: Build and push docker images on: # push: # branches: [main] # paths: # - ".github/workflows/docker.yml" workflow_run: workflows: ["Build Linux Artifacts"] types: - completed env: VERSION: 0.10.5 jobs: build: if: ${{ github.event.workflow_run.conclusion == 'success' }} runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 - name: Get Version run: | echo VERSION=$(cat share/const.go |grep 'const VERSION' | awk '{print $4}' | sed "s/\"//g") >> $GITHUB_ENV - name: Check Version run: echo $VERSION - name: Set up QEMU uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_TOKEN }} - name: Build Development uses: docker/build-push-action@v6 env: DOCKER_CONTENT_TRUST: 1 with: context: ./docker/development platforms: linux/amd64 build-args: | VERSION=${{ env.VERSION }} ARCH=amd64 push: true tags: yaoapp/yao:${{ env.VERSION }}-amd64-dev - name: Build Development Arm64 uses: docker/build-push-action@v6 env: DOCKER_CONTENT_TRUST: 1 with: context: ./docker/development platforms: linux/arm64 build-args: | VERSION=${{ env.VERSION }} ARCH=arm64 push: true 
tags: yaoapp/yao:${{ env.VERSION }}-arm64-dev - name: Build Production uses: docker/build-push-action@v6 env: DOCKER_CONTENT_TRUST: 1 with: context: ./docker/production platforms: linux/amd64 build-args: | VERSION=${{ env.VERSION }} ARCH=amd64 push: true tags: yaoapp/yao:${{ env.VERSION }}-amd64 - name: Build Production Arm64 uses: docker/build-push-action@v6 env: DOCKER_CONTENT_TRUST: 1 with: context: ./docker/production platforms: linux/arm64 build-args: | VERSION=${{ env.VERSION }} ARCH=arm64 push: true tags: yaoapp/yao:${{ env.VERSION }}-arm64 - name: Build Production Slim uses: docker/build-push-action@v6 env: DOCKER_CONTENT_TRUST: 1 with: context: ./docker/production-slim platforms: linux/amd64 build-args: | VERSION=${{ env.VERSION }} ARCH=amd64 push: true tags: yaoapp/yao:${{ env.VERSION }}-amd64-slim - name: Build Production Slim Arm64 uses: docker/build-push-action@v6 env: DOCKER_CONTENT_TRUST: 1 with: context: ./docker/production-slim platforms: linux/arm64 build-args: | VERSION=${{ env.VERSION }} ARCH=arm64 push: true tags: yaoapp/yao:${{ env.VERSION }}-arm64-slim ================================================ FILE: .github/workflows/build-linux.yml ================================================ name: Build Linux Artifacts on: workflow_dispatch: inputs: tags: description: "Version tags" jobs: build: runs-on: "ubuntu-latest" container: image: yaoapp/yao-build:1.0.0 env: CF_ACCESS_KEY_ID: ${{ secrets.CF_ACCESS_KEY_ID }} CF_SECRET_ACCESS_KEY: ${{ secrets.CF_SECRET_ACCESS_KEY }} R2_BUCKET: ${{ secrets.R2_BUCKET }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} steps: - name: Configure R2 For Cloudflare run: | aws configure set aws_access_key_id $CF_ACCESS_KEY_ID aws configure set aws_secret_access_key $CF_SECRET_ACCESS_KEY aws configure set default.region us-east-1 # Update with your R2 region if different aws configure set default.s3.signature_version s3v4 aws configure set default.s3.endpoint_url https://$R2_ACCOUNT_ID.r2.cloudflarestorage.com aws 
--version - name: Build run: | export PATH=$PATH:/github/home/go/bin /app/build.sh ls -l /data - name: Archive production artifacts uses: actions/upload-artifact@v4 with: name: yao-linux path: | /data/* - name: Push To R2 Cloudflare run: | for file in /data/*; do aws s3 cp $file s3://$R2_BUCKET/archives/ --endpoint-url https://$R2_ACCOUNT_ID.r2.cloudflarestorage.com done ================================================ FILE: .github/workflows/build-macos.yml ================================================ name: Build MacOS Artifacts on: workflow_dispatch: inputs: tags: description: "Version tags" env: VERSION: 1.0.0 jobs: build: strategy: matrix: go: ["1.25"] runs-on: "macos-latest" steps: - name: Setup Node.js uses: actions/setup-node@v4 with: node-version: 18 - name: Install pnpm run: npm install -g pnpm - name: Setup Cache uses: actions/cache@v4 with: path: | ~/.cache/go-build ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") # Get the directory where the ZIP file is located echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout CUI v1.0 # ** XGEN will be renamed to CUI in the future and moved to the new repository. 
** # ** new repository: https://github.com/YaoApp/cui.git ** uses: actions/checkout@v4 with: repository: yaoapp/cui path: cui-v1.0 - name: Checkout Yao-Init uses: actions/checkout@v4 with: repository: yaoapp/yao-init path: yao-init - name: Move Kun, Xun, Gou, UI, V8Go run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv cui-v1.0 ../ mv yao-init ../ rm -f ../cui-v1.0/packages/setup/vite.config.ts.* ls -l . ls -l ../ ls -l ../cui-v1.0/packages/setup/ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Setup Go Tools run: | make tools - name: Get Version run: | echo VERSION=$(cat share/const.go |grep 'const VERSION' | awk '{print $4}' | sed "s/\"//g") >> $GITHUB_ENV - name: Make Artifacts MacOS run: | make artifacts-macos - name: Install Certificates env: KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} run: | mkdir -p certs echo "${{ secrets.APPLE_DEVELOPERIDG2CA }}" | base64 --decode > certs/DeveloperIDG2CA.cer echo "${{ secrets.APPLE_DISTRIBUTION }}" | base64 --decode > certs/distribution.cer echo "${{ secrets.APPLE_PRIVATE_KEY }}" | base64 --decode > certs/private_key.p12 security verify-cert -c certs/DeveloperIDG2CA.cer security verify-cert -c certs/distribution.cer - name: Import Certificates run: | KEYCHAIN_PATH=$RUNNER_TEMP/app-signing.keychain-db # create temporary keychain security create-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH security set-keychain-settings -lut 21600 $KEYCHAIN_PATH security unlock-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH # import certificate to keychain security import ./certs/DeveloperIDG2CA.cer -k $KEYCHAIN_PATH -T /usr/bin/codesign security import ./certs/distribution.cer -k $KEYCHAIN_PATH -T /usr/bin/codesign # import private key to keychain security import ./certs/private_key.p12 -k $KEYCHAIN_PATH -P "${{ secrets.APPLE_PRIVATE_KEY_PASSWORD }}" -T /usr/bin/codesign security list-keychain -d user -s $KEYCHAIN_PATH - 
name: Sign Artifacts run: | codesign --deep --force --verbose --timestamp --sign "Developer ID Application: ${{ secrets.APPLE_SIGN }}" dist/release/yao-$VERSION-darwin-arm64 codesign --deep --force --verbose --timestamp --sign "Developer ID Application: ${{ secrets.APPLE_SIGN }}" dist/release/yao-$VERSION-darwin-amd64 codesign --deep --force --verbose --timestamp --sign "Developer ID Application: ${{ secrets.APPLE_SIGN }}" dist/release/yao-$VERSION-darwin-arm64-prod codesign --deep --force --verbose --timestamp --sign "Developer ID Application: ${{ secrets.APPLE_SIGN }}" dist/release/yao-$VERSION-darwin-amd64-prod - name: Verify Signature run: | codesign --verify --deep --strict --verbose=2 dist/release/yao-$VERSION-darwin-arm64 codesign --verify --deep --strict --verbose=2 dist/release/yao-$VERSION-darwin-amd64 codesign --verify --deep --strict --verbose=2 dist/release/yao-$VERSION-darwin-arm64-prod codesign --verify --deep --strict --verbose=2 dist/release/yao-$VERSION-darwin-amd64-prod - name: Send to Apple Notary Service run: | zip -r dist/release/yao-$VERSION-darwin-arm64.zip dist/release/yao-$VERSION-darwin-arm64 zip -r dist/release/yao-$VERSION-darwin-amd64.zip dist/release/yao-$VERSION-darwin-amd64 zip -r dist/release/yao-$VERSION-darwin-arm64-prod.zip dist/release/yao-$VERSION-darwin-arm64-prod zip -r dist/release/yao-$VERSION-darwin-amd64-prod.zip dist/release/yao-$VERSION-darwin-amd64-prod xcrun notarytool submit dist/release/yao-$VERSION-darwin-arm64.zip --apple-id "${{ secrets.APPLE_ID }}" --team-id "${{ secrets.APPLE_TEAME_ID }}" --password "${{ secrets.APPLE_APP_SPEC_PASS }}" --output-format json xcrun notarytool submit dist/release/yao-$VERSION-darwin-amd64.zip --apple-id "${{ secrets.APPLE_ID }}" --team-id "${{ secrets.APPLE_TEAME_ID }}" --password "${{ secrets.APPLE_APP_SPEC_PASS }}" --output-format json xcrun notarytool submit dist/release/yao-$VERSION-darwin-arm64-prod.zip --apple-id "${{ secrets.APPLE_ID }}" --team-id "${{ 
secrets.APPLE_TEAME_ID }}" --password "${{ secrets.APPLE_APP_SPEC_PASS }}" --output-format json xcrun notarytool submit dist/release/yao-$VERSION-darwin-amd64-prod.zip --apple-id "${{ secrets.APPLE_ID }}" --team-id "${{ secrets.APPLE_TEAME_ID }}" --password "${{ secrets.APPLE_APP_SPEC_PASS }}" --output-format json rm -f dist/release/yao-$VERSION-darwin-arm64.zip rm -f dist/release/yao-$VERSION-darwin-amd64.zip rm -f dist/release/yao-$VERSION-darwin-arm64-prod.zip rm -f dist/release/yao-$VERSION-darwin-amd64-prod.zip - name: Archive production artifacts uses: actions/upload-artifact@v4 with: name: yao-macos path: | dist/release/* ================================================ FILE: .github/workflows/notarize-macos.yml ================================================ name: Notarize macOS on: workflow_dispatch: inputs: run_id: description: "Release macOS workflow run ID (to download artifacts from)" required: true version: description: "Version used in the release build (e.g. 1.0.0 or 1.0.0-alpha)" required: true permissions: contents: write jobs: # =================================================================== # Notarize Yao binaries (arm64 + amd64) # =================================================================== notarize: runs-on: macos-latest strategy: matrix: arch: [arm64, amd64] steps: - name: Download Yao Binary uses: actions/download-artifact@v4 with: name: yao-darwin-${{ matrix.arch }} path: bin run-id: ${{ github.event.inputs.run_id }} github-token: ${{ secrets.GITHUB_TOKEN }} - name: Install Certificates env: KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} run: | mkdir -p certs echo "${{ secrets.APPLE_DEVELOPERIDG2CA }}" | base64 --decode > certs/DeveloperIDG2CA.cer echo "${{ secrets.APPLE_DISTRIBUTION }}" | base64 --decode > certs/distribution.cer echo "${{ secrets.APPLE_PRIVATE_KEY }}" | base64 --decode > certs/private_key.p12 security verify-cert -c certs/DeveloperIDG2CA.cer security verify-cert -c certs/distribution.cer - name: Import 
Certificates env: KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} run: | KEYCHAIN_PATH=$RUNNER_TEMP/app-signing.keychain-db security create-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH security set-keychain-settings -lut 21600 $KEYCHAIN_PATH security unlock-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH security import ./certs/DeveloperIDG2CA.cer -k $KEYCHAIN_PATH -T /usr/bin/codesign security import ./certs/distribution.cer -k $KEYCHAIN_PATH -T /usr/bin/codesign security import ./certs/private_key.p12 -k $KEYCHAIN_PATH -P "${{ secrets.APPLE_PRIVATE_KEY_PASSWORD }}" -T /usr/bin/codesign security list-keychain -d user -s $KEYCHAIN_PATH - name: Verify Signature run: codesign --verify --deep --strict --verbose=2 bin/yao - name: Notarize Yao ${{ matrix.arch }} timeout-minutes: 15 env: APPLE_ID: ${{ secrets.APPLE_ID }} APPLE_TEAME_ID: ${{ secrets.APPLE_TEAME_ID }} APPLE_APP_SPEC_PASS: ${{ secrets.APPLE_APP_SPEC_PASS }} run: | zip -j bin/yao.zip bin/yao SUBMIT_OUT=$(xcrun notarytool submit bin/yao.zip \ --apple-id "$APPLE_ID" \ --team-id "$APPLE_TEAME_ID" \ --password "$APPLE_APP_SPEC_PASS" \ --wait --timeout 10m --output-format json 2>&1) || true echo "$SUBMIT_OUT" STATUS=$(echo "$SUBMIT_OUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status',''))" 2>/dev/null || true) SUB_ID=$(echo "$SUBMIT_OUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('id',''))" 2>/dev/null || true) if [ "$STATUS" != "Accepted" ]; then echo "::error::Yao ${{ matrix.arch }} notarization failed (status: $STATUS)" [ -n "$SUB_ID" ] && xcrun notarytool log "$SUB_ID" \ --apple-id "$APPLE_ID" \ --team-id "$APPLE_TEAME_ID" \ --password "$APPLE_APP_SPEC_PASS" || true exit 1 fi echo "Yao ${{ matrix.arch }} notarization accepted." 
================================================ FILE: .github/workflows/pr-receive.yml ================================================ name: Receive PR # read-only repo token # no access to secrets on: pull_request: jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Save PR number run: | mkdir -p ./pr echo ${{ github.event.number }} > ./pr/NR echo ${{ github.event.pull_request.head.sha }} > ./pr/SHA - uses: actions/upload-artifact@v4 with: name: pr path: pr/ ================================================ FILE: .github/workflows/pr-test.yml ================================================ name: PR Unit Test # read-write repo token # access to secrets on: workflow_run: workflows: ["Receive PR"] types: - completed env: YAO_DEV: ${{ github.WORKSPACE }} YAO_ENV: development YAO_ROOT: ${{ github.WORKSPACE }}/../app YAO_HOST: 0.0.0.0 YAO_PORT: 5099 YAO_SESSION: "memory" YAO_LOG: "./logs/application.log" YAO_LOG_MODE: "TEXT" YAO_JWT_SECRET: "bLp@bi!oqo-2U+hoTRUG" YAO_DB_AESKEY: "ZLX=T&f6refeCh-ro*r@" OSS_TEST_ID: ${{ secrets.OSS_TEST_ID}} OSS_TEST_SECRET: ${{ secrets.OSS_TEST_SECRET}} ROOT_PLUGIN: ${{ github.WORKSPACE }}/../../../data/gou-unit/plugins MYSQL_TEST_HOST: "127.0.0.1" MYSQL_TEST_PORT: "3308" MYSQL_TEST_USER: test MYSQL_TEST_PASS: "123456" SQLITE_DB: "./app/db/yao.db" REDIS_TEST_HOST: "127.0.0.1" REDIS_TEST_PORT: "6379" REDIS_TEST_DB: "2" MONGO_TEST_HOST: "127.0.0.1" MONGO_TEST_PORT: "27017" MONGO_TEST_USER: "root" MONGO_TEST_PASS: "123456" OPENAI_TEST_KEY: ${{ secrets.OPENAI_TEST_KEY }} TEST_MOAPI_SECRET: ${{ secrets.OPENAI_TEST_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_TEST_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} TEST_MOAPI_MIRROR: https://api.openai.com # DeepSeek API Configuration DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} DEEPSEEK_API_PROXY: ${{ secrets.DEEPSEEK_API_PROXY }} DEEPSEEK_MODELS_R1: ${{ secrets.DEEPSEEK_MODELS_R1 }} DEEPSEEK_MODELS_V3: ${{ secrets.DEEPSEEK_MODELS_V3 }} 
DEEPSEEK_MODELS_V3_1: ${{ secrets.DEEPSEEK_MODELS_V3_1 }} # Search API Configuration TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} SERPAPI_API_KEY: ${{ secrets.SERPAPI_API_KEY }} SERPER_API_KEY: ${{ secrets.SERPER_API_KEY }} # Claude API Configuration CLAUDE_API_KEY: ${{ secrets.CLAUDE_API_KEY }} CLAUDE_PROXY: ${{ secrets.CLAUDE_PROXY }} CLAUDE_API_HOST: ${{ secrets.CLAUDE_API_HOST }} CLAUDE_SONNET_4: ${{ secrets.CLAUDE_SONNET_4 }} CLAUDE_SONNET_4_THINKING: ${{ secrets.CLAUDE_SONNET_4_THINKING }} # Moonshot / Kimi API Configuration MOONSHOT_API_KEY: ${{ secrets.MOONSHOT_API_KEY }} MOONSHOT_PROXY: "https://api.moonshot.cn" KIMI_CODE_API_KEY: ${{ secrets.KIMI_CODE_API_KEY }} KIMI_CODE_PROXY: "https://api.kimi.com/coding" TAB_NAME: "::PET ADMIN" PAGE_SIZE: "20" PAGE_LINK: "https://yaoapps.com" PAGE_ICON: "icon-trash" DEMO_APP: ${{ github.WORKSPACE }}/../app # Application Setting ## Path YAO_EXTENSION_ROOT: ${{ github.WORKSPACE }}/../extension YAO_TEST_APPLICATION: ${{ github.WORKSPACE }}/../app YAO_SUI_TEST_APPLICATION: ${{ github.WORKSPACE }}/../yao-startup-webapp ## Runtime YAO_RUNTIME_MIN: 3 YAO_RUNTIME_MAX: 6 YAO_RUNTIME_HEAP_LIMIT: 1500000000 YAO_RUNTIME_HEAP_RELEASE: 10000000 YAO_RUNTIME_HEAP_AVAILABLE: 550000000 YAO_RUNTIME_PRECOMPILE: true # Neo4j NEO4J_TEST_URL: "neo4j://localhost:7687" NEO4J_TEST_USER: "neo4j" NEO4J_TEST_PASS: "Yao2026Neo4j" # Qdrant QDRANT_TEST_HOST: "127.0.0.1" QDRANT_TEST_PORT: "6334" # S3 S3_API: ${{ secrets.S3_API }} S3_ACCESS_KEY: ${{ secrets.S3_ACCESS_KEY }} S3_SECRET_KEY: ${{ secrets.S3_SECRET_KEY }} S3_BUCKET: ${{ secrets.S3_BUCKET }} S3_PUBLIC_URL: ${{ secrets.S3_PUBLIC_URL }} # === Openapi Signin Configs === SIGNIN_CLIENT_ID: "kiCeR88kDwHBDuNHvN51cZgmpp3tmF6Z" ## Google GOOGLE_CLIENT_ID: ${{ secrets.GOOGLE_CLIENT_ID }} GOOGLE_CLIENT_SECRET: ${{ secrets.GOOGLE_CLIENT_SECRET }} ## Microsoft MICROSOFT_CLIENT_ID: ${{ secrets.MICROSOFT_CLIENT_ID }} MICROSOFT_CLIENT_SECRET: ${{ secrets.MICROSOFT_CLIENT_SECRET }} ## Apple 
APPLE_SERVICE_ID: ${{ secrets.APPLE_SERVICE_ID }} APPLE_PRIVATE_KEY_PATH: "apple/signin_client_secret_key.p8" APPLE_KEY_ID: ${{ secrets.APPLE_KEY_ID }} APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} ## Github GITHUBUSER_CLIENT_ID: ${{ secrets.GITHUBUSER_CLIENT_ID }} GITHUBUSER_CLIENT_SECRET: ${{ secrets.GITHUBUSER_CLIENT_SECRET }} ## Cloudflare Turnstile CLOUDFLARE_TURNSTILE_SITEKEY: ${{ secrets.CLOUDFLARE_TURNSTILE_SITEKEY }} CLOUDFLARE_TURNSTILE_SECRET: ${{ secrets.CLOUDFLARE_TURNSTILE_SECRET }} # === Messaging Services === ## Mailgun MAILGUN_DOMAIN: ${{ secrets.MAILGUN_DOMAIN }} MAILGUN_API_KEY: ${{ secrets.MAILGUN_API_KEY }} MAILGUN_FROM: "Yaobots Tests " ## SMTP Server( Mailgun ) SMTP_HOST: "smtp.mailgun.org" SMTP_PORT: "465" SMTP_USERNAME: ${{ secrets.SMTP_USERNAME }} SMTP_PASSWORD: ${{ secrets.SMTP_PASSWORD }} SMTP_FROM: "Yaobots SMTP Tests " ## SMTP Server( Gmail ) RELIABLE_SMTP_HOST: "smtp.gmail.com" RELIABLE_SMTP_PORT: "465" RELIABLE_SMTP_USERNAME: ${{ secrets.RELIABLE_SMTP_USERNAME }} RELIABLE_SMTP_PASSWORD: ${{ secrets.RELIABLE_SMTP_PASSWORD }} RELIABLE_SMTP_FROM: "Yaobots Gmail Tests " ## IMAP Server (Gmail) RELIABLE_IMAP_HOST: "imap.gmail.com" RELIABLE_IMAP_PORT: "993" RELIABLE_IMAP_USERNAME: ${{ secrets.RELIABLE_SMTP_USERNAME }} RELIABLE_IMAP_PASSWORD: ${{ secrets.RELIABLE_SMTP_PASSWORD }} RELIABLE_IMAP_MAILBOX: "INBOX" ## Twilio TWILIO_ACCOUNT_SID: ${{ secrets.TWILIO_ACCOUNT_SID }} TWILIO_AUTH_TOKEN: ${{ secrets.TWILIO_AUTH_TOKEN }} TWILIO_API_SID: ${{ secrets.TWILIO_API_SID }} TWILIO_API_KEY: ${{ secrets.TWILIO_API_KEY }} TWILIO_SENDGRID_API_SID: ${{ secrets.TWILIO_SENDGRID_API_SID }} TWILIO_SENDGRID_API_KEY: ${{ secrets.TWILIO_SENDGRID_API_KEY }} TWILIO_FROM_PHONE: "+17035701412" TWILIO_FROM_EMAIL: "unit-test@sendgrid.yaobots.com" TWILIO_TEST_PHONE: ${{ secrets.TWILIO_TEST_PHONE }} jobs: # ============================================================================= # KB Tests (kb) - Run once with SQLite (requires Qdrant, Neo4j, FastEmbed) # 
=============================================================================
  KBTest:
    runs-on: ubuntu-latest
    # Service containers started for the job, with their published ports.
    services:
      qdrant:
        image: qdrant/qdrant:latest
        ports:
          - 6333:6333
          - 6334:6334
      fastembed:
        image: yaoapp/fastembed:latest-amd64
        env:
          FASTEMBED_PASSWORD: Yao@2026
        ports:
          - 6001:8000
      neo4j:
        image: neo4j:latest
        ports:
          - "7687:7687"
        env:
          NEO4J_AUTH: neo4j/Yao2026Neo4j
      mongodb:
        image: mongo:6.0
        ports:
          - 27017:27017
        env:
          MONGO_INITDB_ROOT_USERNAME: root
          MONGO_INITDB_ROOT_PASSWORD: 123456
          MONGO_INITDB_DATABASE: test
    strategy:
      matrix:
        go: ["1.25"]
    # Run only when the triggering workflow run was a pull_request run that
    # succeeded. Written as a bare expression on purpose: wrapping `${{ }}` in
    # a folded scalar (`if: >`) leaves a trailing newline, the value is then
    # evaluated as a non-empty string, and the job is never skipped.
    if: github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success'
    steps:
      # Find the artifact named "pr" on the triggering run and save it as
      # pr.zip in the workspace.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      # Export the contents of the NR and SHA files from the archive as the
      # NR and HEAD environment variables for later steps.
      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
          cat NR
          cat SHA
          echo HEAD=$(cat SHA) >> $GITHUB_ENV
          echo NR=$(cat NR) >> $GITHUB_ENV

      # Post a "running" comment on the issue/PR numbered NR.
      - name: "Comment on PR"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '🤖 KB Tests (kb) running with SQLite...'
}); - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout pull request HEAD commit uses: actions/checkout@v4 with: ref: ${{ env.HEAD }} - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Run KB Tests (kb) run: make unit-test-kb - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - name: "Comment on PR - KB Tests Done" uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const { NR } = process.env var issue_number = NR; await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: 
issue_number,
              body: '✅ KB Tests (kb) passed!'
            });

  # =============================================================================
  # Agent Tests (agent, aigc) - Run once with SQLite
  # =============================================================================
  AgentTest:
    runs-on: ubuntu-latest
    # Service containers started for the job, with their published ports.
    services:
      qdrant:
        image: qdrant/qdrant:latest
        ports:
          - 6333:6333
          - 6334:6334
      fastembed:
        image: yaoapp/fastembed:latest-amd64
        env:
          FASTEMBED_PASSWORD: Yao@2026
        ports:
          - 6001:8000
      neo4j:
        image: neo4j:latest
        ports:
          - "7687:7687"
        env:
          NEO4J_AUTH: neo4j/Yao2026Neo4j
      mcp-everything:
        image: yaoapp/mcp-everything:latest
        ports:
          - "3021:3021"
          - "3022:3022"
      mongodb:
        image: mongo:6.0
        ports:
          - 27017:27017
        env:
          MONGO_INITDB_ROOT_USERNAME: root
          MONGO_INITDB_ROOT_PASSWORD: 123456
          MONGO_INITDB_DATABASE: test
    strategy:
      matrix:
        go: ["1.25"]
    # Run only when the triggering workflow run was a pull_request run that
    # succeeded. Written as a bare expression on purpose: wrapping `${{ }}` in
    # a folded scalar (`if: >`) leaves a trailing newline, the value is then
    # evaluated as a non-empty string, and the job is never skipped.
    if: github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success'
    steps:
      # Find the artifact named "pr" on the triggering run and save it as
      # pr.zip in the workspace.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      # Export the contents of the NR and SHA files from the archive as the
      # NR and HEAD environment variables for later steps.
      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
          cat NR
          cat SHA
          echo HEAD=$(cat SHA) >> $GITHUB_ENV
          echo NR=$(cat NR) >> $GITHUB_ENV

      # Post a "running" comment on the issue/PR numbered NR.
      - name: "Comment on PR"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '🤖 Agent Tests
(agent, aigc) running with SQLite...' }); - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout pull request HEAD commit uses: actions/checkout@v4 with: ref: ${{ env.HEAD }} - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Install pdftoppm, mutool, imagemagick run: | sudo apt update sudo apt install -y poppler-utils mupdf-tools imagemagick - name: Test pdftoppm, mutool, imagemagick run: | pdftoppm -v mutool -v convert -version - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Pull Sandbox Test Images run: | docker pull alpine:latest docker pull yaoapp/sandbox-base:latest || true docker pull yaoapp/sandbox-claude:latest || true - name: Run Agent 
Tests (agent, aigc)
        env:
          YAO_SANDBOX_WORKSPACE: ${{ runner.temp }}/sandbox/workspace
          YAO_SANDBOX_IPC: ${{ runner.temp }}/sandbox/ipc
        run: |
          export YAO_SANDBOX_CONTAINER_USER="$(id -u):$(id -g)"
          make unit-test-agent

      - name: Codecov Report
        uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

      - name: "Comment on PR - Agent Tests Done"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '✅ Agent Tests (agent, aigc) passed!'
            });

  # =============================================================================
  # Robot Tests (all agent/robot/... packages) - Unit + E2E with real LLM calls
  # =============================================================================
  RobotTest:
    runs-on: ubuntu-latest
    # Service containers started for the job, with their published ports.
    services:
      mcp-everything:
        image: yaoapp/mcp-everything:latest
        ports:
          - "3021:3021"
          - "3022:3022"
      mongodb:
        image: mongo:6.0
        ports:
          - 27017:27017
        env:
          MONGO_INITDB_ROOT_USERNAME: root
          MONGO_INITDB_ROOT_PASSWORD: 123456
          MONGO_INITDB_DATABASE: test
    strategy:
      matrix:
        go: ["1.25"]
    # Run only when the triggering workflow run was a pull_request run that
    # succeeded. Written as a bare expression on purpose: wrapping `${{ }}` in
    # a folded scalar (`if: >`) leaves a trailing newline, the value is then
    # evaluated as a non-empty string, and the job is never skipped.
    if: github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success'
    steps:
      # Find the artifact named "pr" on the triggering run and save it as
      # pr.zip in the workspace.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
cat NR cat SHA echo HEAD=$(cat SHA) >> $GITHUB_ENV echo NR=$(cat NR) >> $GITHUB_ENV - name: "Comment on PR" uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const { NR } = process.env var issue_number = NR; await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: issue_number, body: '🤖 Robot Tests (Unit + E2E) running with SQLite...' }); - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout pull request HEAD commit uses: actions/checkout@v4 with: ref: ${{ env.HEAD }} - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Run Robot Tests (Unit + E2E) run: make unit-test-robot - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - name: "Comment on PR - 
Robot Tests Done"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '✅ Robot Tests (Unit + E2E) passed!'
            });

  # =============================================================================
  # Sandbox Tests (requires Docker) - Run with Docker-in-Docker
  # =============================================================================
  SandboxTest:
    runs-on: ubuntu-latest
    # Service container started for the job, with its published port.
    services:
      mongodb:
        image: mongo:6.0
        ports:
          - 27017:27017
        env:
          MONGO_INITDB_ROOT_USERNAME: root
          MONGO_INITDB_ROOT_PASSWORD: 123456
          MONGO_INITDB_DATABASE: test
    strategy:
      matrix:
        go: ["1.25"]
    # Run only when the triggering workflow run was a pull_request run that
    # succeeded. Written as a bare expression on purpose: wrapping `${{ }}` in
    # a folded scalar (`if: >`) leaves a trailing newline, the value is then
    # evaluated as a non-empty string, and the job is never skipped.
    if: github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success'
    steps:
      # Find the artifact named "pr" on the triggering run and save it as
      # pr.zip in the workspace.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      # Export the contents of the NR and SHA files from the archive as the
      # NR and HEAD environment variables for later steps.
      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
          cat NR
          cat SHA
          echo HEAD=$(cat SHA) >> $GITHUB_ENV
          echo NR=$(cat NR) >> $GITHUB_ENV

      # Post a "running" comment on the issue/PR numbered NR.
      - name: "Comment on PR"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '🤖 Sandbox Tests running with Docker...'
}); - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout pull request HEAD commit uses: actions/checkout@v4 with: ref: ${{ env.HEAD }} - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Pull Sandbox Test Images run: | docker pull alpine:latest docker pull yaoapp/sandbox-base:latest || true docker pull yaoapp/sandbox-claude:latest || true - name: Run Sandbox Tests env: YAO_SANDBOX_WORKSPACE: ${{ runner.temp }}/sandbox/workspace YAO_SANDBOX_IPC: ${{ runner.temp }}/sandbox/ipc run: | # Use runner's UID:GID to avoid permission issues export YAO_SANDBOX_CONTAINER_USER="$(id -u):$(id -g)" make unit-test-sandbox - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - name: "Comment on PR - Sandbox Tests Done" uses: actions/github-script@v7 
with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '✅ Sandbox Tests passed!'
            });

  # =============================================================================
  # Sandbox V2 Tests (tai SDK + workspace, Docker + K8s via k3d)
  # Full sandbox/v2 integration tests are run locally.
  # =============================================================================
  SandboxV2Test:
    runs-on: ubuntu-latest
    # Service container started for the job, with its published port.
    services:
      mongodb:
        image: mongo:6.0
        ports:
          - 27017:27017
        env:
          MONGO_INITDB_ROOT_USERNAME: root
          MONGO_INITDB_ROOT_PASSWORD: 123456
          MONGO_INITDB_DATABASE: test
    strategy:
      matrix:
        go: ["1.25"]
    # Run only when the triggering workflow run was a pull_request run that
    # succeeded; the first step depends on the "pr" artifact that run uploads.
    # Kept as a bare expression: `${{ }}` inside a folded scalar (`if: >`)
    # evaluates to a non-empty string and would never skip the job.
    if: github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success'
    steps:
      # Find the artifact named "pr" on the triggering run and save it as
      # pr.zip in the workspace.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      # Export the contents of the NR and SHA files from the archive as the
      # NR and HEAD environment variables for later steps.
      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
          cat NR
          cat SHA
          echo HEAD=$(cat SHA) >> $GITHUB_ENV
          echo NR=$(cat NR) >> $GITHUB_ENV

      # Post a "running" comment on the issue/PR numbered NR.
      - name: "Comment on PR"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '🤖 Sandbox V2 CI Tests running (tai + workspace)...'
}); - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout pull request HEAD commit uses: actions/checkout@v4 with: ref: ${{ env.HEAD }} - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Pull Test Images run: | docker pull yaoapp/tai-sandbox-test:latest || true docker pull yaoapp/tai:latest docker pull alpine:latest - name: Install k3d run: curl -s https://raw.githubusercontent.com/k3d-io/k3d/main/install.sh | bash - name: Create k3d cluster run: | k3d cluster create tai-test --no-lb --wait --api-port 16443 kubectl wait --for=condition=Ready node --all --timeout=60s k3d image import alpine:latest -c 
tai-test - name: Start Tai Docker instance run: | docker run -d --name tai-docker \ -v /var/run/docker.sock:/var/run/docker.sock \ -p 19100:19100 -p 8099:8099 -p 12375:12375 -p 16080:16080 \ yaoapp/tai:latest server -direct \ -grpc 0.0.0.0:19100 -http 0.0.0.0:8099 -vnc 0.0.0.0:16080 -docker 0.0.0.0:12375 for i in $(seq 1 30); do if curl -sf http://127.0.0.1:8099/healthz > /dev/null 2>&1; then echo "Tai Docker HTTP ready"; break fi echo "Waiting for Tai Docker HTTP... ($i)"; sleep 1 done curl -sf http://127.0.0.1:8099/healthz > /dev/null 2>&1 || { echo "::error::Tai Docker HTTP failed"; docker logs tai-docker 2>&1; exit 1 } for i in $(seq 1 15); do if nc -z 127.0.0.1 19100 2>/dev/null; then echo "Tai Docker gRPC ready"; break fi echo "Waiting for Tai Docker gRPC... ($i)"; sleep 1 done nc -z 127.0.0.1 19100 2>/dev/null || { echo "::error::Tai Docker gRPC failed"; docker logs tai-docker 2>&1; exit 1 } - name: Generate kubeconfig for Tai K8s run: | K3D_IP=$(docker inspect k3d-tai-test-server-0 | jq -r '.[0].NetworkSettings.Networks["k3d-tai-test"].IPAddress') echo "k3d server IP: ${K3D_IP}" k3d kubeconfig get tai-test > /tmp/kubeconfig-k3d.yml # Kubeconfig for tai-k8s container (uses k3d-internal IP) sed "s|server: .*|server: https://${K3D_IP}:6443|" /tmp/kubeconfig-k3d.yml \ > /tmp/kubeconfig-tai-k8s.yml echo "Container kubeconfig server:" grep server: /tmp/kubeconfig-tai-k8s.yml # Kubeconfig for test runner (uses localhost via port-mapped 6443) sed 's|server: .*|server: https://127.0.0.1:6443|' /tmp/kubeconfig-k3d.yml \ > ${{ runner.temp }}/kubeconfig-tai.yml echo "Test runner kubeconfig server:" grep server: ${{ runner.temp }}/kubeconfig-tai.yml - name: Start Tai K8s instance run: | K3D_IP=$(docker inspect k3d-tai-test-server-0 | jq -r '.[0].NetworkSettings.Networks["k3d-tai-test"].IPAddress') echo "k3d server IP: ${K3D_IP}" docker run -d --name tai-k8s \ --network k3d-tai-test \ -p 19101:19100 -p 8100:8099 -p 6443:16443 -p 16081:16080 \ -v 
/var/run/docker.sock:/var/run/docker.sock:ro \ -v /tmp/kubeconfig-tai-k8s.yml:/etc/tai/kubeconfig.yml:ro \ -e TAI_K8S_UPSTREAM="tcp://${K3D_IP}:6443" \ -e TAI_KUBECONFIG=/etc/tai/kubeconfig.yml \ yaoapp/tai:latest server -direct \ -grpc 0.0.0.0:19100 -http 0.0.0.0:8099 -vnc 0.0.0.0:16080 -k8s 0.0.0.0:16443 for i in $(seq 1 30); do if curl -sf http://127.0.0.1:8100/healthz > /dev/null 2>&1; then echo "Tai K8s HTTP ready"; break fi echo "Waiting for Tai K8s HTTP... ($i)"; sleep 1 done curl -sf http://127.0.0.1:8100/healthz > /dev/null 2>&1 || { echo "::error::Tai K8s HTTP failed"; docker logs tai-k8s 2>&1; exit 1 } for i in $(seq 1 15); do if nc -z 127.0.0.1 19101 2>/dev/null; then echo "Tai K8s gRPC ready"; break fi echo "Waiting for Tai K8s gRPC... ($i)"; sleep 1 done nc -z 127.0.0.1 19101 2>/dev/null || { echo "::error::Tai K8s gRPC failed"; docker logs tai-k8s 2>&1; exit 1 } - name: Run Sandbox V2 CI Tests (tai + workspace) env: TAI_TEST_HOST: "127.0.0.1" TAI_TEST_DOCKER: "tcp://127.0.0.1:12375" TAI_TEST_GRPC_PORT: "19100" TAI_TEST_HTTP_PORT: "8099" TAI_TEST_VNC_PORT: "16080" TAI_TEST_DOCKER_PORT: "12375" TAI_TEST_K8S_HOST: "127.0.0.1" TAI_TEST_K8S_PORT: "6443" TAI_TEST_K8S_GRPC_PORT: "19101" TAI_TEST_KUBECONFIG: "${{ runner.temp }}/kubeconfig-tai.yml" TAI_TEST_HOST_IP: "172.17.0.1" SANDBOX_TEST_REMOTE_ADDR: "tai://127.0.0.1:19100" SANDBOX_TEST_IMAGE: "yaoapp/tai-sandbox-test:latest" run: make unit-test-sandbox-v2 - name: Codecov Report if: always() uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - name: "Comment on PR - Sandbox V2 Tests Done" uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const { NR } = process.env var issue_number = NR; await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: issue_number, body: '✅ Sandbox V2 CI Tests passed (tai + workspace)!' 
});

  # =============================================================================
  # Benchmark & Memory Leak Tests - Run with MySQL8.0 and SQLite3
  # =============================================================================
  PerfTest:
    runs-on: ubuntu-latest
    # Service containers started for the job, with their published ports.
    services:
      mcp-everything:
        image: yaoapp/mcp-everything:latest
        ports:
          - "3021:3021"
          - "3022:3022"
      mongodb:
        image: mongo:6.0
        ports:
          - 27017:27017
        env:
          MONGO_INITDB_ROOT_USERNAME: root
          MONGO_INITDB_ROOT_PASSWORD: 123456
          MONGO_INITDB_DATABASE: test
    strategy:
      matrix:
        go: ["1.25"]
        db: [MySQL8.0, SQLite3]
    # Run only when the triggering workflow run was a pull_request run that
    # succeeded. Written as a bare expression on purpose: wrapping `${{ }}` in
    # a folded scalar (`if: >`) leaves a trailing newline, the value is then
    # evaluated as a non-empty string, and the job is never skipped.
    if: github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success'
    steps:
      # Find the artifact named "pr" on the triggering run and save it as
      # pr.zip in the workspace.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      # Export the contents of the NR and SHA files from the archive as the
      # NR and HEAD environment variables for later steps.
      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
          cat NR
          cat SHA
          echo HEAD=$(cat SHA) >> $GITHUB_ENV
          echo NR=$(cat NR) >> $GITHUB_ENV

      - name: Checkout Kun
        uses: actions/checkout@v4
        with:
          repository: yaoapp/kun
          path: kun

      - name: Checkout Xun
        uses: actions/checkout@v4
        with:
          repository: yaoapp/xun
          path: xun

      - name: Checkout Gou
        uses: actions/checkout@v4
        with:
          repository: yaoapp/gou
          path: gou

      - name: Checkout V8Go
        uses: actions/checkout@v4
        with:
          repository: yaoapp/v8go
          path: v8go

      # Extract every libv8*.zip found under ./v8go next to the archive itself
      # and drop the __MACOSX directory the extraction may leave behind.
      - name: Unzip libv8
        run: |
          files=$(find ./v8go -name "libv8*.zip")
          for file in $files; do
            dir=$(dirname "$file")
            echo "Extracting $file to directory $dir"
            unzip -o -d $dir $file
            rm -rf $dir/__MACOSX
          done

      - name: Checkout
Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout pull request HEAD commit uses: actions/checkout@v4 with: ref: ${{ env.HEAD }} - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Install FFmpeg 7.x run: | wget https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz tar -xf ffmpeg-master-latest-linux64-gpl.tar.xz sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffmpeg /usr/local/bin/ sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffprobe /usr/local/bin/ sudo chmod +x /usr/local/bin/ffmpeg /usr/local/bin/ffprobe - name: Test FFmpeg run: ffmpeg -version - name: Setup Go Tools run: make tools - name: Setup ${{ matrix.db }} uses: ./.github/actions/setup-db with: kind: "${{ matrix.db }}" db: "xiang" user: "xiang" password: ${{ secrets.UNIT_PASS }} - name: Setup ENV env: PASSWORD: ${{ secrets.UNIT_PASS }} run: | echo "YAO_DB_DRIVER=$DB_DRIVER" >> $GITHUB_ENV if [ "$DB_DRIVER" = "mysql" ]; then echo "YAO_DB_PRIMARY=$DB_USER:$PASSWORD@$DB_HOST" >> $GITHUB_ENV else echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV mkdir -p ${{ github.WORKSPACE }}/../app/db fi - name: Run Benchmark & Memory Leak Tests run: | make benchmark make memory-leak # ============================================================================= # Core Tests - Run with DB matrix (MySQL/SQLite/Redis/Mongo combinations) # ============================================================================= CoreTest: runs-on: ubuntu-latest services: qdrant: image: qdrant/qdrant:latest ports: - 6333:6333 # HTTP API - 6334:6334 # gRPC fastembed: 
image: yaoapp/fastembed:latest-amd64 env: FASTEMBED_PASSWORD: Yao@2026 ports: - 6001:8000 neo4j: image: neo4j:latest ports: - "7687:7687" env: NEO4J_AUTH: neo4j/Yao2026Neo4j mcp-everything: image: yaoapp/mcp-everything:latest ports: - "3021:3021" - "3022:3022" strategy: matrix: go: ["1.25"] db: [MySQL8.0, SQLite3] redis: [4, 5, 6] mongo: ["6.0"] if: > ${{ github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success' }} steps: - name: "Download artifact" uses: actions/github-script@v7 with: script: | var artifacts = await github.rest.actions.listWorkflowRunArtifacts({ owner: context.repo.owner, repo: context.repo.repo, run_id: ${{github.event.workflow_run.id }}, }); var matchArtifact = artifacts.data.artifacts.filter((artifact) => { return artifact.name == "pr" })[0]; var download = await github.rest.actions.downloadArtifact({ owner: context.repo.owner, repo: context.repo.repo, artifact_id: matchArtifact.id, archive_format: 'zip', }); var fs = require('fs'); fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data)); - name: "Read NR & SHA" run: | unzip pr.zip cat NR cat SHA echo HEAD=$(cat SHA) >> $GITHUB_ENV echo NR=$(cat NR) >> $GITHUB_ENV - name: "Comment on PR" uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const { NR } = process.env var fs = require('fs'); var issue_number = NR; await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: issue_number, body: 'Thank you for the PR! The db: ${{ matrix.db }} redis: ${{ matrix.redis }} mongo: ${{ matrix.mongo }} test workflow is running, the results of the run will be commented later.' 
}); - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") # Get the directory where the ZIP file is located echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Yao Startup Webapp uses: actions/checkout@v4 with: repository: yaoapp/yao-startup-webapp submodules: true token: ${{ secrets.YAO_TEST_TOKEN }} path: yao-startup-webapp - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Kun, Xun, Gou, V8Go run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ mv yao-startup-webapp ../ ls -l . 
ls -l ../ - name: Checkout pull request HEAD commit uses: actions/checkout@v4 with: ref: ${{ env.HEAD }} - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:${{ matrix.redis }} - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Install FFmpeg 7.x run: | wget https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz tar -xf ffmpeg-master-latest-linux64-gpl.tar.xz sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffmpeg /usr/local/bin/ sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffprobe /usr/local/bin/ sudo chmod +x /usr/local/bin/ffmpeg /usr/local/bin/ffprobe - name: Test FFmpeg run: ffmpeg -version - name: Install pdftoppm, mutool, imagemagick run: | sudo apt update sudo apt install -y poppler-utils mupdf-tools imagemagick - name: Test pdftoppm, mutool, imagemagick run: | pdftoppm -v mutool -v convert -version - name: Start MongoDB run: | docker run --name mongodb --publish 27017:27017 \ -e MONGO_INITDB_DATABASE=test \ -e MONGO_INITDB_ROOT_USERNAME=root \ -e MONGO_INITDB_ROOT_PASSWORD=123456 \ --detach mongo:${{ matrix.mongo }} # Wait for MongoDB to be ready for i in $(seq 1 20); do if docker exec mongodb mongosh --quiet --port 27017 --username root --password 123456 --eval "db.serverStatus()" > /dev/null 2>&1; then echo "MongoDB is ready" break fi echo "Waiting for MongoDB... 
($i)" sleep 1 done - name: Setup MySQL8.0 (connector) uses: ./.github/actions/setup-db with: kind: "MySQL8.0" db: "test" user: "test" password: "123456" port: "3308" - name: Setup ${{ matrix.db }} uses: ./.github/actions/setup-db with: kind: "${{ matrix.db }}" db: "xiang" user: "xiang" password: ${{ secrets.UNIT_PASS }} - name: Setup Go Tools run: | make tools - name: Setup ENV & Host env: PASSWORD: ${{ secrets.UNIT_PASS }} run: | sudo echo "127.0.0.1 local.iqka.com" | sudo tee -a /etc/hosts echo "YAO_DB_DRIVER=$DB_DRIVER" >> $GITHUB_ENV echo "GITHUB_WORKSPACE:\n" && ls -l $GITHUB_WORKSPACE if [ "$DB_DRIVER" = "mysql" ]; then echo "YAO_DB_PRIMARY=$DB_USER:$PASSWORD@$DB_HOST" >> $GITHUB_ENV elif [ "$DB_DRIVER" = "postgres" ]; then echo "YAO_DB_PRIMARY=postgres://$DB_USER:$PASSWORD@$DB_HOST" >> $GITHUB_ENV else echo "YAO_DB_PRIMARY=$YAO_ROOT/$DB_HOST" >> $GITHUB_ENV fi echo ".:\n" && ls -l . echo "..:\n" && ls -l .. ping -c 1 -t 1 local.iqka.com - name: Test Prepare run: | make vet make fmt-check make misspell-check - name: Run Core Tests (exclude AI) run: make unit-test-core - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos - name: "Comment on PR" uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const { NR } = process.env var fs = require('fs'); var issue_number = NR; await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: issue_number, body: '✨DONE✨ db: ${{ matrix.db }} redis: ${{ matrix.redis }} mongo: ${{ matrix.mongo }} passed.' 
});

  # =============================================================================
  # Registry Client SDK Tests (requires Yao Registry Docker service)
  # =============================================================================
  # Runs `make unit-test-registry` against the pull request HEAD commit with a
  # registry service container published on localhost:8080.
  RegistryTest:
    runs-on: ubuntu-latest
    services:
      yao-registry:
        image: yaoapp/registry:latest
        ports:
          - "8080:8080"
        env:
          REGISTRY_INIT_USER: yaoagents
          REGISTRY_INIT_PASS: yaoagents
    strategy:
      matrix:
        go: ["1.25"]
    # Plain expression with `>-`. Wrapping the expression in ${{ }} inside a
    # folded scalar leaves a trailing newline, which turns the result into a
    # non-empty string, so the condition would always be truthy.
    if: >-
      github.event.workflow_run.event == 'pull_request' &&
      github.event.workflow_run.conclusion == 'success'
    steps:
      # Fetch the "pr" artifact uploaded by the triggering workflow run.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
          cat NR
          cat SHA
          # The artifact comes from the pull_request-triggered run, so validate
          # its content before appending it to $GITHUB_ENV.
          NR_VALUE=$(tr -d '[:space:]' < NR)
          SHA_VALUE=$(tr -d '[:space:]' < SHA)
          if ! [[ "$NR_VALUE" =~ ^[0-9]+$ ]]; then
            echo "::error::NR is not a pull request number"
            exit 1
          fi
          if ! [[ "$SHA_VALUE" =~ ^[0-9a-fA-F]{40,64}$ ]]; then
            echo "::error::SHA is not a commit hash"
            exit 1
          fi
          echo "HEAD=$SHA_VALUE" >> $GITHUB_ENV
          echo "NR=$NR_VALUE" >> $GITHUB_ENV

      - name: "Comment on PR"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '🤖 Registry Client SDK Tests running...'
            });

      - name: Wait for Registry
        run: |
          ready=0
          for i in $(seq 1 15); do
            if curl -sf http://localhost:8080/.well-known/yao-registry > /dev/null 2>&1; then
              echo "Registry is ready"
              ready=1
              break
            fi
            echo "Waiting for registry... ($i)"
            sleep 1
          done
          # Surface a timeout instead of falling through silently; later
          # steps still run so a slow start does not fail the job by itself.
          if [ "$ready" -ne 1 ]; then
            echo "::warning::Registry was not ready after 15 attempts"
          fi

      - name: Checkout Kun
        uses: actions/checkout@v4
        with:
          repository: yaoapp/kun
          path: kun

      - name: Checkout Xun
        uses: actions/checkout@v4
        with:
          repository: yaoapp/xun
          path: xun

      - name: Checkout Gou
        uses: actions/checkout@v4
        with:
          repository: yaoapp/gou
          path: gou

      - name: Checkout V8Go
        uses: actions/checkout@v4
        with:
          repository: yaoapp/v8go
          path: v8go

      - name: Unzip libv8
        run: |
          files=$(find ./v8go -name "libv8*.zip")
          for file in $files; do
            dir=$(dirname "$file")
            echo "Extracting $file to directory $dir"
            unzip -o -d $dir $file
            rm -rf $dir/__MACOSX
          done

      - name: Checkout Demo App
        uses: actions/checkout@v4
        with:
          repository: yaoapp/yao-dev-app
          path: app

      - name: Checkout Extension
        uses: actions/checkout@v4
        with:
          repository: yaoapp/yao-extensions-dev
          path: extension

      # Dependencies are moved next to the workspace before the PR checkout below.
      - name: Move Dependencies
        run: |
          mv kun ../
          mv xun ../
          mv gou ../
          mv v8go ../
          mv app ../
          mv extension ../

      - name: Checkout pull request HEAD commit
        uses: actions/checkout@v4
        with:
          ref: ${{ env.HEAD }}

      - name: Setup Go ${{ matrix.go }}
        uses: actions/setup-go@v5
        with:
          go-version: ${{ matrix.go }}

      - name: Run Registry Client Tests
        env:
          YAO_REGISTRY_URL: http://localhost:8080
          YAO_TEST_APPLICATION: ${{ github.workspace }}/../app
        run: make unit-test-registry

      - name: Codecov Report
        uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

      - name: "Comment on PR - Registry Tests Done"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '✅ Registry Client SDK Tests passed!'
});

  # =============================================================================
  # gRPC Tests - Run once with SQLite (transport layer, no DB matrix needed)
  # =============================================================================
  # Runs `make unit-test-grpc` against the pull request HEAD commit using a
  # SQLite database file under ../app/db.
  GRPCTest:
    runs-on: ubuntu-latest
    services:
      mongodb:
        image: mongo:6.0
        ports:
          - 27017:27017
        env:
          MONGO_INITDB_ROOT_USERNAME: root
          MONGO_INITDB_ROOT_PASSWORD: 123456
          MONGO_INITDB_DATABASE: test
    strategy:
      matrix:
        go: ["1.25"]
    # Plain expression with `>-`. Wrapping the expression in ${{ }} inside a
    # folded scalar leaves a trailing newline, which turns the result into a
    # non-empty string, so the condition would always be truthy.
    if: >-
      github.event.workflow_run.event == 'pull_request' &&
      github.event.workflow_run.conclusion == 'success'
    steps:
      # Fetch the "pr" artifact uploaded by the triggering workflow run.
      - name: "Download artifact"
        uses: actions/github-script@v7
        with:
          script: |
            var artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{github.event.workflow_run.id }},
            });
            var matchArtifact = artifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr"
            })[0];
            var download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            var fs = require('fs');
            fs.writeFileSync('${{github.workspace}}/pr.zip', Buffer.from(download.data));

      - name: "Read NR & SHA"
        run: |
          unzip pr.zip
          cat NR
          cat SHA
          # The artifact comes from the pull_request-triggered run, so validate
          # its content before appending it to $GITHUB_ENV.
          NR_VALUE=$(tr -d '[:space:]' < NR)
          SHA_VALUE=$(tr -d '[:space:]' < SHA)
          if ! [[ "$NR_VALUE" =~ ^[0-9]+$ ]]; then
            echo "::error::NR is not a pull request number"
            exit 1
          fi
          if ! [[ "$SHA_VALUE" =~ ^[0-9a-fA-F]{40,64}$ ]]; then
            echo "::error::SHA is not a commit hash"
            exit 1
          fi
          echo "HEAD=$SHA_VALUE" >> $GITHUB_ENV
          echo "NR=$NR_VALUE" >> $GITHUB_ENV

      - name: "Comment on PR"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: issue_number,
              body: '🤖 gRPC Tests running with SQLite...'
            });

      - name: Checkout Kun
        uses: actions/checkout@v4
        with:
          repository: yaoapp/kun
          path: kun

      - name: Checkout Xun
        uses: actions/checkout@v4
        with:
          repository: yaoapp/xun
          path: xun

      - name: Checkout Gou
        uses: actions/checkout@v4
        with:
          repository: yaoapp/gou
          path: gou

      - name: Checkout V8Go
        uses: actions/checkout@v4
        with:
          repository: yaoapp/v8go
          path: v8go

      - name: Unzip libv8
        run: |
          files=$(find ./v8go -name "libv8*.zip")
          for file in $files; do
            dir=$(dirname "$file")
            echo "Extracting $file to directory $dir"
            unzip -o -d $dir $file
            rm -rf $dir/__MACOSX
          done

      - name: Checkout Demo App
        uses: actions/checkout@v4
        with:
          repository: yaoapp/yao-dev-app
          path: app

      - name: Checkout Extension
        uses: actions/checkout@v4
        with:
          repository: yaoapp/yao-extensions-dev
          path: extension

      # Dependencies are moved next to the workspace before the PR checkout below.
      - name: Move Dependencies
        run: |
          mv kun ../
          mv xun ../
          mv gou ../
          mv v8go ../
          mv app ../
          mv extension ../

      - name: Checkout pull request HEAD commit
        uses: actions/checkout@v4
        with:
          ref: ${{ env.HEAD }}

      # The secret is passed through the environment instead of being
      # interpolated into the script text, so its content is never parsed
      # by the shell.
      - name: Setup Apple Private Key
        env:
          APPLE_PRIVATE_KEY_USER: ${{ secrets.APPLE_PRIVATE_KEY_USER }}
        run: |
          mkdir -p ../app/openapi/certs/apple
          echo "$APPLE_PRIVATE_KEY_USER" > ../app/openapi/certs/apple/signin_client_secret_key.p8

      - name: Setup Go ${{ matrix.go }}
        uses: actions/setup-go@v5
        with:
          go-version: ${{ matrix.go }}

      - name: Start Redis
        run: docker run --name redis --publish 6379:6379 --detach redis:6

      - name: Setup Go Tools
        run: make tools

      - name: Setup ENV (SQLite)
        run: |
          mkdir -p ${{ github.WORKSPACE }}/../app/db
          echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV
          echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV

      - name: Run gRPC Tests
        run: make unit-test-grpc

      - name: Codecov Report
        uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

      - name: "Comment on PR - gRPC Tests Done"
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { NR } = process.env
            var issue_number = NR;
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number:
issue_number, body: '✅ gRPC Tests passed!' }); ================================================ FILE: .github/workflows/release-linux.yml ================================================ name: Release Linux on: workflow_dispatch: push: tags: - "v*" permissions: contents: write env: IMAGE_NAME: yaoapp/yao jobs: # =================================================================== # Build Linux Binaries (amd64 + arm64) # =================================================================== build: runs-on: ubuntu-latest container: image: yaoapp/yao-build:1.0.0 env: CF_ACCESS_KEY_ID: ${{ secrets.CF_ACCESS_KEY_ID }} CF_SECRET_ACCESS_KEY: ${{ secrets.CF_SECRET_ACCESS_KEY }} R2_BUCKET: ${{ secrets.R2_BUCKET }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} steps: - name: Configure R2 For Cloudflare run: | aws configure set aws_access_key_id $CF_ACCESS_KEY_ID aws configure set aws_secret_access_key $CF_SECRET_ACCESS_KEY aws configure set default.region us-east-1 aws configure set default.s3.signature_version s3v4 aws configure set default.s3.endpoint_url https://$R2_ACCOUNT_ID.r2.cloudflarestorage.com - name: Build run: | export PATH=$PATH:/github/home/go/bin # Clone dependencies cd /app git clone https://github.com/yaoapp/kun.git /app/kun git clone https://github.com/yaoapp/xun.git /app/xun git clone https://github.com/yaoapp/gou.git /app/gou git clone https://github.com/yaoapp/v8go.git /app/v8go git clone https://github.com/yaoapp/cui.git /app/cui-v1.0 git clone https://github.com/yaoapp/yao-init.git /app/yao-init git clone https://github.com/yaoapp/yao.git /app/yao # Extract libv8 files=$(find /app/v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done # Set VERSION from git tag (required) cd /app/yao if [[ "$GITHUB_REF" != refs/tags/v* ]]; then echo "::error::This workflow requires a tag (refs/tags/v*). 
Got: $GITHUB_REF" exit 1 fi TAG_VERSION="${GITHUB_REF#refs/tags/v}" echo "Setting VERSION to $TAG_VERSION" sed -i "s/const VERSION = \".*\"/const VERSION = \"${TAG_VERSION}\"/g" share/const.go grep 'const VERSION' share/const.go make tools && make artifacts-linux mv /app/yao/dist/release/* /data/ ls -l /data - name: Push To R2 run: | for file in /data/*; do aws s3 cp "$file" s3://$R2_BUCKET/archives/ \ --endpoint-url https://$R2_ACCOUNT_ID.r2.cloudflarestorage.com done - name: Upload Artifact uses: actions/upload-artifact@v4 with: name: yao-linux path: /data/* # =================================================================== # Docker Images (multi-arch manifest: linux/amd64 + linux/arm64) # =================================================================== docker: needs: build runs-on: ubuntu-latest steps: - name: Checkout Code uses: actions/checkout@v4 - name: Get Version id: version run: | if [[ "$GITHUB_REF" != refs/tags/v* ]]; then echo "::error::This workflow requires a tag (refs/tags/v*). 
Got: $GITHUB_REF" exit 1 fi VERSION="${GITHUB_REF#refs/tags/v}" echo "version=${VERSION}" >> $GITHUB_OUTPUT echo "VERSION=${VERSION}" - name: Set up QEMU uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_TOKEN }} - name: Build & Push Development (multi-arch) uses: docker/build-push-action@v6 with: context: ./docker/development platforms: linux/amd64,linux/arm64 build-args: | VERSION=${{ steps.version.outputs.version }} push: true tags: | ${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }}-dev ${{ env.IMAGE_NAME }}:dev - name: Build & Push Production (multi-arch) uses: docker/build-push-action@v6 with: context: ./docker/production platforms: linux/amd64,linux/arm64 build-args: | VERSION=${{ steps.version.outputs.version }} push: true tags: | ${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }} ${{ env.IMAGE_NAME }}:latest ================================================ FILE: .github/workflows/release-macos.yml ================================================ name: Release macOS on: workflow_dispatch: push: tags: - "v*" permissions: contents: write jobs: # =================================================================== # Build Yao macOS binaries (arm64 + amd64) — one job, both arches # =================================================================== build: runs-on: macos-latest steps: - name: Setup Node.js uses: actions/setup-node@v4 with: node-version: 18 - name: Install pnpm run: npm install -g pnpm - name: Setup Cache uses: actions/cache@v4 with: path: | ~/.cache/go-build ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - name: Checkout Kun uses: actions/checkout@v4 with: repository: yaoapp/kun path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: yaoapp/xun path: xun - name: Checkout Gou uses: 
actions/checkout@v4 with: repository: yaoapp/gou path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout CUI v1.0 uses: actions/checkout@v4 with: repository: yaoapp/cui path: cui-v1.0 - name: Checkout Yao-Init uses: actions/checkout@v4 with: repository: yaoapp/yao-init path: yao-init - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv cui-v1.0 ../ mv yao-init ../ rm -f ../cui-v1.0/packages/setup/vite.config.ts.* - name: Checkout Yao uses: actions/checkout@v4 - name: Set Version from Tag run: | if [[ "$GITHUB_REF" != refs/tags/v* ]]; then echo "::error::This workflow requires a tag (refs/tags/v*). Got: $GITHUB_REF" exit 1 fi TAG="${GITHUB_REF#refs/tags/v}" echo "Setting VERSION to $TAG" sed -i.bak "s/const VERSION = \".*\"/const VERSION = \"${TAG}\"/g" share/const.go rm -f share/const.go.bak grep 'const VERSION' share/const.go - name: Setup Go uses: actions/setup-go@v5 with: go-version: "1.25" - name: Setup Go Tools run: make tools - name: Make Artifacts macOS run: make artifacts-macos - name: Get Version id: version run: | VERSION=$(grep 'const VERSION =' share/const.go | awk '{print $4}' | sed 's/"//g') echo "version=${VERSION}" >> $GITHUB_OUTPUT - name: List Build Output run: ls -lh dist/release/ - name: Install Certificates env: KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} run: | mkdir -p certs echo "${{ secrets.APPLE_DEVELOPERIDG2CA }}" | base64 --decode > certs/DeveloperIDG2CA.cer echo "${{ secrets.APPLE_DISTRIBUTION }}" | base64 --decode > certs/distribution.cer echo "${{ secrets.APPLE_PRIVATE_KEY }}" | base64 --decode > certs/private_key.p12 security verify-cert -c certs/DeveloperIDG2CA.cer security verify-cert -c certs/distribution.cer - name: Import Certificates 
env: KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} run: | KEYCHAIN_PATH=$RUNNER_TEMP/app-signing.keychain-db security create-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH security set-keychain-settings -lut 21600 $KEYCHAIN_PATH security unlock-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH security import ./certs/DeveloperIDG2CA.cer -k $KEYCHAIN_PATH -T /usr/bin/codesign security import ./certs/distribution.cer -k $KEYCHAIN_PATH -T /usr/bin/codesign security import ./certs/private_key.p12 -k $KEYCHAIN_PATH -P "${{ secrets.APPLE_PRIVATE_KEY_PASSWORD }}" -T /usr/bin/codesign security list-keychain -d user -s $KEYCHAIN_PATH - name: Sign Yao Binaries run: | VERSION="${{ steps.version.outputs.version }}" IDENTITY="Developer ID Application: ${{ secrets.APPLE_SIGN }}" for ARCH in arm64 amd64; do for SUFFIX in "" "-prod"; do BIN="dist/release/yao-${VERSION}-darwin-${ARCH}${SUFFIX}" codesign --force --verbose --timestamp --options runtime \ --entitlements .github/codesign/entitlements.plist \ --sign "$IDENTITY" "$BIN" codesign --verify --deep --strict --verbose=2 "$BIN" done done - name: Prepare Output and Checksums run: | VERSION="${{ steps.version.outputs.version }}" for ARCH in arm64 amd64; do for VARIANT in dev prod; do if [ "$VARIANT" = "dev" ]; then SRC="dist/release/yao-${VERSION}-darwin-${ARCH}" else SRC="dist/release/yao-${VERSION}-darwin-${ARCH}-prod" fi DIR="/tmp/yao-output-${ARCH}-${VARIANT}" mkdir -p "$DIR" cp "$SRC" "$DIR/yao" chmod +x "$DIR/yao" done done mkdir -p /tmp/checksums for ARCH in arm64 amd64; do for VARIANT in dev prod; do shasum -a 256 "/tmp/yao-output-${ARCH}-${VARIANT}/yao" | awk '{print $1" yao"}' > "/tmp/checksums/yao-darwin-${ARCH}-${VARIANT}.sha256" done done cat /tmp/checksums/*.sha256 - name: Upload Artifacts uses: actions/upload-artifact@v4 with: name: yao-darwin-arm64 path: /tmp/yao-output-arm64-prod/yao - name: Upload arm64 Dev Binary uses: actions/upload-artifact@v4 with: name: yao-darwin-arm64-dev path: 
/tmp/yao-output-arm64-dev/yao

      - name: Upload amd64 Binary
        uses: actions/upload-artifact@v4
        with:
          name: yao-darwin-amd64
          path: /tmp/yao-output-amd64-prod/yao

      - name: Upload amd64 Dev Binary
        uses: actions/upload-artifact@v4
        with:
          name: yao-darwin-amd64-dev
          path: /tmp/yao-output-amd64-dev/yao

      - name: Upload Checksums
        uses: actions/upload-artifact@v4
        with:
          name: yao-darwin-checksums
          path: /tmp/checksums/*.sha256

================================================
FILE: .github/workflows/release.yml
================================================
name: Release

on:
  workflow_run:
    workflows: ["Release Linux", "Release macOS"]
    types:
      - completed

permissions:
  contents: write
  # The job calls `gh run list` / `gh run download` with GITHUB_TOKEN; grant
  # read access to Actions explicitly since unlisted permissions are removed
  # once a `permissions` block is present.
  # NOTE(review): confirm this is required for this repository's visibility.
  actions: read

jobs:
  # ===================================================================
  # Wait for both workflows to succeed, then create a unified release
  # ===================================================================
  release:
    runs-on: ubuntu-latest
    # Only react to successful runs whose head_branch starts with "v".
    if: >
      github.event.workflow_run.conclusion == 'success' &&
      startsWith(github.event.workflow_run.head_branch, 'v')
    steps:
      - name: Checkout Code
        uses: actions/checkout@v4

      # head_branch is passed through the environment rather than being
      # interpolated into the script, so its content is never parsed as shell.
      - name: Get Version
        id: version
        env:
          HEAD_BRANCH: ${{ github.event.workflow_run.head_branch }}
        run: |
          TAG="$HEAD_BRANCH"
          VERSION="${TAG#v}"
          echo "version=${VERSION}" >> $GITHUB_OUTPUT
          echo "tag=${TAG}" >> $GITHUB_OUTPUT
          echo "TAG=${TAG} VERSION=${VERSION}"

      # Polls once per minute, up to 60 times, until the latest run of each
      # release workflow for this tag has concluded.
      - name: Wait for Both Workflows
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          TAG: ${{ steps.version.outputs.tag }}
        run: |
          echo "Waiting for both Release Linux and Release macOS to complete for $TAG..."
          for i in $(seq 1 60); do
            LINUX_STATUS=$(gh run list --workflow="Release Linux" --branch="$TAG" --limit=1 --json conclusion --jq '.[0].conclusion // "pending"')
            MACOS_STATUS=$(gh run list --workflow="Release macOS" --branch="$TAG" --limit=1 --json conclusion --jq '.[0].conclusion // "pending"')
            echo "Attempt $i: Linux=$LINUX_STATUS macOS=$MACOS_STATUS"
            if [ "$LINUX_STATUS" = "success" ] && [ "$MACOS_STATUS" = "success" ]; then
              echo "Both workflows completed successfully."
              exit 0
            fi
            # Any terminal, non-success conclusion ends the wait immediately
            # instead of polling until the timeout.
            for STATUS in "$LINUX_STATUS" "$MACOS_STATUS"; do
              case "$STATUS" in
                failure|cancelled|timed_out|startup_failure)
                  echo "::error::One or both workflows failed (Linux=$LINUX_STATUS macOS=$MACOS_STATUS)"
                  exit 1
                  ;;
              esac
            done
            sleep 60
          done
          echo "::error::Timed out waiting for workflows"
          exit 1

      - name: Download Linux Artifacts
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          TAG: ${{ steps.version.outputs.tag }}
        run: |
          LINUX_RUN_ID=$(gh run list --workflow="Release Linux" --branch="$TAG" --limit=1 --json databaseId --jq '.[0].databaseId')
          mkdir -p dist/linux
          gh run download "$LINUX_RUN_ID" --name yao-linux --dir dist/linux

      - name: Download macOS Artifacts
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          TAG: ${{ steps.version.outputs.tag }}
        run: |
          MACOS_RUN_ID=$(gh run list --workflow="Release macOS" --branch="$TAG" --limit=1 --json databaseId --jq '.[0].databaseId')
          mkdir -p dist/macos
          gh run download "$MACOS_RUN_ID" --name yao-darwin-arm64 --dir dist/macos/arm64-prod
          gh run download "$MACOS_RUN_ID" --name yao-darwin-arm64-dev --dir dist/macos/arm64-dev
          gh run download "$MACOS_RUN_ID" --name yao-darwin-amd64 --dir dist/macos/amd64-prod
          gh run download "$MACOS_RUN_ID" --name yao-darwin-amd64-dev --dir dist/macos/amd64-dev
          gh run download "$MACOS_RUN_ID" --name yao-darwin-checksums --dir dist/macos/checksums

      - name: Prepare Release Files
        env:
          VERSION: ${{ steps.version.outputs.version }}
        run: |
          mkdir -p release

          # Linux artifacts (already named correctly from build.sh)
          cp dist/linux/* release/ 2>/dev/null || true

          # macOS prod binaries
          cp dist/macos/arm64-prod/yao "release/yao-${VERSION}-darwin-arm64"
          cp dist/macos/amd64-prod/yao "release/yao-${VERSION}-darwin-amd64"

          # macOS dev binaries
          cp dist/macos/arm64-dev/yao "release/yao-${VERSION}-darwin-arm64-dev"
          cp dist/macos/amd64-dev/yao "release/yao-${VERSION}-darwin-amd64-dev"

          # Checksums
          cp dist/macos/checksums/*.sha256 release/ 2>/dev/null || true

          chmod +x release/yao-* 2>/dev/null || true

          echo "=== Release files ==="
          ls -lh release/

      - name: Create GitHub Release
        uses: softprops/action-gh-release@v2
        with:
          tag_name: ${{ steps.version.outputs.tag }}
          name: Yao v${{ steps.version.outputs.version }}
          files: release/*
          generate_release_notes: true

================================================
FILE: .github/workflows/unit-test-v1.yml
================================================
name: Unit Test V1

on:
  workflow_dispatch:
    inputs:
      tags:
        description: "Version"

env:
  CI_VERSION: "1.0.0"
  REPO_KUN: ${{ github.repository_owner }}/kun
  REPO_XUN: ${{ github.repository_owner }}/xun
  REPO_GOU: ${{ github.repository_owner }}/gou

  YAO_DEV: ${{ github.WORKSPACE }}
  YAO_ENV: development
  YAO_ROOT: ${{ github.WORKSPACE }}/../app
  YAO_HOST: 0.0.0.0
  YAO_PORT: 5099
  YAO_SESSION: "memory"
  YAO_LOG: "./logs/application.log"
  YAO_LOG_MODE: "TEXT"
  YAO_JWT_SECRET: "bLp@bi!oqo-2U+hoTRUG"
  YAO_DB_AESKEY: "ZLX=T&f6refeCh-ro*r@"
  YAO_EXTENSION_ROOT: ${{ github.WORKSPACE }}/../extension
  YAO_TEST_APPLICATION: ${{ github.WORKSPACE }}/../app

  YAO_RUNTIME_MIN: 3
  YAO_RUNTIME_MAX: 6
  YAO_RUNTIME_HEAP_LIMIT: 1500000000
  YAO_RUNTIME_HEAP_RELEASE: 10000000
  YAO_RUNTIME_HEAP_AVAILABLE: 550000000
  YAO_RUNTIME_PRECOMPILE: true

  MYSQL_TEST_HOST: "127.0.0.1"
  MYSQL_TEST_PORT: "3308"
  MYSQL_TEST_USER: "test"
  MYSQL_TEST_PASS: "123456"

  REDIS_TEST_HOST: "127.0.0.1"
  REDIS_TEST_PORT: "6379"
  REDIS_TEST_DB: "2"

  MONGO_TEST_HOST: "127.0.0.1"
  MONGO_TEST_PORT: "27017"
  MONGO_TEST_USER: "root"
  MONGO_TEST_PASS: "123456"

  PG_TEST_HOST: "127.0.0.1"
  PG_TEST_PORT: "5432"
  PG_TEST_USER: "test"
  PG_TEST_PASS: "123456"

jobs:
  # =============================================================================
  # Environment Setup & Verification
  # Build Yao, start services, connect Tai via gRPC tunnel, verify everything.
  # No tests are run — this job validates the CI environment is healthy.
# ============================================================================= setup-and-verify: runs-on: ubuntu-latest services: mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test postgres: image: postgres:14 ports: - 5432:5432 env: POSTGRES_USER: test POSTGRES_PASSWORD: 123456 POSTGRES_DB: test options: >- --health-cmd="pg_isready -U test" --health-interval=10s --health-timeout=5s --health-retries=5 strategy: matrix: go: ["1.25"] steps: # ==== Phase 1: Checkout & Setup ==== - name: Checkout Yao uses: actions/checkout@v4 - name: Setup Build Environment uses: ./.github/actions/setup-yao with: repo-kun: ${{ env.REPO_KUN }} repo-xun: ${{ env.REPO_XUN }} repo-gou: ${{ env.REPO_GOU }} checkout-init: "true" apple-private-key: ${{ secrets.APPLE_PRIVATE_KEY_USER }} - name: Load sandbox-v2 env run: grep -vE '^\s*#|^\s*$' .github/env/sandbox-v2.env >> $GITHUB_ENV - name: Setup SQLite run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Start MySQL 8.0 run: | docker run -d --name mysql \ -e MYSQL_RANDOM_ROOT_PASSWORD=true \ -e MYSQL_USER=${MYSQL_TEST_USER} \ -e MYSQL_PASSWORD=${MYSQL_TEST_PASS} \ -e MYSQL_DATABASE=test \ -p ${MYSQL_TEST_PORT}:3306 \ mysql:8.0 --port=3306 --sql-mode='' \ --character-set-server=utf8mb4 --collation-server=utf8mb4_general_ci for i in $(seq 1 30); do if docker exec mysql mysqladmin ping -h 127.0.0.1 -u ${MYSQL_TEST_USER} -p${MYSQL_TEST_PASS} > /dev/null 2>&1; then echo "MySQL ready" break fi echo "Waiting for MySQL... ($i/30)" sleep 2 done - name: Start Redis run: docker run --name redis -d -p 6379:6379 redis:6 # ==== Phase 2: Code Quality & Build ==== - name: Code Quality Check run: | make vet make fmt-check - name: Build Yao run: go build -v -o $RUNNER_TEMP/yao . 
- name: Build ci-token run: go build -tags ci -v -o $RUNNER_TEMP/ci-token ./cmd/ci-token - name: Extract tai binary from image run: | CID=$(docker create yaoapp/tai:latest) docker cp "$CID":/usr/local/bin/tai $RUNNER_TEMP/tai docker rm "$CID" chmod +x $RUNNER_TEMP/tai $RUNNER_TEMP/tai version || echo "tai binary extracted" # ==== Phase 3: Prepare & Start Yao ==== - name: Prepare test app directory run: | cp -r ${{ github.WORKSPACE }}/../yao-init $RUNNER_TEMP/yao-test-app mkdir -p $RUNNER_TEMP/yao-test-app/db - name: Start Yao server run: | cd $RUNNER_TEMP/yao-test-app YAO_ROOT=$(pwd) \ YAO_HOST=0.0.0.0 \ YAO_PORT=5099 \ YAO_GRPC_HOST=0.0.0.0 \ YAO_GRPC_PORT=9099 \ YAO_DB_DRIVER=sqlite3 \ YAO_DB_PRIMARY=$(pwd)/db/yao.db \ YAO_SESSION=memory \ YAO_ENV=development \ YAO_JWT_SECRET="${{ env.YAO_JWT_SECRET }}" \ YAO_DB_AESKEY="${{ env.YAO_DB_AESKEY }}" \ $RUNNER_TEMP/yao start & # Wait for Yao HTTP to be ready (up to 120s) for i in $(seq 1 60); do if curl -sf http://127.0.0.1:5099/.well-known/yao > /dev/null 2>&1; then echo "Yao HTTP ready" curl -s http://127.0.0.1:5099/.well-known/yao | jq . break fi echo "Waiting for Yao... 
($i/60)"
      sleep 2
    done
    curl -sf http://127.0.0.1:5099/.well-known/yao > /dev/null 2>&1 || {
      echo "::error::Yao HTTP failed to start"
      exit 1
    }

# ==== Phase 4: Generate Tai credentials ====
# Mint one access token per Tai instance with ci-token and store a
# base64-encoded credentials JSON document for each of them in $RUNNER_TEMP.
- name: Generate Tai credentials
  run: |
    # gen_cred CLIENT_ID TAI_ID OUT_FILE
    #   CLIENT_ID  value passed to ci-token as --client-id and stored as client_id
    #   TAI_ID     stored as tai_id
    #   OUT_FILE   path that receives the base64-encoded JSON (single line)
    # Aborts the step when ci-token yields no token, so an unusable
    # credentials file is never written.
    gen_cred() {
      local CID=$1 TID=$2 OUT=$3
      # Resolve the scope once so the minted token and the JSON document
      # always record the same value, including the default.
      local SCOPE="${YAO_CI_OAUTH_SCOPE:-tai:tunnel}"
      local TOKEN
      # ci-token's stderr is kept in a side file instead of being discarded,
      # so it can be shown if token generation fails. The pipeline's exit
      # status is that of `tr`, hence the explicit emptiness check below.
      TOKEN=$($RUNNER_TEMP/ci-token \
        --app $RUNNER_TEMP/yao-test-app \
        --client-id "$CID" \
        --subject "${YAO_CI_OAUTH_SUBJECT:-ci-tai}" \
        --user-id "${YAO_CI_OAUTH_USER_ID}" \
        --team-id "${YAO_CI_OAUTH_TEAM_ID}" \
        --scope "$SCOPE" \
        --ttl "${YAO_CI_OAUTH_TTL:-24h}" 2>"$OUT.err" | tail -1 | tr -d '[:space:]')
      if [ -z "$TOKEN" ]; then
        echo "::error::ci-token produced no access token for $CID"
        cat "$OUT.err" || true
        exit 1
      fi
      rm -f "$OUT.err"
      local JSON="{\"client_id\":\"$CID\",\"tai_id\":\"$TID\",\"machine_id\":\"ci-runner\",\"server\":\"http://127.0.0.1:${YAO_CI_HTTP_PORT}\",\"yao_grpc_addr\":\"127.0.0.1:${YAO_CI_GRPC_PORT}\",\"access_token\":\"$TOKEN\",\"scope\":\"$SCOPE\",\"expires_at\":\"2099-01-01T00:00:00Z\",\"registered\":true}"
      echo -n "$JSON" | base64 -w0 > "$OUT"
      echo ""
      # Only the first 200 bytes of the debug line are printed.
      echo "Credentials JSON (debug): $JSON" | head -c 200
      echo "..."
      echo "Generated credentials for $CID → $OUT"
    }
    gen_cred tai-ci-local tai-local-001 $RUNNER_TEMP/tai-local-credentials
    gen_cred tai-ci-docker tai-docker-001 $RUNNER_TEMP/tai-docker-credentials
    gen_cred tai-ci-k8s tai-k8s-001 $RUNNER_TEMP/tai-k8s-credentials
    gen_cred tai-ci-hostexec tai-hostexec-001 $RUNNER_TEMP/tai-hostexec-credentials

# ==== Phase 5: Pull images & Setup K8s ====
# The sandbox test image is optional (pull failure tolerated); the tai and
# alpine images are required.
- name: Pull test images
  run: |
    docker pull yaoapp/tai-sandbox-test:latest || true
    docker pull yaoapp/tai:latest
    docker pull alpine:latest

# Install k3d, create the single-node "tai-test" cluster with its API on
# ${YAO_CI_K3D_API_PORT}, wait for the node, and preload alpine into it.
- name: Install k3d & create cluster
  run: |
    curl -s https://raw.githubusercontent.com/k3d-io/k3d/main/install.sh | bash
    k3d cluster create tai-test --no-lb --wait --api-port ${YAO_CI_K3D_API_PORT}
    kubectl wait --for=condition=Ready node --all --timeout=60s
    k3d image import alpine:latest -c tai-test

# Derive two kubeconfigs from the k3d one by rewriting its `server:` line,
# then publish the runner-side path to later steps through $GITHUB_ENV.
- name: Generate kubeconfig
  run: |
    K3D_IP=$(docker inspect k3d-tai-test-server-0 | jq -r '.[0].NetworkSettings.Networks["k3d-tai-test"].IPAddress')
    echo "k3d server IP: ${K3D_IP}"
    k3d kubeconfig get tai-test > /tmp/kubeconfig-k3d.yml

    # For tai-k8s container (uses k3d internal IP)
    sed "s|server: .*|server: https://${K3D_IP}:6443|" /tmp/kubeconfig-k3d.yml \
      > /tmp/kubeconfig-tai-k8s.yml
    echo "Container kubeconfig server:"
    grep server: /tmp/kubeconfig-tai-k8s.yml

    # For test runner (uses localhost via k3d port-mapped API)
    sed "s|server: .*|server: https://127.0.0.1:${YAO_CI_K3D_API_PORT}|" /tmp/kubeconfig-k3d.yml \
      > $RUNNER_TEMP/kubeconfig-tai.yml
    echo "Test runner kubeconfig server:"
    grep server: $RUNNER_TEMP/kubeconfig-tai.yml

    # Export for later steps
    echo "TAI_TEST_KUBECONFIG=$RUNNER_TEMP/kubeconfig-tai.yml" >> $GITHUB_ENV
    echo "YAO_CI_TAI_KUBECONFIG=$RUNNER_TEMP/kubeconfig-tai.yml" >> $GITHUB_ENV

# ==== Phase 6: Start Tai instances (host processes) ====
- name: Start tai-local (DIRECT mode, auto-detect Docker)
  run: |
    mkdir -p $RUNNER_TEMP/tai-local-data
    TAI_CREDENTIALS=$RUNNER_TEMP/tai-local-credentials \
    TAI_YAO_SERVER=http://127.0.0.1:${YAO_CI_HTTP_PORT} \
TAI_DATA_DIR=$RUNNER_TEMP/tai-local-data \
    $RUNNER_TEMP/tai server \
      --grpc 127.0.0.1:${YAO_CI_TAI_LOCAL_GRPC_PORT} \
      --http 127.0.0.1:${YAO_CI_TAI_LOCAL_HTTP_PORT} \
      --vnc 127.0.0.1:${YAO_CI_TAI_LOCAL_VNC_PORT} \
      --docker 127.0.0.1:${YAO_CI_TAI_LOCAL_DOCKER_PORT} \
      --direct \
      --host-exec --host-exec-full-access \
      --log-level debug &
    # $! is the PID of the tai process just sent to the background; the
    # verification step reads this file for its liveness check.
    echo $! > $RUNNER_TEMP/tai-local.pid
    echo "tai-local PID: $(cat $RUNNER_TEMP/tai-local.pid)"
    # Poll /healthz up to 30 times, one second apart. The loop does not fail
    # the step when the endpoint never answers.
    for i in $(seq 1 30); do
      if curl -sf http://127.0.0.1:${YAO_CI_TAI_LOCAL_HTTP_PORT}/healthz > /dev/null 2>&1; then
        echo "tai-local HTTP ready"
        break
      fi
      echo "Waiting for tai-local HTTP... ($i/30)"
      sleep 1
    done

# Second instance: same launch shape as tai-local, but without --direct and
# with its own credentials file, data directory and port set.
- name: Start tai-docker (TUNNEL mode, Docker API proxy)
  run: |
    mkdir -p $RUNNER_TEMP/tai-docker-data
    TAI_CREDENTIALS=$RUNNER_TEMP/tai-docker-credentials \
    TAI_YAO_SERVER=http://127.0.0.1:${YAO_CI_HTTP_PORT} \
    TAI_DATA_DIR=$RUNNER_TEMP/tai-docker-data \
    $RUNNER_TEMP/tai server \
      --grpc 127.0.0.1:${YAO_CI_TAI_DOCKER_GRPC_PORT} \
      --http 127.0.0.1:${YAO_CI_TAI_DOCKER_HTTP_PORT} \
      --vnc 127.0.0.1:${YAO_CI_TAI_DOCKER_VNC_PORT} \
      --docker 127.0.0.1:${YAO_CI_TAI_DOCKER_API_PORT} \
      --host-exec --host-exec-full-access \
      --log-level debug &
    # Record the background PID for the later liveness check.
    echo $! > $RUNNER_TEMP/tai-docker.pid
    echo "tai-docker PID: $(cat $RUNNER_TEMP/tai-docker.pid)"
    # Up to 30 one-second health probes; a timeout is not fatal here.
    for i in $(seq 1 30); do
      if curl -sf http://127.0.0.1:${YAO_CI_TAI_DOCKER_HTTP_PORT}/healthz > /dev/null 2>&1; then
        echo "tai-docker HTTP ready"
        break
      fi
      echo "Waiting for tai-docker HTTP... ($i/30)"
      sleep 1
    done

# Third instance: exposes a --k8s listener and is pointed at the k3d API
# through TAI_K8S_UPSTREAM and the runner-side kubeconfig.
# NOTE(review): --docker="" presumably disables the Docker API proxy for
# this instance; confirm against the tai CLI documentation.
- name: Start tai-k8s (TUNNEL mode, K8s API proxy)
  run: |
    mkdir -p $RUNNER_TEMP/tai-k8s-data
    TAI_CREDENTIALS=$RUNNER_TEMP/tai-k8s-credentials \
    TAI_YAO_SERVER=http://127.0.0.1:${YAO_CI_HTTP_PORT} \
    TAI_DATA_DIR=$RUNNER_TEMP/tai-k8s-data \
    TAI_K8S_UPSTREAM="tcp://127.0.0.1:${YAO_CI_K3D_API_PORT}" \
    TAI_KUBECONFIG=$RUNNER_TEMP/kubeconfig-tai.yml \
    $RUNNER_TEMP/tai server \
      --grpc 127.0.0.1:${YAO_CI_TAI_K8S_GRPC_PORT} \
      --http 127.0.0.1:${YAO_CI_TAI_K8S_HTTP_PORT} \
      --vnc 127.0.0.1:${YAO_CI_TAI_K8S_VNC_PORT} \
      --k8s 127.0.0.1:${YAO_CI_TAI_K8S_API_PORT} \
      --docker="" \
      --host-exec --host-exec-full-access \
      --log-level debug &
    # Record the background PID for the later liveness check.
    echo $! > $RUNNER_TEMP/tai-k8s.pid
    echo "tai-k8s PID: $(cat $RUNNER_TEMP/tai-k8s.pid)"
    # Up to 30 one-second health probes; a timeout is not fatal here.
    for i in $(seq 1 30); do
      if curl -sf http://127.0.0.1:${YAO_CI_TAI_K8S_HTTP_PORT}/healthz > /dev/null 2>&1; then
        echo "tai-k8s HTTP ready"
        break
      fi
      echo "Waiting for tai-k8s HTTP... ($i/30)"
      sleep 1
    done

# Fourth instance: started with --docker="" and no --k8s listener, leaving
# only the host-exec flags enabled.
- name: Start tai-hostexec (TUNNEL mode, HostExec only, no runtime)
  run: |
    mkdir -p $RUNNER_TEMP/tai-hostexec-data
    TAI_CREDENTIALS=$RUNNER_TEMP/tai-hostexec-credentials \
    TAI_YAO_SERVER=http://127.0.0.1:${YAO_CI_HTTP_PORT} \
    TAI_DATA_DIR=$RUNNER_TEMP/tai-hostexec-data \
    $RUNNER_TEMP/tai server \
      --grpc 127.0.0.1:${YAO_CI_TAI_HOSTEXEC_GRPC_PORT} \
      --http 127.0.0.1:${YAO_CI_TAI_HOSTEXEC_HTTP_PORT} \
      --vnc 127.0.0.1:${YAO_CI_TAI_HOSTEXEC_VNC_PORT} \
      --docker="" \
      --host-exec --host-exec-full-access \
      --log-level debug &
    # Record the background PID for the later liveness check.
    echo $! > $RUNNER_TEMP/tai-hostexec.pid
    echo "tai-hostexec PID: $(cat $RUNNER_TEMP/tai-hostexec.pid)"
    # Up to 30 one-second health probes; a timeout is not fatal here.
    for i in $(seq 1 30); do
      if curl -sf http://127.0.0.1:${YAO_CI_TAI_HOSTEXEC_HTTP_PORT}/healthz > /dev/null 2>&1; then
        echo "tai-hostexec HTTP ready"
        break
      fi
      echo "Waiting for tai-hostexec HTTP...
($i/30)"
      sleep 1
    done

# ==== Phase 7: Environment Verification (fail fast) ====
# Runs every health, topology, credential and HostExec probe, counting
# failures in FAILED instead of stopping at the first one. If any probe
# failed, diagnostics are printed and the step exits 1.
- name: Verify Environment
  run: |
    echo "CI Environment v${CI_VERSION}"
    echo ""
    FAILED=0

    # check LABEL CMD [ARGS...]
    # Runs CMD with its output discarded, prints [PASS]/[FAIL] for LABEL and
    # increments FAILED on a non-zero exit status.
    check() {
      local name=$1; shift
      if "$@" > /dev/null 2>&1; then
        echo " [PASS] $name"
      else
        echo " [FAIL] $name"
        FAILED=$((FAILED + 1))
      fi
    }

    # pid_alive PIDFILE
    # Succeeds only when PIDFILE exists, is non-empty and names a process
    # that accepts signal 0. A missing file must fail: falling back to
    # `kill -0 0` would probe the caller's own process group and always pass.
    pid_alive() {
      local pidfile=$1 pid
      [ -s "$pidfile" ] || return 1
      pid=$(cat "$pidfile")
      kill -0 "$pid"
    }

    echo "=== Environment Verification ==="

    # ── 1. Service Health ──
    echo ""
    echo "--- 1. Service Health ---"

    echo "[Yao]"
    check "Yao HTTP (/.well-known/yao)" curl -sf http://127.0.0.1:5099/.well-known/yao
    check "Yao gRPC port" nc -z 127.0.0.1 9099

    echo "[tai-local (DIRECT)]"
    check "tai-local process alive" pid_alive $RUNNER_TEMP/tai-local.pid
    check "tai-local HTTP (/healthz)" curl -sf http://127.0.0.1:${YAO_CI_TAI_LOCAL_HTTP_PORT}/healthz
    check "tai-local gRPC reachable (direct)" nc -z 127.0.0.1 ${YAO_CI_TAI_LOCAL_GRPC_PORT}

    echo "[tai-docker (TUNNEL)]"
    check "tai-docker process alive" pid_alive $RUNNER_TEMP/tai-docker.pid
    check "tai-docker HTTP (/healthz)" curl -sf http://127.0.0.1:${YAO_CI_TAI_DOCKER_HTTP_PORT}/healthz
    check "tai-docker gRPC listener" nc -z 127.0.0.1 ${YAO_CI_TAI_DOCKER_GRPC_PORT}

    echo "[tai-k8s (TUNNEL)]"
    check "tai-k8s process alive" pid_alive $RUNNER_TEMP/tai-k8s.pid
    check "tai-k8s HTTP (/healthz)" curl -sf http://127.0.0.1:${YAO_CI_TAI_K8S_HTTP_PORT}/healthz
    check "tai-k8s gRPC listener" nc -z 127.0.0.1 ${YAO_CI_TAI_K8S_GRPC_PORT}

    echo "[tai-hostexec (TUNNEL)]"
    check "tai-hostexec process alive" pid_alive $RUNNER_TEMP/tai-hostexec.pid
    check "tai-hostexec HTTP (/healthz)" curl -sf http://127.0.0.1:${YAO_CI_TAI_HOSTEXEC_HTTP_PORT}/healthz
    check "tai-hostexec gRPC listener" nc -z 127.0.0.1 ${YAO_CI_TAI_HOSTEXEC_GRPC_PORT}

    echo "[K8s (k3d via direct API)]"
    check "kubectl get nodes (k3d direct)" kubectl --kubeconfig=$RUNNER_TEMP/kubeconfig-tai.yml get nodes

    echo "[Data Stores]"
    # MongoDB and PostgreSQL containers are located by image ancestry;
    # MySQL and Redis are addressed by container name.
    MONGO_CID=$(docker ps -qf "ancestor=mongo:6.0" | head -1)
    check "MongoDB ping" docker exec "$MONGO_CID" mongosh --quiet \
      -u ${MONGO_TEST_USER} -p ${MONGO_TEST_PASS} --authenticationDatabase admin \
      --eval "db.runCommand({ping:1})"
    check "MySQL ping" docker exec mysql mysqladmin ping -h 127.0.0.1 \
      -u ${MYSQL_TEST_USER} -p${MYSQL_TEST_PASS}
    PG_CID=$(docker ps -qf "ancestor=postgres:14" | head -1)
    check "PostgreSQL ping" docker exec "$PG_CID" pg_isready -U ${PG_TEST_USER}
    check "Redis ping" docker exec redis redis-cli ping

    # ── 2. Network Topology ──
    echo ""
    echo "--- 2. Network Topology ---"
    BRIDGE_IP=${YAO_CI_BRIDGE_IP}

    echo "[Host network basics]"
    check "docker0 bridge exists" ip addr show docker0
    check "Bridge IP reachable (${BRIDGE_IP})" ping -c1 -W2 ${BRIDGE_IP}
    check "Docker socket accessible" test -S /var/run/docker.sock
    check "tai binary on runner" test -x $RUNNER_TEMP/tai

    echo "[Yao endpoints (all Tai instances need these)]"
    check "Yao HTTP :${YAO_CI_HTTP_PORT}" curl -sf http://127.0.0.1:${YAO_CI_HTTP_PORT}/.well-known/yao
    check "Yao gRPC :${YAO_CI_GRPC_PORT}" nc -z 127.0.0.1 ${YAO_CI_GRPC_PORT}

    echo "[DIRECT path: Yao → tai-local]"
    check "Yao→tai-local gRPC :${YAO_CI_TAI_LOCAL_GRPC_PORT}" nc -z 127.0.0.1 ${YAO_CI_TAI_LOCAL_GRPC_PORT}
    check "Yao→tai-local HTTP :${YAO_CI_TAI_LOCAL_HTTP_PORT}" curl -sf http://127.0.0.1:${YAO_CI_TAI_LOCAL_HTTP_PORT}/healthz
    check "tai-local Docker API proxy :${YAO_CI_TAI_LOCAL_DOCKER_PORT}" nc -z 127.0.0.1 ${YAO_CI_TAI_LOCAL_DOCKER_PORT}

    echo "[TUNNEL path: tai-docker → Yao gRPC (reverse tunnel)]"
    check "tai-docker→Yao gRPC :${YAO_CI_GRPC_PORT}" nc -z 127.0.0.1 ${YAO_CI_GRPC_PORT}
    check "tai-docker Docker API proxy :${YAO_CI_TAI_DOCKER_API_PORT}" nc -z 127.0.0.1 ${YAO_CI_TAI_DOCKER_API_PORT}
    check "Docker API via proxy" bash -c "curl -sf http://127.0.0.1:${YAO_CI_TAI_DOCKER_API_PORT}/version | jq -r .ApiVersion"

    echo "[TUNNEL path: tai-k8s → Yao gRPC (reverse tunnel)]"
    check "tai-k8s→Yao gRPC :${YAO_CI_GRPC_PORT}" nc -z 127.0.0.1 ${YAO_CI_GRPC_PORT}
    check "tai-k8s K8s API proxy :${YAO_CI_TAI_K8S_API_PORT}" nc -z 127.0.0.1 ${YAO_CI_TAI_K8S_API_PORT}
    # Point a copy of the runner kubeconfig at the tai-k8s proxy port so
    # kubectl exercises the proxy rather than the k3d API directly.
    sed "s|server: .*|server: https://127.0.0.1:${YAO_CI_TAI_K8S_API_PORT}|" $RUNNER_TEMP/kubeconfig-tai.yml \
      > $RUNNER_TEMP/kubeconfig-tai-proxy.yml
    check "K8s API via tai-k8s proxy (kubectl)" kubectl --kubeconfig=$RUNNER_TEMP/kubeconfig-tai-proxy.yml --insecure-skip-tls-verify get nodes

    echo "[TUNNEL path: tai-hostexec → Yao gRPC (reverse tunnel)]"
    check "tai-hostexec→Yao gRPC :${YAO_CI_GRPC_PORT}" nc -z 127.0.0.1 ${YAO_CI_GRPC_PORT}

    echo "[Data Store connectivity from runner]"
    check "MongoDB :27017" nc -z 127.0.0.1 27017
    check "Redis :6379" nc -z 127.0.0.1 6379
    check "MySQL :${MYSQL_TEST_PORT}" nc -z 127.0.0.1 ${MYSQL_TEST_PORT}
    check "PostgreSQL :${PG_TEST_PORT}" nc -z 127.0.0.1 ${PG_TEST_PORT}

    # ── 3. Connection Mode Verification ──
    echo ""
    echo "--- 3. Connection Modes ---"
    sleep 5

    echo "[Credentials (4 tokens)]"
    check "tai-local credentials exist" test -f $RUNNER_TEMP/tai-local-credentials
    check "tai-docker credentials exist" test -f $RUNNER_TEMP/tai-docker-credentials
    check "tai-k8s credentials exist" test -f $RUNNER_TEMP/tai-k8s-credentials
    check "tai-hostexec credentials exist" test -f $RUNNER_TEMP/tai-hostexec-credentials

    # The lines below are informational output only; they run no probes.
    echo "[DIRECT: tai-local → Yao HTTP register → Yao dials tai-local gRPC]"
    echo " tai-local registers via POST /tai-nodes/register"
    echo " Yao dials back tai-local gRPC at 127.0.0.1:${YAO_CI_TAI_LOCAL_GRPC_PORT}"

    echo "[TUNNEL: tai-docker → Yao gRPC :${YAO_CI_GRPC_PORT} (Register + Forward)]"
    echo " Sandbox connects Yao gRPC → Forward stream → tai-docker :${YAO_CI_TAI_DOCKER_GRPC_PORT}"

    echo "[TUNNEL: tai-k8s → Yao gRPC :${YAO_CI_GRPC_PORT} (Register + Forward)]"
    echo " Sandbox connects Yao gRPC → Forward stream → tai-k8s :${YAO_CI_TAI_K8S_GRPC_PORT}"

    echo "[TUNNEL: tai-hostexec → Yao gRPC :${YAO_CI_GRPC_PORT} (Register + Forward)]"
    echo " HostExec only, no container runtime"

    # Falls back to "{}" when the endpoint cannot be fetched, and to the raw
    # text when jq cannot parse it.
    WELL_KNOWN=$(curl -sf http://127.0.0.1:5099/.well-known/yao 2>/dev/null || echo "{}")
    echo ""
    echo " Yao .well-known/yao:"
    echo "$WELL_KNOWN" | jq . 2>/dev/null || echo " $WELL_KNOWN"

    # ── 4. HostExec Readiness ──
    echo ""
    echo "--- 4. HostExec ---"

    echo "[HostExec gRPC ports (all 4 instances)]"
    check "HostExec tai-local (direct, auto-Docker)" nc -z 127.0.0.1 ${YAO_CI_TAI_LOCAL_GRPC_PORT}
    check "HostExec tai-docker (tunnel, Docker proxy)" nc -z 127.0.0.1 ${YAO_CI_TAI_DOCKER_GRPC_PORT}
    check "HostExec tai-k8s (tunnel, K8s proxy)" nc -z 127.0.0.1 ${YAO_CI_TAI_K8S_GRPC_PORT}
    check "HostExec tai-hostexec (tunnel, no runtime)" nc -z 127.0.0.1 ${YAO_CI_TAI_HOSTEXEC_GRPC_PORT}

    echo ""
    echo "=========================================="
    if [ $FAILED -gt 0 ]; then
      echo "::error::$FAILED verification check(s) FAILED"
      echo ""
      echo "=== Diagnostic Info ==="
      echo "--- Docker containers ---"
      docker ps -a
      echo ""
      echo "--- Processes (yao + tai) ---"
      # `|| true` keeps an empty grep result from aborting the diagnostics.
      ps aux | grep -E "yao|tai" | grep -v grep || true
      echo ""
      echo "--- Listening ports ---"
      ss -tlnp | grep -E "5099|9099|19100|19101|19102|19103|8099|8100|8101|8102|12375|12376|16443|16444" || true
      exit 1
    else
      echo "All verification checks PASSED"
    fi

================================================
FILE: .github/workflows/unit-test.yml
================================================
name: Unit Test

on:
  workflow_dispatch:
    inputs:
      tags:
        description: "Version"
  push:
    branches: [main]

env:
  YAO_DEV: ${{ github.WORKSPACE }}
  YAO_ENV: development
  YAO_ROOT: ${{ github.WORKSPACE }}/../app
  YAO_HOST: 0.0.0.0
  YAO_PORT: 5099
  YAO_SESSION: "memory"
  YAO_LOG: "./logs/application.log"
  YAO_LOG_MODE: "TEXT"
  YAO_JWT_SECRET: "bLp@bi!oqo-2U+hoTRUG"
  YAO_DB_AESKEY: "ZLX=T&f6refeCh-ro*r@"
  OSS_TEST_ID: ${{ secrets.OSS_TEST_ID}}
  OSS_TEST_SECRET: ${{ secrets.OSS_TEST_SECRET}}
  ROOT_PLUGIN: ${{ github.WORKSPACE }}/../../../data/gou-unit/plugins
  REPO_KUN: ${{ github.repository_owner }}/kun
  REPO_XUN: ${{ github.repository_owner }}/xun
  REPO_GOU: ${{ github.repository_owner }}/gou
  MYSQL_TEST_HOST: "127.0.0.1"
  MYSQL_TEST_PORT: "3308"
  MYSQL_TEST_USER: test
MYSQL_TEST_PASS: "123456" SQLITE_DB: "./app/db/yao.db" REDIS_TEST_HOST: "127.0.0.1" REDIS_TEST_PORT: "6379" REDIS_TEST_DB: "2" MONGO_TEST_HOST: "127.0.0.1" MONGO_TEST_PORT: "27017" MONGO_TEST_USER: "root" MONGO_TEST_PASS: "123456" OPENAI_TEST_KEY: ${{ secrets.OPENAI_TEST_KEY }} TEST_MOAPI_SECRET: ${{ secrets.OPENAI_TEST_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_TEST_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} TEST_MOAPI_MIRROR: https://api.openai.com # DeepSeek API Configuration DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} DEEPSEEK_API_PROXY: ${{ secrets.DEEPSEEK_API_PROXY }} DEEPSEEK_MODELS_R1: ${{ secrets.DEEPSEEK_MODELS_R1 }} DEEPSEEK_MODELS_V3: ${{ secrets.DEEPSEEK_MODELS_V3 }} DEEPSEEK_MODELS_V3_1: ${{ secrets.DEEPSEEK_MODELS_V3_1 }} # Search API Configuration TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} SERPAPI_API_KEY: ${{ secrets.SERPAPI_API_KEY }} SERPER_API_KEY: ${{ secrets.SERPER_API_KEY }} # Claude API Configuration CLAUDE_API_KEY: ${{ secrets.CLAUDE_API_KEY }} CLAUDE_PROXY: ${{ secrets.CLAUDE_PROXY }} CLAUDE_API_HOST: ${{ secrets.CLAUDE_API_HOST }} CLAUDE_SONNET_4: ${{ secrets.CLAUDE_SONNET_4 }} CLAUDE_SONNET_4_THINKING: ${{ secrets.CLAUDE_SONNET_4_THINKING }} # Moonshot / Kimi API Configuration MOONSHOT_API_KEY: ${{ secrets.MOONSHOT_API_KEY }} MOONSHOT_PROXY: "https://api.moonshot.cn" KIMI_CODE_API_KEY: ${{ secrets.KIMI_CODE_API_KEY }} KIMI_CODE_PROXY: "https://api.kimi.com/coding" TAB_NAME: "::PET ADMIN" PAGE_SIZE: "20" PAGE_LINK: "https://yaoapps.com" PAGE_ICON: "icon-trash" DEMO_APP: ${{ github.WORKSPACE }}/../app # Application Setting ## Path YAO_EXTENSION_ROOT: ${{ github.WORKSPACE }}/../extension YAO_TEST_APPLICATION: ${{ github.WORKSPACE }}/../app YAO_SUI_TEST_APPLICATION: ${{ github.WORKSPACE }}/../yao-startup-webapp ## Runtime YAO_RUNTIME_MIN: 3 YAO_RUNTIME_MAX: 6 YAO_RUNTIME_HEAP_LIMIT: 1500000000 YAO_RUNTIME_HEAP_RELEASE: 10000000 YAO_RUNTIME_HEAP_AVAILABLE: 550000000 YAO_RUNTIME_PRECOMPILE: true # Neo4j NEO4J_TEST_URL: 
"neo4j://localhost:7687" NEO4J_TEST_USER: "neo4j" NEO4J_TEST_PASS: "Yao2026Neo4j" # Qdrant QDRANT_TEST_HOST: "127.0.0.1" QDRANT_TEST_PORT: "6334" # S3 S3_API: ${{ secrets.S3_API }} S3_ACCESS_KEY: ${{ secrets.S3_ACCESS_KEY }} S3_SECRET_KEY: ${{ secrets.S3_SECRET_KEY }} S3_BUCKET: ${{ secrets.S3_BUCKET }} S3_PUBLIC_URL: ${{ secrets.S3_PUBLIC_URL }} # === Openapi Signin Configs === SIGNIN_CLIENT_ID: "kiCeR88kDwHBDuNHvN51cZgmpp3tmF6Z" ## Google GOOGLE_CLIENT_ID: ${{ secrets.GOOGLE_CLIENT_ID }} GOOGLE_CLIENT_SECRET: ${{ secrets.GOOGLE_CLIENT_SECRET }} ## Microsoft MICROSOFT_CLIENT_ID: ${{ secrets.MICROSOFT_CLIENT_ID }} MICROSOFT_CLIENT_SECRET: ${{ secrets.MICROSOFT_CLIENT_SECRET }} ## Apple APPLE_SERVICE_ID: ${{ secrets.APPLE_SERVICE_ID }} APPLE_PRIVATE_KEY_PATH: "apple/signin_client_secret_key.p8" APPLE_KEY_ID: ${{ secrets.APPLE_KEY_ID }} APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} ## Github GITHUBUSER_CLIENT_ID: ${{ secrets.GITHUBUSER_CLIENT_ID }} GITHUBUSER_CLIENT_SECRET: ${{ secrets.GITHUBUSER_CLIENT_SECRET }} ## Cloudflare Turnstile CLOUDFLARE_TURNSTILE_SITEKEY: ${{ secrets.CLOUDFLARE_TURNSTILE_SITEKEY }} CLOUDFLARE_TURNSTILE_SECRET: ${{ secrets.CLOUDFLARE_TURNSTILE_SECRET }} # === Messaging Services === ## Mailgun MAILGUN_DOMAIN: ${{ secrets.MAILGUN_DOMAIN }} MAILGUN_API_KEY: ${{ secrets.MAILGUN_API_KEY }} MAILGUN_FROM: "Yaobots Tests " ## SMTP Server( Mailgun ) SMTP_HOST: "smtp.mailgun.org" SMTP_PORT: "465" SMTP_USERNAME: ${{ secrets.SMTP_USERNAME }} SMTP_PASSWORD: ${{ secrets.SMTP_PASSWORD }} SMTP_FROM: "Yaobots SMTP Tests " ## SMTP Server( Gmail ) RELIABLE_SMTP_HOST: "smtp.gmail.com" RELIABLE_SMTP_PORT: "465" RELIABLE_SMTP_USERNAME: ${{ secrets.RELIABLE_SMTP_USERNAME }} RELIABLE_SMTP_PASSWORD: ${{ secrets.RELIABLE_SMTP_PASSWORD }} RELIABLE_SMTP_FROM: "Yaobots Gmail Tests " ## IMAP Server (Gmail) RELIABLE_IMAP_HOST: "imap.gmail.com" RELIABLE_IMAP_PORT: "993" RELIABLE_IMAP_USERNAME: ${{ secrets.RELIABLE_SMTP_USERNAME }} RELIABLE_IMAP_PASSWORD: ${{ 
secrets.RELIABLE_SMTP_PASSWORD }} RELIABLE_IMAP_MAILBOX: "INBOX" ## Twilio TWILIO_ACCOUNT_SID: ${{ secrets.TWILIO_ACCOUNT_SID }} TWILIO_AUTH_TOKEN: ${{ secrets.TWILIO_AUTH_TOKEN }} TWILIO_API_SID: ${{ secrets.TWILIO_API_SID }} TWILIO_API_KEY: ${{ secrets.TWILIO_API_KEY }} TWILIO_SENDGRID_API_SID: ${{ secrets.TWILIO_SENDGRID_API_SID }} TWILIO_SENDGRID_API_KEY: ${{ secrets.TWILIO_SENDGRID_API_KEY }} TWILIO_FROM_PHONE: "+17035701412" TWILIO_FROM_EMAIL: "unit-test@sendgrid.yaobots.com" TWILIO_TEST_PHONE: ${{ secrets.TWILIO_TEST_PHONE }} jobs: # ============================================================================= # KB Tests (kb) - Run once with SQLite (requires Qdrant, Neo4j, FastEmbed) # ============================================================================= kb-test: runs-on: ubuntu-latest services: qdrant: image: qdrant/qdrant:latest ports: - 6333:6333 - 6334:6334 fastembed: image: yaoapp/fastembed:latest-amd64 env: FASTEMBED_PASSWORD: Yao@2026 ports: - 6001:8000 neo4j: image: neo4j:latest ports: - "7687:7687" env: NEO4J_AUTH: neo4j/Yao2026Neo4j mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test strategy: matrix: go: ["1.25"] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: 
repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Run KB Tests (kb) run: make unit-test-kb - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} # ============================================================================= # Agent Tests (agent, aigc) - Run once with SQLite # ============================================================================= agent-test: runs-on: ubuntu-latest services: qdrant: image: qdrant/qdrant:latest ports: - 6333:6333 - 6334:6334 fastembed: image: yaoapp/fastembed:latest-amd64 env: FASTEMBED_PASSWORD: Yao@2026 ports: - 6001:8000 neo4j: image: neo4j:latest ports: - "7687:7687" env: NEO4J_AUTH: neo4j/Yao2026Neo4j mcp-everything: image: yaoapp/mcp-everything:latest ports: - "3021:3021" - "3022:3022" mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test strategy: matrix: go: ["1.25"] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ 
env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Install pdftoppm, mutool, imagemagick run: | sudo apt update sudo apt install -y poppler-utils mupdf-tools imagemagick - name: Test pdftoppm, mutool, imagemagick run: | pdftoppm -v mutool -v convert -version - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Pull Sandbox Test Images run: | docker pull alpine:latest docker pull yaoapp/sandbox-base:latest || true docker pull yaoapp/sandbox-claude:latest || true - name: Run Agent Tests (agent, aigc) env: YAO_SANDBOX_WORKSPACE: ${{ runner.temp }}/sandbox/workspace YAO_SANDBOX_IPC: ${{ runner.temp }}/sandbox/ipc run: | export YAO_SANDBOX_CONTAINER_USER="$(id -u):$(id -g)" make unit-test-agent - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} # 
============================================================================= # Robot Tests (all agent/robot/... packages) - Unit + E2E with real LLM calls # ============================================================================= robot-test: runs-on: ubuntu-latest services: mcp-everything: image: yaoapp/mcp-everything:latest ports: - "3021:3021" - "3022:3022" mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test strategy: matrix: go: ["1.25"] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Run Robot Tests (Unit + E2E) run: make unit-test-robot - name: Codecov 
Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} # ============================================================================= # Sandbox Tests (requires Docker) - Run with Docker-in-Docker # ============================================================================= sandbox-test: runs-on: ubuntu-latest services: mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test strategy: matrix: go: ["1.25"] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Pull Sandbox Test Images run: | docker pull alpine:latest docker pull 
yaoapp/sandbox-base:latest || true docker pull yaoapp/sandbox-claude:latest || true - name: Run Sandbox Tests env: YAO_SANDBOX_WORKSPACE: ${{ runner.temp }}/sandbox/workspace YAO_SANDBOX_IPC: ${{ runner.temp }}/sandbox/ipc run: | # Use runner's UID:GID to match host permissions export YAO_SANDBOX_CONTAINER_USER="$(id -u):$(id -g)" make unit-test-sandbox - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} # ============================================================================= # Sandbox V2 Tests (tai SDK + workspace, Docker + K8s via k3d) # Full sandbox/v2 integration tests are run locally. # ============================================================================= sandbox-v2-test: runs-on: ubuntu-latest services: mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test strategy: matrix: go: ["1.25"] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ 
secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Pull Test Images run: | docker pull yaoapp/tai-sandbox-test:latest || true docker pull yaoapp/tai:latest docker pull alpine:latest - name: Install k3d run: curl -s https://raw.githubusercontent.com/k3d-io/k3d/main/install.sh | bash - name: Create k3d cluster run: | k3d cluster create tai-test --no-lb --wait --api-port 16443 kubectl wait --for=condition=Ready node --all --timeout=60s k3d image import alpine:latest -c tai-test - name: Start Tai Docker instance run: | docker run -d --name tai-docker \ -v /var/run/docker.sock:/var/run/docker.sock \ -p 19100:19100 -p 8099:8099 -p 12375:12375 -p 16080:16080 \ yaoapp/tai:latest server -direct \ -grpc 0.0.0.0:19100 -http 0.0.0.0:8099 -vnc 0.0.0.0:16080 -docker 0.0.0.0:12375 for i in $(seq 1 30); do if curl -sf http://127.0.0.1:8099/healthz > /dev/null 2>&1; then echo "Tai Docker HTTP ready"; break fi echo "Waiting for Tai Docker HTTP... ($i)"; sleep 1 done curl -sf http://127.0.0.1:8099/healthz > /dev/null 2>&1 || { echo "::error::Tai Docker HTTP failed"; docker logs tai-docker 2>&1; exit 1 } for i in $(seq 1 15); do if nc -z 127.0.0.1 19100 2>/dev/null; then echo "Tai Docker gRPC ready"; break fi echo "Waiting for Tai Docker gRPC... 
($i)"; sleep 1 done nc -z 127.0.0.1 19100 2>/dev/null || { echo "::error::Tai Docker gRPC failed"; docker logs tai-docker 2>&1; exit 1 } - name: Generate kubeconfig for Tai K8s run: | K3D_IP=$(docker inspect k3d-tai-test-server-0 | jq -r '.[0].NetworkSettings.Networks["k3d-tai-test"].IPAddress') echo "k3d server IP: ${K3D_IP}" k3d kubeconfig get tai-test > /tmp/kubeconfig-k3d.yml # Kubeconfig for tai-k8s container (uses k3d-internal IP) sed "s|server: .*|server: https://${K3D_IP}:6443|" /tmp/kubeconfig-k3d.yml \ > /tmp/kubeconfig-tai-k8s.yml echo "Container kubeconfig server:" grep server: /tmp/kubeconfig-tai-k8s.yml # Kubeconfig for test runner (uses localhost via port-mapped 6443) sed 's|server: .*|server: https://127.0.0.1:6443|' /tmp/kubeconfig-k3d.yml \ > ${{ runner.temp }}/kubeconfig-tai.yml echo "Test runner kubeconfig server:" grep server: ${{ runner.temp }}/kubeconfig-tai.yml - name: Start Tai K8s instance run: | K3D_IP=$(docker inspect k3d-tai-test-server-0 | jq -r '.[0].NetworkSettings.Networks["k3d-tai-test"].IPAddress') echo "k3d server IP: ${K3D_IP}" docker run -d --name tai-k8s \ --network k3d-tai-test \ -p 19101:19100 -p 8100:8099 -p 6443:16443 -p 16081:16080 \ -v /var/run/docker.sock:/var/run/docker.sock:ro \ -v /tmp/kubeconfig-tai-k8s.yml:/etc/tai/kubeconfig.yml:ro \ -e TAI_K8S_UPSTREAM="tcp://${K3D_IP}:6443" \ -e TAI_KUBECONFIG=/etc/tai/kubeconfig.yml \ yaoapp/tai:latest server -direct \ -grpc 0.0.0.0:19100 -http 0.0.0.0:8099 -vnc 0.0.0.0:16080 -k8s 0.0.0.0:16443 for i in $(seq 1 30); do if curl -sf http://127.0.0.1:8100/healthz > /dev/null 2>&1; then echo "Tai K8s HTTP ready"; break fi echo "Waiting for Tai K8s HTTP... ($i)"; sleep 1 done curl -sf http://127.0.0.1:8100/healthz > /dev/null 2>&1 || { echo "::error::Tai K8s HTTP failed"; docker logs tai-k8s 2>&1; exit 1 } for i in $(seq 1 15); do if nc -z 127.0.0.1 19101 2>/dev/null; then echo "Tai K8s gRPC ready"; break fi echo "Waiting for Tai K8s gRPC... 
($i)"; sleep 1 done nc -z 127.0.0.1 19101 2>/dev/null || { echo "::error::Tai K8s gRPC failed"; docker logs tai-k8s 2>&1; exit 1 } - name: Run Sandbox V2 CI Tests (tai + workspace) env: TAI_TEST_HOST: "127.0.0.1" TAI_TEST_DOCKER: "tcp://127.0.0.1:12375" TAI_TEST_GRPC_PORT: "19100" TAI_TEST_HTTP_PORT: "8099" TAI_TEST_VNC_PORT: "16080" TAI_TEST_DOCKER_PORT: "12375" TAI_TEST_K8S_HOST: "127.0.0.1" TAI_TEST_K8S_PORT: "6443" TAI_TEST_K8S_GRPC_PORT: "19101" TAI_TEST_KUBECONFIG: "${{ runner.temp }}/kubeconfig-tai.yml" TAI_TEST_HOST_IP: "172.17.0.1" SANDBOX_TEST_REMOTE_ADDR: "tai://127.0.0.1:19100" SANDBOX_TEST_IMAGE: "yaoapp/tai-sandbox-test:latest" run: make unit-test-sandbox-v2 - name: Codecov Report if: always() uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false # ============================================================================= # Benchmark & Memory Leak Tests - Run with MySQL8.0 and SQLite3 # ============================================================================= perf-test: runs-on: ubuntu-latest services: mcp-everything: image: yaoapp/mcp-everything:latest ports: - "3021:3021" - "3022:3022" mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test strategy: matrix: go: ["1.25"] db: [MySQL8.0, SQLite3] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: 
actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Install FFmpeg 7.x run: | wget https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz tar -xf ffmpeg-master-latest-linux64-gpl.tar.xz sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffmpeg /usr/local/bin/ sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffprobe /usr/local/bin/ sudo chmod +x /usr/local/bin/ffmpeg /usr/local/bin/ffprobe - name: Test FFmpeg run: ffmpeg -version - name: Setup Go Tools run: make tools - name: Setup ${{ matrix.db }} uses: ./.github/actions/setup-db with: kind: "${{ matrix.db }}" db: "xiang" user: "xiang" password: ${{ secrets.UNIT_PASS }} - name: Setup ENV env: PASSWORD: ${{ secrets.UNIT_PASS }} run: | echo "YAO_DB_DRIVER=$DB_DRIVER" >> $GITHUB_ENV if [ "$DB_DRIVER" = "mysql" ]; then echo "YAO_DB_PRIMARY=$DB_USER:$PASSWORD@$DB_HOST" >> $GITHUB_ENV else echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV mkdir -p ${{ github.WORKSPACE }}/../app/db fi - name: Run Benchmark & Memory Leak Tests run: | make benchmark make memory-leak # ============================================================================= # Core Tests - Run with DB matrix (MySQL/SQLite/Redis/Mongo combinations) # ============================================================================= core-test: runs-on: ubuntu-latest services: qdrant: image: qdrant/qdrant:latest ports: - 6333:6333 # HTTP API - 6334:6334 # gRPC fastembed: image: yaoapp/fastembed:latest-amd64 env: FASTEMBED_PASSWORD: 
Yao@2026 ports: - 6001:8000 neo4j: image: neo4j:latest ports: - "7687:7687" env: NEO4J_AUTH: neo4j/Yao2026Neo4j mcp-everything: image: yaoapp/mcp-everything:latest ports: - "3021:3021" - "3022:3022" strategy: matrix: go: ["1.25"] db: [MySQL8.0, SQLite3] redis: [4, 5, 6] mongo: ["6.0"] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") # Get the directory where the ZIP file is located echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Yao Startup Webapp uses: actions/checkout@v4 with: repository: yaoapp/yao-startup-webapp submodules: true token: ${{ secrets.YAO_TEST_TOKEN }} path: yao-startup-webapp - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Kun, Xun, Gou, V8Go, Extension run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ mv yao-startup-webapp ../ ls -l . 
ls -l ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Install FFmpeg 7.x run: | wget https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz tar -xf ffmpeg-master-latest-linux64-gpl.tar.xz sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffmpeg /usr/local/bin/ sudo cp ffmpeg-master-latest-linux64-gpl/bin/ffprobe /usr/local/bin/ sudo chmod +x /usr/local/bin/ffmpeg /usr/local/bin/ffprobe - name: Test FFmpeg run: ffmpeg -version - name: Install pdftoppm, mutool, imagemagick run: | sudo apt update sudo apt install -y poppler-utils mupdf-tools imagemagick - name: Test pdftoppm, mutool, imagemagick run: | pdftoppm -v mutool -v convert -version - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:${{ matrix.redis }} - name: Start MongoDB run: | docker run --name mongodb --publish 27017:27017 \ -e MONGO_INITDB_DATABASE=test \ -e MONGO_INITDB_ROOT_USERNAME=root \ -e MONGO_INITDB_ROOT_PASSWORD=123456 \ --detach mongo:${{ matrix.mongo }} # Wait for MongoDB to be ready for i in $(seq 1 20); do if docker exec mongodb mongosh --quiet --port 27017 --username root --password 123456 --eval "db.serverStatus()" > /dev/null 2>&1; then echo "MongoDB is ready" break fi echo "Waiting for MongoDB... 
($i)" sleep 1 done - name: Setup MySQL8.0 (connector) uses: ./.github/actions/setup-db with: kind: "MySQL8.0" db: "test" user: "test" password: "123456" port: "3308" - name: Setup ${{ matrix.db }} uses: ./.github/actions/setup-db with: kind: "${{ matrix.db }}" db: "xiang" user: "xiang" password: ${{ secrets.UNIT_PASS }} - name: Setup Go Tools run: | make tools - name: Setup ENV & Host env: PASSWORD: ${{ secrets.UNIT_PASS }} run: | sudo echo "127.0.0.1 local.iqka.com" | sudo tee -a /etc/hosts echo "YAO_DB_DRIVER=$DB_DRIVER" >> $GITHUB_ENV echo "GITHUB_WORKSPACE:\n" && ls -l $GITHUB_WORKSPACE if [ "$DB_DRIVER" = "mysql" ]; then echo "YAO_DB_PRIMARY=$DB_USER:$PASSWORD@$DB_HOST" >> $GITHUB_ENV elif [ "$DB_DRIVER" = "postgres" ]; then echo "YAO_DB_PRIMARY=postgres://$DB_USER:$PASSWORD@$DB_HOST" >> $GITHUB_ENV else echo "YAO_DB_PRIMARY=$YAO_ROOT/$DB_HOST" >> $GITHUB_ENV fi echo ".:\n" && ls -l . echo "..:\n" && ls -l .. echo "../app:\n" && ls -l ../app ping -c 1 -t 1 local.iqka.com - name: Test Prepare run: | make vet make fmt-check make misspell-check - name: Inspect run: | go run . run utils.env.Get MONGO_TEST_HOST go run . run utils.env.Get REDIS_TEST_HOST go run . 
inspect - name: Run Core Tests (exclude AI) run: make unit-test-core - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos # ============================================================================= # Registry Client SDK Tests (requires Yao Registry Docker service) # ============================================================================= registry-test: runs-on: ubuntu-latest services: yao-registry: image: yaoapp/registry:latest ports: - "8080:8080" env: REGISTRY_INIT_USER: yaoagents REGISTRY_INIT_PASS: yaoagents strategy: matrix: go: ["1.25"] steps: - name: Wait for Registry run: | for i in $(seq 1 15); do if curl -sf http://localhost:8080/.well-known/yao-registry > /dev/null 2>&1; then echo "Registry is ready" break fi echo "Waiting for registry... ($i)" sleep 1 done - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Run Registry Client Tests env: YAO_REGISTRY_URL: http://localhost:8080 
YAO_TEST_APPLICATION: ${{ github.workspace }}/../app run: make unit-test-registry - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} # ============================================================================= # gRPC Tests - Run once with SQLite (transport layer, no DB matrix needed) # ============================================================================= grpc-test: runs-on: ubuntu-latest services: mongodb: image: mongo:6.0 ports: - 27017:27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: 123456 MONGO_INITDB_DATABASE: test strategy: matrix: go: ["1.25"] steps: - name: Checkout Kun uses: actions/checkout@v4 with: repository: ${{ env.REPO_KUN }} path: kun - name: Checkout Xun uses: actions/checkout@v4 with: repository: ${{ env.REPO_XUN }} path: xun - name: Checkout Gou uses: actions/checkout@v4 with: repository: ${{ env.REPO_GOU }} path: gou - name: Checkout V8Go uses: actions/checkout@v4 with: repository: yaoapp/v8go path: v8go - name: Unzip libv8 run: | files=$(find ./v8go -name "libv8*.zip") for file in $files; do dir=$(dirname "$file") echo "Extracting $file to directory $dir" unzip -o -d $dir $file rm -rf $dir/__MACOSX done - name: Checkout Demo App uses: actions/checkout@v4 with: repository: yaoapp/yao-dev-app path: app - name: Checkout Extension uses: actions/checkout@v4 with: repository: yaoapp/yao-extensions-dev path: extension - name: Move Dependencies run: | mv kun ../ mv xun ../ mv gou ../ mv v8go ../ mv app ../ mv extension ../ - name: Checkout Code uses: actions/checkout@v4 - name: Setup Apple Private Key run: | mkdir -p ../app/openapi/certs/apple echo "${{ secrets.APPLE_PRIVATE_KEY_USER }}" > ../app/openapi/certs/apple/signin_client_secret_key.p8 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Start Redis run: docker run --name redis --publish 6379:6379 --detach redis:6 - name: Setup Go Tools run: make tools - name: Setup 
ENV (SQLite) run: | mkdir -p ${{ github.WORKSPACE }}/../app/db echo "YAO_DB_DRIVER=sqlite3" >> $GITHUB_ENV echo "YAO_DB_PRIMARY=${{ github.WORKSPACE }}/../app/db/yao.db" >> $GITHUB_ENV - name: Run gRPC Tests run: make unit-test-grpc - name: Codecov Report uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} ================================================ FILE: .gitignore ================================================ # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib # Test binary, built with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out *.log # Dependency directories (remove the comment below to include it) # vendor/ .DS_Store .env* .tmp dist ui ^ui/index.html tests/data/* !tests/data/assets debug.log yao-test dev.sh logs logs/application.log yao-arm64 tests/db/yao.db env.*.sh xgen/v0.9/* xgen/v1.0/* !xgen/v0.9/index.html !xgen/v1.0/index.html !xgen/v1.0/umi.js !xgen/v1.0/layouts__index.async.js !pipe/ui *-unit-test docker/build/test db !agent/search/handlers/db *.sh data/bindata.go.bak share/const.go.bak share/const.goe .cursor openapi/*.md coverage.html agent/assistant/hook/*.test.md agent/search/TODO.md agent/search/job-logs.txt agent/test/MULTI_TURN_DESIGN.md agent/test/UPGRADE_PLAN.md introduction/* !sandbox/docker/build.sh !sandbox/docker/vnc/*.sh !sandbox/docker/desktop/config/*.sh sandbox/docker/yao-bridge-* sandbox/docker/claude-proxy-* sandbox/docker/claude/claude-proxy-* sandbox/proxy/claude-proxy-linux-* release/* sandbox/TODO-VNC.md sandbox/docker/chrome/PLAN.md sandbox/DESIGN-REMOTE.md event/DESIGN.md event/TODO.md agent/robot/DESIGN-V2.md tg-session.json tg-login tg-send registry/data/ registry/manager/DESIGN*.md tai/testdata/ agent/sandbox/docs/*.md tai/docs/refactor-registration.md ================================================ FILE: COMMERCIAL_LICENSE.md ================================================ > **DEPRECATED**: This license is no longer in 
effect. Please refer to the [LICENSE](LICENSE) file for current licensing terms. # Commercial License for Yao This document outlines the terms for the commercial license of the **Yao** project. While the Yao project is primarily licensed under the **Apache License, Version 2.0**, certain commercial use cases require a separate commercial license. ## 1. Commercial License Requirements The following use cases require a commercial license: 1. **Application Hosting Services** If you use Yao, or any derivative product (such as a forked or modified version of Yao), to provide Yao-based application hosting services (e.g., Software-as-a-Service (SaaS) or Platform-as-a-Service (PaaS)) to users, you must obtain a commercial license. This restriction applies regardless of whether the original Yao code or a modified version is used to host and manage applications on behalf of third-party users for commercial purposes. **In addition**, if you provide hosting services for applications that are built using Yao (even if they are customized or modified versions of Yao), a commercial license is required. ### Definition: Application Hosting Services "Application Hosting Services" refers to any service that involves hosting Yao-based applications or web applications created with Yao (including modified versions of Yao) for third-party users. This includes, but is not limited to: - **Hosting platforms** providing software or services built on top of Yao for third-party users. - **SaaS or PaaS offerings** where you manage and host applications that are based on or utilize Yao, either in their original or modified form. - **Managed hosting services** where Yao is used as the underlying technology for applications deployed for external clients. In these cases, a commercial license is required, whether you are using the original Yao code or a fork/modified version. 2. 
**AI Web Application Generation Services** If you provide services that generate AI-driven web applications using Yao, or any derivative product (such as a fork or modified version of Yao), to third-party users, you are required to purchase a commercial license. ### Definition: AI Web Application Generation Services "AI Web Application Generation Services" refers to any service or functionality that utilizes Yao (or any forked or modified version of Yao) to automate the creation of web applications with AI capabilities. This includes, but is not limited to, providing third-party users with: - **Automated web application development** driven by AI, where the service generates complete or partial web applications. - **Customizable web solutions** that are powered by AI and built using Yao as the core technology. - **On-demand application generation** for specific client needs, using Yao to dynamically build, configure, or deploy applications for users. In these cases, whether Yao is directly used, forked, or modified, a commercial license is required to operate legally. ## 2. Use Under Apache License 2.0 For all other uses, the **Apache License, Version 2.0** applies. You are free to use, modify, and distribute the Yao project under the terms of Apache 2.0 as long as your usage does not fall within the restricted scenarios outlined above. ## 3. Obtaining a Commercial License To inquire about or obtain a commercial license, please contact us at: - **Email**: [friends@iqka.com](mailto:friends@iqka.com) - **Website**: [https://moapi.ai/contact](https://moapi.ai/contact) Pricing and terms for commercial licenses vary based on usage scenarios, user scale, and other factors. ## 4. Compliance and Auditing If you have any questions about whether your use case requires a commercial license, please contact us for clarification. We reserve the right to audit usage for compliance and enforce commercial licensing terms where necessary. ## 5.
Disclaimer Failure to comply with these licensing terms may result in a violation of the Yao licensing agreement and could lead to legal action. --- **Note:** This commercial license is supplementary to the Apache 2.0 license and only applies in specific commercial scenarios outlined above. ================================================ FILE: COMMERCIAL_LICENSE.zh-CN.md ================================================ > **已废弃**: 本许可证已不再生效。请参考 [LICENSE](LICENSE) 文件获取当前的许可条款。 # Yao 商业许可证 本文件概述了 **Yao** 项目的商业许可证条款。虽然 Yao 项目主要使用 **Apache 许可证 2.0 版** 授权,但某些商业使用场景需要单独的商业许可证。 ## 1. 商业许可证要求 以下使用场景需要商业许可证: 1. **应用托管服务** 如果您使用 Yao 或其衍生产品(如 Yao 的分支版本或修改版本)为用户提供基于 Yao 的应用托管服务(例如,软件即服务(SaaS)或平台即服务(PaaS)),您必须获得商业许可证。此限制适用于无论是否使用原始 Yao 代码或修改版 Yao 代码,托管和管理应用程序的行为只要是为第三方用户提供的商业目的。 **此外**,如果您提供的托管服务是为使用 Yao 构建的应用程序提供托管服务(即使它们是定制或修改版的 Yao),也需要获得商业许可证。 ### 定义:应用托管服务 "应用托管服务"指任何涉及托管基于 Yao 的应用程序或使用 Yao 创建的 WEB 应用程序(包括 Yao 的修改版本)的服务,服务对象为第三方用户。包括但不限于: - **托管平台** 提供基于 Yao 的软件或服务给第三方用户。 - **SaaS 或 PaaS 服务**,在这些服务中,您管理并托管基于或利用 Yao 的应用程序,可能是原版或修改版。 - **托管服务**,其中 Yao 被用作为客户外部部署应用程序的基础技术。 在这些情况下,无论是使用原始 Yao 代码还是修改版 Yao,都需要获得商业许可证。 2. **AI WEB 应用生成服务** 如果您提供利用 Yao 或其衍生产品(如 Yao 的分支版本或修改版本)为第三方用户生成 AI 驱动的 WEB 应用程序的服务,您需要购买商业许可证。 ### 定义:AI WEB 应用生成服务 "AI WEB 应用生成服务"指任何利用 Yao(或任何分支版本或修改版本的 Yao)自动化创建具有 AI 功能的 WEB 应用程序的服务或功能。包括但不限于,为第三方用户提供以下服务: - **AI 驱动的自动化 WEB 应用开发**,该服务生成完整或部分 WEB 应用程序。 - **可定制的 WEB 解决方案**,这些解决方案由 AI 提供支持,并以 Yao 作为核心技术构建。 - **按需应用生成**,根据特定客户需求,使用 Yao 动态构建、配置或部署应用程序。 在这些情况下,无论是直接使用 Yao,还是使用其分支或修改版,均需要获得商业许可证。 ## 2. 使用 Apache 许可证 2.0 对于所有其他用途,**Apache 许可证 2.0 版** 适用。只要您的使用不属于上述限制的商业场景,您可以自由地根据 Apache 2.0 许可证使用、修改和分发 Yao 项目。 ## 3. 获取商业许可证 如需咨询或获取商业许可证,请通过以下方式联系我们: - **电子邮件**:[friends@iqka.com] - **网站**:[https://moapi.ai/contact](https://moapi.ai/contact) 商业许可证的定价和条款会根据使用场景、用户规模及其他因素有所不同。 ## 4. 合规与审计 如果您对您的使用场景是否需要商业许可证有任何疑问,请联系我们以获取澄清。我们保留审核使用情况以确保合规,并在必要时执行商业许可条款的权利。 ## 5. 
免责声明 未遵守这些许可条款可能会导致违反 Yao 许可证协议,并可能导致法律诉讼。 --- **注意:** 此商业许可证是 Apache 2.0 许可证的补充,仅适用于上述特定的商业场景。 ================================================ FILE: LICENSE ================================================ # Open Source License Yao App Engine is licensed under a modified version of the Apache License 2.0, with the following additional conditions: 1. Commercial Usage Terms: Yao App Engine may be utilized commercially, A commercial license from the producer is required if: a. Trademark and Branding Requirements - The Yao App Engine console/application logo and copyright information must not be removed or modified - Logo and copyright information can only be changed with an authorization certificate issued through Yao Developer Certificate b. Authorization Verification Requirements - The Yao certificate verification logic, processes, and related pages (marked in code comments) must be preserved - The complete Yao certificate verification system must be maintained regardless of usage purpose 2. Contributor Agreement: - The producer reserves the right to modify the open-source agreement terms - Contributed code may be used for commercial purposes, including cloud business operations All other rights and restrictions follow the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0). © 2025 Infinite Wisdom Software. ================================================ FILE: Makefile ================================================ GO ?= go GIT ?= git GOFMT ?= gofmt "-s" PACKAGES ?= $(shell $(GO) list ./...) VETPACKAGES ?= $(shell $(GO) list ./... | grep -v /examples/) GOFILES := $(shell find . -name "*.go") VERSION := $(shell grep 'const VERSION =' share/const.go |awk '{print $$4}' |sed 's/\"//g') COMMIT := $(shell git log | head -n 1 | awk '{print substr($$2, 0, 12)}') NOW := $(shell date +"%FT%T%z") OS := $(shell uname) # ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) TESTFOLDER := $(shell $(GO) list ./... 
| grep -vE 'examples|openai|aigc|neo|twilio|share*|registry|agent/sandbox/v2' | awk '!/\/tests\// || /openapi\/tests/' | grep -vE 'openapi/tests/(nodes|sandbox|workspace)') # Core tests (exclude AI-related: agent, aigc, openai, KB, sandbox, registry, grpc, and integrations which require external services) TESTFOLDER_CORE := $(shell $(GO) list ./... | grep -vE 'examples|openai|aigc|neo|twilio|share*|agent|kb|sandbox|integrations|registry|tai|grpc' | awk '!/\/tests\// || /openapi\/tests/' | grep -vE 'openapi/tests/(nodes|sandbox|workspace)') # Agent tests (agent, aigc) - exclude agent/search/handlers/web (requires external API keys), robot packages (tested in robot job), and agent/sandbox/v2 (WIP, has its own job) TESTFOLDER_AGENT := $(shell $(GO) list ./agent/... ./aigc/... | grep -vE 'agent/search/handlers/web|agent/robot/|agent/sandbox/v2') # KB tests (kb) TESTFOLDER_KB := $(shell $(GO) list ./kb/...) # Robot tests (agent/robot/... packages, excluding events/integrations which require Telegram etc.) TESTFOLDER_ROBOT := $(shell $(GO) list ./agent/robot/... | grep -vE 'agent/robot/events') # Sandbox tests (requires Docker) — excludes sandbox/v2 (has its own job) TESTFOLDER_SANDBOX := $(shell $(GO) list ./sandbox/... | grep -v 'sandbox/v2') # Tai SDK tests (requires Tai container with Docker socket) TESTFOLDER_TAI := $(shell $(GO) list ./tai/...) # Workspace tests (requires Tai for remote mode) TESTFOLDER_WORKSPACE := $(shell $(GO) list ./workspace/...) # gRPC tests TESTFOLDER_GRPC := $(shell $(GO) list ./grpc/...) TESTTAGS ?= "" # TESTWIDGETS := $(shell $(GO) list ./widgets/...) 
# Unit Test (all tests)
#
# Runs `go test` once per package in $(TESTFOLDER) and merges the per-package
# coverage profiles into coverage.out.
#
# The exit status of `go test` is not inspected (its output is redirected to
# tmp.out and the next command runs unconditionally), so failures are detected
# by scanning tmp.out for known markers:
#   ^--- FAIL      a top-level test failed
#   ^FAIL          go test's per-package failure line (also emitted when the
#                  test binary times out or exits abnormally)
#   ^panic:        the test binary panicked (e.g. "panic: test timed out")
#   build failed   the package did not compile
#   setup failed   go test could not set the package up
#   runtime error  a runtime panic message appeared in the output
.PHONY: unit-test
unit-test:
	echo "mode: count" > coverage.out
	for d in $(TESTFOLDER); do \
		$(GO) test -tags $(TESTTAGS) -v -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") -skip='TestMemoryLeak|TestIsolateDisposal|TestLeak_|TestScenario_' $$d > tmp.out; \
		cat tmp.out; \
		if grep -q "^--- FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^panic:" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "build failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "setup failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "runtime error" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		fi; \
		if [ -f profile.out ]; then \
			cat profile.out | grep -v "mode:" >> coverage.out; \
			rm profile.out; \
		fi; \
	done

# Core Unit Test (exclude AI-related tests)
#
# Same procedure and failure markers as unit-test, restricted to the packages
# in $(TESTFOLDER_CORE).
.PHONY: unit-test-core
unit-test-core:
	echo "mode: count" > coverage.out
	for d in $(TESTFOLDER_CORE); do \
		$(GO) test -tags $(TESTTAGS) -v -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") -skip='TestMemoryLeak|TestIsolateDisposal|TestLeak_|TestScenario_' $$d > tmp.out; \
		cat tmp.out; \
		if grep -q "^--- FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^panic:" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "build failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "setup failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "runtime error" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		fi; \
		if [ -f profile.out ]; then \
			cat profile.out | grep -v "mode:" >> coverage.out; \
			rm profile.out; \
		fi; \
	done

# Agent Unit Test (agent, aigc) - excludes robot packages (tested in unit-test-robot) and TestE2E*
#
# Iterates over $(TESTFOLDER_AGENT) with a 50 minute per-package timeout.
.PHONY: unit-test-agent
unit-test-agent:
	echo "mode: count" > coverage.out
	for d in $(TESTFOLDER_AGENT); do \
		$(GO) test -tags $(TESTTAGS) -v -timeout=50m -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") -skip='TestMemoryLeak|TestIsolateDisposal|TestE2E' $$d > tmp.out; \
		cat tmp.out; \
		if grep -q "^--- FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^panic:" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "build failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "setup failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "runtime error" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		fi; \
		if [ -f profile.out ]; then \
			cat profile.out | grep -v "mode:" >> coverage.out; \
			rm profile.out; \
		fi; \
	done

# KB Unit Test (kb)
#
# Iterates over $(TESTFOLDER_KB) with a 20 minute per-package timeout.
.PHONY: unit-test-kb
unit-test-kb:
	echo "mode: count" > coverage.out
	for d in $(TESTFOLDER_KB); do \
		$(GO) test -tags $(TESTTAGS) -v -timeout=20m -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") -skip='TestMemoryLeak|TestIsolateDisposal|TestSearchCleanup' $$d > tmp.out; \
		cat tmp.out; \
		if grep -q "^--- FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^panic:" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "build failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "setup failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "runtime error" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		fi; \
		if [ -f profile.out ]; then \
			cat profile.out | grep -v "mode:" >> coverage.out; \
			rm profile.out; \
		fi; \
	done

# Robot Test (all agent/robot/... packages) - runs ALL tests (unit + E2E) with real LLM calls
# These tests require: LLM API keys, database, and longer timeout
.PHONY: unit-test-robot
unit-test-robot:
	echo "mode: count" > coverage.out
	for d in $(TESTFOLDER_ROBOT); do \
		$(GO) test -tags $(TESTTAGS) -v -timeout=50m -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") -skip='TestMemoryLeak|TestIsolateDisposal' $$d > tmp.out; \
		cat tmp.out; \
		if grep -q "^--- FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^panic:" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "build failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "setup failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "runtime error" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		fi; \
		if [ -f profile.out ]; then \
			cat profile.out | grep -v "mode:" >> coverage.out; \
			rm profile.out; \
		fi; \
	done

# Registry Client Test (requires Yao Registry service)
#
# Runs ./registry/... in a single `go test` invocation (-p 1, 5 minute timeout)
# instead of looping over packages.
.PHONY: unit-test-registry
unit-test-registry:
	echo "mode: count" > coverage.out
	$(GO) test -v -p 1 -timeout=5m -covermode=count -coverprofile=profile.out ./registry/... > tmp.out; \
	cat tmp.out; \
	if grep -q "^--- FAIL" tmp.out; then \
		rm tmp.out; \
		exit 1; \
	elif grep -q "^FAIL" tmp.out; then \
		rm tmp.out; \
		exit 1; \
	elif grep -q "^panic:" tmp.out; then \
		rm tmp.out; \
		exit 1; \
	fi; \
	if [ -f profile.out ]; then \
		cat profile.out | grep -v "mode:" >> coverage.out; \
		rm profile.out; \
	fi

# ---------------------------------------------------------------------------
# Sandbox V2 CI Test (tai SDK + workspace only)
# Full sandbox/v2 integration tests (multi-pool, K8s, etc.) are run locally.
# --------------------------------------------------------------------------- .PHONY: unit-test-sandbox-v2 unit-test-sandbox-v2: unit-test-tai unit-test-workspace @echo "" @echo "=============================================" @echo "All Sandbox V2 CI tests passed (tai + workspace)" @echo "=============================================" # Workspace Unit Test (requires Tai for remote mode) .PHONY: unit-test-workspace unit-test-workspace: @echo "" @echo "=============================================" @echo "Running Workspace Tests..." @echo "=============================================" echo "mode: count" > coverage.out for d in $(TESTFOLDER_WORKSPACE); do \ $(GO) test -tags $(TESTTAGS) -v -timeout=10m -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") $$d > tmp.out; \ cat tmp.out; \ if grep -q "^--- FAIL" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "^FAIL" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "^panic:" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "build failed" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "setup failed" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "runtime error" tmp.out; then \ rm tmp.out; \ exit 1; \ fi; \ if [ -f profile.out ]; then \ cat profile.out | grep -v "mode:" >> coverage.out; \ rm profile.out; \ fi; \ done @echo "" @echo "=============================================" @echo "All workspace tests passed" @echo "=============================================" # Sandbox Unit Test (requires Docker) .PHONY: unit-test-sandbox unit-test-sandbox: @echo "" @echo "=============================================" @echo "Running Sandbox Tests (requires Docker)..." @echo "=============================================" @echo "Pulling sandbox test images..." 
docker pull alpine:latest || true docker pull yaoapp/sandbox-base:latest || true docker pull yaoapp/sandbox-claude:latest || true @echo "" echo "mode: count" > coverage.out for d in $(TESTFOLDER_SANDBOX); do \ $(GO) test -tags $(TESTTAGS) -v -timeout=10m -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") -skip='TestMemoryLeak|TestIsolateDisposal' $$d > tmp.out; \ cat tmp.out; \ if grep -q "^--- FAIL" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "^FAIL" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "^panic:" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "build failed" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "setup failed" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "runtime error" tmp.out; then \ rm tmp.out; \ exit 1; \ fi; \ if [ -f profile.out ]; then \ cat profile.out | grep -v "mode:" >> coverage.out; \ rm profile.out; \ fi; \ done @echo "" @echo "=============================================" @echo "✅ All sandbox tests passed" @echo "=============================================" # Tai SDK Test (requires Tai container with Docker socket) .PHONY: unit-test-tai unit-test-tai: @echo "" @echo "=============================================" @echo "Running Tai SDK Tests (requires Tai container)..." 
@echo "=============================================" echo "mode: count" > coverage.out for d in $(TESTFOLDER_TAI); do \ $(GO) test -tags $(TESTTAGS) -v -timeout=5m -covermode=count -coverprofile=profile.out -coverpkg=$$(echo $$d | sed "s/\/test$$//g") $$d > tmp.out; \ cat tmp.out; \ if grep -q "^--- FAIL" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "^FAIL" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "^panic:" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "build failed" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "setup failed" tmp.out; then \ rm tmp.out; \ exit 1; \ elif grep -q "runtime error" tmp.out; then \ rm tmp.out; \ exit 1; \ fi; \ if [ -f profile.out ]; then \ cat profile.out | grep -v "mode:" >> coverage.out; \ rm profile.out; \ fi; \ done @echo "" @echo "=============================================" @echo "All Tai SDK tests passed" @echo "=============================================" # Proto codegen .PHONY: proto proto: protoc --go_out=. --go_opt=paths=source_relative \ --go-grpc_out=. 
--go-grpc_opt=paths=source_relative \
		grpc/pb/yao.proto

# gRPC Unit Test
#
# Runs `go test` once per package listed in $(TESTFOLDER_GRPC), capturing the
# output in tmp.out and the per-package coverage in profile.out. The captured
# output is scanned for failure markers; the first match aborts the target
# with exit status 1. Coverage lines (minus the "mode:" header) are appended
# to coverage.out. Leak/scenario tests are excluded here via -skip.
#
# The failure scan covers the same markers as the other unit-test-* loops:
# "--- FAIL" (test failure), "FAIL" (package-level failure), "panic:",
# "build failed", "setup failed" and "runtime error".
.PHONY: unit-test-grpc
unit-test-grpc:
	echo "mode: count" > coverage.out
	for d in $(TESTFOLDER_GRPC); do \
		$(GO) test -tags $(TESTTAGS) -v -timeout=10m \
			-covermode=count -coverprofile=profile.out \
			-coverpkg=$$(echo $$d | sed "s/\/test$$//g") \
			-skip='TestMemoryLeak|TestIsolateDisposal|TestLeak_|TestScenario_' \
			$$d > tmp.out; \
		cat tmp.out; \
		if grep -q "^--- FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^FAIL" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "^panic:" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "build failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "setup failed" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		elif grep -q "runtime error" tmp.out; then \
			rm tmp.out; \
			exit 1; \
		fi; \
		if [ -f profile.out ]; then \
			cat profile.out | grep -v "mode:" >> coverage.out; \
			rm profile.out; \
		fi; \
	done

# Benchmark Test
.PHONY: benchmark
benchmark:
	@echo ""
	@echo "============================================="
	@echo "Running Benchmark Tests (agent, trace, event)..."
	@echo "============================================="
	@for d in $$($(GO) list ./agent/... ./trace/... ./event/...); do \
		if $(GO) test -list=Benchmark $$d 2>/dev/null | grep -q "^Benchmark"; then \
			echo ""; \
			echo "📊 Benchmarking: $$d"; \
			echo "---------------------------------------------"; \
			$(GO) test -bench=. -benchmem -benchtime=100x -run='^$$' $$d || true; \
		fi; \
	done
	@echo ""
	@echo "============================================="
	@echo "✅ All benchmarks completed"
	@echo "============================================="

# Memory Leak Detection Test
.PHONY: memory-leak
memory-leak:
	@echo ""
	@echo "============================================="
	@echo "Running Memory Leak Detection (agent, trace, event)..."
	@echo "============================================="
	@for d in $$($(GO) list ./agent/... ./trace/...
./event/...); do \
		if $(GO) test -list='TestMemoryLeak|TestIsolateDisposal|TestGoroutineLeak|TestLeak_|TestScenario_' $$d 2>/dev/null | grep -qE "^Test(MemoryLeak|IsolateDisposal|GoroutineLeak|Leak_|Scenario_)"; then \
			echo ""; \
			echo "🔍 Memory Leak Detection: $$d"; \
			echo "---------------------------------------------"; \
			$(GO) test -run='TestMemoryLeak|TestIsolateDisposal|TestGoroutineLeak|TestLeak_|TestScenario_' -v -timeout=5m $$d || exit 1; \
		fi; \
	done
	@echo ""
	@echo "============================================="
	@echo "✅ All memory leak tests passed"
	@echo "============================================="

# Run all tests (unit + benchmark + memory leak)
.PHONY: test
test: unit-test benchmark memory-leak

# Rewrite $(GOFILES) in place with $(GOFMT).
.PHONY: fmt
fmt:
	$(GOFMT) -w $(GOFILES)

# Fail (exit 1) and print the diff when $(GOFMT) would change any of $(GOFILES).
.PHONY: fmt-check
fmt-check:
	@diff=$$($(GOFMT) -d $(GOFILES)); \
	if [ -n "$$diff" ]; then \
		echo "Please run 'make fmt' and commit the result:"; \
		echo "$${diff}"; \
		exit 1; \
	fi;

# Run `go vet` over $(VETPACKAGES).
# Declared phony so a file or directory named "vet" cannot shadow the target.
.PHONY: vet
vet:
	$(GO) vet $(VETPACKAGES)

# Run golint over every package in $(PACKAGES); the first package with
# findings aborts the target. golint is installed on demand when it is not on
# PATH, using `go install …@latest` (the same form the `tools` target uses),
# because `go get` no longer installs executables on current Go toolchains.
.PHONY: lint
lint:
	@hash golint > /dev/null 2>&1; if [ $$? -ne 0 ]; then \
		$(GO) install golang.org/x/lint/golint@latest; \
	fi
	for PKG in $(PACKAGES); do golint -set_exit_status $$PKG || exit 1; done;

# Report spelling mistakes in $(GOFILES) as errors (no files are modified).
# misspell is installed on demand when it is not on PATH.
.PHONY: misspell-check
misspell-check:
	@hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \
		$(GO) install github.com/client9/misspell/cmd/misspell@latest; \
	fi
	misspell -error $(GOFILES)

# Correct spelling mistakes in $(GOFILES) in place.
# misspell is installed on demand when it is not on PATH.
.PHONY: misspell
misspell:
	@hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \
		$(GO) install github.com/client9/misspell/cmd/misspell@latest; \
	fi
	misspell -w $(GOFILES)

# Install the developer tools used by the lint, misspell and bindata targets.
.PHONY: tools
tools:
	go install golang.org/x/lint/golint@latest; \
	go install github.com/client9/misspell/cmd/misspell@latest; \
	go install github.com/go-bindata/go-bindata/...@latest;

# make plugin
.PHONY: plugin
plugin:
	rm -rf $(HOME)/data/gou-unit/plugins
	rm -rf $(HOME)/data/gou-unit/logs
	mkdir -p $(HOME)/data/gou-unit/plugins
	mkdir -p $(HOME)/data/gou-unit/logs
	GOOS=linux GOARCH=amd64 go build -o $(HOME)/data/gou-unit/plugins/user.so ./tests/plugins/user
	chmod +x $(HOME)/data/gou-unit/plugins/user.so
	ls -l $(HOME)/data/gou-unit/plugins
	ls -l $(HOME)/data/gou-unit/logs
	$(HOME)/data/gou-unit/plugins/user.so 2>&1 || true

# make plugin-mac
.PHONY: plugin-mac
plugin-mac:
	rm -rf ./tests/plugins/user/dist
	rm -rf ./tests/plugins/user.so
	go build -o ./tests/plugins/user.so ./tests/plugins/user
	chmod +x ./tests/plugins/user.so

# make pack
.PHONY: pack
pack: bindata fmt

.PHONY: bindata
bindata:

#	Setup Workdir
	rm -rf .tmp/data
	rm -rf .tmp/yao-init
	mkdir -p .tmp/data

#	Checkout init
	git clone https://github.com/YaoApp/yao-init.git .tmp/yao-init
	rm -rf .tmp/yao-init/.git
	rm -rf .tmp/yao-init/.gitignore
	rm -rf .tmp/yao-init/LICENSE
#	rm -rf .tmp/yao-init/README.md

#	Copy Files
	cp -r .tmp/yao-init .tmp/data/init
	cp -r ui .tmp/data/
	cp -r ui .tmp/data/public
	cp -r cui .tmp/data/
	cp -r yao .tmp/data/
	cp -r sui/libsui .tmp/data/
	find .tmp/data -name ".DS_Store" -type f -delete
	go-bindata -fs -pkg data -o data/bindata.go -prefix ".tmp/data/" .tmp/data/...
rm -rf .tmp/data rm -rf .tmp/yao-init # make artifacts-linux .PHONY: artifacts-linux artifacts-linux: clean mkdir -p dist/release # Building CUI v1.0 export NODE_ENV=production # rm -f ../cui-v1.0/pnpm-lock.yaml echo "BASE=__yao_admin_root" > ../cui-v1.0/packages/cui/.env cd ../cui-v1.0 && pnpm install --no-frozen-lockfile && pnpm run build # Init Application cd ../yao-init && rm -rf .git cd ../yao-init && rm -rf .gitignore cd ../yao-init && rm -rf LICENSE # cd ../yao-init rm -rf README.md # Switch .env login URLs from dev mode (__yao_admin_root) to release mode (dashboard) sed -i.bak 's|AFTER_LOGIN_SUCCESS_URL="/__yao_admin_root/|# AFTER_LOGIN_SUCCESS_URL="/__yao_admin_root/|g' ../yao-init/.env sed -i.bak 's|AFTER_LOGIN_FAILURE_URL="/__yao_admin_root/|# AFTER_LOGIN_FAILURE_URL="/__yao_admin_root/|g' ../yao-init/.env sed -i.bak 's|# AFTER_LOGIN_SUCCESS_URL="/dashboard/|AFTER_LOGIN_SUCCESS_URL="/dashboard/|g' ../yao-init/.env sed -i.bak 's|# AFTER_LOGIN_FAILURE_URL="/dashboard/|AFTER_LOGIN_FAILURE_URL="/dashboard/|g' ../yao-init/.env rm -f ../yao-init/.env.bak # Yao Builder # Remove Yao Builder - DUI PageBuilder component will provide online design for pure HTML pages or SUI pages in the future. # mkdir -p .tmp/data/builder # curl -o .tmp/yao-builder-latest.tar.gz https://release-sv.yaoapps.com/archives/yao-builder-latest.tar.gz # tar -zxvf .tmp/yao-builder-latest.tar.gz -C .tmp/data/builder # rm -rf .tmp/yao-builder-latest.tar.gz # Packing # ** CUI will be renamed to CUI in the feature. and move to the new repository. ** # ** new repository: https://github.com/YaoApp/cui.git ** mkdir -p .tmp/data/cui cp -r ./ui .tmp/data/ui cp -r ../cui-v1.0/packages/cui/dist .tmp/data/cui/v1.0 cp -r ../yao-init .tmp/data/init cp -r yao .tmp/data/ cp -r sui/libsui .tmp/data/ go-bindata -fs -pkg data -o data/bindata.go -prefix ".tmp/data/" .tmp/data/... 
rm -rf .tmp/data # Replace PRVERSION sed -ie "s/const PRVERSION = \"DEV\"/const PRVERSION = \"${COMMIT}-${NOW}\"/g" share/const.go @CUI_COMMIT=$$(cd ../cui-v1.0 && git log | head -n 1 | awk '{print substr($$2, 0, 12)}') && \ sed -ie "s/const PRCUI = \"DEV\"/const PRCUI = \"$$CUI_COMMIT-${NOW}\"/g" share/const.go # Making artifacts - dev builds (full debug symbols, ~158M) mkdir -p dist CGO_ENABLED=1 CGO_LDFLAGS="-static" GOOS=linux GOARCH=amd64 go build -v -o dist/yao-${VERSION}-linux-amd64 CGO_ENABLED=1 CGO_LDFLAGS="-static" LD_LIBRARY_PATH=/usr/lib/gcc-cross/aarch64-linux-gnu/13 GOOS=linux GOARCH=arm64 CC=aarch64-linux-gnu-gcc-13 CXX=aarch64-linux-gnu-g++-13 go build -v -o dist/yao-${VERSION}-linux-arm64 # Making artifacts - prod builds (stripped, ~111M) sed -i.tmp 's/const BUILDOPTIONS = ""/const BUILDOPTIONS = "-s -w (production, stripped)"/g' share/const.go && rm -f share/const.go.tmp CGO_ENABLED=1 CGO_LDFLAGS="-static" GOOS=linux GOARCH=amd64 go build -v -ldflags="-s -w" -o dist/yao-${VERSION}-linux-amd64-prod CGO_ENABLED=1 CGO_LDFLAGS="-static" LD_LIBRARY_PATH=/usr/lib/gcc-cross/aarch64-linux-gnu/13 GOOS=linux GOARCH=arm64 CC=aarch64-linux-gnu-gcc-13 CXX=aarch64-linux-gnu-g++-13 go build -v -ldflags="-s -w" -o dist/yao-${VERSION}-linux-arm64-prod mkdir -p dist/release mv dist/yao-*-* dist/release/ chmod +x dist/release/yao-*-* ls -l dist/release/ dist/release/yao-${VERSION}-linux-amd64 version # Reset const # cp -f share/const.goe share/const.go # rm -f share/const.goe # make artifacts-macos .PHONY: artifacts-macos artifacts-macos: clean mkdir -p dist/release # Building CUI v1.0 export NODE_ENV=production # rm -f ../cui-v1.0/pnpm-lock.yaml echo "BASE=__yao_admin_root" > ../cui-v1.0/packages/cui/.env cd ../cui-v1.0 && pnpm install --no-frozen-lockfile && pnpm run build # Init Application cd ../yao-init && rm -rf .git cd ../yao-init && rm -rf .gitignore cd ../yao-init && rm -rf LICENSE # cd ../yao-init && rm -rf README.md # Switch .env login URLs from dev mode 
(__yao_admin_root) to release mode (dashboard) sed -i.bak 's|AFTER_LOGIN_SUCCESS_URL="/__yao_admin_root/|# AFTER_LOGIN_SUCCESS_URL="/__yao_admin_root/|g' ../yao-init/.env sed -i.bak 's|AFTER_LOGIN_FAILURE_URL="/__yao_admin_root/|# AFTER_LOGIN_FAILURE_URL="/__yao_admin_root/|g' ../yao-init/.env sed -i.bak 's|# AFTER_LOGIN_SUCCESS_URL="/dashboard/|AFTER_LOGIN_SUCCESS_URL="/dashboard/|g' ../yao-init/.env sed -i.bak 's|# AFTER_LOGIN_FAILURE_URL="/dashboard/|AFTER_LOGIN_FAILURE_URL="/dashboard/|g' ../yao-init/.env rm -f ../yao-init/.env.bak # Packing mkdir -p .tmp/data/cui cp -r ./ui .tmp/data/ui cp -r ../cui-v1.0/packages/cui/dist .tmp/data/cui/v1.0 cp -r ../yao-init .tmp/data/init cp -r yao .tmp/data/ cp -r sui/libsui .tmp/data/ go-bindata -fs -pkg data -o data/bindata.go -prefix ".tmp/data/" .tmp/data/... rm -rf .tmp/data # Replace PRVERSION sed -ie "s/const PRVERSION = \"DEV\"/const PRVERSION = \"${COMMIT}-${NOW}\"/g" share/const.go @CUI_COMMIT=$$(cd ../cui-v1.0 && git log | head -n 1 | awk '{print substr($$2, 0, 12)}') && \ sed -ie "s/const PRCUI = \"DEV\"/const PRCUI = \"$$CUI_COMMIT-${NOW}\"/g" share/const.go # Making artifacts - dev builds (full debug symbols) mkdir -p dist CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -v -o dist/yao-${VERSION}-darwin-amd64 CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -v -o dist/yao-${VERSION}-darwin-arm64 # Making artifacts - prod builds (stripped, no UPX on macOS) sed -i.tmp 's/const BUILDOPTIONS = ""/const BUILDOPTIONS = "-s -w (production, stripped)"/g' share/const.go && rm -f share/const.go.tmp CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -v -ldflags="-s -w" -o dist/yao-${VERSION}-darwin-amd64-prod CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -v -ldflags="-s -w" -o dist/yao-${VERSION}-darwin-arm64-prod mkdir -p dist/release mv dist/yao-*-* dist/release/ chmod +x dist/release/yao-*-* ls -l dist/release/ dist/release/yao-${VERSION}-darwin-amd64 version .PHONY: debug debug: clean mkdir -p dist/release # Packing # 
mkdir -p .tmp/data # cp -r ui .tmp/data/ui # cp -r yao .tmp/data/ # go-bindata -fs -pkg data -o data/bindata.go -prefix ".tmp/data/" .tmp/data/... # rm -rf .tmp/data # Replace PRVERSION sed -ie "s/const PRVERSION = \"DEV\"/const PRVERSION = \"${COMMIT}-${NOW}-debug\"/g" share/const.go # Making artifacts mkdir -p dist CGO_ENABLED=1 go build -v -o dist/release/yao-debug chmod +x dist/release/yao-debug # Reset const cp -f share/const.goe share/const.go rm -f share/const.goe # make prepare (build CUI, yao-init, bindata - shared by release and prod) .PHONY: prepare prepare: clean mkdir -p dist/release mkdir .tmp # Building CUI v0.9 mkdir -p .tmp/cui/v0.9/dist echo "CUI v0.9" > .tmp/cui/v0.9/dist/index.html # Building CUI v1.0 # ** CUI will be renamed to CUI in the feature. and move to the new repository. ** # ** new repository: https://github.com/YaoApp/cui.git ** export NODE_ENV=production git clone https://github.com/YaoApp/cui.git .tmp/cui/v1.0 # cd .tmp/cui/v1.0 && git checkout 5002c3fded585aaa69a4366135b415ea3234964e echo "BASE=__yao_admin_root" > .tmp/cui/v1.0/packages/cui/.env cd .tmp/cui/v1.0 && pnpm install --no-frozen-lockfile && pnpm run build CUI_COMMIT=$$(cd .tmp/cui/v1.0 && git rev-parse --short HEAD) # Checkout init git clone https://github.com/YaoApp/yao-init.git .tmp/yao-init rm -rf .tmp/yao-init/.git rm -rf .tmp/yao-init/.gitignore rm -rf .tmp/yao-init/LICENSE rm -rf .tmp/yao-init/README.md # Switch .env login URLs from dev mode (__yao_admin_root) to release mode (dashboard) sed -i.bak 's|AFTER_LOGIN_SUCCESS_URL="/__yao_admin_root/|# AFTER_LOGIN_SUCCESS_URL="/__yao_admin_root/|g' .tmp/yao-init/.env sed -i.bak 's|AFTER_LOGIN_FAILURE_URL="/__yao_admin_root/|# AFTER_LOGIN_FAILURE_URL="/__yao_admin_root/|g' .tmp/yao-init/.env sed -i.bak 's|# AFTER_LOGIN_SUCCESS_URL="/dashboard/|AFTER_LOGIN_SUCCESS_URL="/dashboard/|g' .tmp/yao-init/.env sed -i.bak 's|# AFTER_LOGIN_FAILURE_URL="/dashboard/|AFTER_LOGIN_FAILURE_URL="/dashboard/|g' .tmp/yao-init/.env rm -f 
.tmp/yao-init/.env.bak # Yao Builder # Remove Yao Builder - DUI PageBuilder component will provide online design for pure HTML pages or SUI pages in the future. # mkdir -p .tmp/data/builder # curl -o .tmp/yao-builder-latest.tar.gz https://release-sv.yaoapps.com/archives/yao-builder-latest.tar.gz # tar -zxvf .tmp/yao-builder-latest.tar.gz -C .tmp/data/builder # rm -rf .tmp/yao-builder-latest.tar.gz # Packing cp -f data/bindata.go data/bindata.go.bak mkdir -p .tmp/data/cui cp -r ./ui .tmp/data/ui cp -r ./yao .tmp/data/yao cp -r ./sui/libsui .tmp/data/libsui cp -r .tmp/cui/v0.9/dist .tmp/data/cui/v0.9 cp -r .tmp/cui/v1.0/packages/cui/dist .tmp/data/cui/v1.0 cp -r .tmp/yao-init .tmp/data/init go-bindata -fs -pkg data -o data/bindata.go -prefix ".tmp/data/" .tmp/data/... # Replace PRVERSION cp -f share/const.go share/const.go.bak sed -ie "s/const PRVERSION = \"DEV\"/const PRVERSION = \"${COMMIT}-${NOW}\"/g" share/const.go @CUI_COMMIT=$$(cd .tmp/cui/v1.0 && git log | head -n 1 | awk '{print substr($$2, 0, 12)}') && \ sed -ie "s/const PRCUI = \"DEV\"/const PRCUI = \"$$CUI_COMMIT-${NOW}\"/g" share/const.go # make release (development build only, ~158M) .PHONY: release release: prepare # Making artifacts - dev build mkdir -p dist CGO_ENABLED=1 go build -v -o dist/release/yao chmod +x dist/release/yao # Clean up and restore bindata.go and const.go cp data/bindata.go.bak data/bindata.go cp share/const.go.bak share/const.go rm data/bindata.go.bak rm share/const.go.bak rm -rf .tmp # MacOS Application Signing @if [ "$(OS)" = "Darwin" ]; then \ codesign --deep --force --verbose --timestamp --options runtime \ --entitlements .github/codesign/entitlements.plist \ --sign "${APPLE_SIGN}" dist/release/yao ; \ fi # make prod (production build only, ~111M on macOS) .PHONY: prod prod: prepare # Set BUILDOPTIONS @if [ "$$(uname)" = "Linux" ]; then \ sed -i.tmp 's/const BUILDOPTIONS = ""/const BUILDOPTIONS = "-s -w +upx (production, compressed)"/g' share/const.go && rm -f 
share/const.go.tmp; \ else \ sed -i.tmp 's/const BUILDOPTIONS = ""/const BUILDOPTIONS = "-s -w (production, stripped)"/g' share/const.go && rm -f share/const.go.tmp; \ fi # Making artifacts - prod build mkdir -p dist CGO_ENABLED=1 go build -v -ldflags="-s -w" -o dist/release/yao-prod chmod +x dist/release/yao-prod # UPX compression (Linux only) @if [ "$$(uname)" = "Linux" ]; then \ echo "Compressing with UPX..."; \ if command -v upx > /dev/null 2>&1; then \ upx --best dist/release/yao-prod; \ else \ echo "WARNING: UPX not found. Install with: apt install upx"; \ echo "Skipping compression."; \ fi; \ else \ echo "Note: UPX compression skipped on macOS (not supported)"; \ fi # Clean up and restore bindata.go and const.go cp data/bindata.go.bak data/bindata.go cp share/const.go.bak share/const.go rm data/bindata.go.bak rm share/const.go.bak rm -rf .tmp # MacOS Application Signing @if [ "$(OS)" = "Darwin" ]; then \ codesign --deep --force --verbose --timestamp --options runtime \ --entitlements .github/codesign/entitlements.plist \ --sign "${APPLE_SIGN}" dist/release/yao-prod ; \ fi @echo "" @echo "Done! Production binary:" @ls -lh dist/release/yao-prod @echo "" @echo "Test with: dist/release/yao-prod version --all" # make release-all (build both dev and prod in one go) .PHONY: release-all release-all: prepare # Making artifacts - dev build (~158M) @echo "Building dev binary..." mkdir -p dist CGO_ENABLED=1 go build -v -o dist/release/yao chmod +x dist/release/yao # Making artifacts - prod build (~111M on macOS) @echo "Building prod binary..." 
@if [ "$$(uname)" = "Linux" ]; then \ sed -i.tmp 's/const BUILDOPTIONS = ""/const BUILDOPTIONS = "-s -w +upx (production, compressed)"/g' share/const.go && rm -f share/const.go.tmp; \ else \ sed -i.tmp 's/const BUILDOPTIONS = ""/const BUILDOPTIONS = "-s -w (production, stripped)"/g' share/const.go && rm -f share/const.go.tmp; \ fi CGO_ENABLED=1 go build -v -ldflags="-s -w" -o dist/release/yao-prod chmod +x dist/release/yao-prod # UPX compression (Linux only) @if [ "$$(uname)" = "Linux" ]; then \ echo "Compressing with UPX..."; \ if command -v upx > /dev/null 2>&1; then \ upx --best dist/release/yao-prod; \ else \ echo "WARNING: UPX not found. Install with: apt install upx"; \ echo "Skipping compression."; \ fi; \ else \ echo "Note: UPX compression skipped on macOS (not supported)"; \ fi # Clean up and restore bindata.go and const.go cp data/bindata.go.bak data/bindata.go cp share/const.go.bak share/const.go rm data/bindata.go.bak rm share/const.go.bak rm -rf .tmp # MacOS Application Signing @if [ "$(OS)" = "Darwin" ]; then \ codesign --deep --force --verbose --timestamp --options runtime \ --entitlements .github/codesign/entitlements.plist \ --sign "${APPLE_SIGN}" dist/release/yao ; \ codesign --deep --force --verbose --timestamp --options runtime \ --entitlements .github/codesign/entitlements.plist \ --sign "${APPLE_SIGN}" dist/release/yao-prod ; \ fi @echo "" @echo "Done! Binaries:" @ls -lh dist/release/yao dist/release/yao-prod @echo "" @echo "Test with:" @echo " dist/release/yao version --all" @echo " dist/release/yao-prod version --all" .PHONY: linux-release linux-release: clean mkdir -p dist/release mkdir .tmp # Building CUI v1.0 # ** CUI will be renamed to CUI in the feature. and move to the new repository. 
**
#	** new repository: https://github.com/YaoApp/cui.git **
#	NOTE(review): unless .ONESHELL is enabled elsewhere in this Makefile, each
#	recipe line runs in its own shell, so this export does not reach the pnpm
#	commands below — confirm whether NODE_ENV is meant to apply to the build.
	export NODE_ENV=production
	git clone https://github.com/YaoApp/cui.git .tmp/cui/v1.0
	rm -f .tmp/cui/v1.0/pnpm-lock.yaml
	echo "BASE=__yao_admin_root" > .tmp/cui/v1.0/packages/cui/.env
	cd .tmp/cui/v1.0 && pnpm install --no-frozen-lockfile && pnpm run build

#	Setup UI
	cd .tmp/cui/v1.0/packages/setup && pnpm install --no-frozen-lockfile && pnpm run build

#	Checkout init
	git clone https://github.com/YaoApp/yao-init.git .tmp/yao-init
	rm -rf .tmp/yao-init/.git
	rm -rf .tmp/yao-init/.gitignore
	rm -rf .tmp/yao-init/LICENSE
	rm -rf .tmp/yao-init/README.md

#	Yao Builder
#   Remove Yao Builder - DUI PageBuilder component will provide online design for pure HTML pages or SUI pages in the future.
#	mkdir -p .tmp/data/builder
#	curl -o .tmp/yao-builder-latest.tar.gz https://release-sv.yaoapps.com/archives/yao-builder-latest.tar.gz
#	tar -zxvf .tmp/yao-builder-latest.tar.gz -C .tmp/data/builder
#	rm -rf .tmp/yao-builder-latest.tar.gz

#	CUI v0.9 placeholder
#	This target starts from a clean .tmp, so the v0.9 dist directory copied in
#	the Packing step below does not exist yet. Create the same placeholder page
#	that the `prepare` target generates; without it the `cp` below fails.
	mkdir -p .tmp/cui/v0.9/dist
	echo "CUI v0.9" > .tmp/cui/v0.9/dist/index.html

#	Packing
	mkdir -p .tmp/data/cui
	cp -r ./ui .tmp/data/ui
	cp -r ./yao .tmp/data/yao
	cp -r .tmp/cui/v0.9/dist .tmp/data/cui/v0.9
	cp -r .tmp/cui/v1.0/packages/setup/build .tmp/data/cui/setup
	cp -r .tmp/cui/v1.0/packages/cui/dist .tmp/data/cui/v1.0
	cp -r .tmp/yao-init .tmp/data/init
	go-bindata -fs -pkg data -o data/bindata.go -prefix ".tmp/data/" .tmp/data/...
	rm -rf .tmp/data
	rm -rf .tmp/cui

#	Making artifacts
	mkdir -p dist
	CGO_ENABLED=1 CGO_LDFLAGS="-static" go build -v -o dist/release/yao
	chmod +x dist/release/yao

# make clean
# Remove the temporary working directories (./tmp, .tmp) and the dist output.
.PHONY: clean
clean:
	rm -rf ./tmp
	rm -rf .tmp
	rm -rf dist

================================================
FILE: README.md
================================================
# Yao — Build Autonomous Agents. Just Define the Role.

Yao is an open-source engine for autonomous agents — event-driven, proactive, and self-scheduling.
![Mission Control](docs/mission-control.png) **Quick Links:** **🏠 Homepage:** [https://yaoapps.com](https://yaoapps.com) **🚀 Quick Start:** [https://yaoapps.com/docs/documentation/en-us/getting-started](https://yaoapps.com/docs/documentation/en-us/getting-started#quickstart) **📚 Documentation:** [https://yaoapps.com/docs](https://yaoapps.com/docs) **✨ Why Yao?** [https://yaoapps.com/docs/why-yao](https://yaoapps.com/docs/documentation/en-us/getting-started/why-yao) **🤖 Yao Agents:** [https://github.com/YaoAgents/awesome](https://github.com/YaoAgents/awesome) ( Preview ) --- ## What Makes Yao Different? | Traditional AI Assistants | Yao Autonomous Agents | | ----------------------------- | ------------------------------------- | | Entry point: Chatbox | Entry point: Email, Events, Schedules | | Passive: You ask, they answer | Proactive: They work autonomously | | Role: Tool | Role: Team member | > The entry point is not a chatbox — it's email, events, and scheduled tasks. --- ## Features ### Autonomous Agent Framework Build agents that work like real team members: - **Three Trigger Modes** — Clock (scheduled), Human (email/message), Event (webhook/database) - **Six-Phase Execution** — Inspiration → Goals → Tasks → Run → Deliver → Learn - **Multi-Agent Orchestration** — Agents delegate, collaborate, and compose dynamically - **Continuous Learning** — Agents accumulate experience in private knowledge bases ### Native MCP Support Integrate tools without writing adapters: - **Process Transport** — Map Yao processes directly to MCP tools - **External Servers** — Connect via SSE or STDIO - **Schema Mapping** — Declarative input/output schemas ### Built-in GraphRAG - **Vector Search** — Embeddings with OpenAI/FastEmbed - **Knowledge Graph** — Entity-relationship retrieval - **Hybrid Search** — Combine vector similarity with graph traversal ### Full-Stack Runtime Everything in a single executable: - **All-in-One** — Data, API, Agent, UI in one engine - **TypeScript 
Support** — Built-in V8 engine - **Single Binary** — No Node.js, Python, or containers required - **Edge-Ready** — Runs on ARM64/x64 devices ================================================ FILE: README.zh-CN.md ================================================ # Yao [![UnitTest](https://github.com/YaoApp/yao/actions/workflows/unit-test.yml/badge.svg)](https://github.com/YaoApp/yao/actions/workflows/unit-test.yml) [![codecov](https://codecov.io/gh/YaoApp/yao/branch/main/graph/badge.svg?token=294Y05U71J)](https://codecov.io/gh/YaoApp/yao) https://github.com/YaoApp/yao/assets/1842210/6b23ac89-ef6e-4c24-874f-753a98370dec [English](README.md) YAO 是一款开源应用引擎,使用 Golang 编写,以一个命令行工具的形式存在, 下载即用。适合用于开发业务系统、网站/APP API 接口、管理后台、自建低代码平台等。 YAO 采用 flow-based 的编程模式,通过编写 YAO DSL (JSON 格式逻辑描述) 或使用 JavaScript 编写处理器,实现各种功能。 YAO DSL 可以有多种编写方式: 1. 纯手工编写 2. 使用自动化脚本,根据上下文逻辑生成 3. 使用可视化编辑器,通过“拖拉拽”制作 官网: [https://yaoapps.com](https://yaoapps.com) 文档: [https://yaoapps.com/doc](https://yaoapps.com/doc) ## 最新版本下载安装 (推荐) https://github.com/YaoApp/xgen-dev-app ## 演示 ![界面](docs/yao-setup-demo.jpg) 使用 YAO 开发的应用 | 应用 | 简介 | 代码仓库 | | -------------------- | ---------------------------- | --------------------------------------- | | yaoapp/yao-examples | Yao 应用示例 | https://github.com/YaoApp/yao-examples | | yaoapp/yao-knowledge | ChatGPT 驱动的知识管理库应用 | https://github.com/YaoApp/yao-knowledge | | yaoapp/xgen-dev-app | 演示应用 (演示) | https://github.com/YaoApp/xgen-dev-app | | yaoapp/demo-project | 工程项目管理演示应用(演示) | https://github.com/yaoapp/demo-project | | yaoapp/demo-finance | 财务管理演示应用(演示) | https://github.com/yaoapp/demo-finance | | yaoapp/demo-plm | 生产项目管理演示应用(演示) | https://github.com/yaoapp/demo-plm | ## 介绍 Yao 是一个只需使用 JSON 即可创建数据库模型、编写 API 接口、描述管理后台界面的应用引擎,使用 Yao 构建的应用可运行在云端或物联网设备上。 开发者不需要写一行代码,就可以拥有 10 倍生产力。 Yao 基于 **flow-based** 编程思想,采用 **Go** 语言开发,支持多种方式扩展数据流处理器。这使得 Yao 具有极好的**通用性**,大部分场景下可以代替编程语言, 在复用性和编码效率上是传统编程语言的 **10 倍**;应用性能和资源占比上优于 **PHP**, **JAVA** 等语言。 Yao 内置了一套数据管理系统,通过编写 **JSON** 描述界面布局,即可实现 
90% 常见界面交互功能,特别适合快速制作各类管理后台、CRM、ERP 等企业内部系统。对于特殊交互功能亦可通过编写扩展组件或 HTML 页面的方式实现。内置管理系统与 Yao 并不耦合,亦可采用 **VUE**, **React** 等任意前端技术实现管理界面。 ## 安装 Yao v0.10.4 使用说明 https://github.com/YaoApp/xgen-dev-app/blob/main/README.zh-CN.md ## 入门指南 详细说明请看[文档](https://yaoapps.com/doc/%E4%BB%8B%E7%BB%8D/%E5%85%A5%E9%97%A8%E6%8C%87%E5%8D%97) ### 创建应用 #### 新建一个空白应用 新建一个应用目录,进入应用目录,运行 `yao start` 命令, 启动安装界面。 ```bash mkdir -p /data/app # 创建应用目录 cd /data/app # 进入应用目录 yao start # 启动安装界面 ``` **默认账号** - 用户名: **xiang@iqka.com** - 密码: **A123456p+** ![安装界面](docs/yao-setup-step2.jpg) ## 关于 Yao Yao 的名字源于汉字**爻(yáo)**,是构成八卦的基本符号。八卦,是上古大神伏羲观测总结自然规律后,创造的一个可以指代万事万物的符号体系。爻,有阴阳两种状态,就像 0 和 1。爻的阴阳转换,驱动八卦更替,以此来总结记录事物的发展规律。 ================================================ FILE: agent/README.md ================================================ # Yao Agent A powerful AI assistant framework for building intelligent conversational agents with tool integration, knowledge base search, and multi-agent orchestration. ## Quick Start ### 1. Create an Assistant ``` assistants/ └── my-assistant/ ├── package.yao # Configuration ├── prompts.yml # System prompts └── locales/ └── en-us.yml # Translations ``` **package.yao** ```json { "name": "{{ name }}", "connector": "gpt-4o", "description": "{{ description }}", "placeholder": { "title": "{{ chat.title }}", "prompts": ["{{ chat.prompts.0 }}"] } } ``` **prompts.yml** ```yaml - role: system content: | You are a helpful assistant. ``` **locales/en-us.yml** ```yaml name: My Assistant description: A helpful AI assistant chat: title: New Chat prompts: - How can I help you today? ``` ### 2. Add Hooks (Optional) Create `src/index.ts` for custom logic: ```typescript import { agent } from "@yao/runtime"; function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { // Preprocess messages before LLM call return { messages }; } function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { // Post-process LLM response return null; } ``` ### 3. 
Test (Optional)

```bash
# Run tests
yao agent test -i "Hello, how are you?"

# Run tests from JSONL file
yao agent test -i tests/inputs.jsonl -v

# Extract results for review
yao agent extract output-*.jsonl
```

### 4. Run

```bash
yao start
```

Access via API: `POST /v1/chat/completions`

## Examples

### Hook: Route to Specialist

```typescript
// src/index.ts
function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create {
  const last = messages[messages.length - 1]?.content || "";

  if (last.includes("refund")) {
    return { delegate: { agent_id: "refund-specialist", messages } };
  }
  return null;
}
```

### Database Query

```json
// package.yao - Enable auto DB search
{
  "db": {
    "models": ["orders", "products"]
  }
}
```

```bash
# Test: Agent auto-generates QueryDSL and searches database
# (single quotes keep the shell from expanding "$1" inside the prompt)
yao agent test -i 'Find orders over $1000 from last month'
```

### MCP Tools (Process Transport)

```json
// mcps/tools.mcp.yao - Define MCP server with Yao Processes
{
  "label": "Tools",
  "transport": "process",
  "tools": {
    "search_orders": "models.order.Paginate",
    "create_order": "models.order.Create"
  }
}
```

```json
// mcps/mapping/tools/schemes/search_orders.in.yao - Input schema
{
  "type": "object",
  "properties": {
    "keyword": { "type": "string" },
    "page": { "type": "integer" }
  },
  "x-process-args": [":arguments"]
}
```

```json
// package.yao
{
  "mcp": {
    "servers": [{ "server_id": "tools" }]
  }
}
```

### Sidebar Page (Display Data)

Pages render in the right sidebar during conversation to display structured data:

```html

{{ title }}

{{ row.name }} {{ row.value }}
``` ```bash yao sui build agent # Build pages ``` ```javascript // In hook: send action to open page in sidebar ctx.Send({ type: "action", props: { name: "navigate", payload: { route: "/agents/my-assistant/result", title: "Query Results", query: { id: "123" }, // Passed as $query in page }, }, }); ``` ## Documentation - [Configuration](docs/configuration.md) - Assistant settings, connectors, options - [Prompts](docs/prompts.md) - System prompts and prompt presets - [Hooks](docs/hooks.md) - Create/Next hooks and agent lifecycle - [Context API](docs/context-api.md) - Messaging, memory, trace, MCP - [MCP Integration](docs/mcp.md) - Tool servers and resources - [Models](docs/models.md) - Assistant-scoped data models - [Search](docs/search.md) - Web, knowledge base, and database search - [Pages](docs/pages.md) - Web UI for agents (SUI framework) - [Iframe Integration](docs/iframe.md) - Iframe communication with CUI - [Internationalization](docs/i18n.md) - Multi-language support - [Testing](docs/testing.md) - Agent testing framework ## Architecture ```mermaid flowchart LR subgraph Request A[User Request] end subgraph Create["Create Hook"] B1[Preprocess Messages] B2[Configure LLM] B3[Delegate to Agent] end subgraph LLM["LLM Call"] C1[Load Prompts] C2[Generate Response] end subgraph Tools["Tool Execution"] D1[MCP Tools] D2[Search] D3[Memory] end subgraph Next["Next Hook"] E1[Process Results] E2[Transform Output] E3[Delegate to Agent] end subgraph Response F[Stream Response] end A --> Create Create --> LLM LLM --> Tools Tools --> Next Next --> Response Next -.->|Continue| LLM ``` ## API Endpoints OpenAPI endpoints (base URL: `/v1`): | Endpoint | Method | Description | | -------------------------------------- | ------ | --------------------- | | `/v1/chat/completions` | POST | Chat with assistant | | `/v1/chat/sessions` | GET | List chat sessions | | `/v1/chat/sessions/:chat_id` | GET | Get chat session | | `/v1/chat/sessions/:chat_id/messages` | GET | Get messages | | 
`/v1/agent/assistants` | GET | List assistants | | `/v1/agent/assistants/:id` | GET | Get assistant details | | `/v1/file/:uploaderID` | POST | Upload files | | `/v1/file/:uploaderID/:fileID` | GET | Get file info | | `/v1/file/:uploaderID/:fileID/content` | GET | Download file | ## License This project is part of the Yao App Engine and follows the [Yao Open Source License](../LICENSE). ================================================ FILE: agent/agent_test.go ================================================ package agent // type customResponseRecorder struct { // *httptest.ResponseRecorder // closeChannel chan bool // } // func (r *customResponseRecorder) CloseNotify() <-chan bool { // return r.closeChannel // } // func newCustomResponseRecorder() *customResponseRecorder { // return &customResponseRecorder{ // ResponseRecorder: httptest.NewRecorder(), // closeChannel: make(chan bool, 1), // } // } // func TestDSL_Prompts(t *testing.T) { // test.Prepare(t, config.Conf) // defer Test_clean(t) // resetDB() // agent := &DSL{ // Prompts: []Prompt{ // {Role: "system", Content: "You are a helpful assistant", Name: "ai"}, // {Role: "user", Content: "Hello", Name: "user"}, // }, // ConversationSetting: conversation.Setting{ // Connector: "default", // Table: "chat_messages", // }, // } // err := agent.newConversation() // assert.NoError(t, err) // prompts := agent.prompts() // assert.Equal(t, 2, len(prompts)) // assert.Equal(t, "system", prompts[0]["role"]) // assert.Equal(t, "You are a helpful assistant", prompts[0]["content"]) // assert.Equal(t, "ai", prompts[0]["name"]) // } // func TestDSL_ChatMessages(t *testing.T) { // test.Prepare(t, config.Conf) // defer Test_clean(t) // resetDB() // agent := &DSL{ // Prompts: []Prompt{ // {Role: "system", Content: "You are a helpful assistant"}, // }, // ConversationSetting: conversation.Setting{ // Connector: "default", // Table: "chat_messages", // }, // } // err := agent.newConversation() // assert.NoError(t, err) // ctx := 
Context{ // Sid: "test-session", // ChatID: "test-chat", // } // messages, err := agent.chatMessages(ctx, "Hello AI") // assert.NoError(t, err) // assert.Equal(t, 2, len(messages)) // assert.Equal(t, "system", messages[0]["role"]) // assert.Equal(t, "user", messages[1]["role"]) // assert.Equal(t, "Hello AI", messages[1]["content"]) // } // func TestDSL_Answer(t *testing.T) { // test.Prepare(t, config.Conf) // defer Test_clean(t) // gin.SetMode(gin.TestMode) // w := newCustomResponseRecorder() // c, _ := gin.CreateTestContext(w) // ctx := Context{ // Sid: "test-session", // ChatID: "test-chat", // Context: context.Background(), // } // resetDB() // agent := &DSL{ // Connector: "gpt-3_5-turbo", // Option: map[string]interface{}{ // "temperature": 0.7, // "max_tokens": 150, // }, // Prompts: []Prompt{ // {Role: "system", Content: "You are a helpful assistant"}, // }, // ConversationSetting: conversation.Setting{ // Connector: "default", // Table: "chat_messages", // }, // } // err := agent.newAI() // assert.NoError(t, err) // err = agent.newConversation() // assert.NoError(t, err) // c.Request = httptest.NewRequest("POST", "/chat", nil) // agent.AI = &mockAI{} // err = agent.Answer(ctx, "Hello AI", c) // assert.NoError(t, err) // } // // func TestDSL_NewAI(t *testing.T) { // // test.Prepare(t, config.Conf) // // defer Test_clean(t) // // tests := []struct { // // name string // // connector string // // wantErr string // // }{ // // { // // name: "Mock AI", // // connector: "mock", // // wantErr: "", // // }, // // { // // name: "Specific mock model", // // connector: "mock:gpt-4", // // wantErr: "", // // }, // // { // // name: "Invalid connector", // // connector: "invalid-connector", // // wantErr: "AI connector invalid-connector not found", // // }, // // } // // for _, tt := range tests { // // t.Run(tt.name, func(t *testing.T) { // // agent := &DSL{ // // Connector: tt.connector, // // } // // agent.newConversation() // // assert.Panics(t, func() { // // 
agent.newAI() // // }) // // }) // // } // // } // func TestDSL_Select(t *testing.T) { // test.Prepare(t, config.Conf) // defer Test_clean(t) // resetDB() // agent := &DSL{ // ConversationSetting: conversation.Setting{ // Connector: "default", // Table: "chat_messages", // }, // } // err := agent.newConversation() // assert.NoError(t, err) // err = agent.Select("invalid-model") // assert.Error(t, err) // // err = agent.Select("gpt-3_5-turbo") // // assert.NoError(t, err) // // assert.NotNil(t, agent.AI) // } // // func TestDSL_NewConversation(t *testing.T) { // // test.Prepare(t, config.Conf) // // defer Test_clean(t) // // tests := []struct { // // name string // // connector string // // wantErr bool // // }{ // // { // // name: "Default connector", // // connector: "default", // // wantErr: false, // // }, // // { // // name: "Empty connector", // // connector: "", // // wantErr: false, // // }, // // { // // name: "Invalid connector", // // connector: "invalid-connector", // // wantErr: true, // // }, // // } // // for _, tt := range tests { // // t.Run(tt.name, func(t *testing.T) { // // agent := &DSL{ // // ConversationSetting: conversation.Setting{ // // Connector: tt.connector, // // }, // // } // // assert.Panics(t, func() { // // agent.newConversation() // // }) // // }) // // } // // } // func TestDSL_SaveHistory(t *testing.T) { // test.Prepare(t, config.Conf) // defer Test_clean(t) // agent := &DSL{ // ConversationSetting: conversation.Setting{ // Connector: "default", // Table: "chat_messages", // }, // } // resetDB() // err := agent.newConversation() // assert.NoError(t, err) // messages := []map[string]interface{}{ // { // "role": "user", // "content": "Hello", // "name": "test-user", // }, // } // content := []byte("Hi there!") // agent.saveHistory("test-session", "test-chat", content, messages) // // Verify the history was saved // history, err := agent.Conversation.GetHistory("test-session", "test-chat") // assert.NoError(t, err) // 
assert.NotEmpty(t, history) // } // func TestDSL_Send(t *testing.T) { // test.Prepare(t, config.Conf) // defer Test_clean(t) // gin.SetMode(gin.TestMode) // w := httptest.NewRecorder() // c, _ := gin.CreateTestContext(w) // resetDB() // agent := &DSL{ // ConversationSetting: conversation.Setting{ // Connector: "default", // Table: "chat_messages", // }, // } // err := agent.newConversation() // assert.NoError(t, err) // ctx := Context{ // Sid: "test-session", // ChatID: "test-chat", // } // msg := &message.JSON{ // Message: &message.Message{Text: "Test message"}, // } // messages := []map[string]interface{}{ // {"role": "user", "content": "Hello"}, // } // content := []byte("Test content") // err = agent.send(ctx, msg, messages, content, c) // assert.NoError(t, err) // } // func Test_clean(t *testing.T) { // defer test.Clean() // } // func resetDB() { // sch := capsule.Global.Schema() // sch.DropTable("chat_messages") // } // type mockAI struct{} // func (m *mockAI) ChatCompletionsWith(ctx context.Context, messages []map[string]interface{}, options map[string]interface{}, callback func([]byte) int) (interface{}, *exception.Exception) { // callback([]byte(`{"choices":[{"delta":{"content":"Mock response"}}]}`)) // callback([]byte(`{"choices":[{"finish_reason":"stop"}]}`)) // return nil, nil // } // func (m *mockAI) ChatCompletions(messages []map[string]interface{}, options map[string]interface{}, callback func([]byte) int) (interface{}, *exception.Exception) { // return nil, nil // } // func (m *mockAI) GetContent(response interface{}) (string, *exception.Exception) { // return "Mock content", nil // } // func (m *mockAI) Embeddings(input interface{}, user string) (interface{}, *exception.Exception) { // return nil, nil // } // func (m *mockAI) Tiktoken(input string) (int, error) { // return 0, nil // } // func (m *mockAI) MaxToken() int { // return 4096 // } ================================================ FILE: agent/assistant/agent.go 
================================================ package assistant import ( "fmt" "log" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/connector" goullm "github.com/yaoapp/gou/llm" "github.com/yaoapp/yao/agent/assistant/handlers" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/llm" "github.com/yaoapp/yao/agent/output/message" agentsandbox "github.com/yaoapp/yao/agent/sandbox" sandboxTypes "github.com/yaoapp/yao/agent/sandbox/v2/types" infraV2 "github.com/yaoapp/yao/sandbox/v2" ) // Stream stream the agent // handler is optional, if not provided, a default handler will be used func (ast *Assistant) Stream(ctx *context.Context, inputMessages []context.Message, options ...*context.Options) (*context.Response, error) { // Update logger with assistant ID and start logging ctx.Logger.SetAssistantID(ast.ID) ctx.Logger.Start() // Validate user permissions var err error err = ast.checkPermissions(ctx) if err != nil { return nil, err } // Start stream time streamStartTime := time.Now() // Set up interrupt handler if interrupt controller is available // InterruptController handles user interrupt signals (stop button) for appending messages // HTTP context cancellation is handled naturally by LLM/Agent layers if ctx.Interrupt != nil { ctx.Interrupt.SetHandler(func(c *context.Context, signal *context.InterruptSignal) error { return ast.handleInterrupt(c, signal) }) } // ================================================ // Initialize // ================================================ ctx.Logger.Phase("Initialize") // Get or create options var opts *context.Options if len(options) > 0 && options[0] != nil { opts = options[0] } else { opts = &context.Options{} } // Merge caller-provided metadata into ctx so sub-agent hooks can read it via ctx.metadata ctx.MergeMetadata(opts.Metadata) // Initialize stack and auto-handle completion/failure/restore _, _, done := context.EnterStack(ctx, ast.ID, opts) defer 
done() // Auto-skip history for forked Agent-to-Agent calls (ctx.agent.Call/All/Any/Race) // This ensures forked A2A messages don't pollute chat history. // Delegate calls (RefererAgent) still save history as they are part of the main conversation flow. // Note: Output is NOT skipped - sub-agents output normally with ThreadID for UI separation. if ctx.IsForkedA2ACall() { if opts == nil { opts = &context.Options{} } opts.ForceA2A() } // ================================================ // Initialize Chat Buffer (for root stack only) // Buffer is flushed in defer block at the end // ================================================ ast.InitBuffer(ctx) // Track final status for buffer flush var finalStatus = context.StepStatusCompleted var finalError error // Defer buffer flush - always executes on exit (success, error, interrupt, panic) defer func() { // Handle panic recovery for status tracking if r := recover(); r != nil { finalStatus = context.ResumeStatusFailed if e, ok := r.(error); ok { finalError = e } else { finalError = fmt.Errorf("panic: %v", r) } ctx.Logger.Error("Panic recovered in Stream: %v", r) // Re-panic after flush to preserve original behavior defer panic(r) } // Flush buffer to database ast.FlushBuffer(ctx, finalStatus, finalError) // Log end of request ctx.Logger.End(finalStatus == context.StepStatusCompleted, finalError) ctx.Logger.RestoreAssistantID() }() // Determine stream handler streamHandler := ast.getStreamHandler(ctx, opts) // Get connector and capabilities early (before sending stream_start) // so that output adapters can use them when converting stream_start event err = ast.initializeCapabilities(ctx, opts) if err != nil { finalStatus = context.ResumeStatusFailed finalError = err ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // Send ChunkStreamStart only for root stack (agent-level stream start) // Now ctx.Capabilities is set, so output adapters can use it ast.sendAgentStreamStart(ctx, 
streamHandler, streamStartTime) // Initialize chat, prepare kb collection (optional) etc. // Use async version to not block the main flow ast.InitializeConversationAsync(ctx, opts) ctx.Logger.PhaseComplete("Initialize") // Ensure chat session exists ast.EnsureChat(ctx) // Initialize agent trace node agentNode := ast.initAgentTraceNode(ctx, inputMessages) // ================================================ // Get Full Messages with chat history // ================================================ ctx.Logger.Phase("History") historyResult, err := ast.WithHistory(ctx, inputMessages, agentNode, opts) if err != nil { ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } fullMessages := historyResult.FullMessages // Buffer user input messages (use cleaned input without overlap) // Skip if History is disabled in options (for internal calls like needsearch) // Note: For A2A calls, ForceA2A() sets skip.history = true, so this will be skipped if opts == nil || opts.Skip == nil || !opts.Skip.History { ast.BufferUserInput(ctx, historyResult.InputMessages) } ctx.Logger.PhaseComplete("History") // ================================================ // Initialize Sandbox (if configured) // ================================================ // Sandbox must be created BEFORE hooks so that hooks can access ctx.sandbox var sandboxExecutor agentsandbox.Executor var sandboxCleanup func() var sandboxLoadingMsgID string // V2 sandbox state var v2Runner sandboxTypes.Runner var v2Computer infraV2.Computer var v2LoadingMsgID string if ast.HasSandboxV2() { ctx.Logger.Phase("Sandbox V2") var err error var v2Cleanup func() v2Runner, v2Computer, v2Cleanup, v2LoadingMsgID, err = ast.initSandboxV2(ctx, opts) if err != nil { ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } sandboxCleanup = v2Cleanup ctx.Logger.PhaseComplete("Sandbox V2") } else if ast.HasSandbox() { 
ctx.Logger.Phase("Sandbox") var err error sandboxExecutor, sandboxCleanup, sandboxLoadingMsgID, err = ast.initSandbox(ctx, opts) if err != nil { ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // Set sandbox executor in context so hooks can access ctx.sandbox // The executor implements both agentsandbox.Executor and context.SandboxExecutor ctx.SetSandboxExecutor(sandboxExecutor) ctx.Logger.PhaseComplete("Sandbox") } // Ensure sandbox cleanup on exit defer func() { if sandboxCleanup != nil { sandboxCleanup() } }() // ================================================ // Standalone Workspace Loading (no sandbox required) // ================================================ // When no sandbox is configured but the user selected a workspace, // load the workspace FS into context so hooks can access ctx.workspace. if !ctx.HasWorkspace() { ast.initStandaloneWorkspace(ctx) } // ================================================ // Execute Create Hook // ================================================ // Request Create hook ( Optional ) var createResponse *context.HookCreateResponse if ast.HookScript != nil { ctx.Logger.HookStart("Create") // Begin step tracking for hook_create ast.BeginStep(ctx, context.StepTypeHookCreate, map[string]interface{}{ "messages": fullMessages, }) var err error createResponse, opts, err = ast.HookScript.Create(ctx, fullMessages, opts) if err != nil { finalStatus = context.ResumeStatusFailed finalError = err ast.traceAgentFail(agentNode, err) // Send error stream_end for root stack ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // Complete step ast.CompleteStep(ctx, map[string]interface{}{ "response": createResponse, }) // Log the create response ast.traceCreateHook(agentNode, createResponse) ctx.Logger.HookComplete("Create") // Check if Create hook wants to delegate to another agent // This allows early routing to sub-agents without LLM 
call if createResponse != nil && createResponse.Delegate != nil { ctx.Logger.Debug("Create hook delegating to agent: %s", createResponse.Delegate.AgentID) // Delegate to target agent (reuse existing delegation logic from next.go) // Note: User input is already buffered by root agent, delegated agent will skip buffering delegateResponse, err := ast.handleDelegation(ctx, createResponse.Delegate, streamHandler) if err != nil { finalStatus = context.ResumeStatusFailed finalError = err ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // For root stack, send stream_end and close output // (delegated agent handles its own stream events, but root needs to close) if ctx.Stack != nil && ctx.Stack.IsRoot() { ast.sendAgentStreamEnd(ctx, streamHandler, streamStartTime, "completed", nil, nil) if err := ctx.CloseOutput(); err != nil { if trace, _ := ctx.Trace(); trace != nil { trace.Error(i18n.Tr(ast.ID, ctx.Locale, "assistant.agent.stream.close_error"), map[string]any{"error": err.Error()}) } } } // Return delegated response directly (skip LLM call and Next hook) return delegateResponse, nil } } // ================================================ // Execute LLM Call Stream // ================================================ // LLM Call Stream ( Optional ) var completionResponse *context.CompletionResponse var completionMessages []context.Message var completionOptions *context.CompletionOptions if ast.Prompts != nil || ast.MCP != nil { ctx.Logger.Phase("LLM") // Build the LLM request first (use fullMessages which includes history) // Note: completionMessages here are still in original format (with __yao.attachment:// URLs) // Content conversion (BuildContent) happens inside executeLLMStream, right before LLM call // This ensures autoSearch and delegate receive original messages, not converted ones completionMessages, completionOptions, err = ast.BuildRequest(ctx, fullMessages, createResponse) if err != nil { 
finalStatus = context.ResumeStatusFailed finalError = err ast.traceAgentFail(agentNode, err) // Send error stream_end for root stack ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // ================================================ // Execute Auto Search (if enabled) // ================================================ if intent := ast.shouldAutoSearch(ctx, completionMessages, createResponse, opts); intent != nil { refCtx := ast.executeAutoSearch(ctx, completionMessages, createResponse, intent, opts) if refCtx != nil && len(refCtx.References) > 0 { completionMessages = ast.injectSearchContext(completionMessages, refCtx) } } // Begin step tracking for LLM call ast.BeginStep(ctx, context.StepTypeLLM, map[string]interface{}{ "messages": completionMessages, }) // Execute the LLM streaming call // Choose between sandbox execution or direct LLM execution if ast.HasSandboxV2() && v2Runner != nil && v2Computer != nil && v2Runner.Name() != "yao" { // V2 Sandbox execution path (non-yao runners replace LLM.Stream) completionResponse, err = ast.executeSandboxV2Stream(ctx, completionMessages, agentNode, streamHandler, v2Runner, v2Computer, v2LoadingMsgID) } else if ast.HasSandboxV2() && v2Runner != nil && v2Runner.Name() == "yao" { // V2 yao runner: Prepare is done, close loading, fall through to LLM if v2LoadingMsgID != "" { closeLoadingV2(ctx, v2LoadingMsgID, "") } completionResponse, err = ast.executeLLMStream(ctx, completionMessages, completionOptions, agentNode, streamHandler, opts) } else if ast.HasSandbox() { // V1 Sandbox execution path (Claude CLI, Cursor CLI, etc.) 
completionResponse, err = ast.executeSandboxStream(ctx, completionMessages, agentNode, streamHandler, sandboxExecutor, sandboxLoadingMsgID) } else { // Direct LLM execution path completionResponse, err = ast.executeLLMStream(ctx, completionMessages, completionOptions, agentNode, streamHandler, opts) } if err != nil { finalStatus = context.ResumeStatusFailed finalError = err ast.traceAgentFail(agentNode, err) // Send error stream_end for root stack ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // Complete LLM step ast.CompleteStep(ctx, map[string]interface{}{ "content": completionResponse.Content, "tool_calls": completionResponse.ToolCalls, }) hasToolCalls := completionResponse != nil && completionResponse.ToolCalls != nil && len(completionResponse.ToolCalls) > 0 tokens := 0 if completionResponse != nil && completionResponse.Usage != nil { tokens = completionResponse.Usage.TotalTokens } ctx.Logger.LLMComplete(tokens, hasToolCalls) ctx.Logger.PhaseComplete("LLM") } // ================================================ // Execute tool calls with retry // ================================================ // Note: Skip MCP tool calls execution for sandbox mode - Claude CLI handles them internally var toolCallResponses []context.ToolCallResponse = nil if completionResponse != nil && completionResponse.ToolCalls != nil && !ast.HasSandbox() { maxToolRetries := 3 currentMessages := completionMessages currentResponse := completionResponse for attempt := 0; attempt < maxToolRetries; attempt++ { // Begin step tracking for tool calls ast.BeginStep(ctx, context.StepTypeTool, map[string]interface{}{ "tool_calls": currentResponse.ToolCalls, "attempt": attempt, }) // Execute all tool calls toolResults, hasErrors := ast.executeToolCalls(ctx, currentResponse.ToolCalls, attempt) // Build a map of tool call ID to arguments for quick lookup toolCallArgsMap := make(map[string]interface{}) for _, tc := range currentResponse.ToolCalls { 
toolCallArgsMap[tc.ID] = tc.Function.Arguments } // Convert toolResults to toolCallResponses toolCallResponses = make([]context.ToolCallResponse, len(toolResults)) for i, result := range toolResults { parsedContent, _ := result.ParsedContent() toolCallResponses[i] = context.ToolCallResponse{ ToolCallID: result.ToolCallID, Server: result.Server(), Tool: result.Tool(), Arguments: toolCallArgsMap[result.ToolCallID], Result: parsedContent, Error: "", } if result.Error != nil { toolCallResponses[i].Error = result.Error.Error() } } // If all successful, complete step and break out if !hasErrors { ast.CompleteStep(ctx, map[string]interface{}{ "results": toolCallResponses, }) ctx.Logger.Debug("All tool calls succeeded (attempt %d)", attempt) break } // Check if any errors are retryable (parameter/validation issues) hasRetryableErrors := false for _, result := range toolResults { if result.Error != nil && result.IsRetryableError { hasRetryableErrors = true break } } // If no retryable errors, don't retry (MCP internal issues) if !hasRetryableErrors { err := fmt.Errorf("tool calls failed with non-retryable errors (MCP internal issues)") finalStatus = context.ResumeStatusFailed finalError = err ctx.Logger.Error("Tool calls failed: %v", err) ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // If it's the last attempt, return error if attempt == maxToolRetries-1 { err := fmt.Errorf("tool calls failed after %d attempts", maxToolRetries) finalStatus = context.ResumeStatusFailed finalError = err ctx.Logger.Error("Tool calls failed: %v", err) ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // Complete current step (with partial results) ast.CompleteStep(ctx, map[string]interface{}{ "results": toolCallResponses, "has_errors": true, }) // Build retry messages with tool call results (including errors) retryMessages := 
ast.buildToolRetryMessages(currentMessages, currentResponse, toolResults) // Begin LLM retry step ast.BeginStep(ctx, context.StepTypeLLM, map[string]interface{}{ "messages": retryMessages, "retry_attempt": attempt + 1, }) // Retry LLM call (streaming to keep user informed) ctx.Logger.Debug("Retrying LLM for tool call correction (attempt %d/%d)", attempt+1, maxToolRetries-1) currentResponse, err = ast.executeLLMForToolRetry(ctx, retryMessages, completionOptions, agentNode, streamHandler, opts) if err != nil { finalStatus = context.ResumeStatusFailed finalError = err ctx.Logger.Error("LLM retry failed: %v", err) ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // If LLM didn't return tool calls, it might have given up if currentResponse.ToolCalls == nil { err := fmt.Errorf("LLM did not return tool calls in retry attempt %d", attempt+1) finalStatus = context.ResumeStatusFailed finalError = err ctx.Logger.Error("LLM did not return tool calls: %v", err) ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // Complete LLM retry step ast.CompleteStep(ctx, map[string]interface{}{ "content": currentResponse.Content, "tool_calls": currentResponse.ToolCalls, }) // Update messages for next iteration currentMessages = retryMessages } // Update completionResponse with the final successful response completionResponse = currentResponse } // ================================================ // Execute Next Hook and Process Response // ================================================ var finalResponse *context.Response var nextResponse *context.NextHookResponse = nil if ast.HookScript != nil { ctx.Logger.HookStart("Next") // Begin step tracking for hook_next ast.BeginStep(ctx, context.StepTypeHookNext, map[string]interface{}{ "messages": fullMessages, "completion": completionResponse, "tools": toolCallResponses, }) var err error nextResponse, opts, 
err = ast.HookScript.Next(ctx, &context.NextHookPayload{ Messages: fullMessages, Completion: completionResponse, Tools: toolCallResponses, }, opts) if err != nil { finalStatus = context.ResumeStatusFailed finalError = err ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } // Complete hook_next step ast.CompleteStep(ctx, map[string]interface{}{ "response": nextResponse, }) ctx.Logger.HookComplete("Next") // Process Next hook response finalResponse, err = ast.processNextResponse(&NextProcessContext{ Context: ctx, NextResponse: nextResponse, CompletionResponse: completionResponse, FullMessages: fullMessages, ToolCallResponses: toolCallResponses, StreamHandler: streamHandler, CreateResponse: createResponse, }) if err != nil { finalStatus = context.ResumeStatusFailed finalError = err ast.traceAgentFail(agentNode, err) ast.sendStreamEndOnError(ctx, streamHandler, streamStartTime, err) return nil, err } } else { // No Next hook: use standard response finalResponse = ast.buildStandardResponse(&NextProcessContext{ Context: ctx, NextResponse: nil, CompletionResponse: completionResponse, FullMessages: fullMessages, ToolCallResponses: toolCallResponses, StreamHandler: streamHandler, CreateResponse: createResponse, }) } // Create completion node to report final output ast.traceAgentCompletion(ctx, createResponse, nextResponse, completionResponse, finalResponse) // Only close output and send stream_end if this is the root call (entry point) // Nested calls (from MCP, hooks, etc.) 
should not close the output or send stream_end // Note: Flush is already handled by the stream handler (handleStreamEnd) if ctx.Stack != nil && ctx.Stack.IsRoot() { // Log closing output for root call if trace, _ := ctx.Trace(); trace != nil { trace.Debug("Agent: Closing output (root call)", map[string]any{ "stack_id": ctx.Stack.ID, "depth": ctx.Stack.Depth, "assistant_id": ctx.Stack.AssistantID, }) } // Send ChunkStreamEnd (agent-level stream completion) ast.sendAgentStreamEnd(ctx, streamHandler, streamStartTime, "completed", nil, completionResponse) // Close the output writer to send [DONE] marker if err := ctx.CloseOutput(); err != nil { if trace, _ := ctx.Trace(); trace != nil { trace.Error(i18n.Tr(ast.ID, ctx.Locale, "assistant.agent.stream.close_error"), map[string]any{"error": err.Error()}) // "Failed to close output" } } } else { // Log skipping close for nested call if trace, _ := ctx.Trace(); trace != nil && ctx.Stack != nil { trace.Debug("Agent: Skipping output close (nested call)", map[string]any{ "stack_id": ctx.Stack.ID, "depth": ctx.Stack.Depth, "parent_id": ctx.Stack.ParentID, "assistant_id": ctx.Stack.AssistantID, }) } } // Return finalResponse which could be: // 1. Result from delegated agent call (already a Response) // 2. Custom data from Next hook (wrapped in standard Response) // 3. 
Standard response return finalResponse, nil } // GetConnector get the connector object, capabilities, and error with priority: // opts.Connector > ast.Connector > defaultConnector (fallback) // Note: opts.Connector may be set by Create hook's applyOptionsAdjustments // Returns: (connector, capabilities, error) func (ast *Assistant) GetConnector(ctx *context.Context, opts ...*context.Options) (connector.Connector, *goullm.Capabilities, error) { connectorID := ast.Connector if len(opts) > 0 && opts[0] != nil && opts[0].Connector != "" { connectorID = opts[0].Connector } if connectorID == "" { connectorID = defaultConnector } if connectorID == "" { return nil, nil, fmt.Errorf("connector not specified") } conn, err := connector.Select(connectorID) if err != nil && connectorID != defaultConnector && defaultConnector != "" { log.Printf("[Assistant] connector %q not found, falling back to default %q", connectorID, defaultConnector) conn, err = connector.Select(defaultConnector) } if err != nil { return nil, nil, err } capabilities := llm.GetCapabilitiesFromConn(conn) return conn, capabilities, nil } // Info get the assistant information func (ast *Assistant) Info(locale ...string) *message.AssistantInfo { lc := "en" if len(locale) > 0 { lc = locale[0] } return &message.AssistantInfo{ ID: ast.ID, Type: ast.Type, Name: i18n.Tr(ast.ID, lc, ast.Name), Avatar: ast.Avatar, Description: i18n.Tr(ast.ID, lc, ast.Description), } } // getStreamHandler returns the stream handler from options or a default one func (ast *Assistant) getStreamHandler(ctx *context.Context, opts ...*context.Options) message.StreamFunc { // Check if handler is provided in options if len(opts) > 0 && opts[0] != nil && opts[0].Writer != nil { return handlers.DefaultStreamHandler(ctx) } return handlers.DefaultStreamHandler(ctx) } // sendAgentStreamStart sends ChunkStreamStart for root stack only (agent-level stream start) // This ensures only one stream_start per agent execution, even with multiple LLM calls 
func (ast *Assistant) sendAgentStreamStart(ctx *context.Context, handler message.StreamFunc, startTime time.Time) { if ctx.Stack == nil || !ctx.Stack.IsRoot() || handler == nil { return } // Build the start data startData := message.EventStreamStartData{ ContextID: ctx.ID, ChatID: ctx.ChatID, TraceID: ctx.TraceID(), RequestID: ctx.RequestID(), Timestamp: startTime.UnixMilli(), Assistant: ast.Info(ctx.Locale), Metadata: ctx.Metadata, } if startJSON, err := jsoniter.Marshal(startData); err == nil { handler(message.ChunkStreamStart, startJSON) } } // sendAgentStreamEnd sends ChunkStreamEnd for root stack only (agent-level stream completion) func (ast *Assistant) sendAgentStreamEnd(ctx *context.Context, handler message.StreamFunc, startTime time.Time, status string, err error, response *context.CompletionResponse) { if ctx.Stack == nil || !ctx.Stack.IsRoot() || handler == nil { return } // Check if context is cancelled - if so, skip handler call to avoid blocking if ctx.Context != nil && ctx.Context.Err() != nil { ctx.Logger.Debug("Context cancelled, skipping sendAgentStreamEnd handler call") return } endData := &message.EventStreamEndData{ RequestID: ctx.RequestID(), ContextID: ctx.ID, Timestamp: time.Now().UnixMilli(), DurationMs: time.Since(startTime).Milliseconds(), Status: status, TraceID: ctx.TraceID(), Metadata: ctx.Metadata, } if err != nil { endData.Error = err.Error() } if response != nil && response.Usage != nil { endData.Usage = response.Usage } if endJSON, marshalErr := jsoniter.Marshal(endData); marshalErr == nil { handler(message.ChunkStreamEnd, endJSON) } } // sendStreamEndOnError sends ChunkStreamEnd with error status for root stack only func (ast *Assistant) sendStreamEndOnError(ctx *context.Context, handler message.StreamFunc, startTime time.Time, err error) { ast.sendAgentStreamEnd(ctx, handler, startTime, "error", err, nil) } // handleInterrupt handles the interrupt signal // This is called by the interrupt listener when a signal is received func 
(ast *Assistant) handleInterrupt(ctx *context.Context, signal *context.InterruptSignal) error { // Handle based on interrupt type switch signal.Type { case context.InterruptForce: // Force interrupt: context is already cancelled in handleSignal // LLM streaming will detect ctx.Interrupt.Context().Done() and stop ctx.Logger.Debug("Force interrupt: stopping current operations immediately") case context.InterruptGraceful: ctx.Logger.Debug("Graceful interrupt: will process after current step completes") // Graceful interrupt: let current operation complete // The signal is stored in current/pending, can be checked at checkpoints } // TODO: Implement actual interrupt handling logic: // 1. For graceful: wait for current step, then merge messages and restart // 2. For force: immediately stop and restart with new messages // 3. Call Interrupted Hook if configured // 4. Decide whether to continue, restart, or abort based on Hook response return nil } // initializeCapabilities gets connector and capabilities, then sets them in context // This should be called early (before sending stream_start) so that output adapters // can use capabilities when converting stream_start event func (ast *Assistant) initializeCapabilities(ctx *context.Context, opts *context.Options) error { if ast.Prompts == nil && ast.MCP == nil { return nil } _, capabilities, err := ast.GetConnector(ctx, opts) if err != nil { return err } // Set capabilities in context for output adapters to use if capabilities != nil { ctx.Capabilities = capabilities } return nil } // buildToolRetryMessages builds messages for LLM retry with tool call results // Format follows OpenAI's tool call response pattern: // 1. Assistant message with tool calls // 2. Tool messages with results (one per tool call) // 3. 
System message explaining the retry func (ast *Assistant) buildToolRetryMessages( previousMessages []context.Message, completionResponse *context.CompletionResponse, toolResults []ToolCallResult, ) []context.Message { retryMessages := make([]context.Message, 0, len(previousMessages)+len(toolResults)+2) // Add all previous messages retryMessages = append(retryMessages, previousMessages...) // Add assistant message with tool calls assistantMsg := context.Message{ Role: context.RoleAssistant, Content: completionResponse.Content, ToolCalls: completionResponse.ToolCalls, } retryMessages = append(retryMessages, assistantMsg) // Add tool result messages (one per tool call) for _, result := range toolResults { toolMsg := context.Message{ Role: context.RoleTool, Content: result.Content, ToolCallID: &result.ToolCallID, } // Add tool name if available if result.Name != "" { name := result.Name toolMsg.Name = &name } retryMessages = append(retryMessages, toolMsg) } // Add system message explaining the retry (optional, helps LLM understand context) systemMsg := context.Message{ Role: context.RoleSystem, Content: i18n.Tr(ast.ID, "en", "assistant.agent.tool_retry_prompt"), } retryMessages = append(retryMessages, systemMsg) return retryMessages } ================================================ FILE: agent/assistant/agent_interrupt_test.go ================================================ package assistant_test import ( stdContext "context" "fmt" "testing" "time" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newTestContextWithInterrupt creates a Context with interrupt controller for testing // Returns the context and a cancel function that should be called before Release() func newTestContextWithInterrupt(chatID, assistantID string) (*context.Context, stdContext.CancelFunc) { authorized := &types.AuthorizedInfo{ 
Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: "test-session-id", } // Use cancellable context to properly stop goroutines on timeout parentCtx, cancel := stdContext.WithCancel(stdContext.Background()) ctx := context.New(parentCtx, authorized, chatID) ctx.ID = fmt.Sprintf("test_ctx_%d", time.Now().UnixNano()) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "/test/route" ctx.IDGenerator = message.NewIDGenerator() // Initialize context-scoped ID generator ctx.Metadata = map[string]interface{}{ "test": "interrupt_test", } // Initialize interrupt controller ctx.Interrupt = context.NewInterruptController() // Register context globally if err := context.Register(ctx); err != nil { panic(fmt.Sprintf("Failed to register context: %v", err)) } // Start interrupt listener ctx.Interrupt.Start(ctx.ID) return ctx, cancel } // TestAgentInterruptGraceful tests graceful interrupt during agent stream func TestAgentInterruptGraceful(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.interrupt") if err != nil { t.Skipf("Skipping test: assistant 'tests.interrupt' not found: %v", err) return } t.Run("GracefulInterruptDuringStream", func(t *testing.T) { // Create context with interrupt support ctx, cancel := newTestContextWithInterrupt("chat-interrupt-graceful", "tests.interrupt") defer func() { cancel() // Cancel context first to stop goroutines time.Sleep(100 * time.Millisecond) // Wait for goroutines to exit ctx.Release() }() // Track handler invocations handlerInvoked := false var receivedSignal *context.InterruptSignal // Override the handler to track invocations originalHandler := ctx.Interrupt ctx.Interrupt.SetHandler(func(c *context.Context, signal 
*context.InterruptSignal) error { handlerInvoked = true receivedSignal = signal t.Logf("✓ Interrupt handler invoked: type=%s, messages=%d", signal.Type, len(signal.Messages)) return nil }) inputMessages := []context.Message{ {Role: context.RoleUser, Content: "Tell me a long story about artificial intelligence"}, } // Start streaming in a goroutine streamDone := make(chan error, 1) go func() { _, err := agent.Stream(ctx, inputMessages) streamDone <- err }() // Wait a bit to ensure stream has started time.Sleep(300 * time.Millisecond) // Send graceful interrupt signal signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{ {Role: context.RoleUser, Content: "Actually, can you make it shorter?"}, }, Timestamp: time.Now().UnixMilli(), } err = context.SendInterrupt(ctx.ID, signal) if err != nil { t.Logf("Warning: Failed to send interrupt (stream may have completed): %v", err) } else { t.Log("✓ Graceful interrupt signal sent") } // Wait for stream to complete (with timeout) select { case err := <-streamDone: if err != nil { t.Logf("Stream completed with error: %v", err) } else { t.Log("✓ Stream completed successfully") } case <-time.After(10 * time.Second): t.Log("Stream timeout (expected for real LLM calls)") cancel() // Cancel to stop the stream goroutine <-streamDone // Wait for goroutine to exit } // Verify handler was invoked if signal was sent if originalHandler != nil { time.Sleep(200 * time.Millisecond) // Wait for async handler if handlerInvoked { t.Log("✓ Interrupt handler was invoked") if receivedSignal != nil && receivedSignal.Type == context.InterruptGraceful { t.Log("✓ Received graceful interrupt signal") } } } }) } // TestAgentInterruptForce tests force interrupt during agent stream func TestAgentInterruptForce(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.interrupt") if err != nil { t.Skipf("Skipping test: assistant 'tests.interrupt' not found: %v", err) return } 
t.Run("ForceInterruptDuringStream", func(t *testing.T) { // Create context with interrupt support ctx, cancel := newTestContextWithInterrupt("chat-interrupt-force", "tests.interrupt") defer func() { cancel() // Cancel context first to stop goroutines time.Sleep(100 * time.Millisecond) // Wait for goroutines to exit ctx.Release() }() // Track handler invocations handlerInvoked := false streamInterrupted := false ctx.Interrupt.SetHandler(func(c *context.Context, signal *context.InterruptSignal) error { handlerInvoked = true t.Logf("✓ Interrupt handler invoked: type=%s", signal.Type) return nil }) inputMessages := []context.Message{ {Role: context.RoleUser, Content: "Write a very detailed essay about machine learning"}, } // Start streaming in a goroutine streamDone := make(chan error, 1) go func() { _, err := agent.Stream(ctx, inputMessages) streamDone <- err }() // Wait a bit to ensure stream has started time.Sleep(300 * time.Millisecond) // Send force interrupt signal signal := &context.InterruptSignal{ Type: context.InterruptForce, Messages: []context.Message{ {Role: context.RoleUser, Content: "Stop! 
I need something else now."}, }, Timestamp: time.Now().UnixMilli(), } err = context.SendInterrupt(ctx.ID, signal) if err != nil { t.Logf("Warning: Failed to send interrupt: %v", err) } else { t.Log("✓ Force interrupt signal sent") } // Wait for stream to complete or be interrupted select { case err := <-streamDone: if err != nil { // Check if error is due to interrupt if err.Error() == "force interrupted by user" || err.Error() == "interrupted by user" || err.Error() == "interrupted by user before stream start" { streamInterrupted = true t.Logf("✓ Stream was interrupted: %v", err) } else { t.Logf("Stream completed with error: %v", err) } } else { t.Log("Stream completed without error") } case <-time.After(10 * time.Second): t.Log("Stream timeout") cancel() // Cancel to stop the stream goroutine <-streamDone // Wait for goroutine to exit } // Verify interrupt behavior time.Sleep(200 * time.Millisecond) if handlerInvoked { t.Log("✓ Force interrupt handler was invoked") } if streamInterrupted { t.Log("✓ Stream was interrupted by force signal") } }) } // TestAgentMultipleInterrupts tests multiple interrupts during stream func TestAgentMultipleInterrupts(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.interrupt") if err != nil { t.Skipf("Skipping test: assistant 'tests.interrupt' not found: %v", err) return } t.Run("MultipleGracefulInterrupts", func(t *testing.T) { // Create context with interrupt support ctx, cancel := newTestContextWithInterrupt("chat-interrupt-multiple", "tests.interrupt") defer func() { cancel() // Cancel context first to stop goroutines time.Sleep(100 * time.Millisecond) // Wait for goroutines to exit ctx.Release() }() handlerCallCount := 0 ctx.Interrupt.SetHandler(func(c *context.Context, signal *context.InterruptSignal) error { handlerCallCount++ t.Logf("✓ Interrupt handler invoked (call %d): %d messages", handlerCallCount, len(signal.Messages)) return nil }) inputMessages := []context.Message{ 
{Role: context.RoleUser, Content: "Explain quantum computing in detail"}, } // Start streaming streamDone := make(chan error, 1) go func() { _, err := agent.Stream(ctx, inputMessages) streamDone <- err }() // Wait for stream to start time.Sleep(300 * time.Millisecond) // Send multiple graceful interrupts for i := 1; i <= 3; i++ { signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{ {Role: context.RoleUser, Content: fmt.Sprintf("Additional question %d", i)}, }, Timestamp: time.Now().UnixMilli(), } err = context.SendInterrupt(ctx.ID, signal) if err != nil { t.Logf("Warning: Failed to send interrupt %d: %v", i, err) } else { t.Logf("✓ Sent interrupt %d", i) } time.Sleep(100 * time.Millisecond) } // Wait for stream to complete select { case err := <-streamDone: if err != nil { t.Logf("Stream completed with error: %v", err) } case <-time.After(10 * time.Second): t.Log("Stream timeout") cancel() // Cancel to stop the stream goroutine <-streamDone // Wait for goroutine to exit } // Check if interrupts were received time.Sleep(300 * time.Millisecond) pendingCount := ctx.Interrupt.GetPendingCount() t.Logf("Handler was called %d times, pending count: %d", handlerCallCount, pendingCount) if handlerCallCount > 0 { t.Log("✓ Multiple interrupts were processed") } }) } // TestAgentInterruptWithoutStream tests interrupt behavior when no stream is active func TestAgentInterruptWithoutStream(t *testing.T) { t.Run("InterruptBeforeStream", func(t *testing.T) { // Create context with interrupt support ctx, cancel := newTestContextWithInterrupt("chat-interrupt-before", "test-assistant") defer func() { cancel() ctx.Release() }() // Send interrupt before starting stream signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{ {Role: context.RoleUser, Content: "Early interrupt"}, }, Timestamp: time.Now().UnixMilli(), } err := context.SendInterrupt(ctx.ID, signal) if err != nil { t.Fatalf("Failed to send 
interrupt: %v", err) } // Wait for signal to be processed time.Sleep(100 * time.Millisecond) // Check if signal is in queue receivedSignal := ctx.Interrupt.Peek() if receivedSignal == nil { t.Fatal("Expected interrupt signal to be queued") } if receivedSignal.Type != context.InterruptGraceful { t.Errorf("Expected graceful interrupt, got: %s", receivedSignal.Type) } t.Log("✓ Interrupt queued before stream starts") }) } // TestAgentInterruptContextCleanup tests cleanup after interrupt func TestAgentInterruptContextCleanup(t *testing.T) { t.Run("CleanupAfterInterrupt", func(t *testing.T) { ctx, cancel := newTestContextWithInterrupt("chat-interrupt-cleanup", "test-assistant") // Send interrupt signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "test"}}, Timestamp: time.Now().UnixMilli(), } context.SendInterrupt(ctx.ID, signal) time.Sleep(100 * time.Millisecond) // Cancel and release context cancel() ctx.Release() // Try to send interrupt to released context err := context.SendInterrupt(ctx.ID, signal) if err == nil { t.Error("Expected error when sending to released context") } else { t.Logf("✓ Correctly rejected interrupt to released context: %v", err) } }) } ================================================ FILE: agent/assistant/agent_next_test.go ================================================ package assistant_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newAgentNextTestContext creates a test context func newAgentNextTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, chatID) 
ctx.ID = chatID ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Client = context.Client{ Type: "web", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.IDGenerator = message.NewIDGenerator() // Initialize ID generator ctx.Metadata = make(map[string]interface{}) return ctx } // TestAgentNextStandard tests agent with Next Hook returning nil (standard response) func TestAgentNextStandard(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") assert.NoError(t, err) ctx := newAgentNextTestContext("test-standard", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: standard - Hello"}, } response, err := agent.Stream(ctx, messages) assert.NoError(t, err) assert.NotNil(t, response) assert.NotNil(t, response.Completion) assert.Nil(t, response.Next) // Verify response structure assert.Equal(t, "tests.realworld-next", response.AssistantID) assert.NotEmpty(t, response.ContextID) assert.NotEmpty(t, response.RequestID) assert.NotEmpty(t, response.TraceID) assert.NotEmpty(t, response.ChatID) // Verify completion has content assert.NotNil(t, response.Completion.Content) t.Log("✓ Standard response test passed") } // TestAgentNextCustomData tests agent with Next Hook returning custom data func TestAgentNextCustomData(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") assert.NoError(t, err) ctx := newAgentNextTestContext("test-custom", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: custom_data - Give me info"}, } response, err := agent.Stream(ctx, messages) assert.NoError(t, err) assert.NotNil(t, response) assert.NotNil(t, response.Completion) assert.NotNil(t, response.Next) // Verify response structure 
assert.Equal(t, "tests.realworld-next", response.AssistantID) assert.NotEmpty(t, response.ContextID) assert.NotEmpty(t, response.RequestID) assert.NotEmpty(t, response.TraceID) // Verify custom data structure (from scenarioCustomData) // response.Next contains the "data" field value from NextHookResponse nextData, ok := response.Next.(map[string]interface{}) assert.True(t, ok, "Next should be a map") assert.Equal(t, "custom_response", nextData["type"]) assert.Equal(t, "This is a custom response from Next Hook", nextData["message"]) assert.NotEmpty(t, nextData["timestamp"]) assert.NotNil(t, nextData["message_count"]) t.Log("✓ Custom data test passed") } // TestAgentNextDelegate tests agent with Next Hook delegating to another agent func TestAgentNextDelegate(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") assert.NoError(t, err) ctx := newAgentNextTestContext("test-delegate", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: delegate - Forward this"}, } response, err := agent.Stream(ctx, messages) assert.NoError(t, err) assert.NotNil(t, response) // Verify response structure assert.NotEmpty(t, response.AssistantID) assert.NotEmpty(t, response.ContextID) assert.NotEmpty(t, response.RequestID) assert.NotEmpty(t, response.TraceID) // Verify completion (delegated agent should have returned completion) assert.NotNil(t, response.Completion) assert.NotNil(t, response.Completion.Content) // Next should be from the delegated agent // If delegated agent also has Next hook, it will be present t.Logf("✓ Delegation test passed (delegated to: %s)", response.AssistantID) } // TestAgentNextConditional tests agent with conditional logic in Next Hook func TestAgentNextConditional(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := 
assistant.Get("tests.realworld-next") assert.NoError(t, err) ctx := newAgentNextTestContext("test-conditional", "tests.realworld-next") messages := []context.Message{ // Use conditional_success sub-scenario for deterministic behavior // This avoids test flakiness caused by LLM response unpredictability {Role: context.RoleUser, Content: "scenario: conditional_success - Task completed"}, } response, err := agent.Stream(ctx, messages) assert.NoError(t, err) assert.NotNil(t, response) assert.NotNil(t, response.Next) // Verify response structure assert.Equal(t, "tests.realworld-next", response.AssistantID) assert.NotEmpty(t, response.ContextID) assert.NotEmpty(t, response.RequestID) assert.NotEmpty(t, response.TraceID) // Verify conditional response structure (from scenarioConditional) // response.Next contains the "data" field value from NextHookResponse nextData, ok := response.Next.(map[string]interface{}) assert.True(t, ok, "Next should be a map") assert.Equal(t, "Conditional analysis complete", nextData["message"]) assert.Contains(t, nextData, "action") assert.Contains(t, nextData, "reason") assert.Contains(t, nextData, "conditions") // Verify action is one of the expected values action, ok := nextData["action"].(string) assert.True(t, ok) assert.Contains(t, []string{"continue", "flag_for_review", "confirm_success", "summarize", "delegate"}, action) t.Log("✓ Conditional logic test passed") } // TestAgentWithoutNextHook tests agent without Next Hook func TestAgentWithoutNextHook(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") assert.NoError(t, err) ctx := newAgentNextTestContext("test-no-next", "tests.create") messages := []context.Message{ {Role: context.RoleUser, Content: "Hello"}, } response, err := agent.Stream(ctx, messages) assert.NoError(t, err) assert.NotNil(t, response) assert.Nil(t, response.Next) // Verify response structure assert.Equal(t, 
"tests.create", response.AssistantID) assert.NotEmpty(t, response.ContextID) assert.NotEmpty(t, response.RequestID) assert.NotEmpty(t, response.TraceID) assert.NotEmpty(t, response.ChatID) // Verify completion assert.NotNil(t, response.Completion) assert.NotNil(t, response.Completion.Content) t.Log("✓ No Next Hook test passed") } ================================================ FILE: agent/assistant/assistant.go ================================================ package assistant import ( "fmt" "path" "github.com/yaoapp/gou/fs" "github.com/yaoapp/yao/agent/caller" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/llm" "github.com/yaoapp/yao/agent/search" searchTypes "github.com/yaoapp/yao/agent/search/types" store "github.com/yaoapp/yao/agent/store/types" sui "github.com/yaoapp/yao/sui/core" ) func init() { // Initialize AgentGetterFunc to allow content and search packages to call agents caller.AgentGetterFunc = func(agentID string) (caller.AgentCaller, error) { ast, err := Get(agentID) if err != nil { return nil, err } // Return a wrapper that implements AgentCaller interface return &agentCallerWrapper{ast: ast}, nil } // Initialize Agent JSAPI factory for ctx.agent.* methods caller.SetJSAPIFactory() // Initialize LLM JSAPI factory for ctx.llm.* methods llm.SetJSAPIFactory() // Initialize Search JSAPI factory with config getter search.SetJSAPIFactory(func(assistantID string) (*searchTypes.Config, *search.Uses) { ast, err := Get(assistantID) if err != nil || ast == nil { return nil, nil } // Convert assistant.Uses to search.Uses var uses *search.Uses if ast.Uses != nil { uses = &search.Uses{ Search: ast.Uses.Search, Web: ast.Uses.Web, Keyword: ast.Uses.Keyword, QueryDSL: ast.Uses.QueryDSL, Rerank: ast.Uses.Rerank, } } return ast.Search, uses }) } // agentCallerWrapper wraps Assistant to implement AgentCaller interface type agentCallerWrapper struct { ast *Assistant } func (w *agentCallerWrapper) 
Stream(ctx *agentContext.Context, messages []agentContext.Message, options ...*agentContext.Options) (*agentContext.Response, error) { return w.ast.Stream(ctx, messages, options...) } // Get get the assistant by id func Get(id string) (*Assistant, error) { return LoadStore(id) } // GetPlaceholder returns the placeholder of the assistant func (ast *Assistant) GetPlaceholder(locale string) *store.Placeholder { prompts := []string{} if ast.Placeholder.Prompts != nil { prompts = i18n.Translate(ast.ID, locale, ast.Placeholder.Prompts).([]string) } title := i18n.Translate(ast.ID, locale, ast.Placeholder.Title).(string) description := i18n.Translate(ast.ID, locale, ast.Placeholder.Description).(string) return &store.Placeholder{ Title: title, Description: description, Prompts: prompts, } } // GetName returns the name of the assistant func (ast *Assistant) GetName(locale string) string { return i18n.Translate(ast.ID, locale, ast.Name).(string) } // GetDescription returns the description of the assistant func (ast *Assistant) GetDescription(locale string) string { return i18n.Translate(ast.ID, locale, ast.Description).(string) } // Save save the assistant func (ast *Assistant) Save() error { if storage == nil { return fmt.Errorf("storage is not set") } _, err := storage.SaveAssistant(&ast.AssistantModel) if err != nil { return err } return nil } // Map convert the assistant to a map func (ast *Assistant) Map() map[string]interface{} { if ast == nil { return nil } return map[string]interface{}{ "assistant_id": ast.ID, "type": ast.Type, "name": ast.Name, "readonly": ast.Readonly, "public": ast.Public, "share": ast.Share, "avatar": ast.Avatar, "connector": ast.Connector, "connector_options": ast.ConnectorOptions, "path": ast.Path, "built_in": ast.BuiltIn, "sort": ast.Sort, "description": ast.Description, "options": ast.Options, "prompts": ast.Prompts, "prompt_presets": ast.PromptPresets, "disable_global_prompts": ast.DisableGlobalPrompts, "source": ast.Source, "kb": ast.KB, 
"db": ast.DB, "mcp": ast.MCP, "workflow": ast.Workflow, "tags": ast.Tags, "modes": ast.Modes, "default_mode": ast.DefaultMode, "mentionable": ast.Mentionable, "automated": ast.Automated, "placeholder": ast.Placeholder, "locales": ast.Locales, "uses": ast.Uses, "search": ast.Search, "dependencies": ast.Dependencies, "created_at": store.ToMySQLTime(ast.CreatedAt), "updated_at": store.ToMySQLTime(ast.UpdatedAt), } } // Validate validates the assistant configuration func (ast *Assistant) Validate() error { if ast.ID == "" { return fmt.Errorf("assistant_id is required") } if ast.Name == "" { return fmt.Errorf("name is required") } if ast.Connector == "" { return fmt.Errorf("connector is required") } return nil } // Assets get the assets content func (ast *Assistant) Assets(name string, data sui.Data) (string, error) { app, err := fs.Get("app") if err != nil { return "", err } root := path.Join(ast.Path, "assets", name) raw, err := app.ReadFile(root) if err != nil { return "", err } if data != nil { content, _ := data.Replace(string(raw)) return content, nil } return string(raw), nil } // Clone creates a deep copy of the assistant func (ast *Assistant) Clone() *Assistant { if ast == nil { return nil } clone := &Assistant{ AssistantModel: store.AssistantModel{ ID: ast.ID, Type: ast.Type, Name: ast.Name, Avatar: ast.Avatar, Connector: ast.Connector, Path: ast.Path, BuiltIn: ast.BuiltIn, Sort: ast.Sort, Description: ast.Description, Readonly: ast.Readonly, Public: ast.Public, Share: ast.Share, Mentionable: ast.Mentionable, Automated: ast.Automated, DisableGlobalPrompts: ast.DisableGlobalPrompts, Source: ast.Source, CreatedAt: ast.CreatedAt, UpdatedAt: ast.UpdatedAt, }, HookScript: ast.HookScript, } // Deep copy tags if ast.Tags != nil { clone.Tags = make([]string, len(ast.Tags)) copy(clone.Tags, ast.Tags) } // Deep copy modes if ast.Modes != nil { clone.Modes = make([]string, len(ast.Modes)) copy(clone.Modes, ast.Modes) } // Copy default_mode (simple string) 
clone.DefaultMode = ast.DefaultMode // Deep copy KB if ast.KB != nil { clone.KB = &store.KnowledgeBase{} if ast.KB.Collections != nil { clone.KB.Collections = make([]string, len(ast.KB.Collections)) copy(clone.KB.Collections, ast.KB.Collections) } if ast.KB.Options != nil { clone.KB.Options = make(map[string]interface{}) for k, v := range ast.KB.Options { clone.KB.Options[k] = v } } } // Deep copy DB if ast.DB != nil { clone.DB = &store.Database{} if ast.DB.Models != nil { clone.DB.Models = make([]string, len(ast.DB.Models)) copy(clone.DB.Models, ast.DB.Models) } if ast.DB.Options != nil { clone.DB.Options = make(map[string]interface{}) for k, v := range ast.DB.Options { clone.DB.Options[k] = v } } } // Deep copy MCP if ast.MCP != nil { clone.MCP = &store.MCPServers{} if ast.MCP.Servers != nil { clone.MCP.Servers = make([]store.MCPServerConfig, len(ast.MCP.Servers)) for i, server := range ast.MCP.Servers { clone.MCP.Servers[i] = store.MCPServerConfig{ ServerID: server.ServerID, } // Deep copy Resources slice if server.Resources != nil { clone.MCP.Servers[i].Resources = make([]string, len(server.Resources)) copy(clone.MCP.Servers[i].Resources, server.Resources) } // Deep copy Tools slice if server.Tools != nil { clone.MCP.Servers[i].Tools = make([]string, len(server.Tools)) copy(clone.MCP.Servers[i].Tools, server.Tools) } } } if ast.MCP.Options != nil { clone.MCP.Options = make(map[string]interface{}) for k, v := range ast.MCP.Options { clone.MCP.Options[k] = v } } } // Deep copy options if ast.Options != nil { clone.Options = make(map[string]interface{}) for k, v := range ast.Options { clone.Options[k] = v } } // Deep copy prompts if ast.Prompts != nil { clone.Prompts = make([]store.Prompt, len(ast.Prompts)) copy(clone.Prompts, ast.Prompts) } // Deep copy prompt presets if ast.PromptPresets != nil { clone.PromptPresets = make(map[string][]store.Prompt) for k, v := range ast.PromptPresets { prompts := make([]store.Prompt, len(v)) copy(prompts, v) 
clone.PromptPresets[k] = prompts } } // Deep copy connector options if ast.ConnectorOptions != nil { clone.ConnectorOptions = &store.ConnectorOptions{ Optional: ast.ConnectorOptions.Optional, } if ast.ConnectorOptions.Connectors != nil { clone.ConnectorOptions.Connectors = make([]string, len(ast.ConnectorOptions.Connectors)) copy(clone.ConnectorOptions.Connectors, ast.ConnectorOptions.Connectors) } if ast.ConnectorOptions.Filters != nil { clone.ConnectorOptions.Filters = make([]store.ModelCapability, len(ast.ConnectorOptions.Filters)) copy(clone.ConnectorOptions.Filters, ast.ConnectorOptions.Filters) } } // Deep copy workflow if ast.Workflow != nil { clone.Workflow = &store.Workflow{} if ast.Workflow.Workflows != nil { clone.Workflow.Workflows = make([]string, len(ast.Workflow.Workflows)) copy(clone.Workflow.Workflows, ast.Workflow.Workflows) } if ast.Workflow.Options != nil { clone.Workflow.Options = make(map[string]interface{}) for k, v := range ast.Workflow.Options { clone.Workflow.Options[k] = v } } } // Deep copy placeholder if ast.Placeholder != nil { clone.Placeholder = &store.Placeholder{ Title: ast.Placeholder.Title, Description: ast.Placeholder.Description, } if ast.Placeholder.Prompts != nil { clone.Placeholder.Prompts = make([]string, len(ast.Placeholder.Prompts)) copy(clone.Placeholder.Prompts, ast.Placeholder.Prompts) } } // Deep copy locales if ast.Locales != nil { clone.Locales = make(i18n.Map) for k, v := range ast.Locales { // Deep copy messages messages := make(map[string]any) if v.Messages != nil { for mk, mv := range v.Messages { messages[mk] = mv } } clone.Locales[k] = i18n.I18n{ Locale: v.Locale, Messages: messages, } } } // Deep copy uses if ast.Uses != nil { clone.Uses = &agentContext.Uses{ Vision: ast.Uses.Vision, Audio: ast.Uses.Audio, Search: ast.Uses.Search, Fetch: ast.Uses.Fetch, Web: ast.Uses.Web, Keyword: ast.Uses.Keyword, QueryDSL: ast.Uses.QueryDSL, Rerank: ast.Uses.Rerank, } } // Deep copy search config if ast.Search != nil { 
clone.Search = &searchTypes.Config{} if ast.Search.Web != nil { clone.Search.Web = &searchTypes.WebConfig{ Provider: ast.Search.Web.Provider, APIKeyEnv: ast.Search.Web.APIKeyEnv, MaxResults: ast.Search.Web.MaxResults, } } if ast.Search.KB != nil { clone.Search.KB = &searchTypes.KBConfig{ Threshold: ast.Search.KB.Threshold, Graph: ast.Search.KB.Graph, } if ast.Search.KB.Collections != nil { clone.Search.KB.Collections = make([]string, len(ast.Search.KB.Collections)) copy(clone.Search.KB.Collections, ast.Search.KB.Collections) } } if ast.Search.DB != nil { clone.Search.DB = &searchTypes.DBConfig{ MaxResults: ast.Search.DB.MaxResults, } if ast.Search.DB.Models != nil { clone.Search.DB.Models = make([]string, len(ast.Search.DB.Models)) copy(clone.Search.DB.Models, ast.Search.DB.Models) } } if ast.Search.Keyword != nil { clone.Search.Keyword = &searchTypes.KeywordConfig{ MaxKeywords: ast.Search.Keyword.MaxKeywords, Language: ast.Search.Keyword.Language, } } if ast.Search.QueryDSL != nil { clone.Search.QueryDSL = &searchTypes.QueryDSLConfig{ Strict: ast.Search.QueryDSL.Strict, } } if ast.Search.Rerank != nil { clone.Search.Rerank = &searchTypes.RerankConfig{ TopN: ast.Search.Rerank.TopN, } } if ast.Search.Citation != nil { clone.Search.Citation = &searchTypes.CitationConfig{ Format: ast.Search.Citation.Format, AutoInjectPrompt: ast.Search.Citation.AutoInjectPrompt, CustomPrompt: ast.Search.Citation.CustomPrompt, } } if ast.Search.Weights != nil { clone.Search.Weights = &searchTypes.WeightsConfig{ User: ast.Search.Weights.User, Hook: ast.Search.Weights.Hook, Auto: ast.Search.Weights.Auto, } } if ast.Search.Options != nil { clone.Search.Options = &searchTypes.OptionsConfig{ SkipThreshold: ast.Search.Options.SkipThreshold, } } } // Deep copy dependencies if ast.Dependencies != nil { clone.Dependencies = make(map[string]string, len(ast.Dependencies)) for k, v := range ast.Dependencies { clone.Dependencies[k] = v } } return clone } // GetInfo returns the basic info of the 
assistant with optional locale func (ast *Assistant) GetInfo(locale ...string) *store.AssistantInfo { if ast == nil { return nil } loc := "" if len(locale) > 0 { loc = locale[0] } info := &store.AssistantInfo{ AssistantID: ast.ID, Avatar: ast.Avatar, Connector: ast.Connector, ConnectorOptions: ast.ConnectorOptions, Modes: ast.Modes, DefaultMode: ast.DefaultMode, Sandbox: ast.IsSandbox, ComputerFilter: ast.ComputerFilter, } if loc != "" { info.Name = ast.GetName(loc) info.Description = ast.GetDescription(loc) } else { info.Name = ast.Name info.Description = ast.Description } return info } // GetInfoByIDs retrieves basic info for multiple assistants by their IDs // Returns a map of assistant_id -> AssistantInfo func GetInfoByIDs(ids []string, locale ...string) map[string]*store.AssistantInfo { result := make(map[string]*store.AssistantInfo) if len(ids) == 0 { return result } for _, id := range ids { ast, err := Get(id) if err != nil || ast == nil { continue } result[id] = ast.GetInfo(locale...) 
} return result } // Update updates the assistant properties func (ast *Assistant) Update(data map[string]interface{}) error { if ast == nil { return fmt.Errorf("assistant is nil") } if v, ok := data["name"].(string); ok { ast.Name = v } if v, ok := data["avatar"].(string); ok { ast.Avatar = v } if v, ok := data["description"].(string); ok { ast.Description = v } if v, ok := data["connector"].(string); ok { ast.Connector = v } // Note: tools field is deprecated, now handled by MCP if v, ok := data["type"].(string); ok { ast.Type = v } if v, ok := data["sort"].(int); ok { ast.Sort = v } if v, ok := data["mentionable"].(bool); ok { ast.Mentionable = v } if v, ok := data["automated"].(bool); ok { ast.Automated = v } if v, ok := data["disable_global_prompts"].(bool); ok { ast.DisableGlobalPrompts = v } if v, ok := data["readonly"].(bool); ok { ast.Readonly = v } if v, ok := data["public"].(bool); ok { ast.Public = v } if v, ok := data["share"].(string); ok { ast.Share = v } if v, ok := data["tags"].([]string); ok { ast.Tags = v } if v, ok := data["modes"].([]string); ok { ast.Modes = v } if v, ok := data["default_mode"].(string); ok { ast.DefaultMode = v } if v, ok := data["options"].(map[string]interface{}); ok { ast.Options = v } if v, ok := data["source"].(string); ok { ast.Source = v } // ConnectorOptions if v, has := data["connector_options"]; has { connOpts, err := store.ToConnectorOptions(v) if err != nil { return err } ast.ConnectorOptions = connOpts } // PromptPresets if v, has := data["prompt_presets"]; has { presets, err := store.ToPromptPresets(v) if err != nil { return err } ast.PromptPresets = presets } // KB if v, has := data["kb"]; has { kb, err := store.ToKnowledgeBase(v) if err != nil { return err } ast.KB = kb } // DB if v, has := data["db"]; has { db, err := store.ToDatabase(v) if err != nil { return err } ast.DB = db } // MCP if v, has := data["mcp"]; has { mcp, err := store.ToMCPServers(v) if err != nil { return err } ast.MCP = mcp } // Workflow 
if v, has := data["workflow"]; has { workflow, err := store.ToWorkflow(v) if err != nil { return err } ast.Workflow = workflow } // Uses if v, has := data["uses"]; has { uses, err := store.ToUses(v) if err != nil { return err } ast.Uses = uses } // Search if v, has := data["search"]; has { search, err := store.ToSearchConfig(v) if err != nil { return err } ast.Search = search } // Dependencies if v, has := data["dependencies"]; has { if v == nil { ast.Dependencies = nil } else { switch d := v.(type) { case map[string]string: ast.Dependencies = d case map[string]interface{}: deps := make(map[string]string, len(d)) for k, val := range d { if s, ok := val.(string); ok { deps[k] = s } } ast.Dependencies = deps } } } return ast.Validate() } // GetMergedSearchConfig returns the search config for this assistant // Note: The config is already merged with global config during loading (loadMap) func (ast *Assistant) GetMergedSearchConfig() *searchTypes.Config { return ast.Search } ================================================ FILE: agent/assistant/build.go ================================================ package assistant import ( "fmt" "github.com/spf13/cast" "github.com/yaoapp/gou/json" "github.com/yaoapp/yao/agent/context" store "github.com/yaoapp/yao/agent/store/types" ) // BuildRequest build the LLM request func (ast *Assistant) BuildRequest(ctx *context.Context, messages []context.Message, createResponse *context.HookCreateResponse) ([]context.Message, *context.CompletionOptions, error) { // Build completion options from createResponse and ctx (includes MCP tools) options, mcpSamplesPrompt, err := ast.buildCompletionOptions(ctx, createResponse) if err != nil { return nil, nil, err } // Build final messages with proper priority (includes MCP samples if available) finalMessages, err := ast.buildMessages(ctx, messages, createResponse, mcpSamplesPrompt) if err != nil { return nil, nil, err } return finalMessages, options, nil } // buildMessages builds the final message 
list with proper priority // Priority: Prompts > MCP Samples > createResponse.Messages > input messages // If createResponse is nil or has no messages, use input messages func (ast *Assistant) buildMessages(ctx *context.Context, messages []context.Message, createResponse *context.HookCreateResponse, mcpSamplesPrompt string) ([]context.Message, error) { var finalMessages []context.Message // If createResponse is nil or has no messages, use input messages if createResponse == nil || len(createResponse.Messages) == 0 { finalMessages = messages } else { // createResponse.Messages takes priority over input messages finalMessages = createResponse.Messages } // Add MCP samples prompt as a system message (if available) if mcpSamplesPrompt != "" { mcpSamplesMsg := context.Message{ Role: context.RoleSystem, Content: mcpSamplesPrompt, } // Prepend MCP samples before other messages finalMessages = append([]context.Message{mcpSamplesMsg}, finalMessages...) } // Build and prepend system prompts (global + assistant prompts) promptMessages := ast.buildSystemPrompts(ctx, createResponse) if len(promptMessages) > 0 { finalMessages = append(promptMessages, finalMessages...) } return finalMessages, nil } // buildSystemPrompts builds system prompt messages from global prompts and assistant prompts // Order: Global prompts (if not disabled) -> Assistant prompts (or preset) // Variables are parsed with context information // // Priority for prompt preset selection: // 1. createResponse.PromptPreset (highest) // 2. ctx.Metadata["__prompt_preset"] // 3. ast.Prompts (default) // // Priority for disable global prompts: // 1. createResponse.DisableGlobalPrompts (highest) // 2. ctx.Metadata["__disable_global_prompts"] // 3. 
ast.DisableGlobalPrompts (default) func (ast *Assistant) buildSystemPrompts(ctx *context.Context, createResponse *context.HookCreateResponse) []context.Message { // Build context variables from ctx and ast ctxVars := ast.buildContextVariables(ctx) // Determine if global prompts should be disabled disableGlobal := ast.shouldDisableGlobalPrompts(ctx, createResponse) // Get assistant prompts (default or preset) assistantPrompts := ast.getAssistantPrompts(ctx, createResponse) var allPrompts []store.Prompt // 1. Add global prompts (if not disabled) if !disableGlobal && len(globalPrompts) > 0 { // Parse global prompts with context variables parsedGlobal := store.Prompts(globalPrompts).Parse(ctxVars) allPrompts = append(allPrompts, parsedGlobal...) } // 2. Add assistant prompts (default or preset) if len(assistantPrompts) > 0 { // Parse assistant prompts with context variables parsedAssistant := store.Prompts(assistantPrompts).Parse(ctxVars) allPrompts = append(allPrompts, parsedAssistant...) } // Convert to context.Message slice if len(allPrompts) == 0 { return nil } messages := make([]context.Message, 0, len(allPrompts)) for _, prompt := range allPrompts { msg := context.Message{ Role: context.MessageRole(prompt.Role), Content: prompt.Content, } if prompt.Name != "" { name := prompt.Name msg.Name = &name } messages = append(messages, msg) } return messages } // shouldDisableGlobalPrompts determines if global prompts should be disabled // Priority: createResponse > ctx.Metadata > ast.DisableGlobalPrompts func (ast *Assistant) shouldDisableGlobalPrompts(ctx *context.Context, createResponse *context.HookCreateResponse) bool { // Priority 1: Hook response (highest) if createResponse != nil && createResponse.DisableGlobalPrompts != nil { return *createResponse.DisableGlobalPrompts } // Priority 2: ctx.Metadata["__disable_global_prompts"] if ctx != nil && ctx.Metadata != nil { if disable, ok := ctx.Metadata["__disable_global_prompts"].(bool); ok { return disable } } // 
Priority 3: Assistant configuration (default) return ast.DisableGlobalPrompts } // getAssistantPrompts returns the assistant prompts based on preset selection // Priority: createResponse.PromptPreset > ctx.Metadata["__prompt_preset"] > ast.Prompts func (ast *Assistant) getAssistantPrompts(ctx *context.Context, createResponse *context.HookCreateResponse) []store.Prompt { // Get preset key presetKey := ast.getPromptPresetKey(ctx, createResponse) // If preset key is specified and exists, use it if presetKey != "" && ast.PromptPresets != nil { if presets, ok := ast.PromptPresets[presetKey]; ok && len(presets) > 0 { return presets } } // Fallback to default prompts return ast.Prompts } // getPromptPresetKey returns the prompt preset key // Priority: createResponse.PromptPreset > ctx.Metadata["__prompt_preset"] func (ast *Assistant) getPromptPresetKey(ctx *context.Context, createResponse *context.HookCreateResponse) string { // Priority 1: Hook response (highest) if createResponse != nil && createResponse.PromptPreset != "" { return createResponse.PromptPreset } // Priority 2: ctx.Metadata["__prompt_preset"] if ctx != nil && ctx.Metadata != nil { if preset, ok := ctx.Metadata["__prompt_preset"].(string); ok && preset != "" { return preset } } // No preset specified return "" } // buildContextVariables extracts context variables from Context and Assistant for prompt parsing func (ast *Assistant) buildContextVariables(ctx *context.Context) map[string]string { vars := make(map[string]string) // Get locale from ctx (default to empty) locale := "" if ctx != nil && ctx.Locale != "" { locale = ctx.Locale } // Assistant info (with locale support) if ast != nil { if ast.ID != "" { vars["ASSISTANT_ID"] = ast.ID } // Use localized name and description name := ast.GetName(locale) if name != "" { vars["ASSISTANT_NAME"] = name } description := ast.GetDescription(locale) if description != "" { vars["ASSISTANT_DESCRIPTION"] = description } if ast.Type != "" { vars["ASSISTANT_TYPE"] = 
ast.Type } } if ctx == nil { return vars } // Basic context info if ctx.ChatID != "" { vars["CHAT_ID"] = ctx.ChatID } if ctx.Locale != "" { vars["LOCALE"] = ctx.Locale } if ctx.Theme != "" { vars["THEME"] = ctx.Theme } if ctx.Route != "" { vars["ROUTE"] = ctx.Route } if ctx.Referer != "" { vars["REFERER"] = ctx.Referer } // Client info (only non-sensitive fields) if ctx.Client.Type != "" { vars["CLIENT_TYPE"] = ctx.Client.Type } // Authorized info (only internal IDs, no PII) // Note: USER_SUBJECT and CLIENT_IP are excluded for privacy/GDPR compliance if ctx.Authorized != nil { if ctx.Authorized.UserID != "" { vars["USER_ID"] = ctx.Authorized.UserID } if ctx.Authorized.TeamID != "" { vars["TEAM_ID"] = ctx.Authorized.TeamID } if ctx.Authorized.TenantID != "" { vars["TENANT_ID"] = ctx.Authorized.TenantID } } // Metadata - custom variables from ctx.Metadata // All metadata keys are exposed as $CTX.{KEY} // Supports string, int, uint, float, bool types if ctx.Metadata != nil { for key, value := range ctx.Metadata { if value == nil { continue } strVal := cast.ToString(value) if strVal != "" { vars[key] = strVal } } } return vars } // buildCompletionOptions builds completion options from multiple sources // Priority (lowest to highest, later overrides earlier): ast > ctx > createResponse // The priority means: if createResponse has a value, use it; else use ctx; else use ast // Returns (options, mcpSamplesPrompt, error) func (ast *Assistant) buildCompletionOptions(ctx *context.Context, createResponse *context.HookCreateResponse) (*context.CompletionOptions, string, error) { options := &context.CompletionOptions{} // Layer 1 (base): Apply ast - Assistant configuration if err := ast.applyAssistantOptions(options); err != nil { return nil, "", err } // Layer 2 (middle): Apply ctx - Context configuration (overrides ast) ast.applyContextOptions(options, ctx) // Layer 3 (highest): Apply createResponse - Hook configuration (overrides all) if createResponse != nil { 
ast.applyCreateResponseOptions(options, createResponse) } // Add MCP tools if configured and get samples prompt mcpSamplesPrompt, err := ast.applyMCPTools(ctx, options, createResponse) if err != nil { return nil, "", fmt.Errorf("failed to apply MCP tools: %w", err) } return options, mcpSamplesPrompt, nil } // applyAssistantOptions applies options from ast.Options to CompletionOptions // ast.Options can contain any OpenAI API parameters (temperature, top_p, stop, etc.) // Returns error if any option validation fails (e.g., invalid JSON Schema) func (ast *Assistant) applyAssistantOptions(options *context.CompletionOptions) error { if ast.Options == nil { return nil } // Temperature if v, ok := ast.Options["temperature"].(float64); ok { options.Temperature = &v } // MaxTokens if v, ok := ast.Options["max_tokens"].(float64); ok { intVal := int(v) options.MaxTokens = &intVal } else if v, ok := ast.Options["max_tokens"].(int); ok { options.MaxTokens = &v } // MaxCompletionTokens if v, ok := ast.Options["max_completion_tokens"].(float64); ok { intVal := int(v) options.MaxCompletionTokens = &intVal } else if v, ok := ast.Options["max_completion_tokens"].(int); ok { options.MaxCompletionTokens = &v } // TopP if v, ok := ast.Options["top_p"].(float64); ok { options.TopP = &v } // N (number of choices) if v, ok := ast.Options["n"].(float64); ok { intVal := int(v) options.N = &intVal } else if v, ok := ast.Options["n"].(int); ok { options.N = &v } // Stop sequences (can be string or []string) if v, ok := ast.Options["stop"]; ok { options.Stop = v } // PresencePenalty if v, ok := ast.Options["presence_penalty"].(float64); ok { options.PresencePenalty = &v } // FrequencyPenalty if v, ok := ast.Options["frequency_penalty"].(float64); ok { options.FrequencyPenalty = &v } // LogitBias if v, ok := ast.Options["logit_bias"].(map[string]interface{}); ok { logitBias := make(map[string]float64) for key, val := range v { if fval, ok := val.(float64); ok { logitBias[key] = fval } } if 
len(logitBias) > 0 { options.LogitBias = logitBias } } // User if v, ok := ast.Options["user"].(string); ok { options.User = v } // ResponseFormat // @todo: Assistant should have a default response format if v, ok := ast.Options["response_format"]; ok { // Try to convert to *context.ResponseFormat if rf, ok := v.(*context.ResponseFormat); ok { // Validate JSONSchema if present - reject if invalid if rf.JSONSchema != nil && rf.JSONSchema.Schema != nil { if err := json.ValidateSchema(rf.JSONSchema.Schema); err != nil { return fmt.Errorf("invalid JSON Schema in response_format: %w", err) } } options.ResponseFormat = rf } else if rfMap, ok := v.(map[string]interface{}); ok { // Handle legacy map[string]interface{} format // Try to parse into ResponseFormat struct rf := &context.ResponseFormat{} // Parse type if typeStr, ok := rfMap["type"].(string); ok { rf.Type = context.ResponseFormatType(typeStr) } // Parse json_schema if present if jsonSchemaMap, ok := rfMap["json_schema"].(map[string]interface{}); ok { jsonSchema := &context.JSONSchema{} if name, ok := jsonSchemaMap["name"].(string); ok { jsonSchema.Name = name } if desc, ok := jsonSchemaMap["description"].(string); ok { jsonSchema.Description = desc } if schema, ok := jsonSchemaMap["schema"]; ok { // Validate schema format - reject if invalid if err := json.ValidateSchema(schema); err != nil { return fmt.Errorf("invalid JSON Schema in response_format: %w", err) } jsonSchema.Schema = schema } if strict, ok := jsonSchemaMap["strict"].(bool); ok { jsonSchema.Strict = &strict } rf.JSONSchema = jsonSchema } options.ResponseFormat = rf } } // Seed if v, ok := ast.Options["seed"].(float64); ok { intVal := int(v) options.Seed = &intVal } else if v, ok := ast.Options["seed"].(int); ok { options.Seed = &v } // Tools if v, ok := ast.Options["tools"].([]interface{}); ok { tools := make([]map[string]interface{}, 0, len(v)) for _, tool := range v { if toolMap, ok := tool.(map[string]interface{}); ok { tools = append(tools, 
toolMap) } } if len(tools) > 0 { options.Tools = tools } } // ToolChoice if v, ok := ast.Options["tool_choice"]; ok { options.ToolChoice = v } // Stream if v, ok := ast.Options["stream"].(bool); ok { options.Stream = &v } return nil } // applyContextOptions applies options from ctx to CompletionOptions // ctx provides Route and Metadata for CUI context func (ast *Assistant) applyContextOptions(options *context.CompletionOptions, ctx *context.Context) { // Set Route and Metadata from ctx options.Route = ctx.Route options.Metadata = ctx.Metadata // Set Uses configurations (assistant.Uses has priority over global settings) // These can be overridden by createResponse options.Uses = ast.getUses() } // applyCreateResponseOptions applies options from createResponse to CompletionOptions // createResponse takes highest priority and overrides any previous settings func (ast *Assistant) applyCreateResponseOptions(options *context.CompletionOptions, createResponse *context.HookCreateResponse) { // Audio configuration if createResponse.Audio != nil { options.Audio = createResponse.Audio } // Temperature if createResponse.Temperature != nil { options.Temperature = createResponse.Temperature } // MaxTokens if createResponse.MaxTokens != nil { options.MaxTokens = createResponse.MaxTokens } // MaxCompletionTokens if createResponse.MaxCompletionTokens != nil { options.MaxCompletionTokens = createResponse.MaxCompletionTokens } // Route if createResponse.Route != "" { options.Route = createResponse.Route } // Metadata (merge with existing) if createResponse.Metadata != nil { if options.Metadata == nil { options.Metadata = createResponse.Metadata } else { // Merge: createResponse.Metadata overrides existing for key, value := range createResponse.Metadata { options.Metadata[key] = value } } } // Uses configuration (merge with existing) // createResponse.Uses has highest priority and overrides existing Uses if createResponse.Uses != nil { if options.Uses == nil { options.Uses = 
createResponse.Uses } else { // Merge: createResponse.Uses overrides existing (only non-empty fields) if createResponse.Uses.Vision != "" { options.Uses.Vision = createResponse.Uses.Vision } if createResponse.Uses.Audio != "" { options.Uses.Audio = createResponse.Uses.Audio } if createResponse.Uses.Search != "" { options.Uses.Search = createResponse.Uses.Search } if createResponse.Uses.Fetch != "" { options.Uses.Fetch = createResponse.Uses.Fetch } } } // ForceUses configuration // If hook specifies ForceUses, it takes priority if createResponse.ForceUses != nil { options.ForceUses = *createResponse.ForceUses } } // getUses get the Uses configuration with priority: assistant.Uses > global settings // Note: createResponse.Uses (applied in applyCreateResponseOptions) has even higher priority // getUses returns the Uses config for this assistant // Note: The config is already merged with global config during loading (loadMap) func (ast *Assistant) getUses() *context.Uses { return ast.Uses } // applyMCPTools adds MCP tools to completion options and returns samples prompt // Returns (samplesPrompt, error) func (ast *Assistant) applyMCPTools(ctx *context.Context, options *context.CompletionOptions, createResponse *context.HookCreateResponse) (string, error) { // Priority 1: Check if hook provides MCP servers if createResponse != nil && len(createResponse.MCPServers) > 0 { return ast.buildAndApplyMCPTools(ctx, options, createResponse) } // Priority 2: Check if assistant has MCP config if ast.MCP != nil && len(ast.MCP.Servers) > 0 { return ast.buildAndApplyMCPTools(ctx, options, nil) } // No MCP config return "", nil } // buildAndApplyMCPTools builds MCP tools and applies them to options func (ast *Assistant) buildAndApplyMCPTools(ctx *context.Context, options *context.CompletionOptions, createResponse *context.HookCreateResponse) (string, error) { // Build MCP tools and get samples prompt mcpTools, samplesPrompt, err := ast.buildMCPTools(ctx, createResponse) if err != nil 
{ return "", fmt.Errorf("failed to build MCP tools: %w", err) } // Convert mcpTools to map format for CompletionOptions.Tools if len(mcpTools) > 0 { toolMaps := make([]map[string]interface{}, len(mcpTools)) for i, tool := range mcpTools { toolMaps[i] = map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": tool.Name, "description": tool.Description, "parameters": tool.Parameters, }, } } // Add MCP tools to existing tools (append to preserve existing tools) if options.Tools == nil { options.Tools = toolMaps } else { options.Tools = append(options.Tools, toolMaps...) } } return samplesPrompt, nil } ================================================ FILE: agent/assistant/build_content.go ================================================ package assistant import ( "fmt" "github.com/yaoapp/yao/agent/content" "github.com/yaoapp/yao/agent/content/text" contentTypes "github.com/yaoapp/yao/agent/content/types" "github.com/yaoapp/yao/agent/context" ) // BuildContent processes messages through Vision function to convert extended content types // (file, data) to standard LLM-compatible types (text, image_url, input_audio) // // This should be called after BuildRequest and before executing LLM call func (ast *Assistant) BuildContent(ctx *context.Context, messages []context.Message, options *context.CompletionOptions, opts *context.Options) ([]context.Message, error) { // Skip complex content parsing if requested (for internal calls like needsearch) // Still convert file attachments to raw text if opts != nil && opts.Skip != nil && opts.Skip.ContentParsing { return convertFilesToText(ctx, messages), nil } // Set AssistantID in context for file info tracking in Space // This ensures hooks can access file information using the correct namespace if ctx.AssistantID == "" { ctx.AssistantID = ast.ID } // Get connector and capabilities connector, capabilities, err := ast.GetConnector(ctx, opts) if err != nil { return nil, fmt.Errorf("failed to get 
connector: %w", err) } // Build parse options parseOptions := &contentTypes.Options{ Capabilities: capabilities, CompletionOptions: options, Connector: connector, StreamOptions: options.StreamOptions, } contentMessages, referenceContext, err := content.ParseUserInput(ctx, messages, parseOptions) if err != nil { return nil, fmt.Errorf("failed to parse content: %w", err) } // Inject reference context into messages if referenceContext != nil { contentMessages = ast.injectSearchContext(contentMessages, referenceContext) } return contentMessages, nil } // convertFilesToText converts file attachments in messages to raw text // Used when SkipContentParsing is enabled - simple text extraction without vision/PDF processing func convertFilesToText(ctx *context.Context, messages []context.Message) []context.Message { result := make([]context.Message, 0, len(messages)) textHandler := text.New(nil) for _, msg := range messages { // Only process user messages if msg.Role != context.RoleUser { result = append(result, msg) continue } // Handle content parts parts, ok := msg.Content.([]context.ContentPart) if !ok { // Try []interface{} (from history/JSON) if iparts, ok := msg.Content.([]interface{}); ok { parts = convertInterfaceToParts(iparts) } } if len(parts) == 0 { result = append(result, msg) continue } // Convert file parts to text newParts := make([]context.ContentPart, 0, len(parts)) for _, part := range parts { switch part.Type { case context.ContentFile: // Convert file to raw text if part.File != nil && part.File.URL != "" { textPart, _, err := textHandler.ParseRaw(ctx, part) if err == nil { newParts = append(newParts, textPart) continue } } newParts = append(newParts, part) case context.ContentImageURL: // Skip images - cannot convert to text without vision continue default: newParts = append(newParts, part) } } newMsg := msg newMsg.Content = newParts result = append(result, newMsg) } return result } // convertInterfaceToParts converts []interface{} to []ContentPart for 
file extraction func convertInterfaceToParts(items []interface{}) []context.ContentPart { parts := make([]context.ContentPart, 0, len(items)) for _, item := range items { m, ok := item.(map[string]interface{}) if !ok { continue } typeStr, _ := m["type"].(string) part := context.ContentPart{ Type: context.ContentPartType(typeStr), } switch typeStr { case "text": if t, ok := m["text"].(string); ok { part.Text = t } case "file": if fileData, ok := m["file"].(map[string]interface{}); ok { part.File = &context.FileAttachment{} if url, ok := fileData["url"].(string); ok { part.File.URL = url } if filename, ok := fileData["filename"].(string); ok { part.File.Filename = filename } } case "image_url": part.Type = context.ContentImageURL default: continue } parts = append(parts, part) } return parts } ================================================ FILE: agent/assistant/build_mcp_test.go ================================================ package assistant_test import ( "testing" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" ) // TestBuildRequest_MCP tests MCP tool integration in BuildRequest func TestBuildRequest_MCP(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.mcptest") if err != nil { t.Fatalf("Failed to get tests.mcptest assistant: %s", err.Error()) } ctx := newTestContext("chat-test-mcp", "tests.mcptest") t.Run("MCPToolsLoaded", func(t *testing.T) { inputMessages := []context.Message{{Role: context.RoleUser, Content: "test mcp tools"}} // Build LLM request _, options, err := agent.BuildRequest(ctx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify that tools are loaded if options.Tools == nil { t.Fatal("Expected tools to be loaded, got nil") } if len(options.Tools) == 0 { t.Fatal("Expected at least some MCP tools, got empty list") } // Count MCP tools (should be filtered to only ping and echo) 
mcpToolCount := 0 var toolNames []string for _, toolMap := range options.Tools { fn, ok := toolMap["function"].(map[string]interface{}) if !ok { continue } name, ok := fn["name"].(string) if ok { toolNames = append(toolNames, name) mcpToolCount++ } } t.Logf("Found %d MCP tools: %v", mcpToolCount, toolNames) // Verify tool count (should be exactly 2: ping and echo) if mcpToolCount != 2 { t.Errorf("Expected 2 MCP tools (ping, echo), got %d: %v", mcpToolCount, toolNames) } // Verify specific tools exist hasEchoPing := false hasEchoEcho := false for _, name := range toolNames { if name == "echo__ping" { hasEchoPing = true } if name == "echo__echo" { hasEchoEcho = true } } if !hasEchoPing { t.Error("Expected 'echo__ping' tool to be present") } if !hasEchoEcho { t.Error("Expected 'echo__echo' tool to be present") } // Verify that 'status' tool is NOT included (filtered out) for _, name := range toolNames { if name == "echo__status" { t.Error("Tool 'echo__status' should be filtered out but was found") } } t.Log("✓ MCP tools loaded and filtered correctly") }) t.Run("MCPSamplesPrompt", func(t *testing.T) { inputMessages := []context.Message{{Role: context.RoleUser, Content: "test mcp samples"}} // Build LLM request finalMessages, _, err := agent.BuildRequest(ctx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Check if messages contain MCP samples prompt // The samples prompt should be added as a system message hasMCPSamples := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { if content, ok := msg.Content.(string); ok { if len(content) > 50 && (contains(content, "MCP Tool Usage Examples") || contains(content, "echo.ping") || contains(content, "echo.echo")) { hasMCPSamples = true t.Logf("Found MCP samples prompt (length: %d chars)", len(content)) break } } } } // Note: samples may not exist for echo tools, so this is informational if hasMCPSamples { t.Log("✓ MCP samples prompt included in 
messages") } else { t.Log("ℹ No MCP samples prompt found (may not have sample files)") } }) t.Run("MCPToolNameFormat", func(t *testing.T) { inputMessages := []context.Message{{Role: context.RoleUser, Content: "test tool format"}} // Build LLM request _, options, err := agent.BuildRequest(ctx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify tool name format: server_id.tool_name for _, toolMap := range options.Tools { fn, ok := toolMap["function"].(map[string]interface{}) if !ok { continue } name, ok := fn["name"].(string) if ok { // Parse tool name serverID, toolName, ok := assistant.ParseMCPToolName(name) if !ok { t.Errorf("Tool name '%s' is not in correct format (server_id.tool_name)", name) continue } // Verify server ID if serverID != "echo" { t.Errorf("Expected server_id 'echo', got '%s' for tool '%s'", serverID, name) } // Verify tool name is either ping or echo if toolName != "ping" && toolName != "echo" { t.Errorf("Expected tool name 'ping' or 'echo', got '%s'", toolName) } t.Logf("✓ Tool name format correct: %s → (%s, %s)", name, serverID, toolName) } } }) t.Run("MCPToolSchema", func(t *testing.T) { inputMessages := []context.Message{{Role: context.RoleUser, Content: "test tool schema"}} // Build LLM request _, options, err := agent.BuildRequest(ctx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify tool schema structure for _, toolMap := range options.Tools { // Verify type field if toolType, ok := toolMap["type"].(string); !ok || toolType != "function" { t.Errorf("Expected tool type 'function', got: %v", toolMap["type"]) } // Verify function field exists fn, ok := toolMap["function"].(map[string]interface{}) if !ok { t.Error("Tool missing 'function' field or wrong type") continue } // Verify required fields if _, hasName := fn["name"]; !hasName { t.Error("Tool function missing 'name' field") } if _, hasDesc := fn["description"]; !hasDesc { 
t.Error("Tool function missing 'description' field") } if _, hasParams := fn["parameters"]; !hasParams { t.Error("Tool function missing 'parameters' field") } t.Logf("✓ Tool schema valid: %v", fn["name"]) } }) t.Run("MCPHookOverride", func(t *testing.T) { // Test that hook can override MCP servers // Use tests.mcptest-hook which has a create hook that returns only ["ping"] hookAgent, err := assistant.Get("tests.mcptest-hook") if err != nil { t.Fatalf("Failed to get tests.mcptest-hook assistant: %s", err.Error()) } hookCtx := newTestContext("chat-test-mcp-hook", "tests.mcptest-hook") inputMessages := []context.Message{{Role: context.RoleUser, Content: "test hook override"}} // Call create hook to get createResponse var createResponse *context.HookCreateResponse if hookAgent.HookScript != nil { createResponse, _, err = hookAgent.HookScript.Create(hookCtx, inputMessages, &context.Options{}) if err != nil { t.Fatalf("Failed to call create hook: %s", err.Error()) } t.Logf("Create hook response: %+v", createResponse) if createResponse != nil && len(createResponse.MCPServers) > 0 { t.Logf("Hook MCP servers: %+v", createResponse.MCPServers) } } else { t.Fatal("Expected hookAgent to have Script/hook configured") } // Build LLM request with create hook response _, options, err := hookAgent.BuildRequest(hookCtx, inputMessages, createResponse) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify that tools are loaded if options.Tools == nil { t.Fatal("Expected tools to be loaded, got nil") } // Count MCP tools mcpToolCount := 0 var toolNames []string for _, toolMap := range options.Tools { fn, ok := toolMap["function"].(map[string]interface{}) if !ok { continue } name, ok := fn["name"].(string) if ok { toolNames = append(toolNames, name) mcpToolCount++ } } t.Logf("Found %d MCP tools after hook override: %v", mcpToolCount, toolNames) // Verify tool count (hook should override to only 1: ping) if mcpToolCount != 1 { t.Errorf("Expected 1 MCP tool 
// contains reports whether substr occurs anywhere inside s.
// An empty substr is considered to be contained in every string.
func contains(s, substr string) bool {
	switch {
	case len(substr) > len(s):
		// A needle longer than the haystack can never match.
		return false
	case len(substr) == len(s):
		// Equal lengths: the only possible match is the whole string.
		return s == substr
	}
	// s is strictly longer than substr. A full scan visits offset 0 (prefix)
	// and the last offset (suffix), so no separate checks are needed.
	return findSubstring(s, substr)
}

// findSubstring walks every start offset of s at which substr could fit and
// reports whether the window at any of those offsets equals substr.
func findSubstring(s, substr string) bool {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
"test-tenant-789", SessionID: "test-session-id", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Metadata = make(map[string]interface{}) return ctx } // newMinimalTestContext creates a minimal context for testing // Use this when you only need specific fields set func newMinimalTestContext() *context.Context { return context.New(stdContext.Background(), nil, "test-chat") } func TestBuildSystemPromptsIntegration(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("AssistantWithLocale", func(t *testing.T) { // Load an assistant with locales ast, err := assistant.Get("tests.fullfields") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Locale = "zh-cn" ctx.Authorized = &types.AuthorizedInfo{ UserID: "test-user-123", TeamID: "test-team-456", } ctx.Metadata = map[string]interface{}{ "CUSTOM_VAR": "custom-value", "INT_VAR": 42, "BOOL_VAR": true, } // Build request to test the full flow messages := []context.Message{ {Role: context.RoleUser, Content: "Hello"}, } finalMessages, options, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) require.NotNil(t, options) // Should have system prompts prepended assert.Greater(t, len(finalMessages), 1) // First messages should be system prompts hasSystemPrompt := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { hasSystemPrompt = true break } } assert.True(t, hasSystemPrompt, "Should have system prompts") }) t.Run("DisableGlobalPrompts", func(t *testing.T) { // Load fullfields assistant which has disable_global_prompts: true ast, err := assistant.Get("tests.fullfields") require.NoError(t, err) require.True(t, ast.DisableGlobalPrompts) ctx := newMinimalTestContext() ctx.Locale = "en-us" messages := []context.Message{ 
{Role: context.RoleUser, Content: "Hello"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Should still have assistant prompts hasSystemPrompt := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { hasSystemPrompt = true break } } assert.True(t, hasSystemPrompt, "Should have assistant prompts even with global disabled") }) t.Run("MetadataTypeConversion", func(t *testing.T) { ast, err := assistant.Get("yaobots") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Metadata = map[string]interface{}{ "STRING_VAL": "hello", "INT_VAL": 123, "INT64_VAL": int64(456), "FLOAT_VAL": 3.14, "BOOL_TRUE": true, "BOOL_FALSE": false, "UINT_VAL": uint(789), "NIL_VAL": nil, "EMPTY_VAL": "", "ZERO_INT": 0, "ZERO_FLOAT": 0.0, } messages := []context.Message{ {Role: context.RoleUser, Content: "Test metadata"}, } // This should not panic _, _, err = ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) }) t.Run("AuthorizedInfoPrivacy", func(t *testing.T) { ast, err := assistant.Get("yaobots") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Authorized = &types.AuthorizedInfo{ UserID: "user-123", Subject: "user@example.com", // PII - should not be exposed TeamID: "team-456", TenantID: "tenant-789", } ctx.Client = context.Client{ Type: "web", IP: "192.168.1.1", // Should not be exposed } messages := []context.Message{ {Role: context.RoleUser, Content: "Test privacy"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Check that sensitive info is not in any system prompts for _, msg := range finalMessages { if msg.Role == context.RoleSystem { assert.NotContains(t, msg.Content, "user@example.com", "Subject should not be in prompts") assert.NotContains(t, msg.Content, "192.168.1.1", "IP should not be in prompts") } } }) t.Run("ContextVariablesInPrompts", func(t *testing.T) { // Set up global prompts with variables assistant.SetGlobalPrompts([]store.Prompt{ 
{Role: "system", Content: "User ID: $CTX.USER_ID, Team: $CTX.TEAM_ID, Custom: $CTX.MY_VAR"}, }) defer assistant.SetGlobalPrompts(nil) ast, err := assistant.Get("yaobots") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Authorized = &types.AuthorizedInfo{ UserID: "user-abc", TeamID: "team-xyz", } ctx.Metadata = map[string]interface{}{ "MY_VAR": "my-value", } messages := []context.Message{ {Role: context.RoleUser, Content: "Test variables"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Find the global prompt and verify variables are replaced found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem && !found { if assert.Contains(t, msg.Content, "User ID: user-abc") { found = true assert.Contains(t, msg.Content, "Team: team-xyz") assert.Contains(t, msg.Content, "Custom: my-value") } } } assert.True(t, found, "Should find global prompt with replaced variables") }) t.Run("SystemVariablesReplacement", func(t *testing.T) { // Set up global prompts with $SYS.* variables assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: "Time: $SYS.TIME, Date: $SYS.DATE, Datetime: $SYS.DATETIME, Weekday: $SYS.WEEKDAY"}, }) defer assistant.SetGlobalPrompts(nil) ast, err := assistant.Get("yaobots") require.NoError(t, err) ctx := newMinimalTestContext() messages := []context.Message{ {Role: context.RoleUser, Content: "Test system variables"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Find the global prompt and verify $SYS.* variables are replaced found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { // Should NOT contain $SYS. 
prefix (variables should be replaced) if !assert.NotContains(t, msg.Content, "$SYS.TIME") { continue } if !assert.NotContains(t, msg.Content, "$SYS.DATE") { continue } if !assert.NotContains(t, msg.Content, "$SYS.DATETIME") { continue } if !assert.NotContains(t, msg.Content, "$SYS.WEEKDAY") { continue } // Should contain "Time:", "Date:", etc. with actual values assert.Contains(t, msg.Content, "Time:") assert.Contains(t, msg.Content, "Date:") assert.Contains(t, msg.Content, "Datetime:") assert.Contains(t, msg.Content, "Weekday:") found = true break } } assert.True(t, found, "Should find global prompt with replaced $SYS.* variables") }) t.Run("EnvVariablesReplacement", func(t *testing.T) { // Set test environment variable t.Setenv("TEST_PROMPT_VAR", "env-test-value") // Set up global prompts with $ENV.* variables assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: "Env Value: $ENV.TEST_PROMPT_VAR, Not Exist: $ENV.NOT_EXIST_VAR_XYZ"}, }) defer assistant.SetGlobalPrompts(nil) ast, err := assistant.Get("yaobots") require.NoError(t, err) ctx := newMinimalTestContext() messages := []context.Message{ {Role: context.RoleUser, Content: "Test env variables"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Find the global prompt and verify $ENV.* variables are replaced found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { // Should NOT contain $ENV. 
prefix for existing vars if !assert.NotContains(t, msg.Content, "$ENV.TEST_PROMPT_VAR") { continue } // Should contain the actual env value assert.Contains(t, msg.Content, "Env Value: env-test-value") // Non-existent env var should be replaced with empty string assert.Contains(t, msg.Content, "Not Exist: ") assert.NotContains(t, msg.Content, "$ENV.NOT_EXIST_VAR_XYZ") found = true break } } assert.True(t, found, "Should find global prompt with replaced $ENV.* variables") }) t.Run("AllVariableTypesReplacement", func(t *testing.T) { // Set test environment variable t.Setenv("TEST_APP_NAME", "MyTestApp") // Set up global prompts with all variable types assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: `System Info: - Time: $SYS.TIME - Date: $SYS.DATE - App: $ENV.TEST_APP_NAME - User: $CTX.USER_ID - Custom: $CTX.CUSTOM_KEY - Assistant: $CTX.ASSISTANT_NAME`}, }) defer assistant.SetGlobalPrompts(nil) ast, err := assistant.Get("yaobots") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Authorized = &types.AuthorizedInfo{ UserID: "all-vars-user", } ctx.Metadata = map[string]interface{}{ "CUSTOM_KEY": "custom-value-123", } messages := []context.Message{ {Role: context.RoleUser, Content: "Test all variables"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Find the global prompt and verify ALL variable types are replaced found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem && !found { content := msg.Content // Check $SYS.* replaced if assert.NotContains(t, content, "$SYS.TIME") && assert.NotContains(t, content, "$SYS.DATE") { // Check $ENV.* replaced assert.NotContains(t, content, "$ENV.TEST_APP_NAME") assert.Contains(t, content, "App: MyTestApp") // Check $CTX.* replaced assert.NotContains(t, content, "$CTX.USER_ID") assert.Contains(t, content, "User: all-vars-user") assert.NotContains(t, content, "$CTX.CUSTOM_KEY") assert.Contains(t, content, "Custom: custom-value-123") 
// Check assistant name from $CTX.ASSISTANT_NAME assert.NotContains(t, content, "$CTX.ASSISTANT_NAME") found = true } } } assert.True(t, found, "Should find global prompt with all variable types replaced") }) t.Run("PromptPresetFromHook", func(t *testing.T) { // Load fullfields assistant which has prompt_presets ast, err := assistant.Get("tests.fullfields") require.NoError(t, err) require.NotNil(t, ast.PromptPresets) require.Contains(t, ast.PromptPresets, "chat.friendly") ctx := newMinimalTestContext() messages := []context.Message{ {Role: context.RoleUser, Content: "Test preset from hook"}, } // Hook returns prompt_preset createResponse := &context.HookCreateResponse{ PromptPreset: "chat.friendly", } finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should have system prompts from the preset hasSystemPrompt := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { hasSystemPrompt = true // Verify it's from the friendly preset (check content) assert.Contains(t, msg.Content, "friendly", "Should use friendly preset prompts") break } } assert.True(t, hasSystemPrompt, "Should have system prompts from preset") }) t.Run("PromptPresetFromMetadata", func(t *testing.T) { // Load fullfields assistant which has prompt_presets ast, err := assistant.Get("tests.fullfields") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Metadata = map[string]interface{}{ "__prompt_preset": "chat.professional", } messages := []context.Message{ {Role: context.RoleUser, Content: "Test preset from metadata"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Should have system prompts from the preset hasSystemPrompt := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { hasSystemPrompt = true // Verify it's from the professional preset assert.Contains(t, msg.Content, "professional", "Should use professional preset prompts") break } } assert.True(t, 
hasSystemPrompt, "Should have system prompts from preset") }) t.Run("PromptPresetHookOverridesMetadata", func(t *testing.T) { // Load fullfields assistant ast, err := assistant.Get("tests.fullfields") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Metadata = map[string]interface{}{ "__prompt_preset": "chat.professional", // Lower priority } messages := []context.Message{ {Role: context.RoleUser, Content: "Test hook overrides metadata"}, } // Hook returns different preset (higher priority) createResponse := &context.HookCreateResponse{ PromptPreset: "chat.friendly", } finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should use hook's preset, not metadata's for _, msg := range finalMessages { if msg.Role == context.RoleSystem { assert.Contains(t, msg.Content, "friendly", "Hook preset should override metadata preset") break } } }) t.Run("PromptPresetNotFound", func(t *testing.T) { // Load fullfields assistant ast, err := assistant.Get("tests.fullfields") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Metadata = map[string]interface{}{ "__prompt_preset": "non.existent.preset", } messages := []context.Message{ {Role: context.RoleUser, Content: "Test non-existent preset"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Should fallback to default prompts (not crash) hasSystemPrompt := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { hasSystemPrompt = true break } } assert.True(t, hasSystemPrompt, "Should fallback to default prompts when preset not found") }) t.Run("DisableGlobalPromptsFromHook", func(t *testing.T) { // Set global prompts assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: "GLOBAL_PROMPT_MARKER"}, }) defer assistant.SetGlobalPrompts(nil) // Load an assistant that does NOT disable global prompts ast, err := assistant.Get("yaobots") require.NoError(t, err) require.False(t, 
ast.DisableGlobalPrompts) ctx := newMinimalTestContext() messages := []context.Message{ {Role: context.RoleUser, Content: "Test disable from hook"}, } // Hook disables global prompts disableTrue := true createResponse := &context.HookCreateResponse{ DisableGlobalPrompts: &disableTrue, } finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should NOT have global prompt for _, msg := range finalMessages { if msg.Role == context.RoleSystem { assert.NotContains(t, msg.Content, "GLOBAL_PROMPT_MARKER", "Global prompts should be disabled by hook") } } }) t.Run("DisableGlobalPromptsFromMetadata", func(t *testing.T) { // Set global prompts assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: "GLOBAL_PROMPT_MARKER_2"}, }) defer assistant.SetGlobalPrompts(nil) // Load an assistant that does NOT disable global prompts ast, err := assistant.Get("yaobots") require.NoError(t, err) ctx := newMinimalTestContext() ctx.Metadata = map[string]interface{}{ "__disable_global_prompts": true, } messages := []context.Message{ {Role: context.RoleUser, Content: "Test disable from metadata"}, } finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Should NOT have global prompt for _, msg := range finalMessages { if msg.Role == context.RoleSystem { assert.NotContains(t, msg.Content, "GLOBAL_PROMPT_MARKER_2", "Global prompts should be disabled by metadata") } } }) t.Run("EnableGlobalPromptsOverrideAssistant", func(t *testing.T) { // Set global prompts assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: "GLOBAL_ENABLED_MARKER"}, }) defer assistant.SetGlobalPrompts(nil) // Load fullfields assistant which has disable_global_prompts: true ast, err := assistant.Get("tests.fullfields") require.NoError(t, err) require.True(t, ast.DisableGlobalPrompts) ctx := newMinimalTestContext() messages := []context.Message{ {Role: context.RoleUser, Content: "Test enable override"}, } // Hook enables 
global prompts (overrides assistant's disable) disableFalse := false createResponse := &context.HookCreateResponse{ DisableGlobalPrompts: &disableFalse, } finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should have global prompt (hook enabled it) found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem && msg.Content == "GLOBAL_ENABLED_MARKER" { found = true break } } assert.True(t, found, "Global prompts should be enabled by hook override") }) } // TestPromptPresetAssistant tests the tests.promptpreset assistant with Create Hook func TestPromptPresetAssistant(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("LoadPromptPresetAssistant", func(t *testing.T) { ast, err := assistant.Get("tests.promptpreset") require.NoError(t, err) require.NotNil(t, ast) assert.Equal(t, "tests.promptpreset", ast.ID) assert.Equal(t, "Prompt Preset Test", ast.Name) assert.False(t, ast.DisableGlobalPrompts) // Should have prompt presets loaded require.NotNil(t, ast.PromptPresets) assert.Contains(t, ast.PromptPresets, "mode.friendly") assert.Contains(t, ast.PromptPresets, "mode.professional") // Should have script assert.NotNil(t, ast.HookScript) }) t.Run("CreateHookSelectsFriendlyPreset", func(t *testing.T) { ast, err := assistant.Get("tests.promptpreset") require.NoError(t, err) ctx := newPromptTestContext("chat-friendly-test", "tests.promptpreset") messages := []context.Message{ {Role: context.RoleUser, Content: "use friendly mode please"}, } // Call Create hook createResponse, _, err := ast.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, createResponse) assert.Equal(t, "mode.friendly", createResponse.PromptPreset) // Build request finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should have friendly preset marker in one of the system messages found := false for _, msg := range finalMessages { 
if msg.Role == context.RoleSystem && containsString(msg.Content, "FRIENDLY_PRESET_MARKER") { found = true break } } assert.True(t, found, "Should use friendly preset from Create Hook") }) t.Run("CreateHookSelectsProfessionalPreset", func(t *testing.T) { ast, err := assistant.Get("tests.promptpreset") require.NoError(t, err) ctx := newPromptTestContext("chat-professional-test", "tests.promptpreset") messages := []context.Message{ {Role: context.RoleUser, Content: "use professional tone"}, } // Call Create hook createResponse, _, err := ast.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, createResponse) assert.Equal(t, "mode.professional", createResponse.PromptPreset) // Build request finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should have professional preset marker in one of the system messages found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem && containsString(msg.Content, "PROFESSIONAL_PRESET_MARKER") { found = true break } } assert.True(t, found, "Should use professional preset from Create Hook") }) t.Run("CreateHookDisablesGlobalPrompts", func(t *testing.T) { // Set global prompts assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: "GLOBAL_MARKER_FOR_DISABLE_TEST"}, }) defer assistant.SetGlobalPrompts(nil) ast, err := assistant.Get("tests.promptpreset") require.NoError(t, err) ctx := newPromptTestContext("chat-disable-global-test", "tests.promptpreset") messages := []context.Message{ {Role: context.RoleUser, Content: "disable global prompts"}, } // Call Create hook createResponse, _, err := ast.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, createResponse) require.NotNil(t, createResponse.DisableGlobalPrompts) assert.True(t, *createResponse.DisableGlobalPrompts) // Build request finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) 
require.NoError(t, err) // Should NOT have global prompt for _, msg := range finalMessages { if msg.Role == context.RoleSystem { assert.NotContains(t, msg.Content, "GLOBAL_MARKER_FOR_DISABLE_TEST") } } }) t.Run("CreateHookPresetAndDisableGlobal", func(t *testing.T) { // Set global prompts assistant.SetGlobalPrompts([]store.Prompt{ {Role: "system", Content: "GLOBAL_MARKER_COMBINED_TEST"}, }) defer assistant.SetGlobalPrompts(nil) ast, err := assistant.Get("tests.promptpreset") require.NoError(t, err) ctx := newPromptTestContext("chat-combined-test", "tests.promptpreset") messages := []context.Message{ {Role: context.RoleUser, Content: "friendly no global"}, } // Call Create hook createResponse, _, err := ast.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, createResponse) assert.Equal(t, "mode.friendly", createResponse.PromptPreset) require.NotNil(t, createResponse.DisableGlobalPrompts) assert.True(t, *createResponse.DisableGlobalPrompts) // Build request finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should have friendly preset but NOT global hasFriendly := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem { assert.NotContains(t, msg.Content, "GLOBAL_MARKER_COMBINED_TEST") if containsString(msg.Content, "FRIENDLY_PRESET_MARKER") { hasFriendly = true } } } assert.True(t, hasFriendly, "Should have friendly preset") }) t.Run("CreateHookUnknownPresetFallback", func(t *testing.T) { ast, err := assistant.Get("tests.promptpreset") require.NoError(t, err) ctx := newPromptTestContext("chat-unknown-preset-test", "tests.promptpreset") messages := []context.Message{ {Role: context.RoleUser, Content: "unknown preset test"}, } // Call Create hook createResponse, _, err := ast.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, createResponse) assert.Equal(t, "non.existent.preset", createResponse.PromptPreset) // 
Build request - should not error, fallback to default finalMessages, _, err := ast.BuildRequest(ctx, messages, createResponse) require.NoError(t, err) // Should fallback to default prompts found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem && containsString(msg.Content, "DEFAULT_PROMPT_MARKER") { found = true break } } assert.True(t, found, "Should fallback to default prompts when preset not found") }) t.Run("CreateHookReturnsNull", func(t *testing.T) { ast, err := assistant.Get("tests.promptpreset") require.NoError(t, err) ctx := newPromptTestContext("chat-null-test", "tests.promptpreset") messages := []context.Message{ {Role: context.RoleUser, Content: "just a normal message"}, } // Call Create hook - should return nil createResponse, _, err := ast.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) assert.Nil(t, createResponse) // Build request with nil createResponse finalMessages, _, err := ast.BuildRequest(ctx, messages, nil) require.NoError(t, err) // Should use default prompts found := false for _, msg := range finalMessages { if msg.Role == context.RoleSystem && containsString(msg.Content, "DEFAULT_PROMPT_MARKER") { found = true break } } assert.True(t, found, "Should use default prompts when hook returns null") }) } ================================================ FILE: agent/assistant/build_test.go ================================================ package assistant_test import ( stdContext "context" "testing" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newTestContext creates a Context for testing with commonly used fields pre-populated func newTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: 
"test-session-id", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "/test/route" ctx.Metadata = map[string]interface{}{ "test": "context_metadata", } return ctx } // TestBuildRequest tests the BuildRequest function func TestBuildRequest(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.buildrequest") if err != nil { t.Fatalf("Failed to get tests.buildrequest assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("The tests.buildrequest assistant has no script") } ctx := newTestContext("chat-test-buildrequest", "tests.buildrequest") // Test 1: No override from hook - should use ast.Options and ctx values t.Run("NoOverride", func(t *testing.T) { inputMessages := []context.Message{{Role: "user", Content: "no_override"}} // Call Create hook createResponse, _, err := agent.HookScript.Create(ctx, inputMessages, &context.Options{}) if err != nil { t.Fatalf("Failed to call Create hook: %s", err.Error()) } // Build LLM request _, options, err := agent.BuildRequest(ctx, inputMessages, createResponse) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify options - should use ast.Options values if options.Temperature == nil { t.Error("Expected temperature from ast.Options, got nil") } else if *options.Temperature != 0.5 { t.Errorf("Expected temperature 0.5 from ast.Options, got: %f", *options.Temperature) } if options.MaxTokens == nil { t.Error("Expected max_tokens from ast.Options, got nil") } else if *options.MaxTokens != 1000 { t.Errorf("Expected max_tokens 1000 from ast.Options, got: %d", *options.MaxTokens) } if options.TopP == nil { t.Error("Expected top_p from ast.Options, got nil") } else if *options.TopP != 0.9 { 
t.Errorf("Expected top_p 0.9 from ast.Options, got: %f", *options.TopP) } // Verify ctx values if options.Route != "/test/route" { t.Errorf("Expected route '/test/route' from ctx, got: %s", options.Route) } if options.Metadata == nil { t.Error("Expected metadata from ctx, got nil") } else if options.Metadata["test"] != "context_metadata" { t.Errorf("Expected metadata from ctx, got: %v", options.Metadata) } t.Log("✓ No override: ast.Options and ctx values used correctly") }) // Test 2: Override temperature - hook value should take priority t.Run("OverrideTemperature", func(t *testing.T) { inputMessages := []context.Message{{Role: "user", Content: "override_temperature"}} createResponse, _, err := agent.HookScript.Create(ctx, inputMessages, &context.Options{}) if err != nil { t.Fatalf("Failed to call Create hook: %s", err.Error()) } _, options, err := agent.BuildRequest(ctx, inputMessages, createResponse) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify temperature override if options.Temperature == nil { t.Error("Expected temperature, got nil") } else if *options.Temperature != 0.9 { t.Errorf("Expected temperature 0.9 from hook, got: %f", *options.Temperature) } // Other values should still come from ast.Options if options.MaxTokens == nil { t.Error("Expected max_tokens from ast.Options, got nil") } else if *options.MaxTokens != 1000 { t.Errorf("Expected max_tokens 1000 from ast.Options, got: %d", *options.MaxTokens) } t.Log("✓ Temperature override: hook value takes priority over ast.Options") }) // Test 3: Override all - all hook values should take priority t.Run("OverrideAll", func(t *testing.T) { inputMessages := []context.Message{{Role: "user", Content: "override_all"}} createResponse, _, err := agent.HookScript.Create(ctx, inputMessages, &context.Options{}) if err != nil { t.Fatalf("Failed to call Create hook: %s", err.Error()) } _, options, err := agent.BuildRequest(ctx, inputMessages, createResponse) if err != nil { 
t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify all overrides if options.Temperature == nil || *options.Temperature != 0.8 { t.Errorf("Expected temperature 0.8 from hook, got: %v", options.Temperature) } if options.MaxTokens == nil || *options.MaxTokens != 2000 { t.Errorf("Expected max_tokens 2000 from hook, got: %v", options.MaxTokens) } if options.MaxCompletionTokens == nil || *options.MaxCompletionTokens != 1800 { t.Errorf("Expected max_completion_tokens 1800 from hook, got: %v", options.MaxCompletionTokens) } if options.Audio == nil { t.Error("Expected audio from hook, got nil") } else { if options.Audio.Voice != "alloy" { t.Errorf("Expected voice 'alloy', got: %s", options.Audio.Voice) } if options.Audio.Format != "mp3" { t.Errorf("Expected format 'mp3', got: %s", options.Audio.Format) } } if options.Route != "/hook/route" { t.Errorf("Expected route '/hook/route' from hook, got: %s", options.Route) } if options.Metadata == nil { t.Error("Expected metadata from hook, got nil") } else { if options.Metadata["source"] != "hook" { t.Errorf("Expected metadata['source'] = 'hook', got: %v", options.Metadata["source"]) } } t.Log("✓ Override all: all hook values take priority") }) // Test 4: Override route and metadata - tests CUI context priority t.Run("OverrideRouteMetadata", func(t *testing.T) { inputMessages := []context.Message{{Role: "user", Content: "override_route_metadata"}} createResponse, _, err := agent.HookScript.Create(ctx, inputMessages, &context.Options{}) if err != nil { t.Fatalf("Failed to call Create hook: %s", err.Error()) } _, options, err := agent.BuildRequest(ctx, inputMessages, createResponse) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify route override if options.Route != "/custom/route" { t.Errorf("Expected route '/custom/route' from hook, got: %s", options.Route) } // Verify metadata merge (ctx metadata should be merged with hook metadata) if options.Metadata == nil { 
t.Error("Expected metadata, got nil") } else { // Hook metadata should be present if options.Metadata["custom"] != true { t.Errorf("Expected metadata['custom'] = true from hook, got: %v", options.Metadata["custom"]) } if options.Metadata["hook_data"] != "test" { t.Errorf("Expected metadata['hook_data'] = 'test' from hook, got: %v", options.Metadata["hook_data"]) } // Original ctx metadata should still be there (merged) if options.Metadata["test"] != "context_metadata" { t.Errorf("Expected original ctx metadata to be preserved, got: %v", options.Metadata) } } // Other values should still come from ast.Options if options.Temperature == nil || *options.Temperature != 0.5 { t.Errorf("Expected temperature 0.5 from ast.Options, got: %v", options.Temperature) } t.Log("✓ Route and metadata override: hook values take priority, metadata merged") }) // Test 5: Nil createResponse - should use ast.Options and ctx values t.Run("NilCreateResponse", func(t *testing.T) { // Create a fresh context for this test freshCtx := newTestContext("chat-test-nil", "tests.buildrequest") inputMessages := []context.Message{{Role: "user", Content: "test message"}} _, options, err := agent.BuildRequest(freshCtx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Should use ast.Options values if options.Temperature == nil || *options.Temperature != 0.5 { t.Errorf("Expected temperature 0.5 from ast.Options, got: %v", options.Temperature) } // Should use ctx values if options.Route != "/test/route" { t.Errorf("Expected route '/test/route' from ctx, got: %s", options.Route) } t.Log("✓ Nil createResponse: ast.Options and ctx values used") }) // Test 6: ResponseFormat with *context.ResponseFormat t.Run("ResponseFormatStruct", func(t *testing.T) { freshCtx := newTestContext("chat-test-response-format", "tests.buildrequest") inputMessages := []context.Message{{Role: "user", Content: "test message"}} // Create a test agent with response_format in Options 
testAgent := *agent strict := true testAgent.Options = map[string]interface{}{ "temperature": 0.7, "response_format": &context.ResponseFormat{ Type: context.ResponseFormatJSONSchema, JSONSchema: &context.JSONSchema{ Name: "test_schema", Description: "Test schema description", Schema: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "name": map[string]interface{}{ "type": "string", }, }, }, Strict: &strict, }, }, } _, options, err := testAgent.BuildRequest(freshCtx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify ResponseFormat if options.ResponseFormat == nil { t.Fatal("Expected ResponseFormat, got nil") } if options.ResponseFormat.Type != context.ResponseFormatJSONSchema { t.Errorf("Expected type 'json_schema', got: %s", options.ResponseFormat.Type) } if options.ResponseFormat.JSONSchema == nil { t.Fatal("Expected JSONSchema, got nil") } if options.ResponseFormat.JSONSchema.Name != "test_schema" { t.Errorf("Expected schema name 'test_schema', got: %s", options.ResponseFormat.JSONSchema.Name) } if options.ResponseFormat.JSONSchema.Description != "Test schema description" { t.Errorf("Expected schema description 'Test schema description', got: %s", options.ResponseFormat.JSONSchema.Description) } if options.ResponseFormat.JSONSchema.Strict == nil || *options.ResponseFormat.JSONSchema.Strict != true { t.Errorf("Expected strict = true, got: %v", options.ResponseFormat.JSONSchema.Strict) } t.Log("✓ ResponseFormat with *context.ResponseFormat struct works correctly") }) // Test 7: ResponseFormat with legacy map[string]interface{} t.Run("ResponseFormatLegacyMap", func(t *testing.T) { freshCtx := newTestContext("chat-test-response-format-map", "tests.buildrequest") inputMessages := []context.Message{{Role: "user", Content: "test message"}} // Create a test agent with legacy map format testAgent := *agent testAgent.Options = map[string]interface{}{ "temperature": 0.7, "response_format": 
map[string]interface{}{ "type": "json_schema", "json_schema": map[string]interface{}{ "name": "legacy_schema", "description": "Legacy schema format", "schema": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "email": map[string]interface{}{ "type": "string", }, }, }, "strict": true, }, }, } _, options, err := testAgent.BuildRequest(freshCtx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify ResponseFormat was converted from map if options.ResponseFormat == nil { t.Fatal("Expected ResponseFormat, got nil") } if options.ResponseFormat.Type != context.ResponseFormatJSONSchema { t.Errorf("Expected type 'json_schema', got: %s", options.ResponseFormat.Type) } if options.ResponseFormat.JSONSchema == nil { t.Fatal("Expected JSONSchema, got nil") } if options.ResponseFormat.JSONSchema.Name != "legacy_schema" { t.Errorf("Expected schema name 'legacy_schema', got: %s", options.ResponseFormat.JSONSchema.Name) } if options.ResponseFormat.JSONSchema.Description != "Legacy schema format" { t.Errorf("Expected schema description 'Legacy schema format', got: %s", options.ResponseFormat.JSONSchema.Description) } t.Log("✓ ResponseFormat with legacy map[string]interface{} format works correctly") }) // Test 8: ResponseFormat with simple type (text or json_object) t.Run("ResponseFormatSimpleType", func(t *testing.T) { freshCtx := newTestContext("chat-test-response-format-simple", "tests.buildrequest") inputMessages := []context.Message{{Role: "user", Content: "test message"}} // Create a test agent with simple response_format testAgent := *agent testAgent.Options = map[string]interface{}{ "temperature": 0.7, "response_format": map[string]interface{}{ "type": "json_object", }, } _, options, err := testAgent.BuildRequest(freshCtx, inputMessages, nil) if err != nil { t.Fatalf("Failed to build LLM request: %s", err.Error()) } // Verify ResponseFormat if options.ResponseFormat == nil { t.Fatal("Expected 
ResponseFormat, got nil") } if options.ResponseFormat.Type != context.ResponseFormatJSON { t.Errorf("Expected type 'json_object', got: %s", options.ResponseFormat.Type) } if options.ResponseFormat.JSONSchema != nil { t.Errorf("Expected JSONSchema to be nil for simple type, got: %v", options.ResponseFormat.JSONSchema) } t.Log("✓ ResponseFormat with simple type (json_object) works correctly") }) } ================================================ FILE: agent/assistant/cache.go ================================================ package assistant import ( "container/list" "sync" ) // Cache represents a thread-safe LRU cache for Assistant objects type Cache struct { capacity int mu sync.RWMutex list *list.List items map[string]*list.Element } // cacheItem represents an item in the cache type cacheItem struct { key string value *Assistant } // NewCache creates a new LRU cache with the given capacity func NewCache(capacity int) *Cache { return &Cache{ capacity: capacity, list: list.New(), items: make(map[string]*list.Element), } } // Get retrieves an Assistant from the cache by its ID func (c *Cache) Get(id string) (*Assistant, bool) { c.mu.Lock() defer c.mu.Unlock() if element, exists := c.items[id]; exists { c.list.MoveToFront(element) return element.Value.(*cacheItem).value, true } return nil, false } // Put adds or updates an Assistant in the cache func (c *Cache) Put(assistant *Assistant) { if assistant == nil || assistant.ID == "" { return } c.mu.Lock() defer c.mu.Unlock() // If item exists, update it and move to front if element, exists := c.items[assistant.ID]; exists { c.list.MoveToFront(element) element.Value.(*cacheItem).value = assistant return } // If cache is at capacity, remove oldest item before adding new one if c.list.Len() >= c.capacity { c.removeOldest() } // Add new item element := c.list.PushFront(&cacheItem{ key: assistant.ID, value: assistant, }) c.items[assistant.ID] = element } // Remove removes an Assistant from the cache func (c *Cache) Remove(id 
string) { c.mu.Lock() defer c.mu.Unlock() if element, exists := c.items[id]; exists { item := element.Value.(*cacheItem) // Unregister scripts before removing from cache if item.value != nil && len(item.value.Scripts) > 0 { item.value.UnregisterScripts() } c.list.Remove(element) delete(c.items, id) } } // Len returns the current number of items in the cache func (c *Cache) Len() int { c.mu.RLock() defer c.mu.RUnlock() return c.list.Len() } // All returns all assistants in the cache func (c *Cache) All() []*Assistant { c.mu.RLock() defer c.mu.RUnlock() assistants := make([]*Assistant, 0, c.list.Len()) for element := c.list.Front(); element != nil; element = element.Next() { item := element.Value.(*cacheItem) assistants = append(assistants, item.value) } return assistants } // Clear removes all items from the cache func (c *Cache) Clear() { c.mu.Lock() defer c.mu.Unlock() // Unregister all scripts before clearing cache for element := c.list.Front(); element != nil; element = element.Next() { item := element.Value.(*cacheItem) if item.value != nil && len(item.value.Scripts) > 0 { item.value.UnregisterScripts() } } c.list.Init() c.items = make(map[string]*list.Element) } // ClearExcept removes items from the cache except those matching the keep function // keep function returns true for items that should be preserved func (c *Cache) ClearExcept(keep func(id string) bool) { c.mu.Lock() defer c.mu.Unlock() // Collect items to remove var toRemove []*list.Element for element := c.list.Front(); element != nil; element = element.Next() { item := element.Value.(*cacheItem) if !keep(item.key) { toRemove = append(toRemove, element) } } // Remove collected items for _, element := range toRemove { item := element.Value.(*cacheItem) // Unregister scripts before removing if item.value != nil && len(item.value.Scripts) > 0 { item.value.UnregisterScripts() } c.list.Remove(element) delete(c.items, item.key) } } // removeOldest removes the least recently used item from the cache func 
(c *Cache) removeOldest() { if element := c.list.Back(); element != nil { item := element.Value.(*cacheItem) // Unregister scripts before removing from cache if item.value != nil && len(item.value.Scripts) > 0 { item.value.UnregisterScripts() } c.list.Remove(element) delete(c.items, item.key) } } ================================================ FILE: agent/assistant/cache_test.go ================================================ package assistant_test import ( "sync" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/process" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/testutils" ) func TestCacheBasic(t *testing.T) { cache := assistant.NewCache(2) // Test empty cache assert.Equal(t, 0, cache.Len(), "Expected empty cache") // Create test assistants testutils.Prepare(t) defer testutils.Clean(t) ast1, err := assistant.Get("tests.mcpload") assert.NoError(t, err) ast2, err := assistant.Get("tests.create") assert.NoError(t, err) // Test adding items cache.Put(ast1) cache.Put(ast2) assert.Equal(t, 2, cache.Len(), "Expected cache length 2") // Test getting items cached1, exists := cache.Get("tests.mcpload") assert.True(t, exists, "Should find tests.mcpload") assert.Equal(t, "tests.mcpload", cached1.ID) cached2, exists := cache.Get("tests.create") assert.True(t, exists, "Should find tests.create") assert.Equal(t, "tests.create", cached2.ID) } func TestCacheLRU(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) cache := assistant.NewCache(2) ast1, _ := assistant.Get("tests.mcpload") ast2, _ := assistant.Get("tests.create") ast3, _ := assistant.Get("tests.next") // Add first two items cache.Put(ast1) cache.Put(ast2) // Access ast1 to make it most recently used cache.Get("tests.mcpload") // Add third item, should evict ast2 cache.Put(ast3) // Check ast2 was evicted _, exists := cache.Get("tests.create") assert.False(t, exists, "tests.create should have been evicted") // Check ast1 and ast3 are still present _, exists = 
cache.Get("tests.mcpload") assert.True(t, exists, "tests.mcpload should still be in cache") _, exists = cache.Get("tests.next") assert.True(t, exists, "tests.next should be in cache") } func TestCacheRemove(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) cache := assistant.NewCache(2) ast1, _ := assistant.Get("tests.mcpload") cache.Put(ast1) // Verify scripts are registered _, exists := process.Handlers["agents.tests.mcpload.tools"] assert.True(t, exists, "Handler should be registered before removal") // Test remove existing item cache.Remove("tests.mcpload") assert.Equal(t, 0, cache.Len(), "Cache should be empty after removing item") // Verify scripts are unregistered _, exists = process.Handlers["agents.tests.mcpload.tools"] assert.False(t, exists, "Handler should be unregistered after removal") // Test remove non-existing item (should not panic) cache.Remove("nonexistent") assert.Equal(t, 0, cache.Len(), "Cache length should not change") } func TestCacheClear(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) cache := assistant.NewCache(3) ast1, _ := assistant.Get("tests.mcpload") ast2, _ := assistant.Get("tests.create") ast3, _ := assistant.Get("tests.next") cache.Put(ast1) cache.Put(ast2) cache.Put(ast3) assert.Equal(t, 3, cache.Len(), "Cache should have 3 items") // Verify scripts are registered _, exists := process.Handlers["agents.tests.mcpload.tools"] assert.True(t, exists, "Handler should be registered before clear") // Clear cache cache.Clear() assert.Equal(t, 0, cache.Len(), "Cache should be empty after clear") // Verify all scripts are unregistered _, exists = process.Handlers["agents.tests.mcpload.tools"] assert.False(t, exists, "Handler should be unregistered after clear") } func TestCacheLRUEviction(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) cache := assistant.NewCache(2) ast1, _ := assistant.Get("tests.mcpload") ast2, _ := assistant.Get("tests.create") ast3, _ := assistant.Get("tests.next") 
cache.Put(ast1) cache.Put(ast2) // Verify both are registered _, exists1 := process.Handlers["agents.tests.mcpload.tools"] assert.True(t, exists1, "Handler 1 should be registered") // Add third item to trigger LRU eviction of oldest (ast1) cache.Put(ast3) // Verify ast1's handler was unregistered due to eviction _, exists := process.Handlers["agents.tests.mcpload.tools"] assert.False(t, exists, "Handler should be unregistered after LRU eviction") // Verify ast2 and ast3 are still in cache _, exists = cache.Get("tests.create") assert.True(t, exists, "tests.create should still be in cache") _, exists = cache.Get("tests.next") assert.True(t, exists, "tests.next should be in cache") } func TestCacheConcurrent(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) cache := assistant.NewCache(10) var wg sync.WaitGroup workers := 5 iterations := 20 // Load some assistants for concurrent testing assistants := []string{ "tests.mcpload", "tests.create", "tests.next", } // Concurrent writes for i := 0; i < workers; i++ { wg.Add(1) go func(workerID int) { defer wg.Done() for j := 0; j < iterations; j++ { astID := assistants[j%len(assistants)] ast, _ := assistant.Get(astID) if ast != nil { cache.Put(ast) } } }(i) } // Concurrent reads for i := 0; i < workers; i++ { wg.Add(1) go func(workerID int) { defer wg.Done() for j := 0; j < iterations; j++ { astID := assistants[j%len(assistants)] cache.Get(astID) } }(i) } wg.Wait() // Verify cache is in valid state assert.True(t, cache.Len() >= 0, "Cache should have valid length") assert.True(t, cache.Len() <= 10, "Cache should not exceed capacity") } func TestCacheNilInput(t *testing.T) { cache := assistant.NewCache(2) // Test putting nil assistant cache.Put(nil) assert.Equal(t, 0, cache.Len(), "Cache should not store nil assistant") // Test putting assistant with empty ID emptyAST := &assistant.Assistant{} cache.Put(emptyAST) assert.Equal(t, 0, cache.Len(), "Cache should not store assistant with empty ID") } func TestCacheAll(t 
*testing.T) { testutils.Prepare(t) defer testutils.Clean(t) cache := assistant.NewCache(5) ast1, _ := assistant.Get("tests.mcpload") ast2, _ := assistant.Get("tests.create") ast3, _ := assistant.Get("tests.next") cache.Put(ast1) cache.Put(ast2) cache.Put(ast3) all := cache.All() assert.Equal(t, 3, len(all), "All() should return 3 assistants") // Verify all expected assistants are present ids := make(map[string]bool) for _, ast := range all { ids[ast.ID] = true } assert.True(t, ids["tests.mcpload"], "Should contain tests.mcpload") assert.True(t, ids["tests.create"], "Should contain tests.create") assert.True(t, ids["tests.next"], "Should contain tests.next") } ================================================ FILE: agent/assistant/chat.go ================================================ package assistant import ( "fmt" "strings" "time" "github.com/google/uuid" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" storetypes "github.com/yaoapp/yao/agent/store/types" ) // InitializeConversation prepares conversation context (synchronous) // KB collection is now initialized when user logs in (see openapi/user/login.go) func (ast *Assistant) InitializeConversation(ctx *agentcontext.Context, options ...*agentcontext.Options) error { // Reserved for future conversation initialization logic return nil } // InitializeConversationAsync prepares conversation context asynchronously func (ast *Assistant) InitializeConversationAsync(ctx *agentcontext.Context, options ...*agentcontext.Options) { go ast.InitializeConversation(ctx, options...) 
// sanitizeCollectionID maps an arbitrary ID onto the character set allowed in
// collection IDs (a-z, A-Z, 0-9 and underscore).
//
// The input is processed byte by byte: every byte outside the allowed set is
// replaced with an underscore, so the result always has the same byte length
// as the input and a multi-byte UTF-8 character yields one underscore per
// byte. An empty input yields an empty string.
func sanitizeCollectionID(id string) string {
	if id == "" {
		return ""
	}

	cleaned := []byte(id)
	for i, b := range cleaned {
		switch {
		case 'a' <= b && b <= 'z',
			'A' <= b && b <= 'Z',
			'0' <= b && b <= '9',
			b == '_':
			// Allowed byte: keep as is.
		default:
			cleaned[i] = '_'
		}
	}
	return string(cleaned)
}
============================================================================= // Chat Buffer Integration // ============================================================================= // InitBuffer initializes the chat buffer for the context // Should be called at the start of Stream() for root stack only func (ast *Assistant) InitBuffer(ctx *agentcontext.Context) { // Only initialize for root stack if ctx.Stack == nil || !ctx.Stack.IsRoot() { return } // Skip if buffer already exists if ctx.Buffer != nil { return } // Skip if History is disabled in options if ctx.Stack.Options != nil && ctx.Stack.Options.Skip != nil && ctx.Stack.Options.Skip.History { ctx.Logger.Debug("Buffer skipped: Skip.History is true") return } // Generate request ID if not set requestID := ctx.RequestID() if requestID == "" { requestID = uuid.New().String() } // Get connector and mode from options connector := "" mode := "" if ctx.Stack.Options != nil { connector = ctx.Stack.Options.Connector mode = ctx.Stack.Options.Mode } ctx.Buffer = agentcontext.NewChatBuffer(ctx.ChatID, requestID, ast.ID, connector, mode) ctx.Logger.Debug("Buffer initialized: chatID=%s, requestID=%s, assistantID=%s", ctx.ChatID, requestID, ast.ID) } // BufferUserInput adds user input messages to the buffer // Should be called after InitBuffer func (ast *Assistant) BufferUserInput(ctx *agentcontext.Context, inputMessages []agentcontext.Message) { if ctx.Buffer == nil { return } // Only root stack should buffer user input // Delegated agents share the same buffer but should not duplicate user input if ctx.Stack != nil && !ctx.Stack.IsRoot() { return } // Convert input messages to buffer format for _, msg := range inputMessages { // Extract content from message var content interface{} var name string content = msg.Content if msg.Name != nil { name = *msg.Name } ctx.Buffer.AddUserInput(content, name) } } // UpdateSpaceSnapshot updates the context memory snapshot in the buffer // Only captures Context-level memory 
(request-scoped temporary data) for recovery func (ast *Assistant) UpdateSpaceSnapshot(ctx *agentcontext.Context) { if ctx.Buffer == nil || ctx.Memory == nil || ctx.Memory.Context == nil { return } snapshot := ctx.Memory.Context.Snapshot() ctx.Buffer.SetSpaceSnapshot(snapshot) } // BeginStep starts tracking an execution step // Returns the step for further updates func (ast *Assistant) BeginStep(ctx *agentcontext.Context, stepType string, input map[string]interface{}) *agentcontext.BufferedStep { if ctx.Buffer == nil { return nil } // Update space snapshot before beginning step ast.UpdateSpaceSnapshot(ctx) return ctx.Buffer.BeginStep(stepType, input, ctx.Stack) } // CompleteStep marks the current step as completed func (ast *Assistant) CompleteStep(ctx *agentcontext.Context, output map[string]interface{}) { if ctx.Buffer == nil { return } ctx.Buffer.CompleteStep(output) } // FlushBuffer saves all buffered data to the database // Should be called in defer block at the end of Stream() func (ast *Assistant) FlushBuffer(ctx *agentcontext.Context, finalStatus string, err error) { if ctx.Buffer == nil { return } // Only flush for root stack if ctx.Stack == nil || !ctx.Stack.IsRoot() { return } // Get chat store chatStore := GetChatStore() if chatStore == nil { ctx.Logger.Error("Chat store not available, cannot flush buffer") return } // Mark current step as failed/interrupted if needed if finalStatus != agentcontext.StepStatusCompleted && err != nil { ctx.Buffer.FailCurrentStep(finalStatus, err) } // 1. Save all messages (user input + assistant responses) messages := ast.convertBufferedMessages(ctx.Buffer.GetMessages()) if len(messages) > 0 { if saveErr := chatStore.SaveMessages(ctx.ChatID, messages); saveErr != nil { ctx.Logger.Error("Failed to save messages: %v", saveErr) } else { ctx.Logger.Debug("Saved %d messages for chat=%s", len(messages), ctx.ChatID) } } // 2. 
Update chat last_message_at, last_connector, and last_mode if len(messages) > 0 { now := time.Now() updates := map[string]interface{}{ "last_message_at": now, } // Also update last_connector if available if connector := ctx.Buffer.Connector(); connector != "" { updates["last_connector"] = connector } // Also update last_mode if available if mode := ctx.Buffer.Mode(); mode != "" { updates["last_mode"] = mode } if updateErr := chatStore.UpdateChat(ctx.ChatID, updates); updateErr != nil { ctx.Logger.Debug("Failed to update chat: %v", updateErr) } } // 3. Only save resume steps on error/interrupt (not on success) if finalStatus != agentcontext.StepStatusCompleted { steps := ast.convertBufferedSteps(ctx.Buffer.GetStepsForResume(finalStatus)) if len(steps) > 0 { if saveErr := chatStore.SaveResume(steps); saveErr != nil { ctx.Logger.Error("Failed to save resume steps: %v", saveErr) } else { ctx.Logger.Debug("Saved %d resume steps for chat=%s (status=%s)", len(steps), ctx.ChatID, finalStatus) } } } // 4. 
Close SafeWriter to flush remaining writes (root stack only) // This ensures all pending SSE messages are sent before the response completes ctx.CloseSafeWriter() } // convertBufferedMessages converts BufferedMessage slice to store Message slice func (ast *Assistant) convertBufferedMessages(buffered []*agentcontext.BufferedMessage) []*storetypes.Message { if len(buffered) == 0 { return nil } messages := make([]*storetypes.Message, len(buffered)) for i, msg := range buffered { messages[i] = &storetypes.Message{ MessageID: msg.MessageID, ChatID: msg.ChatID, RequestID: msg.RequestID, Role: msg.Role, Type: msg.Type, Props: msg.Props, BlockID: msg.BlockID, ThreadID: msg.ThreadID, AssistantID: msg.AssistantID, Connector: msg.Connector, Mode: msg.Mode, Sequence: msg.Sequence, Metadata: msg.Metadata, CreatedAt: msg.CreatedAt, UpdatedAt: msg.CreatedAt, } } return messages } // convertBufferedSteps converts BufferedStep slice to store Resume slice func (ast *Assistant) convertBufferedSteps(buffered []*agentcontext.BufferedStep) []*storetypes.Resume { if len(buffered) == 0 { return nil } steps := make([]*storetypes.Resume, len(buffered)) for i, step := range buffered { steps[i] = &storetypes.Resume{ ResumeID: step.ResumeID, ChatID: step.ChatID, RequestID: step.RequestID, AssistantID: step.AssistantID, StackID: step.StackID, StackParentID: step.StackParentID, StackDepth: step.StackDepth, Type: step.Type, Status: step.Status, Input: step.Input, Output: step.Output, SpaceSnapshot: step.SpaceSnapshot, Error: step.Error, Sequence: step.Sequence, Metadata: step.Metadata, CreatedAt: step.CreatedAt, UpdatedAt: step.CreatedAt, } } return steps } // EnsureChat ensures a chat session exists, creates if not func (ast *Assistant) EnsureChat(ctx *agentcontext.Context) error { if ctx.ChatID == "" { return nil // No chat ID, skip } // Skip if history is disabled if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Skip != nil && ctx.Stack.Options.Skip.History { return nil // 
Skip.History is true, don't create chat session } chatStore := GetChatStore() if chatStore == nil { return nil // No store, skip } // Check if chat exists _, err := chatStore.GetChat(ctx.ChatID) if err == nil { return nil // Chat exists } // Create new chat with permission fields chat := &storetypes.Chat{ ChatID: ctx.ChatID, AssistantID: ast.ID, Status: "active", Share: "private", Sort: 0, Metadata: ctx.Metadata, CreatedAt: time.Now(), UpdatedAt: time.Now(), } // Set last_connector from options (user selected connector) if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Connector != "" { chat.LastConnector = ctx.Stack.Options.Connector } // Set permission fields from authorized info if ctx.Authorized != nil { chat.CreatedBy = ctx.Authorized.UserID chat.UpdatedBy = ctx.Authorized.UserID chat.TeamID = ctx.Authorized.TeamID chat.TenantID = ctx.Authorized.TenantID } return chatStore.CreateChat(chat) } // GetChatStore returns the chat store instance // Returns nil if storage is not configured func GetChatStore() storetypes.ChatStore { if storage == nil { return nil } return storage } // GetStore returns the full store instance (implements both ChatStore and AssistantStore) // Returns nil if storage is not configured func GetStore() storetypes.Store { if storage == nil { return nil } return storage } // ============================================================================= // Deprecated methods (kept for compatibility) // ============================================================================= func (ast *Assistant) saveChat(ctx *agentcontext.Context, input []agentcontext.Message, opts *agentcontext.Options) error { _ = ctx _ = input _ = opts return nil } ================================================ FILE: agent/assistant/chat_test.go ================================================ package assistant_test import ( "context" "fmt" "sync" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" agentcontext "github.com/yaoapp/yao/agent/context" storetypes "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) func TestGetChatKBID(t *testing.T) { t.Run("WithTeamAndUser", func(t *testing.T) { teamID := "5659-5504-2879" userID := "4287-9400-2030-0504" collectionID := assistant.GetChatKBID(teamID, userID) // Should sanitize dashes to underscores expected := "chat_5659_5504_2879_4287_9400_2030_0504" assert.Equal(t, expected, collectionID) t.Logf("✓ Collection ID with team: %s", collectionID) }) t.Run("WithoutTeam", func(t *testing.T) { teamID := "" userID := "4287-9400-2030-0504" collectionID := assistant.GetChatKBID(teamID, userID) // Should use chat_user_ prefix expected := "chat_user_4287_9400_2030_0504" assert.Equal(t, expected, collectionID) t.Logf("✓ Collection ID without team: %s", collectionID) }) t.Run("Idempotent", func(t *testing.T) { teamID := "test-team-123" userID := "test-user-456" id1 := assistant.GetChatKBID(teamID, userID) id2 := assistant.GetChatKBID(teamID, userID) id3 := assistant.GetChatKBID(teamID, userID) // Same input should always produce same output assert.Equal(t, id1, id2) assert.Equal(t, id2, id3) t.Logf("✓ Idempotent: %s", id1) }) t.Run("SanitizeSpecialChars", func(t *testing.T) { teamID := "team-with-dashes@123" userID := "user.with.dots!" 
collectionID := assistant.GetChatKBID(teamID, userID) // Should only contain alphanumeric and underscores assert.Regexp(t, "^[a-zA-Z0-9_]+$", collectionID) t.Logf("✓ Sanitized ID: %s", collectionID) }) t.Run("EmptyUserID", func(t *testing.T) { teamID := "test-team" userID := "" collectionID := assistant.GetChatKBID(teamID, userID) // Should handle empty user ID gracefully expected := "chat_test_team_" assert.Equal(t, expected, collectionID) t.Logf("✓ Empty user ID handled: %s", collectionID) }) } func TestPrepareKBCollection(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Get assistant ast, err := assistant.Get("mohe") require.NoError(t, err) require.NotNil(t, ast) // Note: KB collection is now created during user login (see openapi/user/login.go) // These tests verify that InitializeConversation handles various scenarios gracefully t.Run("InitializeWithAuthorizedInfo", func(t *testing.T) { // Use unique IDs based on timestamp to avoid conflicts timestamp := fmt.Sprintf("%d", time.Now().UnixNano()) teamID := fmt.Sprintf("test_team_%s", timestamp) userID := fmt.Sprintf("test_user_%s", timestamp) ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ TeamID: teamID, UserID: userID, }, "test_chat_prepare_001") opts := &agentcontext.Options{} // InitializeConversation should succeed (KB collection created at login time) err := ast.InitializeConversation(ctx, opts) assert.NoError(t, err) t.Logf("✓ InitializeConversation completed successfully") }) t.Run("IdempotentInitialization", func(t *testing.T) { // Use unique IDs based on timestamp to avoid conflicts timestamp := fmt.Sprintf("%d", time.Now().UnixNano()) teamID := fmt.Sprintf("idem_team_%s", timestamp) userID := fmt.Sprintf("idem_user_%s", timestamp) ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ TeamID: teamID, UserID: userID, }, "test_chat_idempotent") opts := &agentcontext.Options{} // Multiple calls should all succeed err1 := 
ast.InitializeConversation(ctx, opts) assert.NoError(t, err1) err2 := ast.InitializeConversation(ctx, opts) assert.NoError(t, err2) err3 := ast.InitializeConversation(ctx, opts) assert.NoError(t, err3) t.Logf("✓ Idempotent initialization works correctly") }) t.Run("HandleMissingAuthorizedInfo", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_no_auth") // Missing authorized info opts := &agentcontext.Options{} // Should not error, just return nil err := ast.InitializeConversation(ctx, opts) assert.NoError(t, err) t.Logf("✓ Correctly handled missing authorized info") }) t.Run("ConcurrentInitialization", func(t *testing.T) { // Use unique IDs based on timestamp to avoid conflicts timestamp := fmt.Sprintf("%d", time.Now().UnixNano()) teamID := fmt.Sprintf("concurrent_team_%s", timestamp) userID := fmt.Sprintf("concurrent_user_%s", timestamp) ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ TeamID: teamID, UserID: userID, }, "test_chat_concurrent") opts := &agentcontext.Options{} // Launch 5 concurrent calls var wg sync.WaitGroup errors := make([]error, 5) for i := 0; i < 5; i++ { wg.Add(1) go func(idx int) { defer wg.Done() errors[idx] = ast.InitializeConversation(ctx, opts) }(i) } // Wait for all goroutines to complete wg.Wait() // All calls should succeed for i, err := range errors { assert.NoError(t, err, "Goroutine %d should not error", i) } t.Logf("✓ Concurrent initialization handled correctly") }) } func TestInitializeConversation(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("mohe") require.NoError(t, err) require.NotNil(t, ast) t.Run("FullInitialization", func(t *testing.T) { // Use unique IDs based on timestamp to avoid conflicts timestamp := fmt.Sprintf("%d", time.Now().UnixNano()) teamID := fmt.Sprintf("init_team_%s", timestamp) userID := fmt.Sprintf("init_user_%s", timestamp) ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ TeamID: 
teamID, UserID: userID, }, "test_init_chat_001") opts := &agentcontext.Options{} // Should initialize conversation without error // Note: KB collection is now created during user login, not here err := ast.InitializeConversation(ctx, opts) assert.NoError(t, err) t.Logf("✓ Conversation initialized successfully (KB collection created at login time)") }) t.Run("SkipHistoryFlag", func(t *testing.T) { ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ TeamID: "skip_team", UserID: "skip_user", }, "test_skip_history") opts := &agentcontext.Options{ Skip: &agentcontext.Skip{ History: true, }, } // Should skip initialization when history flag is set err := ast.InitializeConversation(ctx, opts) assert.NoError(t, err) t.Logf("✓ Correctly skipped with history flag") }) } // ============================================================================= // Buffer Integration Tests // ============================================================================= func TestBufferInitialization(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("mohe") require.NoError(t, err) require.NotNil(t, ast) t.Run("InitBufferForRootStack", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_buffer_001") // Enter stack to simulate root stack _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() // Initialize buffer ast.InitBuffer(ctx) // Verify buffer was created assert.NotNil(t, ctx.Buffer, "Buffer should be initialized for root stack") assert.Equal(t, "test_chat_buffer_001", ctx.Buffer.ChatID()) assert.Equal(t, ast.ID, ctx.Buffer.AssistantID()) t.Logf("✓ Buffer initialized: chatID=%s, assistantID=%s", ctx.Buffer.ChatID(), ctx.Buffer.AssistantID()) }) t.Run("SkipBufferForNestedStack", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_buffer_nested") // Enter root stack _, _, doneRoot := agentcontext.EnterStack(ctx, "root_assistant", nil) defer doneRoot() // 
Enter nested stack _, _, doneNested := agentcontext.EnterStack(ctx, "nested_assistant", nil) defer doneNested() // Try to initialize buffer (should be skipped for nested stack) ast.InitBuffer(ctx) // Buffer should be nil because we're not at root assert.Nil(t, ctx.Buffer, "Buffer should not be initialized for nested stack") t.Logf("✓ Buffer correctly skipped for nested stack") }) t.Run("IdempotentBufferInit", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_buffer_idem") // Enter stack _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() // Initialize buffer twice ast.InitBuffer(ctx) firstBuffer := ctx.Buffer ast.InitBuffer(ctx) secondBuffer := ctx.Buffer // Should be the same buffer instance assert.Same(t, firstBuffer, secondBuffer, "Buffer should be idempotent") t.Logf("✓ Buffer initialization is idempotent") }) } func TestBufferUserInput(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("mohe") require.NoError(t, err) t.Run("BufferSimpleTextInput", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_input_001") // Enter stack and init buffer _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Create input messages inputMessages := []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: "Hello, how are you?", }, } // Buffer user input ast.BufferUserInput(ctx, inputMessages) // Verify buffer contains the message messages := ctx.Buffer.GetMessages() assert.Len(t, messages, 1, "Should have 1 buffered message") assert.Equal(t, "user", messages[0].Role) assert.Equal(t, "user_input", messages[0].Type) assert.Equal(t, "Hello, how are you?", messages[0].Props["content"]) t.Logf("✓ User input buffered: %v", messages[0].Props) }) t.Run("BufferMultipleMessages", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_input_multi") // Enter stack and init buffer _, _, done := 
agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Create multiple input messages inputMessages := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "First message"}, {Role: agentcontext.RoleUser, Content: "Second message"}, } // Buffer user input ast.BufferUserInput(ctx, inputMessages) // Verify buffer contains all messages messages := ctx.Buffer.GetMessages() assert.Len(t, messages, 2, "Should have 2 buffered messages") assert.Equal(t, 1, messages[0].Sequence) assert.Equal(t, 2, messages[1].Sequence) t.Logf("✓ Multiple messages buffered with correct sequence") }) t.Run("BufferWithNilBuffer", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_input_nil") // Don't initialize buffer inputMessages := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test"}, } // Should not panic ast.BufferUserInput(ctx, inputMessages) t.Logf("✓ BufferUserInput handles nil buffer gracefully") }) } func TestBufferStepTracking(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("mohe") require.NoError(t, err) t.Run("BeginAndCompleteStep", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_step_001") // Enter stack and init buffer _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Set some context memory data if ctx.Memory != nil && ctx.Memory.Context != nil { ctx.Memory.Context.Set("test_key", "test_value", 0) } // Begin a step step := ast.BeginStep(ctx, agentcontext.StepTypeLLM, map[string]interface{}{ "messages": []string{"Hello"}, }) assert.NotNil(t, step, "Step should be created") assert.Equal(t, agentcontext.StepTypeLLM, step.Type) assert.Equal(t, agentcontext.StepStatusRunning, step.Status) assert.NotEmpty(t, step.StackID) // Complete the step ast.CompleteStep(ctx, map[string]interface{}{ "content": "Response", }) // Verify step is completed steps := ctx.Buffer.GetAllSteps() assert.Len(t, 
steps, 1) assert.Equal(t, agentcontext.StepStatusCompleted, steps[0].Status) assert.Equal(t, "Response", steps[0].Output["content"]) t.Logf("✓ Step tracking works correctly") }) t.Run("ContextMemorySnapshotCapture", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_memory_001") // Enter stack and init buffer _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Set context memory data before step require.NotNil(t, ctx.Memory) require.NotNil(t, ctx.Memory.Context) ctx.Memory.Context.Set("key1", "value1", 0) ctx.Memory.Context.Set("key2", 123, 0) // Begin step (should capture context memory snapshot) ast.BeginStep(ctx, agentcontext.StepTypeHookCreate, nil) // Verify context memory snapshot was captured steps := ctx.Buffer.GetAllSteps() require.Len(t, steps, 1) assert.NotNil(t, steps[0].SpaceSnapshot) assert.Equal(t, "value1", steps[0].SpaceSnapshot["key1"]) assert.Equal(t, 123, steps[0].SpaceSnapshot["key2"]) t.Logf("✓ Context memory snapshot captured: %v", steps[0].SpaceSnapshot) }) t.Run("MultipleSteps", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "test_chat_multi_step") // Enter stack and init buffer _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Step 1: hook_create ast.BeginStep(ctx, agentcontext.StepTypeHookCreate, map[string]interface{}{"phase": "create"}) ast.CompleteStep(ctx, map[string]interface{}{"result": "created"}) // Step 2: llm ast.BeginStep(ctx, agentcontext.StepTypeLLM, map[string]interface{}{"phase": "llm"}) ast.CompleteStep(ctx, map[string]interface{}{"result": "completed"}) // Step 3: hook_next ast.BeginStep(ctx, agentcontext.StepTypeHookNext, map[string]interface{}{"phase": "next"}) ast.CompleteStep(ctx, map[string]interface{}{"result": "done"}) // Verify all steps steps := ctx.Buffer.GetAllSteps() assert.Len(t, steps, 3) assert.Equal(t, agentcontext.StepTypeHookCreate, steps[0].Type) assert.Equal(t, 
agentcontext.StepTypeLLM, steps[1].Type) assert.Equal(t, agentcontext.StepTypeHookNext, steps[2].Type) t.Logf("✓ Multiple steps tracked correctly") }) } func TestFlushBuffer(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("mohe") require.NoError(t, err) // Skip if chat store not available chatStore := assistant.GetChatStore() if chatStore == nil { t.Skip("Chat store not configured, skipping flush tests") } t.Run("FlushOnSuccess", func(t *testing.T) { chatID := fmt.Sprintf("test_flush_success_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), nil, chatID) // Enter stack and init buffer _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Ensure chat exists err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) // Add some messages to buffer require.NotNil(t, ctx.Buffer, "Buffer should be initialized") ctx.Buffer.AddUserInput("Test question", "") ctx.Buffer.AddAssistantMessage("M1", "text", map[string]interface{}{"content": "Test answer"}, "", "", ast.ID, nil) // Add a step ast.BeginStep(ctx, agentcontext.StepTypeLLM, nil) ast.CompleteStep(ctx, nil) // Flush buffer (success case) ast.FlushBuffer(ctx, agentcontext.StepStatusCompleted, nil) // Verify messages were saved messages, err := chatStore.GetMessages(chatID, storetypes.MessageFilter{}) assert.NoError(t, err) assert.Len(t, messages, 2, "Should have 2 messages saved") // Verify no resume records (success case) resumes, err := chatStore.GetResume(chatID) assert.NoError(t, err) assert.Len(t, resumes, 0, "Should have no resume records on success") // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ Buffer flushed on success: %d messages saved, no resume records", len(messages)) }) t.Run("FlushOnFailure", func(t *testing.T) { chatID := fmt.Sprintf("test_flush_fail_%s", 
uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), nil, chatID) // Enter stack and init buffer _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Ensure chat exists err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) // Add messages ctx.Buffer.AddUserInput("Test question", "") // Add a step that will "fail" ast.BeginStep(ctx, agentcontext.StepTypeLLM, map[string]interface{}{"test": "data"}) // Don't complete - simulate failure // Flush buffer (failure case) testErr := fmt.Errorf("simulated error") ast.FlushBuffer(ctx, agentcontext.ResumeStatusFailed, testErr) // Verify messages were saved messages, err := chatStore.GetMessages(chatID, storetypes.MessageFilter{}) assert.NoError(t, err) assert.Len(t, messages, 1, "Should have 1 message saved") // Verify resume records were saved resumes, err := chatStore.GetResume(chatID) assert.NoError(t, err) assert.Len(t, resumes, 1, "Should have 1 resume record on failure") assert.Equal(t, agentcontext.ResumeStatusFailed, resumes[0].Status) // Cleanup chatStore.DeleteResume(chatID) chatStore.DeleteChat(chatID) t.Logf("✓ Buffer flushed on failure: messages and resume records saved") }) t.Run("FlushOnInterrupt", func(t *testing.T) { chatID := fmt.Sprintf("test_flush_interrupt_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), nil, chatID) // Enter stack and init buffer _, _, done := agentcontext.EnterStack(ctx, ast.ID, nil) defer done() ast.InitBuffer(ctx) // Ensure chat exists err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) // Add messages and steps ctx.Buffer.AddUserInput("Test question", "") ast.BeginStep(ctx, agentcontext.StepTypeLLM, nil) // Flush buffer 
(interrupt case) ast.FlushBuffer(ctx, agentcontext.ResumeStatusInterrupted, nil) // Verify resume records were saved with interrupted status resumes, err := chatStore.GetResume(chatID) assert.NoError(t, err) assert.Len(t, resumes, 1, "Should have 1 resume record on interrupt") assert.Equal(t, agentcontext.ResumeStatusInterrupted, resumes[0].Status) // Cleanup chatStore.DeleteResume(chatID) chatStore.DeleteChat(chatID) t.Logf("✓ Buffer flushed on interrupt: resume records saved with interrupted status") }) t.Run("FlushWithModeAndConnector", func(t *testing.T) { chatID := fmt.Sprintf("test_flush_mode_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), nil, chatID) // Enter stack with connector and mode options opts := &agentcontext.Options{ Connector: "deepseek.v3", Mode: "task", } _, _, done := agentcontext.EnterStack(ctx, ast.ID, opts) defer done() ast.InitBuffer(ctx) // Verify buffer has correct connector and mode require.NotNil(t, ctx.Buffer, "Buffer should be initialized") assert.Equal(t, "deepseek.v3", ctx.Buffer.Connector(), "Buffer should have connector set") assert.Equal(t, "task", ctx.Buffer.Mode(), "Buffer should have mode set") // Ensure chat exists err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) // Add some messages to buffer ctx.Buffer.AddUserInput("Test question for mode", "") ctx.Buffer.AddAssistantMessage("M1", "text", map[string]interface{}{"content": "Test answer with mode"}, "", "", ast.ID, nil) // Flush buffer ast.FlushBuffer(ctx, agentcontext.StepStatusCompleted, nil) // Verify messages were saved with connector and mode messages, err := chatStore.GetMessages(chatID, storetypes.MessageFilter{}) assert.NoError(t, err) assert.Len(t, messages, 2, "Should have 2 messages saved") // Assistant message should have connector and mode var assistantMsg *storetypes.Message for _, msg := 
range messages { if msg.Role == "assistant" { assistantMsg = msg break } } require.NotNil(t, assistantMsg, "Should find assistant message") assert.Equal(t, "deepseek.v3", assistantMsg.Connector, "Message should have connector") assert.Equal(t, "task", assistantMsg.Mode, "Message should have mode") // Verify chat was updated with last_connector and last_mode chat, err := chatStore.GetChat(chatID) assert.NoError(t, err) assert.Equal(t, "deepseek.v3", chat.LastConnector, "Chat should have last_connector updated") assert.Equal(t, "task", chat.LastMode, "Chat should have last_mode updated") // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ Buffer flushed with mode and connector: connector=%s, mode=%s", chat.LastConnector, chat.LastMode) }) } func TestEnsureChat(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("mohe") require.NoError(t, err) // Skip if chat store not available chatStore := assistant.GetChatStore() if chatStore == nil { t.Skip("Chat store not configured, skipping EnsureChat tests") } t.Run("CreateNewChat", func(t *testing.T) { chatID := fmt.Sprintf("test_ensure_new_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), nil, chatID) // Ensure chat creates it err := ast.EnsureChat(ctx) assert.NoError(t, err) // Verify chat was created chat, err := chatStore.GetChat(chatID) assert.NoError(t, err) assert.NotNil(t, chat) assert.Equal(t, chatID, chat.ChatID) assert.Equal(t, ast.ID, chat.AssistantID) assert.Equal(t, "active", chat.Status) // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ New chat created: %s", chatID) }) t.Run("SkipExistingChat", func(t *testing.T) { chatID := fmt.Sprintf("test_ensure_exist_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), nil, chatID) // Create chat first err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Title: "Existing Chat", Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) 
require.NoError(t, err) // EnsureChat should not error err = ast.EnsureChat(ctx) assert.NoError(t, err) // Verify chat still has original title chat, err := chatStore.GetChat(chatID) assert.NoError(t, err) assert.Equal(t, "Existing Chat", chat.Title) // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ Existing chat preserved") }) t.Run("SkipEmptyChatID", func(t *testing.T) { ctx := agentcontext.New(context.Background(), nil, "") // Should not error with empty chat ID err := ast.EnsureChat(ctx) assert.NoError(t, err) t.Logf("✓ Empty chat ID handled gracefully") }) t.Run("CreateChatWithPermissions", func(t *testing.T) { chatID := fmt.Sprintf("test_ensure_perm_%s", uuid.New().String()[:8]) // Create context with authorized info ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ UserID: "test_user_001", TeamID: "test_team_001", TenantID: "test_tenant_001", }, chatID) // EnsureChat should create with permission fields err := ast.EnsureChat(ctx) assert.NoError(t, err) // Verify permission fields were saved chat, err := chatStore.GetChat(chatID) assert.NoError(t, err) assert.NotNil(t, chat) assert.Equal(t, "test_user_001", chat.CreatedBy, "CreatedBy should be set") assert.Equal(t, "test_user_001", chat.UpdatedBy, "UpdatedBy should be set") assert.Equal(t, "test_team_001", chat.TeamID, "TeamID should be set") assert.Equal(t, "test_tenant_001", chat.TenantID, "TenantID should be set") // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ Chat created with permission fields: user=%s, team=%s, tenant=%s", chat.CreatedBy, chat.TeamID, chat.TenantID) }) t.Run("SkipHistoryEnabled", func(t *testing.T) { chatID := fmt.Sprintf("test_ensure_skip_%s", uuid.New().String()[:8]) // Create context ctx := agentcontext.New(context.Background(), nil, chatID) // Set up stack with Skip.History = true ctx.Stack = &agentcontext.Stack{ ID: "test_stack", AssistantID: ast.ID, Depth: 0, Options: &agentcontext.Options{ Skip: &agentcontext.Skip{ History: true, }, }, } // EnsureChat 
should NOT create chat when Skip.History is true err := ast.EnsureChat(ctx) assert.NoError(t, err) // Verify chat was NOT created _, err = chatStore.GetChat(chatID) assert.Error(t, err, "Chat should not be created when Skip.History is true") t.Logf("✓ Chat not created when Skip.History is true") }) } // TestEnsureChatMetadata verifies that ctx.Metadata is persisted to the chat record. // This is required for Host Agent: robot_id is passed in metadata so that // ListChats with chat_id_prefix=robot_{id}_ can filter by robot. func TestEnsureChatMetadata(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("mohe") require.NoError(t, err) chatStore := assistant.GetChatStore() if chatStore == nil { t.Skip("Chat store not configured, skipping metadata tests") } t.Run("MetadataPersisted", func(t *testing.T) { chatID := fmt.Sprintf("robot_test_meta_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ UserID: "test_user_meta", TeamID: "test_team_meta", }, chatID) ctx.Metadata = map[string]interface{}{ "robot_id": "robot_member_001", } err := ast.EnsureChat(ctx) require.NoError(t, err) chat, err := chatStore.GetChat(chatID) require.NoError(t, err) require.NotNil(t, chat) require.NotNil(t, chat.Metadata, "Metadata should be persisted") assert.Equal(t, "robot_member_001", chat.Metadata["robot_id"], "robot_id should be stored in chat metadata") // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ Chat metadata persisted: robot_id=%v", chat.Metadata["robot_id"]) }) t.Run("MetadataPersistedWithRobotChatIDPrefix", func(t *testing.T) { // Simulate robot host chat_id format: robot_{member_id}_{timestamp} memberID := "120004485525" chatID := fmt.Sprintf("robot_%s_%d", memberID, time.Now().UnixMilli()) ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ UserID: "test_user_robot", TeamID: "test_team_robot", }, chatID) ctx.Metadata = map[string]interface{}{ "robot_id": memberID, } err 
:= ast.EnsureChat(ctx) require.NoError(t, err) chat, err := chatStore.GetChat(chatID) require.NoError(t, err) require.NotNil(t, chat) require.NotNil(t, chat.Metadata) assert.Equal(t, memberID, chat.Metadata["robot_id"]) // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ Robot-prefix chat persisted with metadata: chat_id=%s", chatID) }) t.Run("NilMetadataHandled", func(t *testing.T) { chatID := fmt.Sprintf("test_meta_nil_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), nil, chatID) ctx.Metadata = nil err := ast.EnsureChat(ctx) assert.NoError(t, err) chat, err := chatStore.GetChat(chatID) require.NoError(t, err) require.NotNil(t, chat) // Metadata nil is acceptable t.Logf("✓ Nil metadata handled gracefully") // Cleanup chatStore.DeleteChat(chatID) }) t.Run("MetadataMultipleFields", func(t *testing.T) { chatID := fmt.Sprintf("test_meta_multi_%s", uuid.New().String()[:8]) ctx := agentcontext.New(context.Background(), &oauthtypes.AuthorizedInfo{ UserID: "test_user_multi", TeamID: "test_team_multi", }, chatID) ctx.Metadata = map[string]interface{}{ "robot_id": "robot_multi_001", "source": "mission_control", } err := ast.EnsureChat(ctx) require.NoError(t, err) chat, err := chatStore.GetChat(chatID) require.NoError(t, err) require.NotNil(t, chat) require.NotNil(t, chat.Metadata) assert.Equal(t, "robot_multi_001", chat.Metadata["robot_id"]) assert.Equal(t, "mission_control", chat.Metadata["source"]) // Cleanup chatStore.DeleteChat(chatID) t.Logf("✓ Multiple metadata fields persisted correctly") }) } func TestConvertBufferedTypes(t *testing.T) { t.Run("ConvertBufferedMessages", func(t *testing.T) { // Create buffered messages buffered := []*agentcontext.BufferedMessage{ { MessageID: "msg_001", ChatID: "chat_001", RequestID: "req_001", Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Hello"}, Sequence: 1, CreatedAt: time.Now(), }, { MessageID: "msg_002", ChatID: "chat_001", RequestID: "req_001", Role: "assistant", Type: 
"text", Props: map[string]interface{}{"content": "Hi there!"}, BlockID: "block_001", AssistantID: "test_assistant", Sequence: 2, CreatedAt: time.Now(), }, } // Verify structure matches store types assert.Len(t, buffered, 2) assert.Equal(t, "user", buffered[0].Role) assert.Equal(t, "assistant", buffered[1].Role) assert.Equal(t, "block_001", buffered[1].BlockID) t.Logf("✓ Buffered messages have correct structure") }) t.Run("ConvertBufferedSteps", func(t *testing.T) { // Create buffered steps buffered := []*agentcontext.BufferedStep{ { ResumeID: "resume_001", ChatID: "chat_001", RequestID: "req_001", AssistantID: "test_assistant", StackID: "stack_001", StackDepth: 0, Type: agentcontext.StepTypeLLM, Status: agentcontext.ResumeStatusFailed, Input: map[string]interface{}{"messages": []string{"Hello"}}, SpaceSnapshot: map[string]interface{}{"key": "value"}, Error: "Test error", Sequence: 1, CreatedAt: time.Now(), }, } // Verify structure assert.Len(t, buffered, 1) assert.Equal(t, agentcontext.StepTypeLLM, buffered[0].Type) assert.Equal(t, agentcontext.ResumeStatusFailed, buffered[0].Status) assert.Equal(t, "Test error", buffered[0].Error) assert.Equal(t, "value", buffered[0].SpaceSnapshot["key"]) t.Logf("✓ Buffered steps have correct structure") }) } ================================================ FILE: agent/assistant/handlers/stream.go ================================================ package handlers import ( "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output" "github.com/yaoapp/yao/agent/output/message" ) // DefaultStreamHandler creates a default stream handler that sends messages via context // This handler is used when no custom handler is provided func DefaultStreamHandler(ctx *context.Context) message.StreamFunc { // Create stream state manager state := &streamState{ ctx: ctx, inGroup: false, currentGroupID: "", messageSeq: 0, } 
return func(chunkType message.StreamChunkType, data []byte) int { trace, _ := ctx.Trace() if trace != nil { trace.Info(i18n.T(ctx.Locale, "llm.handlers.stream.info"), map[string]any{"data": string(data)}) } // Handle different chunk types switch chunkType { case message.ChunkStreamStart: return state.handleStreamStart(data) case message.ChunkMessageStart: return state.handleMessageStart(data) case message.ChunkText: return state.handleText(data) case message.ChunkThinking: return state.handleThinking(data) case message.ChunkToolCall: return state.handleToolCall(data) case message.ChunkMetadata: return state.handleMetadata(data) case message.ChunkError: return state.handleError(data) case message.ChunkMessageEnd: return state.handleMessageEnd(data) case message.ChunkStreamEnd: return state.handleStreamEnd(data) default: // Unknown chunk type, continue return 0 } } } // streamState manages the state of the streaming process type streamState struct { ctx *context.Context inGroup bool currentGroupID string // Current group ID (shared by all chunks in the group) currentType string // Track the current message type (text, thinking, tool_call) buffer []byte chunkCount int // Track number of chunks in current group messageSeq int // Message sequence number (for generating readable IDs) groupStartTime time.Time // Track when group started } // handleStreamStart handles stream start event func (s *streamState) handleStreamStart(data []byte) int { // Send event message to indicate stream has started // This is a lifecycle event, CUI clients can show it, OpenAI clients will ignore it var startData message.EventStreamStartData err := jsoniter.Unmarshal(data, &startData) if err != nil { log.Error("Failed to unmarshal stream start data: %v", err) } msg := output.NewEventMessage("stream_start", "Stream started", startData) s.ctx.Send(msg) return 0 } // handleMessageStart handles message start event func (s *streamState) handleMessageStart(data []byte) int { // Parse message start 
data first to get the message ID var startData message.EventMessageStartData if err := jsoniter.Unmarshal(data, &startData); err != nil { log.Error("Failed to unmarshal message start data: %v", err) return 0 } // Use the message ID from the start data, or generate one if not provided messageID := startData.MessageID if messageID == "" { messageID = s.ctx.IDGenerator.GenerateMessageID() startData.MessageID = messageID } // Auto-set ThreadID from Stack for nested agent calls if startData.ThreadID == "" && s.ctx.Stack != nil && !s.ctx.Stack.IsRoot() { startData.ThreadID = s.ctx.Stack.ID } // Initialize message state with the correct message ID s.inGroup = true s.currentGroupID = messageID s.buffer = []byte{} s.chunkCount = 0 s.messageSeq = 0 // Reset message sequence for each message s.groupStartTime = time.Now() // Send message_start event msg := output.NewEventMessage(message.EventMessageStart, "Message started", startData) s.ctx.Send(msg) return 0 // Continue } // handleText handles text content chunks func (s *streamState) handleText(data []byte) int { if len(data) == 0 { return 0 } // Track current message type s.currentType = message.TypeText // Append to buffer s.buffer = append(s.buffer, data...) s.chunkCount++ s.messageSeq++ // Send delta message // - ChunkID: Unique chunk ID (C1, C2, C3...) 
for this fragment // - MessageID: Same for all chunks of this logical message (frontend merges by message_id) msg := &message.Message{ ChunkID: s.ctx.IDGenerator.GenerateChunkID(), // Unique chunk ID MessageID: s.currentGroupID, // Message ID for merging (all chunks share this) Type: message.TypeText, Delta: true, Props: map[string]interface{}{ "content": string(data), }, } if err := s.ctx.Send(msg); err != nil { // Log error but continue streaming return 0 } return 0 // Continue } // handleThinking handles thinking/reasoning chunks func (s *streamState) handleThinking(data []byte) int { if len(data) == 0 { return 0 } // Track current message type s.currentType = message.TypeThinking // Append to buffer s.buffer = append(s.buffer, data...) s.chunkCount++ s.messageSeq++ // Send delta message // - ChunkID: Unique chunk ID (C1, C2, C3...) for this fragment // - MessageID: Same for all chunks of this logical message (frontend merges by message_id) msg := &message.Message{ ChunkID: s.ctx.IDGenerator.GenerateChunkID(), // Unique chunk ID MessageID: s.currentGroupID, // Message ID for merging (all chunks share this) Type: message.TypeThinking, Delta: true, Props: map[string]interface{}{ "content": string(data), }, } if err := s.ctx.Send(msg); err != nil { return 0 } return 0 // Continue } // handleToolCall handles tool call chunks func (s *streamState) handleToolCall(data []byte) int { if len(data) == 0 { return 0 } // Track current message type s.currentType = message.TypeToolCall // Append to buffer for message_end event s.buffer = append(s.buffer, data...) 
s.chunkCount++ s.messageSeq++ // Parse the tool call delta data (JSON array from OpenAI) var toolCallArray []map[string]interface{} if err := jsoniter.Unmarshal(data, &toolCallArray); err != nil { // If parse fails, log and skip this chunk return 0 } // Extract tool call fields from delta // OpenAI delta typically has one element, but we handle arrays safely var props map[string]interface{} var deltaAction string var deltaPath string if len(toolCallArray) == 1 { tc := toolCallArray[0] props = map[string]interface{}{} hasIdentity := false if id, ok := tc["id"].(string); ok { props["id"] = id hasIdentity = true } if typ, ok := tc["type"].(string); ok { props["type"] = typ hasIdentity = true } if index, ok := tc["index"].(float64); ok { props["index"] = int(index) } if fn, ok := tc["function"].(map[string]interface{}); ok { if name, ok := fn["name"].(string); ok { props["name"] = name hasIdentity = true } if args, ok := fn["arguments"].(string); ok { props["arguments"] = args } } if hasIdentity { // First chunk with id/name/type: merge so all fields are applied. deltaAction = "merge" } else if _, ok := props["arguments"]; ok { // Subsequent chunk with only arguments fragment: append to arguments. deltaAction = "append" deltaPath = "arguments" } else { deltaAction = "merge" } } else { // Multiple tool calls in delta (rare) - keep as array props = map[string]interface{}{ "calls": toolCallArray, } deltaAction = "merge" } // Send delta message // - ChunkID: Unique chunk ID (C1, C2, C3...) 
for this fragment // - MessageID: Same for all chunks of this logical message (frontend merges by message_id) // - DeltaAction: "append" for arguments chunks, "merge" for id/type/name chunks // - DeltaPath: "arguments" when appending arguments field // OpenAI sends: first chunk has id/type/name, subsequent chunks only have arguments fragments msg := &message.Message{ ChunkID: s.ctx.IDGenerator.GenerateChunkID(), // Unique chunk ID MessageID: s.currentGroupID, // Message ID for merging (all chunks share this) Type: message.TypeToolCall, Delta: true, DeltaAction: deltaAction, // "append" for arguments, "merge" for static fields DeltaPath: deltaPath, // "arguments" when appending Props: props, // Flattened tool call fields } if err := s.ctx.Send(msg); err != nil { return 0 } return 0 // Continue } // handleMetadata handles metadata chunks (usage, finish_reason, etc.) func (s *streamState) handleMetadata(data []byte) int { // Metadata is usually not displayed to users // Could be logged or stored for analytics return 0 // Continue } // handleError handles error chunks func (s *streamState) handleError(data []byte) int { // Send error message msg := output.NewErrorMessage(string(data), "stream_error") s.ctx.Send(msg) return 1 // Stop streaming on error } // handleMessageEnd handles message end event func (s *streamState) handleMessageEnd(data []byte) int { if !s.inGroup { return 0 } // Calculate duration durationMs := time.Since(s.groupStartTime).Milliseconds() // Use the tracked message type (thinking, text, tool_call, etc.) 
msgType := s.currentType if msgType == "" { msgType = message.TypeText // Fallback to text if type not set } // Get ThreadID from Stack for nested agent calls var threadID string if s.ctx.Stack != nil && !s.ctx.Stack.IsRoot() { threadID = s.ctx.Stack.ID } // Get BlockID from metadata if available var blockID string if s.ctx != nil { if metadata := s.ctx.GetMessageMetadata(s.currentGroupID); metadata != nil { blockID = metadata.BlockID } } // Buffer the complete LLM message for storage // Delta chunks are not stored, but we need to save the final complete content // Skip if History is disabled in options shouldSkipHistory := s.ctx.Stack != nil && s.ctx.Stack.Options != nil && s.ctx.Stack.Options.Skip != nil && s.ctx.Stack.Options.Skip.History if s.ctx.Buffer != nil && len(s.buffer) > 0 && !shouldSkipHistory { assistantID := "" if s.ctx.Stack != nil { assistantID = s.ctx.Stack.AssistantID } // Build props based on message type var props map[string]interface{} if msgType == message.TypeToolCall { // For tool calls, try to parse the accumulated buffer as JSON var toolCallData interface{} if err := jsoniter.Unmarshal(s.buffer, &toolCallData); err == nil { props = map[string]interface{}{ "calls": toolCallData, } } else { props = map[string]interface{}{ "content": string(s.buffer), } } } else { // For text/thinking, content is the accumulated text props = map[string]interface{}{ "content": string(s.buffer), } } s.ctx.Buffer.AddAssistantMessage( s.currentGroupID, // Use the message ID msgType, props, blockID, threadID, assistantID, nil, ) } // Build EventMessageEndData with complete content endData := message.EventMessageEndData{ MessageID: s.currentGroupID, // Use the message ID Type: msgType, Timestamp: time.Now().UnixMilli(), ThreadID: threadID, // Include ThreadID for concurrent stream identification DurationMs: durationMs, ChunkCount: s.chunkCount, Status: "completed", Extra: map[string]interface{}{ "content": string(s.buffer), // Include complete content in the event 
}, } // Send message_end event msg := output.NewEventMessage(message.EventMessageEnd, "Message completed", endData) s.ctx.Send(msg) // Reset state s.inGroup = false s.currentGroupID = "" s.currentType = "" s.buffer = []byte{} s.chunkCount = 0 return 0 // Continue } // handleStreamEnd handles stream end event func (s *streamState) handleStreamEnd(data []byte) int { // Parse the stream end data var endData message.EventStreamEndData if err := jsoniter.Unmarshal(data, &endData); err != nil { log.Error("Failed to parse stream_end data: %v", err) s.ctx.Flush() return 0 } // Send stream_end event as a message to frontend msg := output.NewEventMessage("stream_end", "Stream completed", endData) s.ctx.Send(msg) // Flush any remaining data s.ctx.Flush() return 0 // Continue (stream will end naturally) } ================================================ FILE: agent/assistant/history.go ================================================ package assistant import ( "fmt" "reflect" jsoniter "github.com/json-iterator/go" agentcontext "github.com/yaoapp/yao/agent/context" storetypes "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/trace/types" ) // ============================================================================= // Chat History Management // ============================================================================= // HistoryResult represents the result of history processing type HistoryResult struct { InputMessages []agentcontext.Message // Clean input messages (without overlap) FullMessages []agentcontext.Message // Full messages (history + clean input) } // getHistorySize returns the history size with priority: opts.HistorySize > storeSetting.MaxSize > default (20) func getHistorySize(opts *agentcontext.Options) int { const defaultHistorySize = 20 if opts != nil && opts.HistorySize > 0 { return opts.HistorySize } if setting := GetStoreSetting(); setting != nil && setting.MaxSize > 0 { return setting.MaxSize } return defaultHistorySize } // 
WithHistory merges the input messages with chat history and traces it // Returns HistoryResult containing: // - InputMessages: cleaned input (overlap removed) // - FullMessages: history + clean input merged func (ast *Assistant) WithHistory(ctx *agentcontext.Context, input []agentcontext.Message, agentNode types.Node, options ...*agentcontext.Options) (*HistoryResult, error) { // Get options var opts *agentcontext.Options if len(options) > 0 && options[0] != nil { opts = options[0] } // SKIP: History (for internal calls like title/prompt etc.) if opts != nil && opts.Skip != nil && opts.Skip.History { result := &HistoryResult{ InputMessages: input, FullMessages: input, } ast.traceAgentHistory(ctx, agentNode, result.FullMessages) return result, nil } // Resolve history size: opts.HistorySize > storeSetting.MaxSize > default (20) maxSize := getHistorySize(opts) // Load history from store historyMessages, err := ast.loadHistory(ctx, maxSize) if err != nil { // Log warning but continue without history ctx.Logger.Warn("Failed to load history for chat=%s: %v", ctx.ChatID, err) result := &HistoryResult{ InputMessages: input, FullMessages: input, } ast.traceAgentHistory(ctx, agentNode, result.FullMessages) return result, nil } // If no history, return input as is if len(historyMessages) == 0 { ctx.Logger.HistoryLoad(0, maxSize) result := &HistoryResult{ InputMessages: input, FullMessages: input, } ast.traceAgentHistory(ctx, agentNode, result.FullMessages) return result, nil } // Log history loaded ctx.Logger.HistoryLoad(len(historyMessages), maxSize) // Find overlap between history and input // Some external clients may include history in their requests overlapIndex := ast.findOverlapIndex(historyMessages, input) // Remove overlap from input cleanInput := input if overlapIndex > 0 { cleanInput = input[overlapIndex:] ctx.Logger.HistoryOverlap(overlapIndex) } // Merge history with clean input fullMessages := make([]agentcontext.Message, 0, 
len(historyMessages)+len(cleanInput)) fullMessages = append(fullMessages, historyMessages...) fullMessages = append(fullMessages, cleanInput...) result := &HistoryResult{ InputMessages: cleanInput, FullMessages: fullMessages, } // Log the chat history ast.traceAgentHistory(ctx, agentNode, result.FullMessages) return result, nil } // loadHistory loads chat history from the store // Returns the most recent maxSize messages, ordered by time (oldest first) func (ast *Assistant) loadHistory(ctx *agentcontext.Context, maxSize int) ([]agentcontext.Message, error) { // Check if chat ID is available if ctx.ChatID == "" { return nil, nil } // Get chat store chatStore := GetChatStore() if chatStore == nil { return nil, nil } // Load messages from store with limit filter := storetypes.MessageFilter{ Limit: maxSize, } storeMessages, err := chatStore.GetMessages(ctx.ChatID, filter) if err != nil { return nil, fmt.Errorf("failed to get messages: %w", err) } if len(storeMessages) == 0 { return nil, nil } // Convert store messages to context messages messages := make([]agentcontext.Message, 0, len(storeMessages)) for _, msg := range storeMessages { // Only include user and assistant messages for LLM context // Skip internal types like loading, event, etc. 
if msg.Role != "user" && msg.Role != "assistant" { continue } // Convert store message to context message ctxMsg := ast.convertStoreMessageToContext(msg) if ctxMsg != nil { messages = append(messages, *ctxMsg) } } return messages, nil } // convertStoreMessageToContext converts a store message to a context message func (ast *Assistant) convertStoreMessageToContext(msg *storetypes.Message) *agentcontext.Message { if msg == nil { return nil } // Handle special message types: // - tool_call/action: convert to historical summary text for LLM context // - loading/event: skip (pure UI/lifecycle signals, no semantic value) // - error: kept as-is so LLM can help troubleshoot issues switch msg.Type { case "tool_call": return ast.convertToolCallToContext(msg) case "action": return ast.convertActionToContext(msg) case "loading", "event": return nil } // Extract content from Props content := ast.extractContentFromProps(msg.Props, msg.Type) if content == nil { return nil } // Build context message ctxMsg := &agentcontext.Message{ Role: agentcontext.MessageRole(msg.Role), Content: content, } // Handle name field if msg.Props != nil { if name, ok := msg.Props["name"].(string); ok && name != "" { ctxMsg.Name = &name } } return ctxMsg } // extractContentFromProps extracts the content from message Props based on message type func (ast *Assistant) extractContentFromProps(props map[string]interface{}, msgType string) interface{} { if props == nil { return nil } // For user input, content is stored directly in props["content"] if msgType == "user_input" { return props["content"] } // For text type messages if msgType == "text" { if text, ok := props["text"].(string); ok { return text } // Also try content field if content, ok := props["content"].(string); ok { return content } } // For other types, try to extract content or text if content, ok := props["content"]; ok { return content } if text, ok := props["text"]; ok { return text } return nil } // convertToolCallToContext converts a 
tool_call store message to a historical summary text message. // This allows the LLM to understand what tools were previously called without re-invoking them. // // Supports two Props formats: // - Standard ToolCallProps: {"name": "tool_name", "arguments": "{...}"} // - Raw stream chunks: {"content": "[{\"index\":0,\"id\":\"call_...\",\"function\":{\"name\":\"tool\"}}][...]"} func (ast *Assistant) convertToolCallToContext(msg *storetypes.Message) *agentcontext.Message { if msg.Props == nil { return nil } // Try standard ToolCallProps format first if name, ok := msg.Props["name"].(string); ok && name != "" { args, _ := msg.Props["arguments"].(string) const maxArgsLen = 500 if len(args) > maxArgsLen { args = args[:maxArgsLen] + "..." } return &agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: fmt.Sprintf("[Historical Tool Call Summary] Called tool \"%s\" with arguments: %s", name, args), } } // Try raw stream chunk format: {"content": "[...][...]..."} // Each chunk is a JSON array like [{"index":0,"id":"call_...","function":{"name":"echo__ping"}}] // Subsequent chunks append arguments: [{"index":0,"function":{"arguments":"..."}}] if raw, ok := msg.Props["content"].(string); ok && raw != "" { name, args := parseToolCallRawChunks(raw) if name == "" { return nil } const maxArgsLen = 500 if len(args) > maxArgsLen { args = args[:maxArgsLen] + "..." } return &agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: fmt.Sprintf("[Historical Tool Call Summary] Called tool \"%s\" with arguments: %s", name, args), } } return nil } // parseToolCallRawChunks parses concatenated raw stream chunks to extract tool name and arguments. // Input format: "[{...}][{...}][{...}]" — multiple JSON arrays concatenated without separator. func parseToolCallRawChunks(raw string) (name, args string) { // Split concatenated JSON arrays: "][" is the boundary // e.g. 
"[{...}][{...}]" → ["[{...}]", "[{...}]"] chunks := splitJSONArrays(raw) var argParts []string for _, chunk := range chunks { var items []map[string]interface{} if err := jsoniter.UnmarshalFromString(chunk, &items); err != nil || len(items) == 0 { continue } item := items[0] if fn, ok := item["function"].(map[string]interface{}); ok { if n, ok := fn["name"].(string); ok && n != "" && name == "" { name = n } if a, ok := fn["arguments"].(string); ok && a != "" { argParts = append(argParts, a) } } } args = "" for _, part := range argParts { args += part } return name, args } // splitJSONArrays splits a string of concatenated JSON arrays "[...][...][...]" into individual arrays. func splitJSONArrays(s string) []string { var result []string depth := 0 start := -1 for i, ch := range s { switch ch { case '[': if depth == 0 { start = i } depth++ case ']': depth-- if depth == 0 && start >= 0 { result = append(result, s[start:i+1]) start = -1 } } } return result } // convertActionToContext converts an action store message to a historical summary text message. // This allows the LLM to understand what system actions were previously executed. func (ast *Assistant) convertActionToContext(msg *storetypes.Message) *agentcontext.Message { if msg.Props == nil { return nil } name, _ := msg.Props["name"].(string) if name == "" { return nil } payload := "" if msg.Props["payload"] != nil { if payloadStr, err := jsoniter.MarshalToString(msg.Props["payload"]); err == nil { const maxPayloadLen = 500 if len(payloadStr) > maxPayloadLen { payloadStr = payloadStr[:maxPayloadLen] + "..." 
} payload = payloadStr } } if payload != "" { return &agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: fmt.Sprintf("[Historical Action Summary] Executed action \"%s\" with payload: %s", name, payload), } } return &agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: fmt.Sprintf("[Historical Action Summary] Executed action \"%s\"", name), } } // findOverlapIndex finds the index in input where history messages end // Returns the number of input messages that overlap with history func (ast *Assistant) findOverlapIndex(history, input []agentcontext.Message) int { if len(history) == 0 || len(input) == 0 { return 0 } // We need to find the longest suffix of history that matches a prefix of input // Start from the end of history and try to match with the beginning of input maxOverlap := len(history) if maxOverlap > len(input) { maxOverlap = len(input) } // Try different overlap lengths, starting from the largest possible for overlapLen := maxOverlap; overlapLen > 0; overlapLen-- { // Check if the last 'overlapLen' messages of history match the first 'overlapLen' of input historyStart := len(history) - overlapLen matched := true for i := 0; i < overlapLen; i++ { if !ast.messagesMatch(history[historyStart+i], input[i]) { matched = false break } } if matched { return overlapLen } } return 0 } // messagesMatch checks if two messages are equivalent func (ast *Assistant) messagesMatch(a, b agentcontext.Message) bool { // Must have same role if a.Role != b.Role { return false } // Compare content return ast.contentMatches(a.Content, b.Content) } // contentMatches compares two content values for equality func (ast *Assistant) contentMatches(a, b interface{}) bool { // Handle nil cases if a == nil && b == nil { return true } if a == nil || b == nil { return false } // If both are strings, compare directly aStr, aIsStr := a.(string) bStr, bIsStr := b.(string) if aIsStr && bIsStr { return aStr == bStr } // For complex content (arrays, etc.), use deep 
equal return reflect.DeepEqual(a, b) } ================================================ FILE: agent/assistant/history_test.go ================================================ package assistant_test import ( "context" "fmt" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" storetypes "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // ============================================================================= // Helper Functions // ============================================================================= // newHistoryTestContext creates a test context for history tests func newHistoryTestContext(chatID string) *agentcontext.Context { authorized := &oauthtypes.AuthorizedInfo{ Subject: "test-user", UserID: "history-test-user", TeamID: "history-test-team", TenantID: "history-test-tenant", } ctx := agentcontext.New(context.Background(), authorized, chatID) ctx.AssistantID = "tests.history" ctx.Locale = "en-us" ctx.Client = agentcontext.Client{ Type: "web", IP: "127.0.0.1", } ctx.Referer = agentcontext.RefererAPI ctx.Accept = agentcontext.AcceptWebCUI ctx.IDGenerator = message.NewIDGenerator() ctx.Metadata = make(map[string]interface{}) return ctx } // ============================================================================= // WithHistory Tests // ============================================================================= func TestWithHistory(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Get assistant ast, err := assistant.Get("tests.history") require.NoError(t, err) require.NotNil(t, ast) // Get chat store for setup/cleanup chatStore := assistant.GetChatStore() if chatStore == nil { t.Skip("Chat store not configured, skipping history tests") } 
t.Run("NoHistory", func(t *testing.T) { chatID := fmt.Sprintf("test_history_none_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) // Create chat without any messages err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer chatStore.DeleteChat(chatID) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Hello, this is my first message"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // With no history, InputMessages and FullMessages should be the same as input assert.Equal(t, input, result.InputMessages) assert.Equal(t, input, result.FullMessages) t.Log("✓ No history: input returned as is") }) t.Run("WithExistingHistory", func(t *testing.T) { chatID := fmt.Sprintf("test_history_exist_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add history messages historyMessages := []*storetypes.Message{ { MessageID: fmt.Sprintf("hist_msg_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Previous question"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now().Add(-2 * time.Minute), }, { MessageID: fmt.Sprintf("hist_msg_2_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_%s", reqID), Role: "assistant", Type: "text", Props: map[string]interface{}{"text": "Previous answer"}, Sequence: 2, AssistantID: ast.ID, CreatedAt: time.Now().Add(-1 * time.Minute), }, } err = chatStore.SaveMessages(chatID, 
historyMessages) require.NoError(t, err) // New input message input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "New question"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // InputMessages should be unchanged (no overlap) assert.Equal(t, input, result.InputMessages) // FullMessages should have history + input assert.Len(t, result.FullMessages, 3) // 2 history + 1 new // Verify order: history first, then input assert.Equal(t, agentcontext.RoleUser, result.FullMessages[0].Role) assert.Equal(t, "Previous question", result.FullMessages[0].Content) assert.Equal(t, agentcontext.RoleAssistant, result.FullMessages[1].Role) assert.Equal(t, "Previous answer", result.FullMessages[1].Content) assert.Equal(t, agentcontext.RoleUser, result.FullMessages[2].Role) assert.Equal(t, "New question", result.FullMessages[2].Content) t.Log("✓ History merged correctly with new input") }) t.Run("SkipHistoryOption", func(t *testing.T) { chatID := fmt.Sprintf("test_history_skip_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat with history err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add history message err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("skip_hist_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_skip_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Should be skipped"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now(), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Only this should appear"}, } // Use Skip.History option opts := &agentcontext.Options{ Skip: 
&agentcontext.Skip{ History: true, }, } result, err := ast.WithHistory(ctx, input, nil, opts) require.NoError(t, err) require.NotNil(t, result) // Both should be same as input (history skipped) assert.Equal(t, input, result.InputMessages) assert.Equal(t, input, result.FullMessages) assert.Len(t, result.FullMessages, 1) t.Log("✓ History skipped when Skip.History=true") }) t.Run("OverlapDetection", func(t *testing.T) { chatID := fmt.Sprintf("test_history_overlap_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add history messages err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("overlap_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_overlap_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Message one"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now().Add(-3 * time.Minute), }, { MessageID: fmt.Sprintf("overlap_2_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_overlap_%s", reqID), Role: "assistant", Type: "text", Props: map[string]interface{}{"text": "Response one"}, Sequence: 2, AssistantID: ast.ID, CreatedAt: time.Now().Add(-2 * time.Minute), }, { MessageID: fmt.Sprintf("overlap_3_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_overlap_2_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Message two"}, Sequence: 3, AssistantID: ast.ID, CreatedAt: time.Now().Add(-1 * time.Minute), }, }) require.NoError(t, err) // Input that overlaps with history (includes last messages) // Some clients send full history + new message input := []agentcontext.Message{ {Role: 
agentcontext.RoleAssistant, Content: "Response one"}, // Overlap {Role: agentcontext.RoleUser, Content: "Message two"}, // Overlap {Role: agentcontext.RoleUser, Content: "New message"}, // New } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // InputMessages should have overlap removed assert.Len(t, result.InputMessages, 1, "Should remove 2 overlapping messages") assert.Equal(t, "New message", result.InputMessages[0].Content) // FullMessages should be history + clean input assert.Len(t, result.FullMessages, 4) // 3 history + 1 new t.Log("✓ Overlap detected and removed from input") }) t.Run("EmptyChatID", func(t *testing.T) { ctx := newHistoryTestContext("") input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "No chat ID"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // With empty chat ID, should return input as is assert.Equal(t, input, result.InputMessages) assert.Equal(t, input, result.FullMessages) t.Log("✓ Empty chat ID handled gracefully") }) t.Run("MultipleUserMessages", func(t *testing.T) { chatID := fmt.Sprintf("test_history_multi_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat with history err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add history err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("multi_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_multi_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "First"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now().Add(-1 * time.Minute), }, }) require.NoError(t, err) // Multiple input messages input 
:= []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Second"}, {Role: agentcontext.RoleUser, Content: "Third"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) assert.Len(t, result.InputMessages, 2) assert.Len(t, result.FullMessages, 3) // 1 history + 2 new t.Log("✓ Multiple input messages handled correctly") }) } // ============================================================================= // History Load Tests // ============================================================================= func TestHistoryLoading(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("tests.history") require.NoError(t, err) chatStore := assistant.GetChatStore() if chatStore == nil { t.Skip("Chat store not configured") } t.Run("FilterNonConversationTypes", func(t *testing.T) { chatID := fmt.Sprintf("test_filter_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add various message types (only user/assistant roles allowed by DB constraint) // loadHistory filters by role (user/assistant only) and converts based on type: // - loading/event: skipped (no semantic value) // - tool_call/action: converted to historical summary text // - text/user_input/error: kept as-is err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("filter_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_filter_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "User message"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now().Add(-3 * time.Minute), }, { MessageID: fmt.Sprintf("filter_2_%s", 
reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_filter_%s", reqID), Role: "assistant", Type: "loading", Props: map[string]interface{}{"text": "Loading..."}, Sequence: 2, AssistantID: ast.ID, CreatedAt: time.Now().Add(-2 * time.Minute), }, { MessageID: fmt.Sprintf("filter_3_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_filter_%s", reqID), Role: "assistant", Type: "text", Props: map[string]interface{}{"text": "Assistant response"}, Sequence: 3, AssistantID: ast.ID, CreatedAt: time.Now().Add(-1 * time.Minute), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "New input"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // loading type is skipped (no semantic value) // History: user_input + text = 2 messages; plus 1 new input = 3 total assert.Len(t, result.FullMessages, 3) // Verify only user and assistant roles for _, msg := range result.FullMessages { assert.True(t, msg.Role == agentcontext.RoleUser || msg.Role == agentcontext.RoleAssistant, "Expected user or assistant role, got: %s", msg.Role) } t.Log("✓ Loading type filtered, user/assistant roles kept") }) t.Run("ToolCallConvertedToSummary", func(t *testing.T) { chatID := fmt.Sprintf("test_toolcall_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add tool_call messages in both formats err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("tc_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_tc_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "echo 3 ping 4"}, Sequence: 1, AssistantID: ast.ID, 
CreatedAt: time.Now().Add(-3 * time.Minute), }, // Raw stream chunk format (actual DB format) { MessageID: fmt.Sprintf("tc_2_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_tc_%s", reqID), Role: "assistant", Type: "tool_call", Props: map[string]interface{}{"content": `[{"index":0,"id":"call_abc","type":"function","function":{"name":"echo__ping"}}][{"index":0,"function":{"arguments":"{\"count\":3}"}}]`}, Sequence: 2, AssistantID: ast.ID, CreatedAt: time.Now().Add(-2 * time.Minute), }, // Standard ToolCallProps format { MessageID: fmt.Sprintf("tc_3_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_tc_%s", reqID), Role: "assistant", Type: "tool_call", Props: map[string]interface{}{"name": "echo__echo", "arguments": `{"message":"hello"}`}, Sequence: 3, AssistantID: ast.ID, CreatedAt: time.Now().Add(-1 * time.Minute), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "echo 5 ping 6"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // 1 user_input + 2 tool_call summaries + 1 new input = 4 assert.Len(t, result.FullMessages, 4) // Verify tool_call messages are converted to summary text tcMsg1 := result.FullMessages[1] assert.Equal(t, agentcontext.RoleAssistant, tcMsg1.Role) assert.Contains(t, tcMsg1.Content, "[Historical Tool Call Summary]") assert.Contains(t, tcMsg1.Content, "echo__ping") assert.Contains(t, tcMsg1.Content, `{"count":3}`) tcMsg2 := result.FullMessages[2] assert.Equal(t, agentcontext.RoleAssistant, tcMsg2.Role) assert.Contains(t, tcMsg2.Content, "[Historical Tool Call Summary]") assert.Contains(t, tcMsg2.Content, "echo__echo") assert.Contains(t, tcMsg2.Content, `{"message":"hello"}`) t.Log("✓ Tool call messages converted to historical summaries") }) t.Run("ActionConvertedToSummary", func(t *testing.T) { chatID := fmt.Sprintf("test_action_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] 
err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("act_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_act_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Do something"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now().Add(-2 * time.Minute), }, // Action with payload { MessageID: fmt.Sprintf("act_2_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_act_%s", reqID), Role: "assistant", Type: "action", Props: map[string]interface{}{ "name": "robot.execute", "payload": map[string]interface{}{"goals": "test goal", "robot_id": "12345"}, }, Sequence: 2, AssistantID: ast.ID, CreatedAt: time.Now().Add(-1 * time.Minute), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "What happened?"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // 1 user_input + 1 action summary + 1 new input = 3 assert.Len(t, result.FullMessages, 3) actMsg := result.FullMessages[1] assert.Equal(t, agentcontext.RoleAssistant, actMsg.Role) assert.Contains(t, actMsg.Content, "[Historical Action Summary]") assert.Contains(t, actMsg.Content, "robot.execute") assert.Contains(t, actMsg.Content, "test goal") assert.Contains(t, actMsg.Content, "12345") t.Log("✓ Action messages converted to historical summaries with payload") }) t.Run("ActionWithoutPayload", func(t *testing.T) { chatID := fmt.Sprintf("test_action_nopay_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", 
CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("actnp_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_actnp_%s", reqID), Role: "assistant", Type: "action", Props: map[string]interface{}{"name": "navigate"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now(), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "What happened?"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // 1 action summary + 1 new input = 2 assert.Len(t, result.FullMessages, 2) actMsg := result.FullMessages[0] assert.Equal(t, agentcontext.RoleAssistant, actMsg.Role) assert.Contains(t, actMsg.Content, "[Historical Action Summary]") assert.Contains(t, actMsg.Content, "navigate") assert.NotContains(t, actMsg.Content, "payload") t.Log("✓ Action without payload handled correctly") }) t.Run("ContentExtraction", func(t *testing.T) { chatID := fmt.Sprintf("test_extract_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add messages with different content formats err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("extract_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_extract_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "User content from props.content"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now().Add(-2 * time.Minute), }, { MessageID: 
fmt.Sprintf("extract_2_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_extract_%s", reqID), Role: "assistant", Type: "text", Props: map[string]interface{}{"text": "Assistant content from props.text"}, Sequence: 2, AssistantID: ast.ID, CreatedAt: time.Now().Add(-1 * time.Minute), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "New message"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // Verify content was extracted correctly assert.Len(t, result.FullMessages, 3) assert.Equal(t, "User content from props.content", result.FullMessages[0].Content) assert.Equal(t, "Assistant content from props.text", result.FullMessages[1].Content) t.Log("✓ Content extracted correctly from different formats") }) } // ============================================================================= // Edge Cases Tests // ============================================================================= func TestHistoryEdgeCases(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.Get("tests.history") require.NoError(t, err) chatStore := assistant.GetChatStore() if chatStore == nil { t.Skip("Chat store not configured") } t.Run("EmptyInput", func(t *testing.T) { chatID := fmt.Sprintf("test_empty_input_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat with history err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("empty_input_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_empty_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": 
"Previous"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now(), }, }) require.NoError(t, err) // Empty input input := []agentcontext.Message{} result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // Should return history only assert.Empty(t, result.InputMessages) assert.Len(t, result.FullMessages, 1) t.Log("✓ Empty input handled correctly") }) t.Run("FullOverlap", func(t *testing.T) { chatID := fmt.Sprintf("test_full_overlap_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add history err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("full_overlap_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_full_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Exact same message"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now(), }, }) require.NoError(t, err) // Input is exactly the same as history input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Exact same message"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // Full overlap: clean input should be empty assert.Empty(t, result.InputMessages) // FullMessages should be just history (no duplicates) assert.Len(t, result.FullMessages, 1) t.Log("✓ Full overlap handled correctly") }) t.Run("NonExistentChat", func(t *testing.T) { chatID := "non_existent_chat_12345" ctx := newHistoryTestContext(chatID) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Message to non-existent chat"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, 
err) require.NotNil(t, result) // Should return input as is (no history found) assert.Equal(t, input, result.InputMessages) assert.Equal(t, input, result.FullMessages) t.Log("✓ Non-existent chat handled gracefully") }) t.Run("MessageWithName", func(t *testing.T) { chatID := fmt.Sprintf("test_name_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add message with name err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("name_msg_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_name_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Message with name", "name": "John"}, Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now(), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "New message"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // First message should have name assert.Len(t, result.FullMessages, 2) assert.NotNil(t, result.FullMessages[0].Name) assert.Equal(t, "John", *result.FullMessages[0].Name) t.Log("✓ Message name field preserved") }) t.Run("EmptyContent", func(t *testing.T) { chatID := fmt.Sprintf("test_empty_content_%s", uuid.New().String()[:8]) ctx := newHistoryTestContext(chatID) reqID := uuid.New().String()[:8] // Create chat err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: ast.ID, Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer func() { chatStore.DeleteMessages(chatID, nil) chatStore.DeleteChat(chatID) }() // Add message with empty 
content in props err = chatStore.SaveMessages(chatID, []*storetypes.Message{ { MessageID: fmt.Sprintf("empty_content_1_%s", reqID), ChatID: chatID, RequestID: fmt.Sprintf("req_empty_content_%s", reqID), Role: "user", Type: "user_input", Props: map[string]interface{}{}, // empty props (no content) Sequence: 1, AssistantID: ast.ID, CreatedAt: time.Now(), }, }) require.NoError(t, err) input := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "New message"}, } result, err := ast.WithHistory(ctx, input, nil) require.NoError(t, err) require.NotNil(t, result) // Message with empty props should be skipped (no content extractable) // Only new input should be present assert.Len(t, result.FullMessages, 1) assert.Equal(t, "New message", result.FullMessages[0].Content) t.Log("✓ Empty content handled gracefully (message skipped)") }) } ================================================ FILE: agent/assistant/hook/REALWORLD_PERFORMANCE_REPORT.md ================================================ # Performance Test Report **Test Date**: November 28, 2025 **System**: Yao Agent Assistant - Create Hook **Hardware**: Apple M2 Max, ARM64, macOS 25.1.0 --- ## Executive Summary All tests passed with 100% success rate. The system demonstrates production-ready performance with stable memory usage and predictable response times. 
**Key Metrics:** - ✅ **Concurrent Capacity**: 1,000 operations @ 100 goroutines - ✅ **Response Time**: 1.57ms average (hook execution only) - ✅ **Memory Stable**: ≤1 MB growth under load - ✅ **Success Rate**: 100% (1,000/1,000 validated) --- ## Performance Benchmarks ### Single Request Performance | Scenario | Mode | Time/op | Memory/op | Allocs/op | | -------- | ----------- | ------- | --------- | --------- | | Simple | Standard | 1.44 ms | 45 KB | 827 | | Simple | Performance | 0.33 ms | 33 KB | 789 | | Business | Standard | 3.33 ms | 95 KB | 1,570 | | Business | Performance | 0.35 ms | 33 KB | 805 | **Note**: Standard mode creates/disposes V8 isolate per request. Performance mode reuses isolates from pool. ### Concurrent Performance | Scenario | Mode | Time/op | Memory/op | Allocs/op | | ------------------- | ----------- | ------- | --------- | --------- | | Simple Concurrent | Standard | 0.42 ms | 46 KB | 829 | | Simple Concurrent | Performance | 0.35 ms | 33 KB | 789 | | Business Concurrent | Standard | 0.64 ms | 89 KB | 1,457 | | Business Concurrent | Performance | 0.35 ms | 33 KB | 786 | **Observation**: Concurrent execution shows better performance than sequential in standard mode due to parallel isolate creation. 
--- ## Stress Test Results ### Basic Tests **Simple Scenario** (100 iterations): - Duration: 0.34s - Memory: 470 MB → 471 MB (0 MB growth) - Result: ✅ Stable **MCP Integration** (50 iterations): - Duration: 0.40s - Memory: 472 MB → 471 MB (0 MB growth) - Result: ✅ No leaks **Full Workflow** (30 iterations, MCP + DB + Trace): - Duration: 0.39s - Average: 12.90 ms/op - Memory: 472 MB → 471 MB (0 MB growth) - Result: ✅ All components working ### Concurrent Stress Test ⭐ **Configuration:** - Goroutines: 100 - Iterations: 10 per goroutine - Total operations: 1,000 - Scenarios: Mixed (simple, mcp_health, mcp_tools, full_workflow) **Results:** - Duration: 1.57 seconds - Average: 1.57 ms/op - Throughput: ~636 ops/second - Success: 1,000/1,000 (100%) - Memory: 472 MB → 473 MB (1 MB growth) - Validation: All responses correct **Scenario Distribution:** - simple: 250 ops (25%) - mcp_health: 250 ops (25%) - mcp_tools: 250 ops (25%) - full_workflow: 250 ops (25%) --- ## Memory Analysis ### Memory Leak Tests All memory leak tests passed with acceptable thresholds: **Standard Mode** (1,000 iterations): - Growth: 11.65 MB (12.2 KB/iteration) - Threshold: <15 KB/iteration - Status: ✅ Pass **Performance Mode** (1,000 iterations): - Growth: -0.15 MB (negative = GC working) - Status: ✅ Pass **Business Scenarios** (200 iterations each): - Growth: 12-15 KB/iteration - Status: ✅ All pass **Concurrent Load** (1,000 iterations): - Growth: 1.73 MB (1.8 KB/iteration) - Status: ✅ Excellent ### Goroutine Behavior **Observation**: Each request creates 2 goroutines (trace pubsub + state worker) that exit asynchronously after `Release()`. **Measured Growth**: ~1.6 goroutines/iteration net (16 over 10 iterations; at most 2 are created per request) - Initial: 106 → Final: 122 (after 10 iterations) - Threshold: <5 goroutines/iteration - Status: ✅ Expected behavior (not a leak) **Root Cause**: Asynchronous cleanup - goroutines exit when channels close, but scheduling takes time. This is normal Go concurrency behavior.
--- ## Capacity Planning ### Single Instance Capacity **Hook Execution Only** (measured): ``` Response Time: 1.57ms Goroutines: 100 tested, stable Throughput: ~636 ops/second actual ``` **Complete Request Flow** (estimated): ``` Hook Execution: 1.57ms LLM API Call: 500-2000ms (typical) Network + Parsing: 50-100ms Total: ~1000ms per request ``` ### Production Estimates **Conservative Capacity** (50% safety factor): | User Activity | Requests/Min | Concurrent Online Users | | ------------------- | ------------ | ----------------------- | | Light (3 req/min) | 3,000 total | 1,000 online | | Normal (6 req/min) | 3,000 total | 500 online | | Active (15 req/min) | 3,000 total | 200 online | | Heavy (30 req/min) | 3,000 total | 100 online | **Calculation Basis:** - 100 goroutines proven stable - ~1 request/second per goroutine - Base: 100 req/s = 6,000 req/min - With 50% safety: 3,000 req/min sustained **Recommendation**: Start with 500-1,000 concurrent online users per instance, monitor and scale horizontally as needed. **Note**: "Concurrent online users" means users actively using the system at the same time, not total registered users. 
### Horizontal Scaling ``` 1 instance → 500-1,000 concurrent online users 2 instances → 1,000-2,000 concurrent online users 5 instances → 2,500-5,000 concurrent online users 10 instances → 5,000-10,000 concurrent online users ``` --- ## Component Verification ### MCP Integration ✅ - ListTools: Working - CallTool: Working (ping, status) - Resource operations: Working - Prompt operations: Working - Performance: <3ms per operation ### Trace Management ✅ - Node creation: <1ms - 20+ nodes per operation: No issues - Memory cleanup: Effective - Goroutine cleanup: Asynchronous (expected) ### Context Management ✅ - Creation: Fast - Release: Working (cascading cleanup) - Memory: No leaks detected - Thread-safe: Yes ### Database Integration ✅ - Query execution: Working - Connection pooling: Efficient - Error handling: Robust --- ## Reliability Metrics **Test Coverage:** - Total tests: 21 - Tests passed: 21 (100%) - Tests failed: 0 - Flaky tests: 0 **Error Rate:** - Operations: 1,200+ - Errors: 0 - Rate: 0.00% **Data Integrity:** - Message validation: 100% - Metadata validation: 100% - Scenario matching: 100% --- ## Known Behaviors ### Goroutine Accumulation **Observation**: ~2 goroutines created per request that exit asynchronously. 
**Root Cause**: - Trace creates 2 background goroutines: `pubsub.forward()` + `stateWorker()` - These exit when channels close (via `Release()`) - Exit is asynchronous - takes 5-15ms after `Release()` - In rapid iterations, new goroutines start before old ones finish exiting **Impact**: - Temporary accumulation during high load - No unbounded growth (goroutines eventually exit) - Go runtime handles this efficiently - Not a memory leak **Status**: ✅ Expected behavior, no action needed --- ## Recommendations ### Production Deployment **Ready to Deploy**: Yes **Suggested Configuration:** - Start with 1-2 instances - Target: 500-1,000 concurrent users per instance - V8 Mode: Standard (safer) or Performance (faster) - Health check: Monitor goroutine count (<10,000) ### Monitoring **Key Metrics to Track:** 1. Response time (alert if >100ms sustained) 2. Goroutine count (alert if >10,000) 3. Memory usage (alert if >1GB growth/hour) 4. Error rate (alert if >1%) ### Scaling Triggers **Scale Up When:** - Response time >50ms average (sustained 5 min) - Goroutine count >5,000 (approaching limits) - CPU >70% (need more capacity) **Scale Out When:** - Need >1,000 concurrent users - Multi-region deployment required - Geographic latency optimization needed --- ## Conclusions ### System Status: **Production Ready** ✅ **Strengths:** - Fast response times (1-3ms for hook execution) - Stable memory usage (no leaks detected) - Excellent concurrent performance (100+ goroutines stable) - 100% test success rate with validation - Clean resource management with proper cleanup **Suitable For:** - SaaS platforms (500-1,000 concurrent online users per instance) - Enterprise applications requiring high reliability - Systems with 100-1,000 concurrent online users - Mission-critical AI agent deployments **Performance Rating**: A (Excellent) **Capacity Rating**: Mid-stage SaaS (Series A/B ready) --- ## Test Execution Summary ``` Platform: darwin/arm64 CPU: Apple M2 Max Go Version: 1.25.0 Test 
Duration: 19.8 seconds Unit Tests: 21 passed Benchmarks: 8 completed Stress Tests: 5 passed (1,000 ops validated) Memory Tests: 7 passed Goroutine Tests: 4 passed (behavior documented) Overall: 100% PASS ✅ ``` --- **Report Generated**: November 28, 2025 **Test Framework**: Go testing + testify **Validation**: Complete (all responses verified) **Status**: PRODUCTION READY ================================================ FILE: agent/assistant/hook/create.go ================================================ package hook import ( "encoding/json" "fmt" "github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/context" ) // Create create a new assistant // opts is optional - if provided, will be adjusted based on hook response func (s *Script) Create(ctx *context.Context, messages []context.Message, opts ...*context.Options) (*context.HookCreateResponse, *context.Options, error) { // Get or create options var options *context.Options if len(opts) > 0 && opts[0] != nil { options = opts[0] } else { options = &context.Options{} } // Execute hook with ctx, messages, and options (convert options to map for JS) optionsMap := options.ToMap() res, err := s.Execute(ctx, "Create", messages, optionsMap) if err != nil { return nil, nil, err } response, err := s.getHookCreateResponse(res) if err != nil { return nil, nil, err } // Apply adjustments from the response if response != nil { s.applyContextAdjustments(ctx, response) s.applyOptionsAdjustments(options, response) } return response, options, nil } // applyContextAdjustments applies session-level field overrides from the hook response back to the context func (s *Script) applyContextAdjustments(ctx *context.Context, response *context.HookCreateResponse) { // Note: AssistantID cannot be overridden - it's set at initialization and immutable // Override locale if provided (session-level) if response.Locale != "" { ctx.Locale = response.Locale } // Override theme if provided (session-level) if response.Theme != "" { 
ctx.Theme = response.Theme } // Override route if provided (session-level) if response.Route != "" { ctx.Route = response.Route } // Merge or override metadata if provided (session-level) if len(response.Metadata) > 0 { if ctx.Metadata == nil { ctx.Metadata = make(map[string]interface{}) } // Merge metadata - response metadata takes precedence for key, value := range response.Metadata { ctx.Metadata[key] = value } } } // applyOptionsAdjustments applies call-level field overrides from the hook response to options func (s *Script) applyOptionsAdjustments(opts *context.Options, response *context.HookCreateResponse) { // Override connector if provided (call-level parameter) if response.Connector != "" { opts.Connector = response.Connector } } // getHookCreateResponse convert the result to a HookCreateResponse func (s *Script) getHookCreateResponse(res interface{}) (*context.HookCreateResponse, error) { // Handle nil result if res == nil { return nil, nil } // Handle undefined result (treat as nil) if _, ok := res.(bridge.UndefinedT); ok { return nil, nil } // Marshal to JSON and unmarshal to HookCreateResponse raw, err := json.Marshal(res) if err != nil { return nil, fmt.Errorf("failed to marshal result: %w", err) } var response context.HookCreateResponse if err := json.Unmarshal(raw, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal to HookCreateResponse: %w", err) } return &response, nil } ================================================ FILE: agent/assistant/hook/create_bench_test.go ================================================ package hook_test import ( stdContext "context" "testing" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" ) // ============================================================================ // Simple Scenario Benchmarks // 
============================================================================ // BenchmarkSimpleStandardMode benchmarks simple scenario in standard V8 mode // Run with: go test -bench=BenchmarkSimpleStandardMode -benchmem -benchtime=100x func BenchmarkSimpleStandardMode(b *testing.B) { testutils.Prepare(&testing.T{}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } b.ResetTimer() for i := 0; i < b.N; i++ { ctx := newBenchContext("bench-simple-standard", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { b.Fatalf("Create failed: %s", err.Error()) } } } // BenchmarkSimplePerformanceMode benchmarks simple scenario in performance V8 mode // Run with: go test -bench=BenchmarkSimplePerformanceMode -benchmem -benchtime=100x func BenchmarkSimplePerformanceMode(b *testing.B) { testutils.Prepare(&testing.T{}, test.PrepareOption{V8Mode: "performance"}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } b.ResetTimer() for i := 0; i < b.N; i++ { ctx := newBenchContext("bench-simple-performance", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { b.Fatalf("Create failed: %s", err.Error()) } } } // ============================================================================ // Business Scenario Benchmarks (with Process calls, DB access, etc.) 
// ============================================================================ // BenchmarkBusinessStandardMode benchmarks business scenarios in standard V8 mode // Run with: go test -bench=BenchmarkBusinessStandardMode -benchmem -benchtime=100x func BenchmarkBusinessStandardMode(b *testing.B) { testutils.Prepare(&testing.T{}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } scenarios := getBusinessScenarios() b.ResetTimer() for i := 0; i < b.N; i++ { scenario := scenarios[i%len(scenarios)] ctx := newBenchContext("bench-business-standard", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: scenario.content}, }) if err != nil { b.Errorf("%s failed: %s", scenario.name, err.Error()) } } } // BenchmarkBusinessPerformanceMode benchmarks business scenarios in performance V8 mode // Run with: go test -bench=BenchmarkBusinessPerformanceMode -benchmem -benchtime=100x func BenchmarkBusinessPerformanceMode(b *testing.B) { testutils.Prepare(&testing.T{}, test.PrepareOption{V8Mode: "performance"}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } scenarios := getBusinessScenarios() b.ResetTimer() for i := 0; i < b.N; i++ { scenario := scenarios[i%len(scenarios)] ctx := newBenchContext("bench-business-performance", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: scenario.content}, }) if err != nil { b.Errorf("%s failed: %s", scenario.name, err.Error()) } } } // ============================================================================ // Concurrent Benchmarks // ============================================================================ // 
BenchmarkConcurrentSimpleStandardMode benchmarks simple concurrent scenario in standard V8 mode // Simulates concurrent users with isolate creation/disposal per request // Run with: go test -bench=BenchmarkConcurrentSimpleStandardMode -benchmem -benchtime=100x func BenchmarkConcurrentSimpleStandardMode(b *testing.B) { testutils.Prepare(&testing.T{}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { i := 0 for pb.Next() { ctx := newBenchContext("bench-concurrent-simple-standard", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { b.Errorf("Create failed (iteration %d): %s", i, err.Error()) } i++ } }) } // BenchmarkConcurrentSimplePerformanceMode benchmarks simple concurrent scenario in performance V8 mode // Simulates 100 users simultaneously using the system with isolate pool // Run with: go test -bench=BenchmarkConcurrentSimplePerformanceMode -benchmem -benchtime=100x func BenchmarkConcurrentSimplePerformanceMode(b *testing.B) { testutils.Prepare(&testing.T{}, test.PrepareOption{V8Mode: "performance"}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { i := 0 for pb.Next() { ctx := newBenchContext("bench-concurrent-simple", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { b.Errorf("Create failed (iteration %d): %s", i, err.Error()) } i++ } }) } // BenchmarkConcurrentBusinessStandardMode benchmarks concurrent business scenarios in standard V8 mode // Tests various scenarios 
with concurrent users and isolate creation/disposal per request // Run with: go test -bench=BenchmarkConcurrentBusinessStandardMode -benchmem -benchtime=100x func BenchmarkConcurrentBusinessStandardMode(b *testing.B) { testutils.Prepare(&testing.T{}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } scenarios := getBusinessScenarios() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { i := 0 for pb.Next() { scenario := scenarios[i%len(scenarios)] ctx := newBenchContext("bench-concurrent-business-standard", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: scenario.content}, }) if err != nil { b.Errorf("%s failed (iteration %d): %s", scenario.name, i, err.Error()) } i++ } }) } // BenchmarkConcurrentBusinessPerformanceMode benchmarks concurrent business scenarios in performance V8 mode // Tests various scenarios with 100 concurrent users with isolate pool // Run with: go test -bench=BenchmarkConcurrentBusinessPerformanceMode -benchmem -benchtime=100x func BenchmarkConcurrentBusinessPerformanceMode(b *testing.B) { testutils.Prepare(&testing.T{}, test.PrepareOption{V8Mode: "performance"}) defer testutils.Clean(&testing.T{}) agent, err := assistant.Get("tests.create") if err != nil { b.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { b.Fatalf("Assistant has no script") } scenarios := getBusinessScenarios() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { i := 0 for pb.Next() { scenario := scenarios[i%len(scenarios)] ctx := newBenchContext("bench-concurrent-business", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: scenario.content}, }) if err != nil { b.Errorf("%s failed (iteration %d): %s", scenario.name, i, err.Error()) } i++ } }) } // 
============================================================================ // Helper Functions // ============================================================================ // getBusinessScenarios returns the business test scenarios func getBusinessScenarios() []struct { name string content string } { return []struct { name string content string }{ {name: "FullResponse", content: "return_full"}, {name: "PartialResponse", content: "return_partial"}, {name: "ProcessCall", content: "return_process"}, {name: "ContextAdjustment", content: "adjust_context"}, {name: "NestedScriptCall", content: "nested_script_call"}, {name: "DeepNestedCall", content: "deep_nested_call"}, } } // newBenchContext creates a minimal context for benchmarking func newBenchContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "bench-user", ClientID: "bench-client", UserID: "bench-user-123", TeamID: "bench-team-456", TenantID: "bench-tenant-789", Constraints: types.DataConstraints{ TeamOnly: true, Extra: map[string]interface{}{ "department": "engineering", }, }, } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "BenchAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } ================================================ FILE: agent/assistant/hook/create_mem_test.go ================================================ package hook_test import ( stdContext "context" "runtime" "testing" "time" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" ) // ============================================================================ // Memory Leak Detection Tests // 
============================================================================ // TestMemoryLeakStandardMode checks for memory leaks in standard V8 mode // Run with: go test -run=TestMemoryLeakStandardMode -v func TestMemoryLeakStandardMode(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Warm up - execute a few times to stabilize memory for i := 0; i < 10; i++ { ctx := newMemTestContext("warmup", "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) ctx.Release() } // Force GC and get baseline memory runtime.GC() time.Sleep(100 * time.Millisecond) var baseline runtime.MemStats runtime.ReadMemStats(&baseline) // Execute many iterations iterations := 1000 for i := 0; i < iterations; i++ { ctx := newMemTestContext("mem-test-standard", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { t.Errorf("Create failed at iteration %d: %s", i, err.Error()) } // Release context resources ctx.Release() // Periodic GC to help detect leaks faster if i%100 == 0 { runtime.GC() } } // Force GC and check final memory runtime.GC() time.Sleep(100 * time.Millisecond) var final runtime.MemStats runtime.ReadMemStats(&final) // Calculate memory growth baselineHeap := baseline.HeapAlloc finalHeap := final.HeapAlloc growth := int64(finalHeap) - int64(baselineHeap) growthPerIteration := float64(growth) / float64(iterations) t.Logf("Memory Statistics (Standard Mode):") t.Logf(" Iterations: %d", iterations) t.Logf(" Baseline HeapAlloc: %d bytes (%.2f MB)", baselineHeap, float64(baselineHeap)/1024/1024) t.Logf(" Final HeapAlloc: %d bytes (%.2f MB)", finalHeap, float64(finalHeap)/1024/1024) t.Logf(" Total Growth: %d bytes (%.2f MB)", growth, float64(growth)/1024/1024) 
t.Logf(" Growth per iteration: %.2f bytes", growthPerIteration) t.Logf(" Total Alloc: %d bytes (%.2f MB)", final.TotalAlloc, float64(final.TotalAlloc)/1024/1024) t.Logf(" Mallocs: %d", final.Mallocs) t.Logf(" Frees: %d", final.Frees) t.Logf(" Live Objects: %d", final.Mallocs-final.Frees) t.Logf(" GC Runs: %d", final.NumGC-baseline.NumGC) // Check for memory leak // Standard mode creates/disposes isolates per request, so some overhead is expected // Allow up to 20KB growth per iteration as threshold // This accounts for V8 isolate creation/disposal overhead and bridge management // Significant leaks would show much higher growth rates (50KB+) maxGrowthPerIteration := 20480.0 // 20 KB if growthPerIteration > maxGrowthPerIteration { t.Errorf("Possible memory leak detected: %.2f bytes/iteration (threshold: %.2f bytes/iteration)", growthPerIteration, maxGrowthPerIteration) } else { t.Logf("✓ Memory growth is within acceptable range (%.2f bytes/iteration)", growthPerIteration) } } // TestMemoryLeakPerformanceMode checks for memory leaks in performance V8 mode // Run with: go test -run=TestMemoryLeakPerformanceMode -v func TestMemoryLeakPerformanceMode(t *testing.T) { testutils.Prepare(t, test.PrepareOption{V8Mode: "performance"}) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Warm up - execute a few times to stabilize memory and fill isolate pool for i := 0; i < 20; i++ { ctx := newMemTestContext("warmup", "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) ctx.Release() } // Force GC and get baseline memory runtime.GC() time.Sleep(100 * time.Millisecond) var baseline runtime.MemStats runtime.ReadMemStats(&baseline) // Execute many iterations iterations := 1000 for i := 0; i < iterations; i++ { ctx := newMemTestContext("mem-test-performance", 
"tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { t.Errorf("Create failed at iteration %d: %s", i, err.Error()) } // Release context resources ctx.Release() // Periodic GC if i%100 == 0 { runtime.GC() } } // Force GC and check final memory runtime.GC() time.Sleep(100 * time.Millisecond) var final runtime.MemStats runtime.ReadMemStats(&final) // Calculate memory growth baselineHeap := baseline.HeapAlloc finalHeap := final.HeapAlloc growth := int64(finalHeap) - int64(baselineHeap) growthPerIteration := float64(growth) / float64(iterations) t.Logf("Memory Statistics (Performance Mode):") t.Logf(" Iterations: %d", iterations) t.Logf(" Baseline HeapAlloc: %d bytes (%.2f MB)", baselineHeap, float64(baselineHeap)/1024/1024) t.Logf(" Final HeapAlloc: %d bytes (%.2f MB)", finalHeap, float64(finalHeap)/1024/1024) t.Logf(" Total Growth: %d bytes (%.2f MB)", growth, float64(growth)/1024/1024) t.Logf(" Growth per iteration: %.2f bytes", growthPerIteration) t.Logf(" Total Alloc: %d bytes (%.2f MB)", final.TotalAlloc, float64(final.TotalAlloc)/1024/1024) t.Logf(" Mallocs: %d", final.Mallocs) t.Logf(" Frees: %d", final.Frees) t.Logf(" Live Objects: %d", final.Mallocs-final.Frees) t.Logf(" GC Runs: %d", final.NumGC-baseline.NumGC) // Performance mode should have less growth due to isolate reuse // Allow up to 5KB per iteration as threshold maxGrowthPerIteration := 5120.0 if growthPerIteration > maxGrowthPerIteration { t.Errorf("Possible memory leak detected: %.2f bytes/iteration (threshold: %.2f bytes/iteration)", growthPerIteration, maxGrowthPerIteration) } else { t.Logf("✓ Memory growth is within acceptable range") } } // TestMemoryLeakBusinessScenarios checks for memory leaks with business logic // Run with: go test -run=TestMemoryLeakBusinessScenarios -v func TestMemoryLeakBusinessScenarios(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if 
err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } scenarios := []struct { name string content string }{ {name: "FullResponse", content: "return_full"}, {name: "PartialResponse", content: "return_partial"}, {name: "ProcessCall", content: "return_process"}, {name: "ContextAdjustment", content: "adjust_context"}, {name: "NestedScriptCall", content: "nested_script_call"}, {name: "DeepNestedCall", content: "deep_nested_call"}, } // Warm up for i := 0; i < 10; i++ { ctx := newMemTestContext("warmup", "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "return_full"}, }) ctx.Release() } // Test each scenario for _, scenario := range scenarios { t.Run(scenario.name, func(t *testing.T) { // Get baseline runtime.GC() time.Sleep(50 * time.Millisecond) var baseline runtime.MemStats runtime.ReadMemStats(&baseline) // Execute iterations (reduced to avoid V8 OOM) iterations := 200 for i := 0; i < iterations; i++ { ctx := newMemTestContext("mem-test-business", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: scenario.content}, }) if err != nil { t.Errorf("Create failed at iteration %d: %s", i, err.Error()) } ctx.Release() if i%50 == 0 { runtime.GC() } } // Check final memory runtime.GC() time.Sleep(50 * time.Millisecond) var final runtime.MemStats runtime.ReadMemStats(&final) growth := int64(final.HeapAlloc) - int64(baseline.HeapAlloc) growthPerIteration := float64(growth) / float64(iterations) t.Logf(" Baseline HeapAlloc: %d bytes (%.2f MB)", baseline.HeapAlloc, float64(baseline.HeapAlloc)/1024/1024) t.Logf(" Final HeapAlloc: %d bytes (%.2f MB)", final.HeapAlloc, float64(final.HeapAlloc)/1024/1024) t.Logf(" Growth: %d bytes (%.2f MB)", growth, float64(growth)/1024/1024) t.Logf(" Growth/iteration: %.2f bytes", growthPerIteration) // Business scenarios may have more memory usage due to complex 
operations // Allow up to 20KB per iteration as threshold // Note: Some scenarios like ContextAdjustment generate dynamic timestamps, // causing slightly higher memory usage. Real leaks would show 50KB+ growth. maxGrowthPerIteration := 20480.0 if growthPerIteration > maxGrowthPerIteration { t.Errorf("Possible memory leak: %.2f bytes/iteration (threshold: %.2f)", growthPerIteration, maxGrowthPerIteration) } else { t.Logf(" ✓ Memory growth is within acceptable range") } }) } } // TestMemoryLeakConcurrent checks for memory leaks under concurrent load // Run with: go test -run=TestMemoryLeakConcurrent -v func TestMemoryLeakConcurrent(t *testing.T) { testutils.Prepare(t, test.PrepareOption{V8Mode: "performance"}) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Warm up for i := 0; i < 20; i++ { ctx := newMemTestContext("warmup", "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) ctx.Release() } // Get baseline runtime.GC() time.Sleep(100 * time.Millisecond) var baseline runtime.MemStats runtime.ReadMemStats(&baseline) // Run concurrent load iterations := 1000 concurrency := 10 iterPerGoroutine := iterations / concurrency done := make(chan bool, concurrency) for g := 0; g < concurrency; g++ { go func(id int) { defer func() { done <- true }() for i := 0; i < iterPerGoroutine; i++ { ctx := newMemTestContext("mem-test-concurrent", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { t.Errorf("Goroutine %d failed at iteration %d: %s", id, i, err.Error()) } ctx.Release() } }(g) } // Wait for all goroutines for g := 0; g < concurrency; g++ { <-done } // Check final memory runtime.GC() time.Sleep(100 * time.Millisecond) var final runtime.MemStats runtime.ReadMemStats(&final) growth := 
int64(final.HeapAlloc) - int64(baseline.HeapAlloc) growthPerIteration := float64(growth) / float64(iterations) t.Logf("Memory Statistics (Concurrent Load):") t.Logf(" Iterations: %d", iterations) t.Logf(" Concurrency: %d", concurrency) t.Logf(" Baseline HeapAlloc: %d bytes (%.2f MB)", baseline.HeapAlloc, float64(baseline.HeapAlloc)/1024/1024) t.Logf(" Final HeapAlloc: %d bytes (%.2f MB)", final.HeapAlloc, float64(final.HeapAlloc)/1024/1024) t.Logf(" Growth: %d bytes (%.2f MB)", growth, float64(growth)/1024/1024) t.Logf(" Growth/iteration: %.2f bytes", growthPerIteration) t.Logf(" GC Runs: %d", final.NumGC-baseline.NumGC) // Concurrent scenarios may have slightly more overhead maxGrowthPerIteration := 10240.0 if growthPerIteration > maxGrowthPerIteration { t.Errorf("Possible memory leak: %.2f bytes/iteration (threshold: %.2f)", growthPerIteration, maxGrowthPerIteration) } else { t.Logf("✓ Memory growth is within acceptable range") } } // TestMemoryLeakNestedCalls checks for memory leaks with nested script calls // Run with: go test -run=TestMemoryLeakNestedCalls -v func TestMemoryLeakNestedCalls(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Warm up for i := 0; i < 10; i++ { ctx := newMemTestContext("warmup", "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "nested_script_call"}, }) ctx.Release() } // Get baseline runtime.GC() time.Sleep(100 * time.Millisecond) var baseline runtime.MemStats runtime.ReadMemStats(&baseline) // Execute iterations with nested calls // Nested calls: hook -> scripts.tests.create.NestedCall -> GetRoles/GetRole -> models iterations := 200 for i := 0; i < iterations; i++ { ctx := newMemTestContext("mem-test-nested", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ 
{Role: "user", Content: "deep_nested_call"}, }) if err != nil { t.Errorf("Nested call failed at iteration %d: %s", i, err.Error()) } ctx.Release() if i%50 == 0 { runtime.GC() } } // Check final memory runtime.GC() time.Sleep(100 * time.Millisecond) var final runtime.MemStats runtime.ReadMemStats(&final) growth := int64(final.HeapAlloc) - int64(baseline.HeapAlloc) growthPerIteration := float64(growth) / float64(iterations) t.Logf("Memory Statistics (Nested Calls):") t.Logf(" Iterations: %d", iterations) t.Logf(" Baseline HeapAlloc: %d bytes (%.2f MB)", baseline.HeapAlloc, float64(baseline.HeapAlloc)/1024/1024) t.Logf(" Final HeapAlloc: %d bytes (%.2f MB)", final.HeapAlloc, float64(final.HeapAlloc)/1024/1024) t.Logf(" Growth: %d bytes (%.2f MB)", growth, float64(growth)/1024/1024) t.Logf(" Growth/iteration: %.2f bytes", growthPerIteration) t.Logf(" GC Runs: %d", final.NumGC-baseline.NumGC) // Nested calls involve database operations, so allow more overhead // Allow up to 20KB per iteration as threshold maxGrowthPerIteration := 20480.0 if growthPerIteration > maxGrowthPerIteration { t.Errorf("Possible memory leak: %.2f bytes/iteration (threshold: %.2f)", growthPerIteration, maxGrowthPerIteration) } else { t.Logf("✓ Memory growth is within acceptable range") } } // TestMemoryLeakNestedConcurrent checks for memory leaks with concurrent nested calls // Run with: go test -run=TestMemoryLeakNestedConcurrent -v func TestMemoryLeakNestedConcurrent(t *testing.T) { testutils.Prepare(t, test.PrepareOption{V8Mode: "performance"}) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Warm up for i := 0; i < 20; i++ { ctx := newMemTestContext("warmup", "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "nested_script_call"}, }) ctx.Release() } // Get baseline runtime.GC() 
time.Sleep(100 * time.Millisecond) var baseline runtime.MemStats runtime.ReadMemStats(&baseline) // Run concurrent nested calls iterations := 500 concurrency := 10 iterPerGoroutine := iterations / concurrency done := make(chan bool, concurrency) for g := 0; g < concurrency; g++ { go func(id int) { defer func() { done <- true }() for i := 0; i < iterPerGoroutine; i++ { ctx := newMemTestContext("mem-test-nested-concurrent", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "deep_nested_call"}, }) if err != nil { t.Errorf("Goroutine %d nested call failed at iteration %d: %s", id, i, err.Error()) } ctx.Release() } }(g) } // Wait for all goroutines for g := 0; g < concurrency; g++ { <-done } // Check final memory runtime.GC() time.Sleep(100 * time.Millisecond) var final runtime.MemStats runtime.ReadMemStats(&final) growth := int64(final.HeapAlloc) - int64(baseline.HeapAlloc) growthPerIteration := float64(growth) / float64(iterations) t.Logf("Memory Statistics (Concurrent Nested Calls):") t.Logf(" Iterations: %d", iterations) t.Logf(" Concurrency: %d", concurrency) t.Logf(" Baseline HeapAlloc: %d bytes (%.2f MB)", baseline.HeapAlloc, float64(baseline.HeapAlloc)/1024/1024) t.Logf(" Final HeapAlloc: %d bytes (%.2f MB)", final.HeapAlloc, float64(final.HeapAlloc)/1024/1024) t.Logf(" Growth: %d bytes (%.2f MB)", growth, float64(growth)/1024/1024) t.Logf(" Growth/iteration: %.2f bytes", growthPerIteration) t.Logf(" GC Runs: %d", final.NumGC-baseline.NumGC) // Concurrent nested calls with database operations // Allow up to 25KB per iteration as threshold maxGrowthPerIteration := 25600.0 if growthPerIteration > maxGrowthPerIteration { t.Errorf("Possible memory leak: %.2f bytes/iteration (threshold: %.2f)", growthPerIteration, maxGrowthPerIteration) } else { t.Logf("✓ Memory growth is within acceptable range") } } // TestIsolateDisposal verifies that isolates are properly disposed in standard mode // Run with: go test 
-run=TestIsolateDisposal -v func TestIsolateDisposal(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Track goroutine count to detect goroutine leaks initialGoroutines := runtime.NumGoroutine() // Execute multiple iterations iterations := 100 for i := 0; i < iterations; i++ { ctx := newMemTestContext("disposal-test", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { t.Errorf("Create failed at iteration %d: %s", i, err.Error()) } ctx.Release() } // Give time for cleanup time.Sleep(200 * time.Millisecond) runtime.GC() time.Sleep(200 * time.Millisecond) finalGoroutines := runtime.NumGoroutine() goroutineGrowth := finalGoroutines - initialGoroutines t.Logf("Goroutine Statistics:") t.Logf(" Initial: %d", initialGoroutines) t.Logf(" Final: %d", finalGoroutines) t.Logf(" Growth: %d", goroutineGrowth) // Allow some goroutine growth for runtime internals // // ROOT CAUSE ANALYSIS: // Each Create() call creates a Trace, which starts 2 goroutines: // 1. trace/pubsub.(*PubSub).forward() - PubSub event forwarding // 2. 
trace.(*manager).startStateWorker() - State machine worker // // These goroutines exit when Release() closes their channels, but: // - Exit is ASYNCHRONOUS (goroutine needs to reach select statement) // - Go runtime needs time to schedule and cleanup // - In rapid iterations, new goroutines are created before old ones fully exit // // This is NOT a true leak: // ✓ Goroutines eventually exit (channels are closed) // ✓ No unbounded growth (they will be GC'd) // ✓ Typical pattern for async cleanup in Go // // Acceptable: ~2 goroutines per iteration (trace pubsub + state worker) // Concerning: >5 goroutines per iteration (indicates goroutines NOT exiting) maxGoroutineGrowthPerIteration := 5.0 growthPerIteration := float64(goroutineGrowth) / float64(iterations) if growthPerIteration > maxGoroutineGrowthPerIteration { t.Errorf("Goroutine leak detected: %.2f goroutines per iteration (threshold: %.2f)", growthPerIteration, maxGoroutineGrowthPerIteration) t.Errorf("This indicates goroutines are NOT being cleaned up properly") } else { t.Logf("✓ Goroutine growth is acceptable: %.2f per iteration", growthPerIteration) t.Logf(" (Trace creates 2 goroutines per call: pubsub.forward + stateWorker)") t.Logf(" (These exit asynchronously after Release(), causing temporary accumulation)") } } // ============================================================================ // Helper Functions // ============================================================================ // newMemTestContext creates a context for memory leak testing func newMemTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "mem-test-user", ClientID: "mem-test-client", UserID: "mem-user-123", TeamID: "mem-team-456", TenantID: "mem-tenant-789", Constraints: types.DataConstraints{ TeamOnly: true, Extra: map[string]interface{}{ "department": "engineering", }, }, } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID 
ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "MemTestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } ================================================ FILE: agent/assistant/hook/create_nested_test.go ================================================ package hook_test import ( "sync" "testing" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" ) // TestNestedScriptCall tests nested script calls with V8 context sharing // This test calls: hook -> scripts.tests.create.NestedCall -> GetRoles/GetRole -> models func TestNestedScriptCall(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Create context ctx := newTestContext("test-nested-call", "tests.create") // Call with deep_nested_call scenario // This will: hook -> scripts.tests.create.NestedCall -> GetRoles -> model res, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "deep_nested_call"}, }) if err != nil { t.Fatalf("Nested call failed: %s", err.Error()) } if res == nil { t.Fatal("Expected non-nil response") } // Verify messages if len(res.Messages) == 0 { t.Fatal("Expected messages in response") } t.Logf("✓ Nested script call completed successfully") t.Logf(" Messages count: %d", len(res.Messages)) if res.Metadata != nil { t.Logf(" Metadata: %+v", res.Metadata) } } // TestNestedScriptCallConcurrent tests nested script calls under high concurrency // Simulates 100 concurrent users making nested script calls func TestNestedScriptCallConcurrent(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil 
{ t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // High concurrency test: 100 concurrent users (testing race condition) concurrency := 100 iterations := 1 // Each user makes 1 call var wg sync.WaitGroup errors := make(chan error, concurrency*iterations) t.Logf("Starting concurrent test: %d users × %d iterations = %d total calls", concurrency, iterations, concurrency*iterations) for i := 0; i < concurrency; i++ { wg.Add(1) go func(userID int) { defer wg.Done() for j := 0; j < iterations; j++ { ctx := newTestContext("test-concurrent", "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "deep_nested_call"}, }) if err != nil { errors <- err return } } }(i) } // Wait for all goroutines to complete wg.Wait() close(errors) // Check for errors errorCount := 0 for err := range errors { errorCount++ t.Errorf("Concurrent call failed: %s", err.Error()) } if errorCount > 0 { t.Fatalf("Failed with %d errors out of %d total calls", errorCount, concurrency*iterations) } t.Logf("✓ All %d concurrent nested calls completed successfully", concurrency*iterations) } ================================================ FILE: agent/assistant/hook/create_test.go ================================================ package hook_test import ( stdContext "context" "testing" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newTestContext creates a Context for testing with commonly used fields pre-populated. // You can override any fields after creation as needed for specific test scenarios. 
func newTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", Scope: "openid profile email", SessionID: "test-session-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", RememberMe: true, Constraints: types.DataConstraints{ OwnerOnly: false, CreatorOnly: false, EditorOnly: false, TeamOnly: true, Extra: map[string]interface{}{ "department": "engineering", "region": "us-west", "project": "yao", }, }, } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } // TestCreate test the create hook func TestCreate(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get the tests.create assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("The tests.create assistant has no script") } // Use the helper function to create a test context ctx := newTestContext("chat-test-create-hook", "tests.create") // Test scenario 1: Return null (should get nil response) t.Run("ReturnNull", func(t *testing.T) { res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "return_null"}}) if err != nil { t.Fatalf("Failed to create with null return: %s", err.Error()) } if res != nil { t.Errorf("Expected nil response for null return, got: %v", res) } }) // Test scenario 2: Return undefined (should get nil response) t.Run("ReturnUndefined", func(t *testing.T) { res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "return_undefined"}}) if err != nil { t.Fatalf("Failed to create with undefined return: %s", 
err.Error()) } if res != nil { t.Errorf("Expected nil response for undefined return, got: %v", res) } }) // Test scenario 3: Return empty object (should get empty HookCreateResponse) t.Run("ReturnEmpty", func(t *testing.T) { res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "return_empty"}}) if err != nil { t.Fatalf("Failed to create with empty return: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response for empty object, got nil") } if len(res.Messages) != 0 { t.Errorf("Expected empty messages, got: %d messages", len(res.Messages)) } }) // Test scenario 4: Return full response with all fields t.Run("ReturnFull", func(t *testing.T) { res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "return_full"}}) if err != nil { t.Fatalf("Failed to create with full return: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify messages if len(res.Messages) != 2 { t.Errorf("Expected 2 messages, got: %d", len(res.Messages)) } else { if res.Messages[0].Role != context.RoleSystem { t.Errorf("Expected system role for first message, got: %s", res.Messages[0].Role) } if res.Messages[1].Role != context.RoleUser { t.Errorf("Expected user role for second message, got: %s", res.Messages[1].Role) } } // Verify audio config if res.Audio == nil { t.Error("Expected audio config, got nil") } else { if res.Audio.Voice != "alloy" { t.Errorf("Expected voice 'alloy', got: %s", res.Audio.Voice) } if res.Audio.Format != "mp3" { t.Errorf("Expected format 'mp3', got: %s", res.Audio.Format) } } // Verify temperature if res.Temperature == nil { t.Error("Expected temperature, got nil") } else if *res.Temperature != 0.7 { t.Errorf("Expected temperature 0.7, got: %f", *res.Temperature) } // Verify max_tokens if res.MaxTokens == nil { t.Error("Expected max_tokens, got nil") } else if *res.MaxTokens != 2000 { t.Errorf("Expected max_tokens 2000, got: %d", *res.MaxTokens) } // 
Verify max_completion_tokens if res.MaxCompletionTokens == nil { t.Error("Expected max_completion_tokens, got nil") } else if *res.MaxCompletionTokens != 1500 { t.Errorf("Expected max_completion_tokens 1500, got: %d", *res.MaxCompletionTokens) } }) // Test scenario 5: Return partial response t.Run("ReturnPartial", func(t *testing.T) { res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "return_partial"}}) if err != nil { t.Fatalf("Failed to create with partial return: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify messages if len(res.Messages) != 1 { t.Errorf("Expected 1 message, got: %d", len(res.Messages)) } // Verify temperature if res.Temperature == nil { t.Error("Expected temperature, got nil") } else if *res.Temperature != 0.5 { t.Errorf("Expected temperature 0.5, got: %f", *res.Temperature) } // Verify optional fields are nil if res.Audio != nil { t.Errorf("Expected audio to be nil, got: %v", res.Audio) } if res.MaxTokens != nil { t.Errorf("Expected max_tokens to be nil, got: %d", *res.MaxTokens) } }) // Test scenario 6: Process call - calls models.__yao.role.Get and adds to messages t.Run("ReturnProcess", func(t *testing.T) { res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "return_process"}}) if err != nil { t.Fatalf("Failed to create with process return: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify messages - should have at least 1 (system message) if len(res.Messages) < 1 { t.Errorf("Expected at least 1 message, got: %d", len(res.Messages)) } else { // First message should be system role if res.Messages[0].Role != context.RoleSystem { t.Errorf("Expected system role for first message, got: %s", res.Messages[0].Role) } // Check system message content if content, ok := res.Messages[0].Content.(string); ok { if content != "Here are the available roles in the system:" { t.Errorf("Unexpected 
system message content: %s", content) } } } }) // Test scenario 7: Default response t.Run("ReturnDefault", func(t *testing.T) { testContent := "Hello, how are you?" res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: testContent}}) if err != nil { t.Fatalf("Failed to create with default return: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify messages if len(res.Messages) != 1 { t.Errorf("Expected 1 message, got: %d", len(res.Messages)) } else { if res.Messages[0].Role != context.RoleUser { t.Errorf("Expected user role, got: %s", res.Messages[0].Role) } if content, ok := res.Messages[0].Content.(string); ok { if content != testContent { t.Errorf("Expected content '%s', got: '%s'", testContent, content) } } else { t.Errorf("Expected string content, got: %T", res.Messages[0].Content) } } }) // Test scenario 8: Verify context fields - validates all context fields in JavaScript t.Run("VerifyContext", func(t *testing.T) { res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "verify_context"}}) if err != nil { t.Fatalf("Failed to create with verify_context: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify we have messages if len(res.Messages) < 1 { t.Fatalf("Expected at least 1 message, got: %d", len(res.Messages)) } // First message should be system role with success/failure indicator if res.Messages[0].Role != context.RoleSystem { t.Errorf("Expected system role for first message, got: %s", res.Messages[0].Role) } // Check the validation result content, ok := res.Messages[0].Content.(string) if !ok { t.Fatalf("Expected string content for system message, got: %T", res.Messages[0].Content) } // The content should be "success:all_fields_validated" if content != "success:all_fields_validated" { t.Errorf("Context validation failed: %s", content) // Print detailed validation results if available if len(res.Messages) > 1 
{ if details, ok := res.Messages[1].Content.(string); ok { t.Logf("Validation details:\n%s", details) } } } else { t.Log("✓ All context fields validated successfully in JavaScript") // Optionally print validation details if len(res.Messages) > 1 { if details, ok := res.Messages[1].Content.(string); ok { t.Logf("Validation details:\n%s", details) } } } }) // Test scenario 9: Adjust context fields - tests that context fields can be modified by the hook t.Run("AdjustContext", func(t *testing.T) { // Create a fresh context for this test adjustCtx := newTestContext("chat-test-adjust", "tests.create") // Call the hook which should adjust context fields res, _, err := agent.HookScript.Create(adjustCtx, []context.Message{{Role: "user", Content: "adjust_context"}}) if err != nil { t.Fatalf("Failed to create with adjust_context: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify the response contains adjusted fields // Note: AssistantID cannot be overridden by hooks, removed from HookCreateResponse if res.Connector != "adjusted-connector" { t.Errorf("Expected adjusted connector 'adjusted-connector', got: %s", res.Connector) } if res.Locale != "zh-cn" { t.Errorf("Expected adjusted locale 'zh-cn', got: %s", res.Locale) } if res.Theme != "dark" { t.Errorf("Expected adjusted theme 'dark', got: %s", res.Theme) } if res.Route != "/adjusted/route" { t.Errorf("Expected adjusted route '/adjusted/route', got: %s", res.Route) } // Verify metadata if res.Metadata == nil { t.Fatalf("Expected metadata, got nil") } if adjusted, ok := res.Metadata["adjusted"].(bool); !ok || !adjusted { t.Errorf("Expected metadata['adjusted'] = true, got: %v", res.Metadata["adjusted"]) } // Verify context fields were actually updated // Note: AssistantID is immutable and cannot be overridden // Note: Connector is now in Options, not in Context if adjustCtx.Locale != "zh-cn" { t.Errorf("Context locale not updated. 
Expected 'zh-cn', got: %s", adjustCtx.Locale) } if adjustCtx.Theme != "dark" { t.Errorf("Context theme not updated. Expected 'dark', got: %s", adjustCtx.Theme) } if adjustCtx.Route != "/adjusted/route" { t.Errorf("Context route not updated. Expected '/adjusted/route', got: %s", adjustCtx.Route) } if adjustCtx.Metadata["adjusted"] != true { t.Errorf("Context metadata not updated. Expected metadata['adjusted'] = true, got: %v", adjustCtx.Metadata["adjusted"]) } t.Log("✓ Context fields successfully adjusted by hook") }) // Test scenario 10: Adjust uses configuration - tests that uses can be modified by the hook t.Run("AdjustUses", func(t *testing.T) { // Create a fresh context for this test usesCtx := newTestContext("chat-test-uses", "tests.create") // Call the hook which should adjust uses configuration res, _, err := agent.HookScript.Create(usesCtx, []context.Message{{Role: "user", Content: "adjust_uses"}}) if err != nil { t.Fatalf("Failed to create with adjust_uses: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify the response contains uses configuration if res.Uses == nil { t.Fatalf("Expected uses configuration, got nil") } // Verify each uses field if res.Uses.Vision != "mcp:vision-server" { t.Errorf("Expected vision 'mcp:vision-server', got: %s", res.Uses.Vision) } if res.Uses.Audio != "mcp:audio-server" { t.Errorf("Expected audio 'mcp:audio-server', got: %s", res.Uses.Audio) } if res.Uses.Search != "agent" { t.Errorf("Expected search 'agent', got: %s", res.Uses.Search) } if res.Uses.Fetch != "mcp:fetch-server" { t.Errorf("Expected fetch 'mcp:fetch-server', got: %s", res.Uses.Fetch) } // Verify metadata if res.Metadata == nil { t.Fatalf("Expected metadata, got nil") } if usesAdjusted, ok := res.Metadata["uses_adjusted"].(bool); !ok || !usesAdjusted { t.Errorf("Expected metadata['uses_adjusted'] = true, got: %v", res.Metadata["uses_adjusted"]) } // Now test that BuildRequest properly applies the uses configuration 
inputMessages := []context.Message{{Role: "user", Content: "test uses"}} _, options, err := agent.BuildRequest(usesCtx, inputMessages, res) if err != nil { t.Fatalf("Failed to build request: %s", err.Error()) } // Verify that options.Uses has the values from createResponse if options.Uses == nil { t.Fatalf("Expected options.Uses to be set, got nil") } if options.Uses.Vision != "mcp:vision-server" { t.Errorf("Expected options.Uses.Vision 'mcp:vision-server', got: %s", options.Uses.Vision) } if options.Uses.Audio != "mcp:audio-server" { t.Errorf("Expected options.Uses.Audio 'mcp:audio-server', got: %s", options.Uses.Audio) } if options.Uses.Search != "agent" { t.Errorf("Expected options.Uses.Search 'agent', got: %s", options.Uses.Search) } if options.Uses.Fetch != "mcp:fetch-server" { t.Errorf("Expected options.Uses.Fetch 'mcp:fetch-server', got: %s", options.Uses.Fetch) } t.Log("✓ Uses configuration successfully adjusted by hook and applied to options") }) // Test scenario 11: Adjust uses configuration with force_uses flag t.Run("AdjustUsesForce", func(t *testing.T) { // Create a fresh context for this test usesCtx := newTestContext("chat-test-uses-force", "tests.create") // Call the hook which should adjust uses configuration and set force_uses res, _, err := agent.HookScript.Create(usesCtx, []context.Message{{Role: "user", Content: "adjust_uses_force"}}) if err != nil { t.Fatalf("Failed to create with adjust_uses_force: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify the response contains uses configuration if res.Uses == nil { t.Fatalf("Expected uses configuration, got nil") } // Verify uses fields if res.Uses.Vision != "tests.vision-helper" { t.Errorf("Expected vision 'tests.vision-helper', got: %s", res.Uses.Vision) } if res.Uses.Audio != "mcp:audio-server" { t.Errorf("Expected audio 'mcp:audio-server', got: %s", res.Uses.Audio) } // Verify force_uses flag if res.ForceUses == nil { t.Fatalf("Expected force_uses to 
be set, got nil") } if !*res.ForceUses { t.Errorf("Expected force_uses to be true, got: %v", *res.ForceUses) } // Verify metadata if res.Metadata == nil { t.Fatalf("Expected metadata, got nil") } if usesForced, ok := res.Metadata["uses_forced"].(bool); !ok || !usesForced { t.Errorf("Expected metadata['uses_forced'] = true, got: %v", res.Metadata["uses_forced"]) } // Now test that BuildRequest properly applies the force_uses flag inputMessages := []context.Message{{Role: "user", Content: "test force uses"}} _, options, err := agent.BuildRequest(usesCtx, inputMessages, res) if err != nil { t.Fatalf("Failed to build request: %s", err.Error()) } // Verify that options.ForceUses is true if !options.ForceUses { t.Errorf("Expected options.ForceUses to be true, got: %v", options.ForceUses) } t.Log("✓ Uses configuration with force_uses flag successfully adjusted by hook and applied to options") }) } ================================================ FILE: agent/assistant/hook/goroutine_leak_test.go ================================================ package hook_test import ( stdContext "context" "fmt" "os" "runtime" "runtime/pprof" "strings" "testing" "time" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // TestGoroutineLeakDetailed performs detailed goroutine leak analysis func TestGoroutineLeakDetailed(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("Assistant has no script") } // Create profile directory os.MkdirAll("/tmp/goroutine_profiles", 0755) // Take initial snapshot runtime.GC() time.Sleep(200 * time.Millisecond) initialGoroutines := runtime.NumGoroutine() // Save initial profile saveGoroutineProfile("/tmp/goroutine_profiles/00_initial.txt") t.Logf("Initial goroutines: %d", 
initialGoroutines) // Test with just 10 iterations to see the pattern iterations := 10 for i := 0; i < iterations; i++ { ctx := newLeakTestContext(fmt.Sprintf("leak-test-%d", i), "tests.create") _, _, err := agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) if err != nil { t.Errorf("Create failed at iteration %d: %s", i, err.Error()) } // Release context ctx.Release() // Check goroutines after each iteration current := runtime.NumGoroutine() growth := current - initialGoroutines t.Logf("After iteration %d: %d goroutines (growth: %d)", i+1, current, growth) // Save profile every 5 iterations if (i+1)%5 == 0 { saveGoroutineProfile(fmt.Sprintf("/tmp/goroutine_profiles/%02d_after_iter_%d.txt", i+1, i+1)) } } // Force cleanup runtime.GC() time.Sleep(500 * time.Millisecond) finalGoroutines := runtime.NumGoroutine() growth := finalGoroutines - initialGoroutines t.Logf("\n=== SUMMARY ===") t.Logf("Initial: %d goroutines", initialGoroutines) t.Logf("Final: %d goroutines", finalGoroutines) t.Logf("Growth: %d goroutines (%.2f per iteration)", growth, float64(growth)/float64(iterations)) // Save final profile saveGoroutineProfile("/tmp/goroutine_profiles/99_final.txt") // Analyze the leak t.Logf("\n=== ANALYSIS ===") analyzeGoroutineProfiles(t, "/tmp/goroutine_profiles") } // TestGoroutineLeakByComponent tests each component separately func TestGoroutineLeakByComponent(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } os.MkdirAll("/tmp/component_profiles", 0755) t.Run("ContextCreationOnly", func(t *testing.T) { runtime.GC() time.Sleep(100 * time.Millisecond) initial := runtime.NumGoroutine() for i := 0; i < 10; i++ { ctx := newLeakTestContext(fmt.Sprintf("test-%d", i), "tests.create") _ = ctx ctx.Release() } runtime.GC() time.Sleep(100 * time.Millisecond) final := runtime.NumGoroutine() t.Logf("Context creation: 
initial=%d, final=%d, growth=%d", initial, final, final-initial) }) t.Run("ScriptExecutionOnly", func(t *testing.T) { runtime.GC() time.Sleep(100 * time.Millisecond) initial := runtime.NumGoroutine() for i := 0; i < 10; i++ { ctx := newLeakTestContext(fmt.Sprintf("test-%d", i), "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) ctx.Release() } runtime.GC() time.Sleep(100 * time.Millisecond) final := runtime.NumGoroutine() t.Logf("Script execution: initial=%d, final=%d, growth=%d", initial, final, final-initial) saveGoroutineProfile("/tmp/component_profiles/script_execution.txt") }) t.Run("TraceOperations", func(t *testing.T) { runtime.GC() time.Sleep(100 * time.Millisecond) initial := runtime.NumGoroutine() for i := 0; i < 10; i++ { ctx := newLeakTestContext(fmt.Sprintf("test-%d", i), "tests.create") // Create trace trace, err := ctx.Trace() if err == nil && trace != nil { // Trace operations _ = trace } ctx.Release() } runtime.GC() time.Sleep(100 * time.Millisecond) final := runtime.NumGoroutine() t.Logf("Trace operations: initial=%d, final=%d, growth=%d", initial, final, final-initial) saveGoroutineProfile("/tmp/component_profiles/trace_operations.txt") }) } // TestGoroutineLeakWithoutRelease tests if Release() fixes the leak func TestGoroutineLeakWithoutRelease(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.create") if err != nil { t.Fatalf("Failed to get assistant: %s", err.Error()) } t.Run("WithoutRelease", func(t *testing.T) { runtime.GC() time.Sleep(100 * time.Millisecond) initial := runtime.NumGoroutine() for i := 0; i < 10; i++ { ctx := newLeakTestContext(fmt.Sprintf("no-release-%d", i), "tests.create") _, _, _ = agent.HookScript.Create(ctx, []context.Message{ {Role: "user", Content: "Hello"}, }) // Intentionally NOT calling ctx.Release() } runtime.GC() time.Sleep(100 * time.Millisecond) final := runtime.NumGoroutine() t.Logf("WITHOUT Release: 
// saveGoroutineProfile writes the current goroutine profile to filename.
//
// It is a best-effort diagnostic helper: when the file cannot be created the
// function returns silently, and a failed write is ignored as well, so a
// broken profile directory never fails the calling test.
func saveGoroutineProfile(filename string) {
	f, err := os.Create(filename)
	if err != nil {
		return
	}
	defer f.Close()

	// Debug level 2 is the textual, one-stack-per-goroutine form that
	// countGoroutinesByFunction parses.
	_ = pprof.Lookup("goroutine").WriteTo(f, 2)
}

// analyzeGoroutineProfiles compares dir/00_initial.txt with dir/99_final.txt
// and logs every function whose occurrence count grew between the two
// profiles. If either file cannot be read, the reason is logged and the
// analysis is skipped.
func analyzeGoroutineProfiles(t *testing.T, dir string) {
	initialData, err := os.ReadFile(dir + "/00_initial.txt")
	if err != nil {
		t.Logf("Could not read initial profile: %v", err)
		return
	}

	finalData, err := os.ReadFile(dir + "/99_final.txt")
	if err != nil {
		t.Logf("Could not read final profile: %v", err)
		return
	}

	initialFuncs := countGoroutinesByFunction(string(initialData))
	finalFuncs := countGoroutinesByFunction(string(finalData))

	t.Logf("\nGoroutine growth by function:")
	t.Logf("%-60s %8s %8s %8s", "Function", "Initial", "Final", "Growth")
	t.Logf("%s", strings.Repeat("-", 90))

	// Only functions that appear more often in the final profile are reported;
	// a function absent from the initial profile counts as zero there.
	for fn, finalCount := range finalFuncs {
		initialCount := initialFuncs[fn]
		if growth := finalCount - initialCount; growth > 0 {
			t.Logf("%-60s %8d %8d %8d", truncate(fn, 60), initialCount, finalCount, growth)
		}
	}

	t.Logf("\nProfiles saved to: %s", dir)
	t.Logf("To compare: diff %s/00_initial.txt %s/99_final.txt | grep '^>'", dir, dir)
}

// countGoroutinesByFunction tallies, per function name, how many lines of a
// textual goroutine profile look like a call frame ("name(args)").
//
// Lines starting with "#" are ignored. The name is everything before the LAST
// "(" on the line, so pointer-receiver methods such as
// "pkg.(*T).Method(0x1)" are counted under "pkg.(*T).Method" instead of all
// collapsing into "pkg.".
func countGoroutinesByFunction(profile string) map[string]int {
	counts := make(map[string]int)

	for _, line := range strings.Split(profile, "\n") {
		line = strings.TrimSpace(line)
		if strings.HasPrefix(line, "#") {
			continue
		}

		// The argument list is the final parenthesised group; an earlier "("
		// may belong to a receiver type like (*T).
		idx := strings.LastIndex(line, "(")
		if idx <= 0 {
			continue
		}
		counts[strings.TrimSpace(line[:idx])]++
	}

	return counts
}

// truncate shortens s to at most max bytes. When s is longer than max, the
// result ends in "..." if there is room for it (max >= 3); for smaller limits
// a plain prefix is returned, and a non-positive max yields "". The cut never
// lands inside a multi-byte UTF-8 sequence, so the result may be slightly
// shorter than max.
func truncate(s string, max int) string {
	if len(s) <= max {
		return s
	}
	if max <= 0 {
		return ""
	}

	const ellipsis = "..."
	cut, suffix := max-len(ellipsis), ellipsis
	if cut < 0 {
		// Not enough room for the ellipsis: keep as much text as allowed.
		cut, suffix = max, ""
	}

	// Back off over UTF-8 continuation bytes (10xxxxxx) so a rune is never split.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}

	return s[:cut] + suffix
}
&context.Options{} } // Convert payload to map for JS (use JSON tag names) payloadMap := map[string]interface{}{ "messages": payload.Messages, "completion": payload.Completion, "tools": payload.Tools, "error": payload.Error, } // Execute hook with ctx, payload, and options (convert options to map for JS) optionsMap := options.ToMap() res, err := s.Execute(ctx, "Next", payloadMap, optionsMap) if err != nil { return nil, nil, err } response, err := s.getNextHookResponse(res) if err != nil { return nil, nil, err } return response, options, nil } // getNextHookResponse convert the result to a NextHookResponse func (s *Script) getNextHookResponse(res interface{}) (*context.NextHookResponse, error) { // Handle nil result if res == nil { return nil, nil } // Handle undefined result (treat as nil) if _, ok := res.(bridge.UndefinedT); ok { return nil, nil } // Marshal to JSON and unmarshal to NextHookResponse raw, err := json.Marshal(res) if err != nil { return nil, fmt.Errorf("failed to marshal Next hook result: %w", err) } var response context.NextHookResponse if err := json.Unmarshal(raw, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal to NextHookResponse: %w", err) } return &response, nil } ================================================ FILE: agent/assistant/hook/next_test.go ================================================ package hook_test import ( stdContext "context" "testing" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newTestContextForNext creates a Context for testing Next Hook with commonly used fields pre-populated. // You can override any fields after creation as needed for specific test scenarios. 
func newTestContextForNext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", Scope: "openid profile email", SessionID: "test-session-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", RememberMe: true, Constraints: types.DataConstraints{ OwnerOnly: false, CreatorOnly: false, EditorOnly: false, TeamOnly: true, Extra: map[string]interface{}{ "department": "engineering", "region": "us-west", "project": "yao", }, }, } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } // TestNext tests the Next hook func TestNext(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.next") if err != nil { t.Fatalf("Failed to get the tests.next assistant: %s", err.Error()) } if agent.HookScript == nil { t.Fatalf("The tests.next assistant has no script") } // Use the helper function to create a test context ctx := newTestContextForNext("chat-test-next-hook", "tests.next") // Test scenario 1: Return null (should get nil response) t.Run("ReturnNull", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleUser, Content: "return_null"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, Tools: nil, Error: "", } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook with null return: %s", err.Error()) } if res != nil { t.Errorf("Expected nil response for null return, got: %v", res) } }) // Test scenario 2: Return undefined (should get nil response) t.Run("ReturnUndefined", func(t *testing.T) { payload := 
&context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleUser, Content: "return_undefined"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook with undefined return: %s", err.Error()) } if res != nil { t.Errorf("Expected nil response for undefined return, got: %v", res) } }) // Test scenario 3: Return empty object (should get empty NextHookResponse) t.Run("ReturnEmpty", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleUser, Content: "return_empty"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook with empty return: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response for empty object, got nil") } if res.Delegate != nil { t.Errorf("Expected nil Delegate, got: %v", res.Delegate) } if res.Data != nil { t.Errorf("Expected nil Data, got: %v", res.Data) } }) // Test scenario 4: Return custom data t.Run("ReturnCustomData", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleUser, Content: "return_custom_data"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook with custom data: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify Data is present if res.Data == nil { t.Fatalf("Expected Data to be present, got nil") } // Data should be a map dataMap, ok := res.Data.(map[string]interface{}) if !ok { t.Fatalf("Expected Data to be map[string]interface{}, got: %T", res.Data) } // Verify custom data fields if message, ok := dataMap["message"].(string); !ok || message != "Custom response from Next Hook" 
{ t.Errorf("Expected custom message, got: %v", dataMap["message"]) } if test, ok := dataMap["test"].(bool); !ok || !test { t.Errorf("Expected test=true, got: %v", dataMap["test"]) } if _, ok := dataMap["timestamp"]; !ok { t.Errorf("Expected timestamp field") } // Verify Delegate is nil if res.Delegate != nil { t.Errorf("Expected nil Delegate, got: %v", res.Delegate) } }) // Test scenario 5: Return data with metadata t.Run("ReturnDataWithMetadata", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleUser, Content: "return_data_with_metadata"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify Data if res.Data == nil { t.Fatalf("Expected Data to be present, got nil") } dataMap, ok := res.Data.(map[string]interface{}) if !ok { t.Fatalf("Expected Data to be map[string]interface{}, got: %T", res.Data) } if result, ok := dataMap["result"].(string); !ok || result != "success" { t.Errorf("Expected result='success', got: %v", dataMap["result"]) } // Verify Metadata if res.Metadata == nil { t.Fatalf("Expected Metadata to be present, got nil") } if hook, ok := res.Metadata["hook"].(string); !ok || hook != "next" { t.Errorf("Expected hook='next', got: %v", res.Metadata["hook"]) } if processed, ok := res.Metadata["processed"].(bool); !ok || !processed { t.Errorf("Expected processed=true, got: %v", res.Metadata["processed"]) } }) // Test scenario 6: Return delegate t.Run("ReturnDelegate", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleUser, Content: "return_delegate"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next 
hook with delegate: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify Delegate is present if res.Delegate == nil { t.Fatalf("Expected Delegate to be present, got nil") } // Verify delegate fields if res.Delegate.AgentID != "tests.create" { t.Errorf("Expected AgentID='tests.create', got: %s", res.Delegate.AgentID) } if len(res.Delegate.Messages) != 1 { t.Errorf("Expected 1 message, got: %d", len(res.Delegate.Messages)) } else { if res.Delegate.Messages[0].Role != context.RoleUser { t.Errorf("Expected user role, got: %s", res.Delegate.Messages[0].Role) } if content, ok := res.Delegate.Messages[0].Content.(string); !ok || content != "Hello from delegated agent" { t.Errorf("Expected specific content, got: %v", res.Delegate.Messages[0].Content) } } // Verify Data is nil (only delegate, no custom data) if res.Data != nil { t.Logf("Note: Data is present alongside Delegate: %v", res.Data) } }) // Test scenario 7: Verify payload structure t.Run("VerifyPayload", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleSystem, Content: "System message"}, {Role: context.RoleUser, Content: "verify_payload"}, }, Completion: &context.CompletionResponse{ Content: "Test completion content", Usage: &message.UsageInfo{ PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30, }, }, Tools: []context.ToolCallResponse{ { ToolCallID: "call_123", Server: "test-server", Tool: "test-tool", Result: map[string]interface{}{"success": true}, Error: "", }, }, Error: "", } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify Data contains validation results if res.Data == nil { t.Fatalf("Expected Data with validation results, got nil") } dataMap, ok := res.Data.(map[string]interface{}) if !ok { t.Fatalf("Expected Data to be map[string]interface{}, got: %T", 
res.Data) } if validation, ok := dataMap["validation"].(string); !ok || validation != "success" { t.Errorf("Expected validation='success', got: %v", dataMap["validation"]) } if checks, ok := dataMap["checks"].([]interface{}); !ok { t.Errorf("Expected checks array, got: %T", dataMap["checks"]) } else { t.Logf("✓ Payload validation checks: %d items", len(checks)) for i, check := range checks { t.Logf(" [%d] %v", i, check) } } }) // Test scenario 8: Verify tools processing t.Run("VerifyTools", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: []context.Message{ {Role: context.RoleUser, Content: "verify_tools"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, Tools: []context.ToolCallResponse{ { ToolCallID: "call_1", Server: "server1", Tool: "tool1", Result: map[string]interface{}{"value": 42}, Error: "", }, { ToolCallID: "call_2", Server: "server2", Tool: "tool2", Result: nil, Error: "Tool execution failed", }, }, } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify Data if res.Data == nil { t.Fatalf("Expected Data, got nil") } dataMap, ok := res.Data.(map[string]interface{}) if !ok { t.Fatalf("Expected Data to be map, got: %T", res.Data) } // Verify tool statistics if totalTools, ok := dataMap["total_tools"].(float64); !ok || int(totalTools) != 2 { t.Errorf("Expected total_tools=2, got: %v", dataMap["total_tools"]) } if successful, ok := dataMap["successful"].(float64); !ok || int(successful) != 1 { t.Errorf("Expected successful=1, got: %v", dataMap["successful"]) } if failed, ok := dataMap["failed"].(float64); !ok || int(failed) != 1 { t.Errorf("Expected failed=1, got: %v", dataMap["failed"]) } t.Log("✓ Tools processing validated successfully") }) // Test scenario 9: Handle error t.Run("HandleError", func(t *testing.T) { payload := &context.NextHookPayload{ Messages: 
[]context.Message{ {Role: context.RoleUser, Content: "handle_error"}, }, Completion: &context.CompletionResponse{ Content: "Test completion", }, Error: "Tool execution failed: timeout", } res, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Failed to execute Next hook: %s", err.Error()) } if res == nil { t.Fatalf("Expected non-nil response, got nil") } // Verify error handling if res.Data == nil { t.Fatalf("Expected Data, got nil") } dataMap, ok := res.Data.(map[string]interface{}) if !ok { t.Fatalf("Expected Data to be map, got: %T", res.Data) } if errorMsg, ok := dataMap["error"].(string); !ok || errorMsg != "Tool execution failed: timeout" { t.Errorf("Expected error message, got: %v", dataMap["error"]) } if recovered, ok := dataMap["recovered"].(bool); !ok || !recovered { t.Errorf("Expected recovered=true, got: %v", dataMap["recovered"]) } t.Log("✓ Error handling validated successfully") }) } ================================================ FILE: agent/assistant/hook/realworld_next_test.go ================================================ package hook_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newRealWorldNextContext creates a Context for real world Next Hook testing func newRealWorldNextContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "realworld-test-user", ClientID: "realworld-test-client", Scope: "openid profile", SessionID: "realworld-test-session", UserID: "realworld-user-123", TeamID: "realworld-team-456", TenantID: "realworld-tenant-789", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "RealWorldTest/1.0", IP: "127.0.0.1", } ctx.Referer = 
context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } // TestRealWorldNextStandard tests standard response (nil return) func TestRealWorldNextStandard(t *testing.T) { if testing.Short() { t.Skip("Skipping real world Next Hook test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldNextContext("test-next-standard", "tests.realworld-next") // Simulate completion with scenario marker messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: standard"}, {Role: context.RoleAssistant, Content: "I'll process your request using standard response."}, } completion := &context.CompletionResponse{ Content: "Processing complete. Standard response will be used.", } payload := &context.NextHookPayload{ Messages: messages, Completion: completion, Tools: nil, Error: "", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } // Should return nil for standard response assert.Nil(t, response, "Standard scenario should return nil") t.Log("✓ Standard response scenario passed") } // TestRealWorldNextCustomData tests custom data response func TestRealWorldNextCustomData(t *testing.T) { if testing.Short() { t.Skip("Skipping real world Next Hook test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldNextContext("test-next-custom", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: custom_data"}, {Role: context.RoleAssistant, Content: "Here's some information for you."}, } completion := &context.CompletionResponse{ Content: "This is the LLM completion that will be summarized.", } payload := 
&context.NextHookPayload{ Messages: messages, Completion: completion, Tools: nil, Error: "", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } assert.NotNil(t, response, "Custom data scenario should return response") assert.NotNil(t, response.Data, "Response should have Data") dataMap, ok := response.Data.(map[string]interface{}) assert.True(t, ok, "Data should be a map") assert.Equal(t, "custom_response", dataMap["type"]) assert.Contains(t, dataMap, "timestamp") t.Log("✓ Custom data response scenario passed") } // TestRealWorldNextDelegate tests agent delegation func TestRealWorldNextDelegate(t *testing.T) { if testing.Short() { t.Skip("Skipping real world Next Hook test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldNextContext("test-next-delegate", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: delegate"}, } completion := &context.CompletionResponse{ Content: "I should delegate this request to another agent.", } payload := &context.NextHookPayload{ Messages: messages, Completion: completion, Tools: nil, Error: "", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } assert.NotNil(t, response, "Delegate scenario should return response") assert.NotNil(t, response.Delegate, "Response should have Delegate") assert.Equal(t, "tests.create", response.Delegate.AgentID) assert.NotEmpty(t, response.Delegate.Messages) t.Log("✓ Delegation scenario passed") } // TestRealWorldNextProcessTools tests tool result processing func TestRealWorldNextProcessTools(t *testing.T) { if testing.Short() { t.Skip("Skipping real world Next Hook test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") if 
err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldNextContext("test-next-tools", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: process_tools"}, } completion := &context.CompletionResponse{ Content: "Tool calls have been executed.", } // Simulate tool call results tools := []context.ToolCallResponse{ { ToolCallID: "call_1", Server: "test-server", Tool: "test-tool-1", Result: map[string]interface{}{"status": "success"}, Error: "", }, { ToolCallID: "call_2", Server: "test-server", Tool: "test-tool-2", Result: nil, Error: "Tool execution failed", }, } payload := &context.NextHookPayload{ Messages: messages, Completion: completion, Tools: tools, Error: "", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } assert.NotNil(t, response, "Process tools scenario should return response") assert.NotNil(t, response.Data, "Response should have Data") dataMap, ok := response.Data.(map[string]interface{}) assert.True(t, ok, "Data should be a map") assert.Equal(t, "Tool execution summary", dataMap["message"]) // Check summary summary, ok := dataMap["summary"].(map[string]interface{}) assert.True(t, ok, "Should have summary") assert.Equal(t, float64(2), summary["total"]) assert.Equal(t, float64(1), summary["successful"]) assert.Equal(t, float64(1), summary["failed"]) t.Log("✓ Process tools scenario passed") } // TestRealWorldNextErrorRecovery tests error handling and recovery func TestRealWorldNextErrorRecovery(t *testing.T) { if testing.Short() { t.Skip("Skipping real world Next Hook test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldNextContext("test-next-error", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: error_recovery"}, } 
completion := &context.CompletionResponse{ Content: "An error occurred during processing.", } payload := &context.NextHookPayload{ Messages: messages, Completion: completion, Tools: nil, Error: "System error: Database connection timeout", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } assert.NotNil(t, response, "Error recovery scenario should return response") assert.NotNil(t, response.Data, "Response should have Data") dataMap, ok := response.Data.(map[string]interface{}) assert.True(t, ok, "Data should be a map") assert.Equal(t, "Error was handled by Next Hook", dataMap["message"]) assert.Contains(t, dataMap, "error") assert.Contains(t, dataMap, "recovery_action") t.Log("✓ Error recovery scenario passed") } // TestRealWorldNextConditional tests conditional logic based on completion func TestRealWorldNextConditional(t *testing.T) { if testing.Short() { t.Skip("Skipping real world Next Hook test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldNextContext("test-next-conditional", "tests.realworld-next") t.Run("ConditionalSuccess", func(t *testing.T) { messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: conditional"}, } completion := &context.CompletionResponse{ Content: "The operation completed successfully. 
All tasks are done.", } payload := &context.NextHookPayload{ Messages: messages, Completion: completion, Tools: nil, Error: "", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } assert.NotNil(t, response, "Conditional scenario should return response") assert.NotNil(t, response.Data, "Response should have Data") dataMap, ok := response.Data.(map[string]interface{}) assert.True(t, ok, "Data should be a map") assert.Equal(t, "Conditional analysis complete", dataMap["message"]) assert.Contains(t, dataMap, "action") assert.Contains(t, dataMap, "conditions") t.Log("✓ Conditional (success) scenario passed") }) t.Run("ConditionalDelegate", func(t *testing.T) { messages := []context.Message{ {Role: context.RoleUser, Content: "scenario: conditional"}, } completion := &context.CompletionResponse{ Content: "I should delegate this request to another service for better handling.", } payload := &context.NextHookPayload{ Messages: messages, Completion: completion, Tools: nil, Error: "", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } assert.NotNil(t, response, "Conditional delegate should return response") assert.NotNil(t, response.Delegate, "Should delegate based on condition") assert.Equal(t, "tests.create", response.Delegate.AgentID) t.Log("✓ Conditional (delegate) scenario passed") }) } // TestRealWorldNextDefault tests default behavior func TestRealWorldNextDefault(t *testing.T) { if testing.Short() { t.Skip("Skipping real world Next Hook test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld-next") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldNextContext("test-next-default", "tests.realworld-next") messages := []context.Message{ {Role: context.RoleUser, Content: "Just a normal request"}, } completion := &context.CompletionResponse{ Content: "Here's the 
response to your request.", } payload := &context.NextHookPayload{ Messages: messages, Completion: completion, Tools: nil, Error: "", } response, _, err := agent.HookScript.Next(ctx, payload) if err != nil { t.Fatalf("Next hook failed: %v", err) } // Default behavior should return nil assert.Nil(t, response, "Default scenario should return nil for standard response") t.Log("✓ Default scenario passed") } ================================================ FILE: agent/assistant/hook/realworld_stress_test.go ================================================ package hook_test import ( stdContext "context" "fmt" "runtime" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // ============================================================================ // Real World Stress Tests // These tests simulate actual production usage patterns with Stream() flow // ============================================================================ // TestRealWorldSimpleScenario tests basic Stream() flow with simple Create hook func TestRealWorldSimpleScenario(t *testing.T) { if testing.Short() { t.Skip("Skipping real world test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldContext("test-simple", "tests.realworld") // Test Create hook with simple scenario messages := []context.Message{ {Role: "user", Content: "simple"}, } response, _, err := agent.HookScript.Create(ctx, messages) if err != nil { t.Fatalf("Create failed: %v", err) } assert.NotNil(t, response) assert.NotEmpty(t, response.Messages) assert.Equal(t, "simple", response.Metadata["scenario"]) } // TestRealWorldMCPScenarios tests MCP integration scenarios func TestRealWorldMCPScenarios(t *testing.T) { if testing.Short() 
{ t.Skip("Skipping real world test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } t.Run("MCP Health", func(t *testing.T) { ctx := newRealWorldContext("test-mcp-health", "tests.realworld") messages := []context.Message{ {Role: "user", Content: "mcp_health"}, } response, _, err := agent.HookScript.Create(ctx, messages) if err != nil { t.Fatalf("Create failed: %v", err) } // Detailed validation assert.NotNil(t, response) assert.NotEmpty(t, response.Messages) // Check if metadata exists if response.Metadata == nil { t.Logf("⚠ Metadata is nil - checking messages content") // Verify messages contain expected content messageContent := "" for _, msg := range response.Messages { if content, ok := msg.Content.(string); ok { messageContent += content + "\n" } } assert.Contains(t, messageContent, "Health", "Message should mention health") assert.Contains(t, messageContent, "Tools", "Message should mention tools") t.Logf("✓ MCP Health executed (verified via message content)") } else { assert.Equal(t, "mcp_health", response.Metadata["scenario"]) // Verify metadata contains MCP results if toolsCount, ok := response.Metadata["tools_count"]; ok { count := int(toolsCount.(float64)) assert.Greater(t, count, 0, "Should have tools from MCP") t.Logf("✓ MCP Health: %d tools, health data: %v", count, response.Metadata["health_data"]) } } ctx.Release() }) t.Run("MCP Tools", func(t *testing.T) { ctx := newRealWorldContext("test-mcp-tools", "tests.realworld") messages := []context.Message{ {Role: "user", Content: "mcp_tools"}, } response, _, err := agent.HookScript.Create(ctx, messages) if err != nil { t.Fatalf("Create failed: %v", err) } // Detailed validation assert.NotNil(t, response) assert.NotEmpty(t, response.Messages) // Check if metadata exists if response.Metadata == nil { t.Logf("⚠ Metadata is nil - checking messages content") // Verify messages contain 
expected content messageContent := "" for _, msg := range response.Messages { if content, ok := msg.Content.(string); ok { messageContent += content + "\n" } } assert.Contains(t, messageContent, "Tools", "Message should mention tools") assert.Contains(t, messageContent, "Ping", "Message should mention ping") t.Logf("✓ MCP Tools executed (verified via message content)") } else { assert.Equal(t, "mcp_tools", response.Metadata["scenario"]) // Verify tools were called if toolsCount, ok := response.Metadata["tools_count"]; ok { count := int(toolsCount.(float64)) assert.Greater(t, count, 0, "Should have tools from MCP") // Verify operations list if operations, ok := response.Metadata["operations"].([]interface{}); ok { assert.Len(t, operations, 2, "Should execute 2 operations: ping, status") t.Logf("✓ MCP Tools: %d tools, operations: %v", count, operations) } } } ctx.Release() }) t.Run("Full Workflow", func(t *testing.T) { ctx := newRealWorldContext("test-full-workflow", "tests.realworld") // Initialize stack for trace stack, _, done := context.EnterStack(ctx, "tests.realworld", &context.Options{}) defer done() ctx.Stack = stack messages := []context.Message{ {Role: "user", Content: "full_workflow"}, } response, _, err := agent.HookScript.Create(ctx, messages) if err != nil { t.Fatalf("Create failed: %v", err) } // Detailed validation assert.NotNil(t, response) assert.NotEmpty(t, response.Messages) // Check if metadata exists if response.Metadata == nil { t.Logf("⚠ Metadata is nil - checking messages content") // Verify messages contain expected content messageContent := "" for _, msg := range response.Messages { if content, ok := msg.Content.(string); ok { messageContent += content + "\n" } } assert.Contains(t, messageContent, "Workflow", "Message should mention workflow") assert.Contains(t, messageContent, "Tools", "Message should mention tools") assert.Contains(t, messageContent, "Roles", "Message should mention database roles") t.Logf("✓ Full Workflow executed 
(verified via message content)") } else { assert.Equal(t, "full_workflow", response.Metadata["scenario"]) // Verify all phases completed if phasesCompleted, ok := response.Metadata["phases_completed"]; ok { phases := int(phasesCompleted.(float64)) assert.Equal(t, 4, phases, "Should complete 4 phases") // Verify MCP tools if mcpTools, ok := response.Metadata["mcp_tools"]; ok { tools := int(mcpTools.(float64)) assert.Greater(t, tools, 0, "Should have MCP tools") // Verify DB records if dbRecords, ok := response.Metadata["db_records"]; ok { records := int(dbRecords.(float64)) assert.GreaterOrEqual(t, records, 0, "Should have DB query result") t.Logf("✓ Full Workflow: %d phases, %d MCP tools, %d DB records", phases, tools, records) } } } } ctx.Release() }) } // TestRealWorldTraceIntensive tests trace-heavy scenarios func TestRealWorldTraceIntensive(t *testing.T) { if testing.Short() { t.Skip("Skipping real world test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } ctx := newRealWorldContext("test-trace-intensive", "tests.realworld") stack, _, done := context.EnterStack(ctx, "tests.realworld", &context.Options{}) defer done() ctx.Stack = stack messages := []context.Message{ {Role: "user", Content: "trace_intensive"}, } response, _, err := agent.HookScript.Create(ctx, messages) if err != nil { t.Fatalf("Create failed: %v", err) } assert.NotNil(t, response) assert.Equal(t, "trace_intensive", response.Metadata["scenario"]) assert.NotZero(t, response.Metadata["nodes_created"]) } // TestRealWorldStressSimple tests simple scenario under stress func TestRealWorldStressSimple(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.realworld") if err != nil { t.Fatalf("Failed to get assistant: %v", err) } iterations := 100 startMemory := 
getMemStats()

	for i := 0; i < iterations; i++ {
		ctx := newRealWorldContext(fmt.Sprintf("stress-simple-%d", i), "tests.realworld")

		messages := []context.Message{
			{Role: "user", Content: "simple"},
		}

		response, _, err := agent.HookScript.Create(ctx, messages)
		if err != nil {
			t.Fatalf("Iteration %d failed: %v", i, err)
		}

		// Validate response
		assert.NotNil(t, response, "Iteration %d: response should not be nil", i)
		assert.NotEmpty(t, response.Messages, "Iteration %d: messages should not be empty", i)
		if response.Metadata != nil {
			assert.Equal(t, "simple", response.Metadata["scenario"], "Iteration %d: scenario mismatch", i)
		}

		// Explicit cleanup
		ctx.Release()

		// Sample memory every 20 iterations.
		if i%20 == 0 {
			runtime.GC()
			currentMemory := getMemStats()
			t.Logf("Iteration %d: Memory: %d MB", i, currentMemory/1024/1024)
		}
	}

	runtime.GC()
	endMemory := getMemStats()

	t.Logf("Simple stress: %d iterations", iterations)
	t.Logf("Start memory: %d MB", startMemory/1024/1024)
	t.Logf("End memory: %d MB", endMemory/1024/1024)
	// The counters are unsigned, so subtract in the direction that cannot wrap.
	if endMemory > startMemory {
		t.Logf("Memory growth: %d MB", (endMemory-startMemory)/1024/1024)
	} else {
		t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024)
	}
}

// TestRealWorldStressMCP tests MCP scenarios under stress
//
// It runs 50 iterations that alternate between the "mcp_health" and
// "mcp_tools" scenarios. Every iteration creates a fresh context, enters a
// stack, calls the Create hook, validates the scenario-specific metadata and
// then leaves the stack and releases the context. Memory usage is logged every
// 10 iterations and summarised at the end; it is reported, not asserted.
func TestRealWorldStressMCP(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping stress test in short mode")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	agent, err := assistant.Get("tests.realworld")
	if err != nil {
		t.Fatalf("Failed to get assistant: %v", err)
	}

	iterations := 50
	scenarios := []string{"mcp_health", "mcp_tools"}
	startMemory := getMemStats()

	for i := 0; i < iterations; i++ {
		// Alternate scenarios by iteration index.
		scenario := scenarios[i%len(scenarios)]
		ctx := newRealWorldContext(fmt.Sprintf("stress-mcp-%d", i), "tests.realworld")

		// Initialize stack for trace
		stack, _, done := context.EnterStack(ctx, "tests.realworld", &context.Options{})
		ctx.Stack = stack

		messages := []context.Message{
			{Role: "user", Content: scenario},
		}

		response, _, err := agent.HookScript.Create(ctx, messages)
		if err != nil {
			t.Fatalf("Iteration %d (%s) failed: %v", i, scenario, err)
		}

		// Validate response
		assert.NotNil(t, response, "Iteration %d (%s): response should not be nil", i, scenario)
		assert.NotEmpty(t, response.Messages, "Iteration %d (%s): messages should not be empty", i, scenario)

		// Validate metadata
		if response.Metadata != nil {
			assert.Equal(t, scenario, response.Metadata["scenario"], "Iteration %d: scenario mismatch", i)

			// Verify MCP-specific data
			if scenario == "mcp_health" {
				assert.NotNil(t, response.Metadata["tools_count"], "Iteration %d: should have tools_count", i)
				// The count is only checked when the value is a float64.
				if toolsCount, ok := response.Metadata["tools_count"].(float64); ok {
					assert.Greater(t, int(toolsCount), 0, "Iteration %d: should have at least 1 tool", i)
					assert.Equal(t, 3, int(toolsCount), "Iteration %d: echo should have 3 tools", i)
				}
				assert.NotNil(t, response.Metadata["health_data"], "Iteration %d: should have health_data", i)
			} else if scenario == "mcp_tools" {
				assert.NotNil(t, response.Metadata["tools_count"], "Iteration %d: should have tools_count", i)
				if toolsCount, ok := response.Metadata["tools_count"].(float64); ok {
					assert.Equal(t, 3, int(toolsCount), "Iteration %d: echo should have 3 tools", i)
				}
				assert.NotNil(t, response.Metadata["operations"], "Iteration %d: should have operations", i)
				if operations, ok := response.Metadata["operations"].([]interface{}); ok {
					assert.Len(t, operations, 2, "Iteration %d: should have 2 operations (ping, status)", i)
				}
			}
		} else {
			t.Errorf("Iteration %d (%s): metadata is nil", i, scenario)
		}

		// Cleanup
		done()
		ctx.Release()

		// Sample memory every 10 iterations.
		if i%10 == 0 {
			runtime.GC()
			currentMemory := getMemStats()
			t.Logf("Iteration %d (%s): Memory: %d MB", i, scenario, currentMemory/1024/1024)
		}
	}

	runtime.GC()
	endMemory := getMemStats()

	t.Logf("MCP stress: %d iterations", iterations)
	t.Logf("Start memory: %d MB", startMemory/1024/1024)
	t.Logf("End memory: %d MB", endMemory/1024/1024)
	// The counters are unsigned, so subtract in the direction that cannot wrap.
	if endMemory > startMemory {
		t.Logf("Memory growth: %d MB", (endMemory-startMemory)/1024/1024)
	} else {
		t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024)
	}
}

// TestRealWorldStressFullWorkflow tests complete workflow under stress
func TestRealWorldStressFullWorkflow(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping stress test in short mode")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	agent, err := assistant.Get("tests.realworld")
	if err != nil {
		t.Fatalf("Failed to get assistant: %v", err)
	}

	iterations := 30
	startMemory := getMemStats()
	startTime := time.Now()

	for i := 0; i < iterations; i++ {
		ctx := newRealWorldContext(fmt.Sprintf("stress-workflow-%d", i), "tests.realworld")

		// Initialize stack for trace
		stack, _, done := context.EnterStack(ctx, "tests.realworld", &context.Options{})
		ctx.Stack = stack

		messages := []context.Message{
			{Role: "user", Content: "full_workflow"},
		}

		response, _, err := agent.HookScript.Create(ctx, messages)
		if err != nil {
			t.Fatalf("Iteration %d failed: %v", i, err)
		}

		// Verify response
		assert.NotNil(t, response, "Iteration %d: response should not be nil", i)
		assert.NotEmpty(t, response.Messages, "Iteration %d: messages should not be empty", i)
		if response.Metadata != nil {
			assert.Equal(t, "full_workflow", response.Metadata["scenario"], "Iteration %d: scenario mismatch", i)

			// Verify workflow-specific metadata
			if phasesCompleted, ok := response.Metadata["phases_completed"]; ok {
				phases := int(phasesCompleted.(float64))
				assert.Equal(t, 4, phases, "Iteration %d: should complete 4 phases", i)
			}
			if mcpTools, ok := response.Metadata["mcp_tools"]; ok {
				tools := int(mcpTools.(float64))
				assert.Greater(t, tools, 0, "Iteration %d: should have MCP tools", i)
			}
		}

		// Cleanup
		done()
		ctx.Release()

		if i%10 == 0 {
			runtime.GC()
			currentMemory := getMemStats()
			elapsed := time.Since(startTime)
			t.Logf("Iteration %d: Memory: %d MB, Elapsed: %v", i, currentMemory/1024/1024, elapsed)
		}
	}

	duration := time.Since(startTime)
	runtime.GC()
	endMemory := getMemStats()

	avgTime := duration / time.Duration(iterations)

	t.Logf("Full workflow
stress: %d iterations", iterations)
	t.Logf("Total time: %v", duration)
	t.Logf("Average time per iteration: %v", avgTime)
	t.Logf("Start memory: %d MB", startMemory/1024/1024)
	t.Logf("End memory: %d MB", endMemory/1024/1024)
	if endMemory > startMemory {
		t.Logf("Memory growth: %d MB", (endMemory-startMemory)/1024/1024)
	} else {
		t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024)
	}
}

// TestRealWorldStressConcurrent runs the Create hook from 100 goroutines, 10
// iterations each, rotating through the simple, MCP and full-workflow
// scenarios. A goroutine stops at its first failure. After all goroutines
// finish, the test fails if any error was reported, if any metadata scenario
// does not match the requested one, or if fewer results than operations were
// collected. Timing and memory figures are logged, not asserted.
func TestRealWorldStressConcurrent(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping stress test in short mode")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	agent, err := assistant.Get("tests.realworld")
	if err != nil {
		t.Fatalf("Failed to get assistant: %v", err)
	}

	goroutines := 100
	iterationsPerGoroutine := 10
	totalOperations := goroutines * iterationsPerGoroutine
	scenarios := []string{"simple", "mcp_health", "mcp_tools", "full_workflow"}

	startMemory := getMemStats()
	startTime := time.Now()

	// Track results for validation
	type Result struct {
		goroutineID int
		iteration   int
		scenario    string
		metadata    map[string]interface{}
	}

	// Both channels can hold one entry per operation, so workers never block on send.
	var wg sync.WaitGroup
	errCh := make(chan error, totalOperations)
	results := make(chan Result, totalOperations)

	for g := 0; g < goroutines; g++ {
		wg.Add(1)
		go func(goroutineID int) {
			defer wg.Done()

			// runOnce performs one Create call. The deferred calls leave the stack
			// and then release the context on every exit path.
			runOnce := func(i int) error {
				scenario := scenarios[(goroutineID+i)%len(scenarios)]
				ctx := newRealWorldContext(
					fmt.Sprintf("concurrent-%d-%d", goroutineID, i),
					"tests.realworld",
				)
				defer ctx.Release()

				// Initialize stack for trace
				stack, _, done := context.EnterStack(ctx, "tests.realworld", &context.Options{})
				defer done()
				ctx.Stack = stack

				messages := []context.Message{
					{Role: "user", Content: scenario},
				}

				response, _, err := agent.HookScript.Create(ctx, messages)
				switch {
				case err != nil:
					return fmt.Errorf("goroutine %d iteration %d (%s): %v", goroutineID, i, scenario, err)
				case response == nil:
					return fmt.Errorf("goroutine %d iteration %d (%s): nil response", goroutineID, i, scenario)
				case len(response.Messages) == 0:
					return fmt.Errorf("goroutine %d iteration %d (%s): empty messages", goroutineID, i, scenario)
				}

				// Collect result
				results <- Result{
					goroutineID: goroutineID,
					iteration:   i,
					scenario:    scenario,
					metadata:    response.Metadata,
				}
				return nil
			}

			for i := 0; i < iterationsPerGoroutine; i++ {
				if err := runOnce(i); err != nil {
					errCh <- err
					return // stop this goroutine at its first failure
				}
			}
		}(g)
	}

	wg.Wait()
	close(errCh)
	close(results)

	duration := time.Since(startTime)
	runtime.GC()
	endMemory := getMemStats()

	// Check for errors
	errorCount := 0
	for err := range errCh {
		t.Error(err)
		errorCount++
	}
	assert.Equal(t, 0, errorCount, "No errors should occur in concurrent operations")

	// Validate results
	scenarioCounts := make(map[string]int)
	validResults := 0
	for result := range results {
		validResults++
		scenarioCounts[result.scenario]++

		// Validate metadata exists and has expected scenario
		if result.metadata == nil {
			continue
		}
		if scenario, ok := result.metadata["scenario"].(string); ok && scenario != result.scenario {
			t.Errorf("Metadata mismatch: expected %s, got %s (goroutine %d, iteration %d)",
				result.scenario, scenario, result.goroutineID, result.iteration)
		}
	}

	assert.Equal(t, totalOperations, validResults, "All operations should return valid results")

	avgTime := duration / time.Duration(totalOperations)
	// Report the measured success rate rather than a fixed "100%".
	validPercent := float64(validResults) * 100 / float64(totalOperations)

	t.Logf("✓ Concurrent stress: %d operations (goroutines: %d, iterations: %d)",
		totalOperations, goroutines, iterationsPerGoroutine)
	t.Logf("✓ Valid results: %d/%d (%.1f%%)", validResults, totalOperations, validPercent)
	t.Logf("✓ Scenario distribution:")
	for scenario, count := range scenarioCounts {
		t.Logf(" - %s: %d operations", scenario, count)
	}
	t.Logf("✓ Total time: %v", duration)
	t.Logf("✓ Average time per operation: %v", avgTime)
	t.Logf("✓ Start memory: %d MB", startMemory/1024/1024)
	t.Logf("✓ End memory: %d MB", endMemory/1024/1024)
	if endMemory > startMemory {
		t.Logf("✓ Memory growth: %d MB", (endMemory-startMemory)/1024/1024)
	} else {
		t.Logf("✓ Memory decreased: %d MB", (startMemory-endMemory)/1024/1024)
	}
}

// TestRealWorldStressResourceHeavy tests resource-intensive scenarios
func TestRealWorldStressResourceHeavy(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping stress test in short mode")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	agent, err := assistant.Get("tests.realworld")
	if err != nil {
		t.Fatalf("Failed to get assistant: %v", err)
	}

	iterations := 20
	startMemory := getMemStats()
	startTime := time.Now()

	for i := 0; i < iterations; i++ {
		ctx := newRealWorldContext(fmt.Sprintf("stress-heavy-%d", i), "tests.realworld")

		// Initialize stack for trace
		stack, _, done := context.EnterStack(ctx, "tests.realworld", &context.Options{})
		ctx.Stack = stack

		messages := []context.Message{
			{Role: "user", Content: "resource_heavy"},
		}

		response, _, err := agent.HookScript.Create(ctx, messages)
		if err != nil {
			t.Fatalf("Iteration %d failed: %v", i, err)
		}

		// Validate response
		assert.NotNil(t, response, "Iteration %d: response should not be nil", i)
		assert.NotEmpty(t, response.Messages, "Iteration %d: messages should not be empty", i)
		if response.Metadata != nil {
			assert.Equal(t, "resource_heavy", response.Metadata["scenario"], "Iteration %d: scenario mismatch", i)

			// Verify resource-heavy metadata
			if mcpIterations, ok := response.Metadata["mcp_iterations"]; ok {
				// Named mcpCount so it does not shadow the loop bound `iterations`.
				mcpCount := int(mcpIterations.(float64))
				assert.Equal(t, 5, mcpCount, "Iteration %d: should have 5 MCP iterations", i)
			}
		}

		// Cleanup
		done()
		ctx.Release()

		if i%5 == 0 {
			runtime.GC()
			currentMemory := getMemStats()
			elapsed := time.Since(startTime)
			t.Logf("Iteration %d: Memory: %d MB, Elapsed: %v", i, currentMemory/1024/1024, elapsed)
		}
	}

	duration := time.Since(startTime)
	runtime.GC()
	endMemory := getMemStats()

	avgTime := duration / time.Duration(iterations)

	t.Logf("Resource heavy stress: %d iterations", iterations)
	t.Logf("Total time: %v", duration)
	t.Logf("Average time per iteration: %v", avgTime)
	t.Logf("Start memory: %d MB", startMemory/1024/1024)
t.Logf("End memory: %d MB", endMemory/1024/1024) if endMemory > startMemory { memoryGrowth := int64(endMemory - startMemory) t.Logf("Memory growth: %d MB", memoryGrowth/1024/1024) // Allow up to 100MB growth for resource-heavy operations assert.Less(t, memoryGrowth, int64(100*1024*1024), "Memory growth should be reasonable") } else { t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } } // ============================================================================ // Helper Functions // ============================================================================ // newRealWorldContext creates a Context for real-world testing func newRealWorldContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "realworld-test-user", ClientID: "realworld-test-client", Scope: "openid profile email", SessionID: "realworld-test-session", UserID: "realworld-user-123", TeamID: "realworld-team-456", TenantID: "realworld-tenant-789", RememberMe: true, Constraints: types.DataConstraints{ OwnerOnly: false, CreatorOnly: false, EditorOnly: false, TeamOnly: true, Extra: map[string]interface{}{ "department": "engineering", "region": "us-west", "project": "yao-realworld-test", }, }, } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "RealWorldTest/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } // getMemStats returns current memory allocation in bytes func getMemStats() uint64 { runtime.GC() var m runtime.MemStats runtime.ReadMemStats(&m) return m.Alloc } ================================================ FILE: agent/assistant/hook/script.go ================================================ package hook import ( "strings" "github.com/yaoapp/yao/agent/context" ) // Execute execute 
the script func (s *Script) Execute(ctx *context.Context, method string, args ...interface{}) (interface{}, error) { if s == nil || s.Script == nil { return nil, nil } var sid = "" if ctx.Authorized != nil { sid = ctx.Authorized.SessionID } scriptCtx, err := s.NewContext(sid, nil) if err != nil { return nil, err } defer scriptCtx.Close() // Set authorized information if available if ctx.Authorized != nil { scriptCtx.WithAuthorized(ctx.Authorized.AuthorizedToMap()) } // The first argument is the context args = append([]interface{}{ctx}, args...) // Try to call the method result, err := scriptCtx.CallWith(ctx.Context, method, args...) // If method doesn't exist (ReferenceError or similar), return nil without error if err != nil && (strings.Contains(err.Error(), "is not defined") || strings.Contains(err.Error(), "is not a function") || strings.Contains(err.Error(), "is not a Function")) { return nil, nil } return result, err } ================================================ FILE: agent/assistant/hook/types.go ================================================ package hook import ( v8 "github.com/yaoapp/gou/runtime/v8" ) // Script the script hook align type Script struct { *v8.Script } ================================================ FILE: agent/assistant/llm.go ================================================ package assistant import ( "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/llm" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/trace/types" ) // executeLLMStream executes the LLM streaming call with pre-built request // Returns completionResponse and error func (ast *Assistant) executeLLMStream( ctx *context.Context, completionMessages []context.Message, completionOptions *context.CompletionOptions, agentNode types.Node, streamHandler message.StreamFunc, opts *context.Options, ) (*context.CompletionResponse, error) { // Get connector object (capabilities were already set above, before stream_start) conn, capabilities, err := 
ast.GetConnector(ctx, opts) if err != nil { ast.traceAgentFail(agentNode, err) return nil, err } // Set capabilities in options if not already set if completionOptions.Capabilities == nil && capabilities != nil { completionOptions.Capabilities = capabilities } // Log the capabilities ast.traceConnectorCapabilities(agentNode, capabilities) // Build content - convert extended types (file, data, __yao.attachment://) to standard LLM types // This is done here (right before LLM call) to ensure: // 1. autoSearch receives original messages (not converted) // 2. delegate receives original messages (not converted) // 3. Only the actual LLM call sees converted messages llmMessages, err := ast.BuildContent(ctx, completionMessages, completionOptions, opts) if err != nil { ast.traceAgentFail(agentNode, err) return nil, err } // Trace Add LLM request (use converted messages for trace) ast.traceLLMRequest(ctx, conn.ID(), llmMessages, completionOptions) // Log LLM call start ctx.Logger.LLMStart(conn.ID(), "", len(llmMessages)) // Create LLM instance with connector and options llmInstance, err := llm.New(conn, completionOptions) if err != nil { // Mark LLM Request as failed in trace ast.traceLLMFail(ctx, err) return nil, err } // Call the LLM Completion Stream (streamHandler was set earlier) // Use llmMessages (converted) instead of completionMessages (original) completionResponse, err := llmInstance.Stream(ctx, llmMessages, completionOptions, streamHandler) if err != nil { // Mark LLM Request as failed in trace ast.traceLLMFail(ctx, err) return nil, err } // Mark LLM Request Complete ast.traceLLMComplete(ctx, completionResponse) return completionResponse, nil } // executeLLMForToolRetry executes LLM call for tool retry with streaming output // This is used when retrying tool calls - we still want to show LLM's response to users // Returns completionResponse and error func (ast *Assistant) executeLLMForToolRetry( ctx *context.Context, completionMessages []context.Message, 
completionOptions *context.CompletionOptions, agentNode types.Node, streamHandler message.StreamFunc, opts *context.Options, ) (*context.CompletionResponse, error) { // Get connector object conn, capabilities, err := ast.GetConnector(ctx, opts) if err != nil { ast.traceAgentFail(agentNode, err) return nil, err } // Set capabilities in options if not already set if completionOptions.Capabilities == nil && capabilities != nil { completionOptions.Capabilities = capabilities } // Build content - convert extended types for LLM call llmMessages, err := ast.BuildContent(ctx, completionMessages, completionOptions, opts) if err != nil { ast.traceAgentFail(agentNode, err) return nil, err } // Trace Add LLM retry request ast.traceLLMRetryRequest(ctx, conn.ID(), llmMessages, completionOptions) // Log LLM call start (retry) ctx.Logger.LLMStart(conn.ID(), "", len(llmMessages)) // Create LLM instance with connector and options llmInstance, err := llm.New(conn, completionOptions) if err != nil { // Mark LLM Retry Request as failed in trace ast.traceLLMFail(ctx, err) return nil, err } // Call the LLM Completion Stream (still streaming for tool retry) // Use llmMessages (converted) instead of completionMessages (original) completionResponse, err := llmInstance.Stream(ctx, llmMessages, completionOptions, streamHandler) if err != nil { // Mark LLM Retry Request as failed in trace ast.traceLLMFail(ctx, err) return nil, err } // Mark LLM Request Complete ast.traceLLMComplete(ctx, completionResponse) return completionResponse, nil } ================================================ FILE: agent/assistant/load.go ================================================ package assistant import ( "fmt" "os" "path/filepath" "strings" jsoniter "github.com/json-iterator/go" "github.com/spf13/cast" "github.com/yaoapp/gou/application" "github.com/yaoapp/gou/fs" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" sandboxTypes "github.com/yaoapp/yao/agent/sandbox/v2/types" searchTypes 
"github.com/yaoapp/yao/agent/search/types" store "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/config" "gopkg.in/yaml.v3" ) // loaded the loaded assistant var loaded = NewCache(200) // 200 is the default capacity var storage store.Store = nil var storeSetting *store.Setting = nil // store setting from agent.yml var defaultConnector string = "" // default connector var globalUses *context.Uses = nil // global uses configuration from agent.yml var globalPrompts []store.Prompt = nil // global prompts from agent/prompts.yml var globalKBSetting *store.KBSetting = nil // global KB setting from agent/kb.yml var globalSearchConfig *searchTypes.Config = nil // global search config from agent/search.yml // LoadBuiltIn load the built-in assistants func LoadBuiltIn() error { // Clear non-system agents from cache (preserve system agents loaded by LoadSystemAgents) loaded.ClearExcept(func(id string) bool { return strings.HasPrefix(id, "__yao.") // Keep system agents }) root := `/assistants` app, err := fs.Get("app") if err != nil { return err } // Get all existing built-in assistants deletedBuiltIn := map[string]bool{} // Remove the built-in assistants (exclude system agents with __yao. 
prefix) if storage != nil { builtIn := true res, err := storage.GetAssistants(store.AssistantFilter{BuiltIn: &builtIn, Select: []string{"assistant_id", "id"}}) if err != nil { return err } // Get all existing built-in assistants (exclude system agents) for _, assistant := range res.Data { // Skip system agents (they are managed by LoadSystemAgents) if strings.HasPrefix(assistant.ID, "__yao.") { continue } deletedBuiltIn[assistant.ID] = true } } // Check if the assistant is built-in if exists, _ := app.Exists(root); !exists { return nil } paths, err := app.ReadDir(root, true) if err != nil { return err } sort := 1 for _, path := range paths { pkgfile := filepath.Join(path, "package.yao") if has, _ := app.Exists(pkgfile); !has { continue } assistant, err := LoadPath(path) if err != nil { return err } assistant.Readonly = true assistant.BuiltIn = true if assistant.Sort == 0 { assistant.Sort = sort } if assistant.Tags == nil { assistant.Tags = []string{} } // Save the assistant err = assistant.Save() if err != nil { return err } // Initialize the assistant err = assistant.initialize() if err != nil { return err } sort++ loaded.Put(assistant) // Remove the built-in assistant from the store delete(deletedBuiltIn, assistant.ID) } // Remove deleted built-in assistants if len(deletedBuiltIn) > 0 { assistantIDs := []string{} for assistantID := range deletedBuiltIn { assistantIDs = append(assistantIDs, assistantID) } _, err := storage.DeleteAssistants(store.AssistantFilter{AssistantIDs: assistantIDs}) if err != nil { return err } } return nil } // SetStorage set the storage func SetStorage(s store.Store) { storage = s } // GetStorage returns the storage (for testing purposes) func GetStorage() store.Store { return storage } // SetConnector set the connector func SetConnector(c string) { defaultConnector = c } // SetGlobalUses set the global uses configuration func SetGlobalUses(uses *context.Uses) { globalUses = uses } // SetGlobalPrompts set the global prompts from 
agent/prompts.yml func SetGlobalPrompts(prompts []store.Prompt) { globalPrompts = prompts } // SetStoreSetting set the store setting from agent.yml func SetStoreSetting(setting *store.Setting) { storeSetting = setting } // GetStoreSetting returns the store setting func GetStoreSetting() *store.Setting { return storeSetting } // GetGlobalPrompts returns the global prompts with variables parsed // ctx: context variables for parsing $CTX.* variables func GetGlobalPrompts(ctx map[string]string) []store.Prompt { if len(globalPrompts) == 0 { return nil } return store.Prompts(globalPrompts).Parse(ctx) } // SetGlobalKBSetting set the global KB setting from agent/kb.yml func SetGlobalKBSetting(kbSetting *store.KBSetting) { globalKBSetting = kbSetting } // GetGlobalKBSetting returns the global KB setting func GetGlobalKBSetting() *store.KBSetting { return globalKBSetting } // SetGlobalSearchConfig set the global search config from agent/search.yml func SetGlobalSearchConfig(config *searchTypes.Config) { globalSearchConfig = config } // GetGlobalSearchConfig returns the global search config func GetGlobalSearchConfig() *searchTypes.Config { return globalSearchConfig } // SetCache set the cache func SetCache(capacity int) { ClearCache() loaded = NewCache(capacity) } // ClearCache clear the cache func ClearCache() { if loaded != nil { loaded.Clear() loaded = nil } } // GetCache returns the loaded cache func GetCache() *Cache { return loaded } // LoadStore create a new assistant from store func LoadStore(id string) (*Assistant, error) { if id == "" { return nil, fmt.Errorf("assistant_id is required") } assistant, exists := loaded.Get(id) if exists { return assistant, nil } if storage == nil { return nil, fmt.Errorf("storage is not set") } // Request all fields when loading assistant from store storeModel, err := storage.GetAssistant(id, store.AssistantFullFields) if err != nil { return nil, err } // Load from path if storeModel.Path != "" { assistant, err = 
LoadPath(storeModel.Path) if err != nil { return nil, err } loaded.Put(assistant) return assistant, nil } // Create assistant from store model assistant = &Assistant{AssistantModel: *storeModel} // Load script from source field if present if assistant.Source != "" { script, err := loadSource(assistant.Source, assistant.ID) if err != nil { return nil, err } assistant.HookScript = script } // Initialize the assistant err = assistant.initialize() if err != nil { return nil, err } loaded.Put(assistant) return assistant, nil } // loadPackage loads and parses the package.yao file func loadPackage(path string) (map[string]interface{}, error) { app, err := fs.Get("app") if err != nil { return nil, err } pkgfile := filepath.Join(path, "package.yao") if has, _ := app.Exists(pkgfile); !has { return nil, fmt.Errorf("package.yao not found in %s", path) } pkgraw, err := app.ReadFile(pkgfile) if err != nil { return nil, err } var data map[string]interface{} err = application.Parse(pkgfile, pkgraw, &data) if err != nil { return nil, err } // Process connector environment variable if connector, ok := data["connector"].(string); ok { if strings.HasPrefix(connector, "$ENV.") { envKey := strings.TrimPrefix(connector, "$ENV.") if envValue := os.Getenv(envKey); envValue != "" { data["connector"] = envValue } } } return data, nil } // LoadPath load assistant from path func LoadPath(path string) (*Assistant, error) { app, err := fs.Get("app") if err != nil { return nil, err } data, err := loadPackage(path) if err != nil { return nil, err } // assistant_id id := strings.ReplaceAll(strings.TrimPrefix(path, "/assistants/"), "/", ".") data["assistant_id"] = id data["path"] = path if _, has := data["type"]; !has { data["type"] = "assistant" } updatedAt := int64(0) // prompts (default prompts from prompts.yml) promptsfile := filepath.Join(path, "prompts.yml") if has, _ := app.Exists(promptsfile); has { prompts, ts, err := store.LoadPrompts(promptsfile, path) if err != nil { return nil, err } 
data["prompts"] = prompts data["updated_at"] = ts updatedAt = ts } // prompt_presets (from prompts directory, key is filename without extension) promptsDir := filepath.Join(path, "prompts") if has, _ := app.Exists(promptsDir); has { presets, ts, err := store.LoadPromptPresets(promptsDir, path) if err != nil { return nil, err } if len(presets) > 0 { data["prompt_presets"] = presets updatedAt = max(updatedAt, ts) } } // load scripts (hook script and other scripts) from src directory srcDir := filepath.Join(path, "src") if has, _ := app.Exists(srcDir); has { hookScript, scripts, err := LoadScripts(srcDir) if err != nil { return nil, err } // Set hook script and update timestamp if hookScript != nil { data["script"] = hookScript // Get timestamp from index.ts if exists scriptfile := filepath.Join(srcDir, "index.ts") if ts, err := app.ModTime(scriptfile); err == nil { data["updated_at"] = max(updatedAt, ts.UnixNano()) } } // Set other scripts if len(scripts) > 0 { data["scripts"] = scripts } } // i18ns locales, err := i18n.GetLocales(path) if err != nil { return nil, err } data["locales"] = locales // V2 sandbox: load standalone sandbox.yao if present (Path A). sandboxFile := filepath.Join(path, "sandbox.yao") if has, _ := app.Exists(sandboxFile); has { absFile := filepath.Join(config.Conf.AppSource, sandboxFile) sbCfg, sbErr := store.LoadSandboxConfig(absFile) if sbErr != nil { return nil, fmt.Errorf("load sandbox.yao: %w", sbErr) } data["__sandbox_v2"] = sbCfg } ast, err := loadMap(data) if err != nil { return nil, err } // If V2 sandbox was loaded via Path A, assign it now. if sbCfg, ok := data["__sandbox_v2"].(*sandboxTypes.SandboxConfig); ok && sbCfg != nil { ast.SandboxV2 = sbCfg } // Extract Sandbox flag and ComputerFilter from V2 sandbox config. if ast.SandboxV2 != nil { ast.IsSandbox = true ast.ComputerFilter = ast.SandboxV2.Filter } // Compute config hash for V2 sandbox. 
if ast.SandboxV2 != nil { var mcpServers []store.MCPServerConfig if ast.MCP != nil { mcpServers = ast.MCP.Servers } skillsDir := "" if ast.Path != "" { dir := filepath.Join(config.Conf.AppSource, ast.Path, "skills") if info, e := os.Stat(dir); e == nil && info.IsDir() { skillsDir = dir } } ast.ConfigHash = store.ComputeConfigHash(ast.SandboxV2, mcpServers, skillsDir) } return ast, nil } func loadMap(data map[string]interface{}) (*Assistant, error) { assistant := &Assistant{} // assistant_id is required id, ok := data["assistant_id"].(string) if !ok { return nil, fmt.Errorf("assistant_id is required") } assistant.ID = id // name is required name, ok := data["name"].(string) if !ok { return nil, fmt.Errorf("name is required") } assistant.Name = name // avatar if avatar, ok := data["avatar"].(string); ok { assistant.Avatar = avatar } // Type if v, ok := data["type"].(string); ok { assistant.Type = v } // Placeholder if v, ok := data["placeholder"]; ok { switch vv := v.(type) { case string: placeholder, err := jsoniter.Marshal(vv) if err != nil { return nil, err } assistant.Placeholder = &store.Placeholder{} err = jsoniter.Unmarshal(placeholder, assistant.Placeholder) if err != nil { return nil, err } case map[string]interface{}: raw, err := jsoniter.Marshal(vv) if err != nil { return nil, err } assistant.Placeholder = &store.Placeholder{} err = jsoniter.Unmarshal(raw, assistant.Placeholder) if err != nil { return nil, err } case *store.Placeholder: assistant.Placeholder = vv case nil: assistant.Placeholder = nil } } // Mentionable if v, ok := data["mentionable"].(bool); ok { assistant.Mentionable = v } // Automated if v, ok := data["automated"].(bool); ok { assistant.Automated = v } // modes if v, has := data["modes"]; has { modes, err := store.ToModes(v) if err != nil { return nil, err } assistant.Modes = modes } // default_mode if v, ok := data["default_mode"].(string); ok { assistant.DefaultMode = v } // DisableGlobalPrompts if v, ok := 
data["disable_global_prompts"].(bool); ok { assistant.DisableGlobalPrompts = v } // Readonly if v, ok := data["readonly"].(bool); ok { assistant.Readonly = v } // Public if v, ok := data["public"].(bool); ok { assistant.Public = v } // Share if v, ok := data["share"].(string); ok { assistant.Share = v } // built_in if v, ok := data["built_in"].(bool); ok { assistant.BuiltIn = v } // sort if v, has := data["sort"]; has { assistant.Sort = cast.ToInt(v) } // path if v, ok := data["path"].(string); ok { assistant.Path = v } // connector if connector, ok := data["connector"].(string); ok { assistant.Connector = connector } // connector_options if connOpts, has := data["connector_options"]; has { opts, err := store.ToConnectorOptions(connOpts) if err != nil { return nil, err } assistant.ConnectorOptions = opts } // tags if v, has := data["tags"]; has { switch vv := v.(type) { case []string: assistant.Tags = vv case []interface{}: var tags []string for _, tag := range vv { tags = append(tags, cast.ToString(tag)) } assistant.Tags = tags case string: assistant.Tags = []string{vv} case interface{}: raw, err := jsoniter.Marshal(vv) if err != nil { return nil, err } var tags []string err = jsoniter.Unmarshal(raw, &tags) if err != nil { return nil, err } assistant.Tags = tags } } // options if v, ok := data["options"].(map[string]interface{}); ok { assistant.Options = v } // description if v, ok := data["description"].(string); ok { assistant.Description = v } // capabilities if v, ok := data["capabilities"].(string); ok { assistant.Capabilities = v } // locales if locales, ok := data["locales"].(i18n.Map); ok { assistant.Locales = locales flattened := locales.FlattenWithGlobal() // Auto-inject assistant name and description into all locales // so that {{name}} and {{description}} templates can be resolved for locale, i18nObj := range flattened { if i18nObj.Messages == nil { i18nObj.Messages = make(map[string]any) } // Add name, description, and capabilities if not already 
present if _, exists := i18nObj.Messages["name"]; !exists && assistant.Name != "" { i18nObj.Messages["name"] = assistant.Name } if _, exists := i18nObj.Messages["description"]; !exists && assistant.Description != "" { i18nObj.Messages["description"] = assistant.Description } if _, exists := i18nObj.Messages["capabilities"]; !exists && assistant.Capabilities != "" { i18nObj.Messages["capabilities"] = assistant.Capabilities } flattened[locale] = i18nObj } i18n.Locales[id] = flattened } else { // No locales defined, create default with name, description, and capabilities for all common locales if assistant.Name != "" || assistant.Description != "" || assistant.Capabilities != "" { defaultLocales := make(map[string]i18n.I18n) commonLocales := []string{"en", "en-us", "zh", "zh-cn", "zh-tw"} for _, locale := range commonLocales { messages := map[string]any{} if assistant.Name != "" { messages["name"] = assistant.Name } if assistant.Description != "" { messages["description"] = assistant.Description } if assistant.Capabilities != "" { messages["capabilities"] = assistant.Capabilities } defaultLocales[locale] = i18n.I18n{ Locale: locale, Messages: messages, } } i18n.Locales[id] = defaultLocales } } // Search configuration (from package.yao search block) // This contains search options like web.max_results, kb.threshold, citation.format, etc. 
// Merge hierarchy: global config < assistant config switch v := data["search"].(type) { case *searchTypes.Config: assistant.Search = v case searchTypes.Config: assistant.Search = &v case map[string]interface{}: var assistantSearch searchTypes.Config raw, err := jsoniter.Marshal(v) if err != nil { return nil, err } err = jsoniter.Unmarshal(raw, &assistantSearch) if err != nil { return nil, err } // Merge with global search config assistant.Search = mergeSearchConfig(globalSearchConfig, &assistantSearch) default: assistant.Search = globalSearchConfig } // prompts if prompts, has := data["prompts"]; has { switch v := prompts.(type) { case []store.Prompt: assistant.Prompts = v case string: var prompts []store.Prompt err := yaml.Unmarshal([]byte(v), &prompts) if err != nil { return nil, err } assistant.Prompts = prompts default: raw, err := jsoniter.Marshal(v) if err != nil { return nil, err } var prompts []store.Prompt err = jsoniter.Unmarshal(raw, &prompts) if err != nil { return nil, err } assistant.Prompts = prompts } } // prompt_presets if presets, has := data["prompt_presets"]; has { promptPresets, err := store.ToPromptPresets(presets) if err != nil { return nil, err } assistant.PromptPresets = promptPresets } // source (hook script code) - store the source code if source, ok := data["source"].(string); ok { assistant.Source = source } // kb if kb, has := data["kb"]; has { knowledgeBase, err := store.ToKnowledgeBase(kb) if err != nil { return nil, err } assistant.KB = knowledgeBase } // db if db, has := data["db"]; has { database, err := store.ToDatabase(db) if err != nil { return nil, err } assistant.DB = database } // mcp if mcp, has := data["mcp"]; has { mcpServers, err := store.ToMCPServers(mcp) if err != nil { return nil, err } assistant.MCP = mcpServers } // workflow if workflow, has := data["workflow"]; has { wf, err := store.ToWorkflow(workflow) if err != nil { return nil, err } assistant.Workflow = wf } // sandbox (for coding agents like Claude CLI, 
Cursor CLI) // V2 sandbox via independent sandbox.yao is loaded in LoadPath (below). // This block handles the package.yao embedded "sandbox" field with version dispatch. if assistant.SandboxV2 == nil { if sandbox, has := data["sandbox"]; has { version := extractSandboxVersion(sandbox) if version == sandboxTypes.SandboxVersionV2 { sb, err := store.ToSandboxV2(sandbox) if err != nil { return nil, err } assistant.SandboxV2 = sb assistant.IsSandbox = true assistant.ComputerFilter = sb.Filter } else { sb, err := store.ToSandbox(sandbox) if err != nil { return nil, err } assistant.Sandbox = sb } } } // dependencies (name -> version constraint, like npm dependencies) if deps, has := data["dependencies"]; has { switch v := deps.(type) { case map[string]string: assistant.Dependencies = v case map[string]interface{}: d := make(map[string]string, len(v)) for k, val := range v { d[k] = cast.ToString(val) } assistant.Dependencies = d default: raw, err := jsoniter.Marshal(v) if err != nil { return nil, err } var d map[string]string if err := jsoniter.Unmarshal(raw, &d); err != nil { return nil, err } assistant.Dependencies = d } } // uses (wrapper configurations for vision, audio, etc.) 
// Merge hierarchy: global uses < assistant uses if uses, has := data["uses"]; has { var assistantUses *context.Uses switch v := uses.(type) { case *context.Uses: assistantUses = v case context.Uses: assistantUses = &v default: raw, err := jsoniter.Marshal(v) if err != nil { return nil, err } var usesConfig context.Uses err = jsoniter.Unmarshal(raw, &usesConfig) if err != nil { return nil, err } assistantUses = &usesConfig } // Merge with global uses assistant.Uses = mergeUses(globalUses, assistantUses) } else if globalUses != nil { // No assistant-specific uses, use global assistant.Uses = globalUses } // Load scripts (hook script and other scripts) hookScript, scripts, scriptErr := LoadScriptsFromData(data, assistant.ID) if scriptErr != nil { return nil, scriptErr } assistant.HookScript = hookScript assistant.Scripts = scripts // created_at if v, has := data["created_at"]; has { ts, err := getTimestamp(v) if err != nil { return nil, err } assistant.CreatedAt = ts } // updated_at if v, has := data["updated_at"]; has { ts, err := getTimestamp(v) if err != nil { return nil, err } assistant.UpdatedAt = ts } // Initialize the assistant err := assistant.initialize() if err != nil { return nil, err } return assistant, nil } // Init init the assistant // Choose the connector and initialize the assistant func (ast *Assistant) initialize() error { conn := defaultConnector if ast.Connector != "" { conn = ast.Connector } ast.Connector = conn // Register scripts as process handlers if len(ast.Scripts) > 0 { if err := ast.RegisterScripts(); err != nil { return fmt.Errorf("failed to register scripts: %w", err) } } return nil } // mergeUses merges two Uses configs (base < override) func mergeUses(base, override *context.Uses) *context.Uses { if base == nil { return override } if override == nil { return base } result := *base // Copy base // Override with non-empty values if override.Vision != "" { result.Vision = override.Vision } if override.Audio != "" { result.Audio = 
override.Audio
	}
	if override.Search != "" {
		result.Search = override.Search
	}
	if override.Fetch != "" {
		result.Fetch = override.Fetch
	}
	if override.Web != "" {
		result.Web = override.Web
	}
	if override.Keyword != "" {
		result.Keyword = override.Keyword
	}
	if override.QueryDSL != "" {
		result.QueryDSL = override.QueryDSL
	}
	if override.Rerank != "" {
		result.Rerank = override.Rerank
	}

	return &result
}

// mergeSearchConfig layers override on top of base (base < override) and
// returns the combined search config.
//
// When either argument is nil the other one is returned unchanged. Otherwise
// base is shallow-copied and each section present in override is merged in:
// a section missing from base is taken from override as is, while a section
// present on both sides is copied and receives only the fields override sets
// to a non-zero value. Citation.AutoInjectPrompt is the exception and always
// follows override whenever a Citation section is supplied.
func mergeSearchConfig(base, override *searchTypes.Config) *searchTypes.Config {
	if base == nil {
		return override
	}
	if override == nil {
		return base
	}

	out := *base // shallow copy; sections are re-pointed below when merged

	// Web
	if o := override.Web; o != nil {
		if out.Web != nil {
			web := *out.Web
			if o.Provider != "" {
				web.Provider = o.Provider
			}
			if o.APIKeyEnv != "" {
				web.APIKeyEnv = o.APIKeyEnv
			}
			if o.MaxResults > 0 {
				web.MaxResults = o.MaxResults
			}
			out.Web = &web
		} else {
			out.Web = o
		}
	}

	// KB
	if o := override.KB; o != nil {
		if out.KB != nil {
			kb := *out.KB
			if len(o.Collections) > 0 {
				kb.Collections = o.Collections
			}
			if o.Threshold > 0 {
				kb.Threshold = o.Threshold
			}
			// Graph can only be switched on by the override, never off.
			if o.Graph {
				kb.Graph = o.Graph
			}
			out.KB = &kb
		} else {
			out.KB = o
		}
	}

	// DB
	if o := override.DB; o != nil {
		if out.DB != nil {
			db := *out.DB
			if len(o.Models) > 0 {
				db.Models = o.Models
			}
			if o.MaxResults > 0 {
				db.MaxResults = o.MaxResults
			}
			out.DB = &db
		} else {
			out.DB = o
		}
	}

	// Keyword
	if o := override.Keyword; o != nil {
		if out.Keyword != nil {
			kw := *out.Keyword
			if o.MaxKeywords > 0 {
				kw.MaxKeywords = o.MaxKeywords
			}
			if o.Language != "" {
				kw.Language = o.Language
			}
			out.Keyword = &kw
		} else {
			out.Keyword = o
		}
	}

	// QueryDSL
	if o := override.QueryDSL; o != nil {
		if out.QueryDSL != nil {
			dsl := *out.QueryDSL
			// Strict can only be switched on by the override, never off.
			if o.Strict {
				dsl.Strict = o.Strict
			}
			out.QueryDSL = &dsl
		} else {
			out.QueryDSL = o
		}
	}

	// Rerank
	if o := override.Rerank; o != nil {
		if out.Rerank != nil {
			rr := *out.Rerank
			if o.TopN > 0 {
				rr.TopN = o.TopN
			}
			out.Rerank = &rr
		} else {
			out.Rerank = o
		}
	}

	// Citation
	if o := override.Citation; o != nil {
		if out.Citation != nil {
			cite := *out.Citation
			if o.Format != "" {
				cite.Format = o.Format
			}
			// A bool cannot express "unset", so supplying a Citation block at
			// all is treated as setting AutoInjectPrompt explicitly.
			cite.AutoInjectPrompt = o.AutoInjectPrompt
			if o.CustomPrompt != "" {
				cite.CustomPrompt = o.CustomPrompt
			}
			out.Citation = &cite
		} else {
			out.Citation = o
		}
	}

	// Weights
	if o := override.Weights; o != nil {
		if out.Weights != nil {
			w := *out.Weights
			if o.User > 0 {
				w.User = o.User
			}
			if o.Hook > 0 {
				w.Hook = o.Hook
			}
			if o.Auto > 0 {
				w.Auto = o.Auto
			}
			out.Weights = &w
		} else {
			out.Weights = o
		}
	}

	// Options
	if o := override.Options; o != nil {
		if out.Options != nil {
			opts := *out.Options
			if o.SkipThreshold > 0 {
				opts.SkipThreshold = o.SkipThreshold
			}
			out.Options = &opts
		} else {
			out.Options = o
		}
	}

	return &out
}

// extractSandboxVersion tries to read the "version" field from a sandbox config value.
func extractSandboxVersion(v any) string { if m, ok := v.(map[string]any); ok { if ver, ok := m["version"].(string); ok { return ver } } return "" } ================================================ FILE: agent/assistant/load_merge_test.go ================================================ package assistant_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/testutils" ) // TestLoadPathMerge tests loading the merge test assistant // This verifies that global config is properly merged with assistant-specific config func TestLoadPathMerge(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.LoadPath("/assistants/tests/merge") require.NoError(t, err) require.NotNil(t, ast) assert.Equal(t, "tests.merge", ast.ID) assert.Equal(t, "Merge Config Test Assistant", ast.Name) // Uses configuration - should merge global with assistant-specific // Global (from agent/agent.yml): // vision: "workers.system.vision" // search: "workers.system.search" // fetch: "workers.system.fetch" // audio: (not set) // querydsl: (not set) // rerank: (not set) // Assistant: // web: "mcp:custom-web" // keyword: "mcp:custom-keyword" // Result: assistant values override, global values inherited assert.NotNil(t, ast.Uses) // Assistant overrides assert.Equal(t, "mcp:custom-web", ast.Uses.Web) // overridden by assistant assert.Equal(t, "mcp:custom-keyword", ast.Uses.Keyword) // overridden by assistant // Inherited from global (agent/agent.yml) assert.Equal(t, "workers.system.vision", ast.Uses.Vision) // inherited from global assert.Equal(t, "workers.system.search", ast.Uses.Search) // inherited from global assert.Equal(t, "workers.system.fetch", ast.Uses.Fetch) // inherited from global // Not set in either global or assistant (should be empty) assert.Empty(t, ast.Uses.Audio) // not set anywhere assert.Empty(t, ast.Uses.QueryDSL) // not set anywhere 
assert.Empty(t, ast.Uses.Rerank) // not set anywhere

	// Search configuration - should merge global with assistant-specific
	// Global (from agent/search.yml):
	//   web.provider=tavily, web.max_results=10
	//   kb.threshold=0.7, kb.graph=false
	//   db.max_results=20
	//   keyword.max_keywords=10, keyword.language=auto
	//   rerank.top_n=10
	//   citation.format=#ref:{id}, citation.auto_inject_prompt=true
	//   weights: user=1.0, hook=0.8, auto=0.6
	//   options.skip_threshold=5
	// Assistant:
	//   web.provider=custom-provider, web.max_results=25
	//   kb.collections=[merge-test-kb], kb.threshold=0.85
	assert.NotNil(t, ast.Search)

	// Web config - assistant overrides global
	assert.NotNil(t, ast.Search.Web)
	assert.Equal(t, "custom-provider", ast.Search.Web.Provider) // overridden
	assert.Equal(t, 25, ast.Search.Web.MaxResults)              // overridden

	// KB config - assistant overrides global
	assert.NotNil(t, ast.Search.KB)
	assert.Equal(t, []string{"merge-test-kb"}, ast.Search.KB.Collections) // overridden
	assert.Equal(t, 0.85, ast.Search.KB.Threshold)                        // overridden
	assert.False(t, ast.Search.KB.Graph)                                  // inherited from global

	// DB config - should inherit from global (assistant doesn't define it)
	assert.NotNil(t, ast.Search.DB)
	assert.Equal(t, 20, ast.Search.DB.MaxResults) // inherited from global

	// Keyword config - should inherit from global
	assert.NotNil(t, ast.Search.Keyword)
	assert.Equal(t, 10, ast.Search.Keyword.MaxKeywords)  // inherited from global
	assert.Equal(t, "auto", ast.Search.Keyword.Language) // inherited from global

	// Rerank config - should inherit from global
	assert.NotNil(t, ast.Search.Rerank)
	assert.Equal(t, 10, ast.Search.Rerank.TopN) // inherited from global

	// Citation config - should inherit from global
	assert.NotNil(t, ast.Search.Citation)
	assert.Equal(t, "#ref:{id}", ast.Search.Citation.Format) // inherited from global
	assert.True(t, ast.Search.Citation.AutoInjectPrompt)     // inherited from global

	// Weights config - should inherit from global
	assert.NotNil(t, ast.Search.Weights)
	assert.Equal(t, 1.0, ast.Search.Weights.User) // inherited from global
	assert.Equal(t, 0.8, ast.Search.Weights.Hook) // inherited from global
	assert.Equal(t, 0.6, ast.Search.Weights.Auto) // inherited from global

	// Options config - should inherit from global
	assert.NotNil(t, ast.Search.Options)
	assert.Equal(t, 5, ast.Search.Options.SkipThreshold) // inherited from global
}

// TestLoadPathMergeOverride tests loading the merge-override test assistant.
// This verifies that assistant config completely overrides global config:
// every uses and search field asserted below carries the assistant's value.
func TestLoadPathMergeOverride(t *testing.T) {
	testutils.Prepare(t)
	defer testutils.Clean(t)

	ast, err := assistant.LoadPath("/assistants/tests/merge-override")
	require.NoError(t, err)
	require.NotNil(t, ast)

	assert.Equal(t, "tests.merge-override", ast.ID)
	assert.Equal(t, "Merge Override Test Assistant", ast.Name)

	// Uses configuration - all fields should be overridden by assistant
	assert.NotNil(t, ast.Uses)
	assert.Equal(t, "mcp:custom-vision", ast.Uses.Vision)
	assert.Equal(t, "mcp:custom-audio", ast.Uses.Audio)
	assert.Equal(t, "mcp:custom-search", ast.Uses.Search)
	assert.Equal(t, "mcp:custom-fetch", ast.Uses.Fetch)
	assert.Equal(t, "mcp:custom-web", ast.Uses.Web)
	assert.Equal(t, "mcp:custom-keyword", ast.Uses.Keyword)
	assert.Equal(t, "mcp:custom-querydsl", ast.Uses.QueryDSL)
	assert.Equal(t, "mcp:custom-rerank", ast.Uses.Rerank)

	// Search configuration - all fields should be overridden by assistant
	assert.NotNil(t, ast.Search)

	// Web config - all overridden
	assert.NotNil(t, ast.Search.Web)
	assert.Equal(t, "override-provider", ast.Search.Web.Provider)
	assert.Equal(t, "$ENV.OVERRIDE_API_KEY", ast.Search.Web.APIKeyEnv)
	assert.Equal(t, 100, ast.Search.Web.MaxResults)

	// KB config - all overridden
	assert.NotNil(t, ast.Search.KB)
	assert.Equal(t, []string{"override-kb"}, ast.Search.KB.Collections)
	assert.Equal(t, 0.99, ast.Search.KB.Threshold)
	assert.True(t, ast.Search.KB.Graph)

	// DB config - all overridden
	assert.NotNil(t, ast.Search.DB)
	assert.Equal(t,
[]string{"override-model"}, ast.Search.DB.Models)
	assert.Equal(t, 200, ast.Search.DB.MaxResults)

	// Keyword config - all overridden
	assert.NotNil(t, ast.Search.Keyword)
	assert.Equal(t, 20, ast.Search.Keyword.MaxKeywords)
	assert.Equal(t, "zh", ast.Search.Keyword.Language)

	// QueryDSL config - overridden
	assert.NotNil(t, ast.Search.QueryDSL)
	assert.True(t, ast.Search.QueryDSL.Strict)

	// Rerank config - overridden
	assert.NotNil(t, ast.Search.Rerank)
	assert.Equal(t, 20, ast.Search.Rerank.TopN)

	// Citation config - all overridden
	// (AutoInjectPrompt is expected to follow the assistant's value, false)
	assert.NotNil(t, ast.Search.Citation)
	assert.Equal(t, "[override:{id}]", ast.Search.Citation.Format)
	assert.False(t, ast.Search.Citation.AutoInjectPrompt)
	assert.Equal(t, "Override citation prompt", ast.Search.Citation.CustomPrompt)

	// Weights config - all overridden
	assert.NotNil(t, ast.Search.Weights)
	assert.Equal(t, 2.0, ast.Search.Weights.User)
	assert.Equal(t, 1.5, ast.Search.Weights.Hook)
	assert.Equal(t, 1.0, ast.Search.Weights.Auto)

	// Options config - overridden
	assert.NotNil(t, ast.Search.Options)
	assert.Equal(t, 10, ast.Search.Options.SkipThreshold)
}

// TestLoadPathMergeEmpty tests loading the merge-empty test assistant.
// This verifies that an assistant with no uses/search config inherits
// everything from the global configuration.
func TestLoadPathMergeEmpty(t *testing.T) {
	testutils.Prepare(t)
	defer testutils.Clean(t)

	ast, err := assistant.LoadPath("/assistants/tests/merge-empty")
	require.NoError(t, err)
	require.NotNil(t, ast)

	assert.Equal(t, "tests.merge-empty", ast.ID)
	assert.Equal(t, "Merge Empty Test Assistant", ast.Name)

	// Uses configuration - all inherited from global (agent/agent.yml)
	assert.NotNil(t, ast.Uses)
	assert.Equal(t, "workers.system.vision", ast.Uses.Vision) // from global
	assert.Equal(t, "workers.system.search", ast.Uses.Search) // from global
	assert.Equal(t, "workers.system.fetch", ast.Uses.Fetch)   // from global
	assert.Empty(t, ast.Uses.Audio)                           // not set in global
	assert.Empty(t, ast.Uses.Web)                             // not set in global
	assert.Empty(t,
ast.Uses.Keyword) // not set in global
	assert.Empty(t, ast.Uses.QueryDSL) // not set in global
	assert.Empty(t, ast.Uses.Rerank)   // not set in global

	// Search configuration - all inherited from global (agent/search.yml)
	assert.NotNil(t, ast.Search)

	// Web config - from global
	assert.NotNil(t, ast.Search.Web)
	assert.Equal(t, "tavily", ast.Search.Web.Provider)
	assert.Equal(t, 10, ast.Search.Web.MaxResults)

	// KB config - from global
	assert.NotNil(t, ast.Search.KB)
	assert.Equal(t, 0.7, ast.Search.KB.Threshold)
	assert.False(t, ast.Search.KB.Graph)

	// DB config - from global
	assert.NotNil(t, ast.Search.DB)
	assert.Equal(t, 20, ast.Search.DB.MaxResults)

	// Keyword config - from global
	assert.NotNil(t, ast.Search.Keyword)
	assert.Equal(t, 10, ast.Search.Keyword.MaxKeywords)
	assert.Equal(t, "auto", ast.Search.Keyword.Language)

	// Rerank config - from global
	assert.NotNil(t, ast.Search.Rerank)
	assert.Equal(t, 10, ast.Search.Rerank.TopN)

	// Citation config - from global
	assert.NotNil(t, ast.Search.Citation)
	assert.Equal(t, "#ref:{id}", ast.Search.Citation.Format)
	assert.True(t, ast.Search.Citation.AutoInjectPrompt)

	// Weights config - from global
	assert.NotNil(t, ast.Search.Weights)
	assert.Equal(t, 1.0, ast.Search.Weights.User)
	assert.Equal(t, 0.8, ast.Search.Weights.Hook)
	assert.Equal(t, 0.6, ast.Search.Weights.Auto)

	// Options config - from global
	assert.NotNil(t, ast.Search.Options)
	assert.Equal(t, 5, ast.Search.Options.SkipThreshold)
}

// TestLoadPathUsesAndSearchMerge tests loading the fullfields assistant.
// This verifies that uses and search configs are properly loaded and merged.
func TestLoadPathUsesAndSearchMerge(t *testing.T) {
	testutils.Prepare(t)
	defer testutils.Clean(t)

	ast, err := assistant.LoadPath("/assistants/tests/fullfields")
	require.NoError(t, err)
	require.NotNil(t, ast)

	// Uses configuration - assistant-specific values
	assert.NotNil(t, ast.Uses)
	assert.Equal(t, "agent", ast.Uses.Vision)
	assert.Equal(t, "mcp:audio-server", ast.Uses.Audio)
	assert.Equal(t,
		"agent", ast.Uses.Fetch)
	assert.Equal(t, "builtin", ast.Uses.Web)
	assert.Equal(t, "builtin", ast.Uses.Keyword)
	assert.Equal(t, "builtin", ast.Uses.QueryDSL)
	assert.Equal(t, "builtin", ast.Uses.Rerank)

	// Search configuration - assistant-specific values
	assert.NotNil(t, ast.Search)

	// Web config - from assistant
	assert.NotNil(t, ast.Search.Web)
	assert.Equal(t, "tavily", ast.Search.Web.Provider)
	assert.Equal(t, 15, ast.Search.Web.MaxResults)

	// KB config - from assistant
	assert.NotNil(t, ast.Search.KB)
	assert.Equal(t, []string{"docs", "faq"}, ast.Search.KB.Collections)
	assert.Equal(t, 0.8, ast.Search.KB.Threshold)
	assert.True(t, ast.Search.KB.Graph)

	// DB config - from assistant
	assert.NotNil(t, ast.Search.DB)
	assert.Equal(t, []string{"user", "product"}, ast.Search.DB.Models)
	assert.Equal(t, 50, ast.Search.DB.MaxResults)

	// Citation config - from assistant
	assert.NotNil(t, ast.Search.Citation)
	assert.Equal(t, "#ref:{id}", ast.Search.Citation.Format)
	assert.True(t, ast.Search.Citation.AutoInjectPrompt)

	// Weights config - from assistant
	assert.NotNil(t, ast.Search.Weights)
	assert.Equal(t, 1.0, ast.Search.Weights.User)
	assert.Equal(t, 0.9, ast.Search.Weights.Hook)
	assert.Equal(t, 0.7, ast.Search.Weights.Auto)
}

// TestLoadPathSearchAssistant tests loading the dedicated search test
// assistant and checks every section of its uses and search configuration.
func TestLoadPathSearchAssistant(t *testing.T) {
	testutils.Prepare(t)
	defer testutils.Clean(t)

	ast, err := assistant.LoadPath("/assistants/tests/search")
	require.NoError(t, err)
	require.NotNil(t, ast)

	assert.Equal(t, "tests.search", ast.ID)
	assert.Equal(t, "Search Config Test Assistant", ast.Name)

	// Uses configuration
	assert.NotNil(t, ast.Uses)
	assert.Equal(t, "builtin", ast.Uses.Web)
	assert.Equal(t, "builtin", ast.Uses.Keyword)
	assert.Equal(t, "builtin", ast.Uses.QueryDSL)
	assert.Equal(t, "builtin", ast.Uses.Rerank)

	// Search configuration
	assert.NotNil(t, ast.Search)

	// Web config
	assert.NotNil(t, ast.Search.Web)
	assert.Equal(t, "serper", ast.Search.Web.Provider)
	assert.Equal(t,
		"$ENV.SERPER_API_KEY", ast.Search.Web.APIKeyEnv)
	assert.Equal(t, 20, ast.Search.Web.MaxResults)

	// KB config
	assert.NotNil(t, ast.Search.KB)
	assert.Equal(t, []string{"knowledge-base", "documents"}, ast.Search.KB.Collections)
	assert.Equal(t, 0.75, ast.Search.KB.Threshold)
	assert.False(t, ast.Search.KB.Graph)

	// DB config
	assert.NotNil(t, ast.Search.DB)
	assert.Equal(t, []string{"article", "comment"}, ast.Search.DB.Models)
	assert.Equal(t, 30, ast.Search.DB.MaxResults)

	// Keyword config
	assert.NotNil(t, ast.Search.Keyword)
	assert.Equal(t, 8, ast.Search.Keyword.MaxKeywords)
	assert.Equal(t, "auto", ast.Search.Keyword.Language)

	// QueryDSL config
	assert.NotNil(t, ast.Search.QueryDSL)
	assert.True(t, ast.Search.QueryDSL.Strict)

	// Rerank config
	assert.NotNil(t, ast.Search.Rerank)
	assert.Equal(t, 5, ast.Search.Rerank.TopN)

	// Citation config
	assert.NotNil(t, ast.Search.Citation)
	assert.Equal(t, "#cite:{id}", ast.Search.Citation.Format)
	assert.False(t, ast.Search.Citation.AutoInjectPrompt)
	assert.Equal(t, "Please cite sources using #cite:{id} format.", ast.Search.Citation.CustomPrompt)

	// Weights config
	assert.NotNil(t, ast.Search.Weights)
	assert.Equal(t, 1.0, ast.Search.Weights.User)
	assert.Equal(t, 0.85, ast.Search.Weights.Hook)
	assert.Equal(t, 0.65, ast.Search.Weights.Auto)

	// Options config
	assert.NotNil(t, ast.Search.Options)
	assert.Equal(t, 3, ast.Search.Options.SkipThreshold)
}

================================================
FILE: agent/assistant/load_process_test.go
================================================
package assistant_test

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/yaoapp/gou/process"
	"github.com/yaoapp/yao/agent/testutils"
)

// TestLoadProcessIntegration calls the mcpload test assistant's script
// functions through the process API and checks both the results and the
// errors returned for an unknown script and an unknown method.
func TestLoadProcessIntegration(t *testing.T) {
	testutils.Prepare(t)
	defer testutils.Clean(t)

	// After testutils.Prepare, all assistants should be loaded and scripts registered
	// Test calling mcpload assistant's tools.Hello function
	t.Run("CallHelloAfterLoad", func(t *testing.T) {
		proc :=
process.New("agents.tests.mcpload.tools.Hello", map[string]interface{}{ "name": "TestUser", }) err := proc.Execute() assert.NoError(t, err) result := proc.Value() assert.NotNil(t, result) resultStr, ok := result.(string) assert.True(t, ok, "Result should be a string") assert.Contains(t, resultStr, "Hello, TestUser") assert.Contains(t, resultStr, "mcpload assistant") }) t.Run("CallPingAfterLoad", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.Ping", map[string]interface{}{ "message": "integration test", }) err := proc.Execute() assert.NoError(t, err) result := proc.Value() assert.NotNil(t, result) resultMap, ok := result.(map[string]interface{}) assert.True(t, ok, "Result should be a map") assert.Equal(t, "integration test", resultMap["message"]) assert.Contains(t, resultMap["echo"], "Pong") assert.NotEmpty(t, resultMap["timestamp"]) }) t.Run("CallCalculateAfterLoad", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.Calculate", map[string]interface{}{ "operation": "add", "a": float64(100), "b": float64(50), }) err := proc.Execute() assert.NoError(t, err) result := proc.Value() assert.NotNil(t, result) resultMap, ok := result.(map[string]interface{}) assert.True(t, ok, "Result should be a map") assert.Equal(t, float64(150), resultMap["result"]) assert.Equal(t, "add", resultMap["operation"]) assert.Equal(t, float64(100), resultMap["a"]) assert.Equal(t, float64(50), resultMap["b"]) }) t.Run("CallNonExistentScript", func(t *testing.T) { proc := process.New("agents.tests.mcpload.nonexistent.Method") err := proc.Execute() assert.NotNil(t, err, "Should return error for non-existent script") assert.Contains(t, err.Error(), "Exception|404") }) t.Run("CallNonExistentMethod", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.NonExistentMethod") err := proc.Execute() assert.NotNil(t, err, "Should return error for non-existent method") assert.Contains(t, err.Error(), "Exception|500") }) } func 
TestLoadProcessMultipleAssistants(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Test that multiple assistants can have their scripts registered // and process calls work correctly for different assistants t.Run("MCPLoadAssistant", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.Hello", map[string]interface{}{ "name": "User1", }) err := proc.Execute() assert.NoError(t, err) result := proc.Value() resultStr, ok := result.(string) assert.True(t, ok) assert.Contains(t, resultStr, "mcpload assistant") }) // If there are other test assistants with scripts, they can be tested here // For now, we verify that the handler is properly isolated per assistant t.Run("VerifyIsolation", func(t *testing.T) { // Verify that the mcpload handler is correctly registered handler, exists := process.Handlers["agents.tests.mcpload.tools"] assert.True(t, exists, "Handler should be registered") assert.NotNil(t, handler) }) } ================================================ FILE: agent/assistant/load_store_test.go ================================================ package assistant_test import ( stdContext "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" searchTypes "github.com/yaoapp/yao/agent/search/types" store "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // TestLoadStoreWithSource tests loading assistant from database with Source field func TestLoadStoreWithSource(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Create assistant with Source assistantID := "test.store-with-source" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Assistant With Source", Type: "assistant", Connector: "gpt-4o", Description: "Test assistant loaded from store with source 
code", Prompts: []store.Prompt{ {Role: "system", Content: "You are a helpful assistant."}, }, Options: map[string]interface{}{ "temperature": 0.7, }, Tags: []string{"Test", "Source"}, // Simple Create hook that returns null Source: ` // @ts-nocheck function Create(ctx, messages) { return null; } `, CreatedAt: now, UpdatedAt: now, }, } // Save to database err := ast.Save() require.NoError(t, err) // Cleanup after test defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() // Clear cache to ensure fresh load from database assistant.GetCache().Clear() // Load from store loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) // Verify basic fields assert.Equal(t, assistantID, loaded.ID) assert.Equal(t, "Test Assistant With Source", loaded.Name) assert.Equal(t, "assistant", loaded.Type) assert.Equal(t, "Test assistant loaded from store with source code", loaded.Description) // Verify prompts require.NotNil(t, loaded.Prompts) assert.Len(t, loaded.Prompts, 1) assert.Equal(t, "system", loaded.Prompts[0].Role) assert.Equal(t, "You are a helpful assistant.", loaded.Prompts[0].Content) // Verify options assert.NotNil(t, loaded.Options) assert.Equal(t, 0.7, loaded.Options["temperature"]) // Verify tags assert.NotNil(t, loaded.Tags) assert.Contains(t, loaded.Tags, "Test") assert.Contains(t, loaded.Tags, "Source") // Verify script was compiled from source assert.NotNil(t, loaded.HookScript, "HookScript should be compiled from Source field") // Verify source is stored assert.NotEmpty(t, loaded.Source) } // TestLoadStoreWithoutSource tests loading assistant from database without Source field func TestLoadStoreWithoutSource(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Create assistant without Source assistantID := "test.store-without-source" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: 
assistantID, Name: "Test Assistant Without Source", Type: "assistant", Connector: "gpt-4o", Description: "Test assistant loaded from store without source code", Prompts: []store.Prompt{ {Role: "system", Content: "You are a helpful assistant without hooks."}, }, Options: map[string]interface{}{ "temperature": 0.5, "max_tokens": 1000, }, Tags: []string{"Test", "NoSource"}, CreatedAt: now, UpdatedAt: now, }, } // Save to database err := ast.Save() require.NoError(t, err) // Cleanup after test defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() // Clear cache to ensure fresh load from database assistant.GetCache().Clear() // Load from store loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) // Verify basic fields assert.Equal(t, assistantID, loaded.ID) assert.Equal(t, "Test Assistant Without Source", loaded.Name) assert.Equal(t, "assistant", loaded.Type) assert.Equal(t, "Test assistant loaded from store without source code", loaded.Description) // Verify prompts require.NotNil(t, loaded.Prompts) assert.Len(t, loaded.Prompts, 1) assert.Equal(t, "system", loaded.Prompts[0].Role) // Verify options assert.NotNil(t, loaded.Options) assert.Equal(t, 0.5, loaded.Options["temperature"]) assert.Equal(t, float64(1000), loaded.Options["max_tokens"]) // Verify tags assert.NotNil(t, loaded.Tags) assert.Contains(t, loaded.Tags, "Test") assert.Contains(t, loaded.Tags, "NoSource") // Verify script is nil (no source) assert.Nil(t, loaded.HookScript, "HookScript should be nil when no Source field") assert.Empty(t, loaded.Source) } // newStoreTestContext creates a Context for testing with commonly used fields pre-populated. 
func newStoreTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", Scope: "openid profile email", SessionID: "test-session-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } // TestLoadStoreWithSourceExecuteHook tests that Source-based script is properly compiled and can execute func TestLoadStoreWithSourceExecuteHook(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Create assistant with a working Create hook assistantID := "test.store-source-hook" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Source Hook", Type: "assistant", Connector: "gpt-4o", Prompts: []store.Prompt{ {Role: "system", Content: "Default prompt"}, }, // Create hook that modifies temperature and adds metadata Source: ` // @ts-nocheck function Create(ctx: any, messages: any[]): any { return { temperature: 0.9, metadata: { hook_executed: true, chat_id: ctx.chat_id } }; } `, CreatedAt: now, UpdatedAt: now, }, } // Save to database err := ast.Save() require.NoError(t, err) // Cleanup after test defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() // Clear cache assistant.GetCache().Clear() // Load from store loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) require.NotNil(t, loaded.HookScript, "HookScript should be compiled from Source") // Verify the script object exists and is usable 
assert.NotNil(t, loaded.HookScript.Script) // Execute the Create hook ctx := newStoreTestContext("test-chat-id", assistantID) messages := []context.Message{{Role: "user", Content: "Hello"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err, "Create hook should execute without error") require.NotNil(t, res, "Create hook should return a response") // Verify temperature was set require.NotNil(t, res.Temperature, "Temperature should be set") assert.Equal(t, 0.9, *res.Temperature, "Temperature should be 0.9") // Verify metadata was set require.NotNil(t, res.Metadata, "Metadata should be set") assert.Equal(t, true, res.Metadata["hook_executed"], "hook_executed should be true") assert.Equal(t, "test-chat-id", res.Metadata["chat_id"], "chat_id should match context") } // TestLoadStoreWithPromptPresets tests loading assistant with prompt presets from database func TestLoadStoreWithPromptPresets(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-with-presets" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test With Presets", Type: "assistant", Connector: "gpt-4o", Prompts: []store.Prompt{ {Role: "system", Content: "Default prompt"}, }, PromptPresets: map[string][]store.Prompt{ "friendly": { {Role: "system", Content: "You are a friendly assistant."}, }, "professional": { {Role: "system", Content: "You are a professional assistant."}, }, "mode.casual": { {Role: "system", Content: "You are a casual assistant."}, }, }, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) // Verify prompt presets require.NotNil(t, loaded.PromptPresets) 
assert.Len(t, loaded.PromptPresets, 3) friendlyPreset, ok := loaded.PromptPresets["friendly"] assert.True(t, ok) assert.Len(t, friendlyPreset, 1) assert.Equal(t, "You are a friendly assistant.", friendlyPreset[0].Content) professionalPreset, ok := loaded.PromptPresets["professional"] assert.True(t, ok) assert.Len(t, professionalPreset, 1) assert.Equal(t, "You are a professional assistant.", professionalPreset[0].Content) casualPreset, ok := loaded.PromptPresets["mode.casual"] assert.True(t, ok) assert.Len(t, casualPreset, 1) assert.Equal(t, "You are a casual assistant.", casualPreset[0].Content) } // TestLoadStoreWithDisableGlobalPrompts tests loading assistant with disable_global_prompts flag func TestLoadStoreWithDisableGlobalPrompts(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-disable-global" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Disable Global Prompts", Type: "assistant", Connector: "gpt-4o", DisableGlobalPrompts: true, Prompts: []store.Prompt{ {Role: "system", Content: "Only this prompt should be used."}, }, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) assert.True(t, loaded.DisableGlobalPrompts) } // TestLoadStoreCaching tests that loaded assistants are cached func TestLoadStoreCaching(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-caching" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Caching", Type: "assistant", Connector: "gpt-4o", CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer 
func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() // First load ast1, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, ast1) // Second load - should be from cache ast2, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, ast2) // Should be the same instance (from cache) assert.Same(t, ast1, ast2) } // TestLoadStoreNotFound tests loading non-existent assistant func TestLoadStoreNotFound(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistant.GetCache().Clear() _, err := assistant.Get("non-existent-assistant-id-12345") assert.Error(t, err) } // TestLoadStoreWithAllFields tests loading assistant with comprehensive fields func TestLoadStoreWithAllFields(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-all-fields" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test All Fields", Type: "assistant", Avatar: "/api/icons/test.png", Connector: "gpt-4o", Description: "Test assistant with all fields", Tags: []string{"Test", "AllFields", "Complete"}, Readonly: true, Public: true, Share: "team", Mentionable: true, Automated: false, Sort: 100, Options: map[string]interface{}{ "temperature": 0.8, "max_tokens": 2000, }, Prompts: []store.Prompt{ {Role: "system", Content: "You are a test assistant."}, {Role: "system", Content: "Follow all instructions carefully."}, }, PromptPresets: map[string][]store.Prompt{ "default": { {Role: "system", Content: "Default mode prompt."}, }, }, DisableGlobalPrompts: true, Placeholder: &store.Placeholder{ Title: "Test Placeholder", Description: "This is a test placeholder", Prompts: []string{"Test prompt 1", "Test prompt 2"}, }, Dependencies: map[string]string{ "echo": "^1.0.0", "customer": ">=2.0.0", }, Source: ` // @ts-nocheck function Create(ctx: any, messages: 
any[]): any { return { temperature: 0.5, metadata: { assistant_name: "Test All Fields", executed: true } }; } `, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) // Verify all fields assert.Equal(t, assistantID, loaded.ID) assert.Equal(t, "Test All Fields", loaded.Name) assert.Equal(t, "assistant", loaded.Type) assert.Equal(t, "/api/icons/test.png", loaded.Avatar) assert.Equal(t, "Test assistant with all fields", loaded.Description) // Boolean fields assert.True(t, loaded.Readonly) assert.True(t, loaded.Public) assert.Equal(t, "team", loaded.Share) assert.True(t, loaded.Mentionable) assert.False(t, loaded.Automated) assert.True(t, loaded.DisableGlobalPrompts) assert.Equal(t, 100, loaded.Sort) // Tags assert.Len(t, loaded.Tags, 3) assert.Contains(t, loaded.Tags, "Test") assert.Contains(t, loaded.Tags, "AllFields") assert.Contains(t, loaded.Tags, "Complete") // Options assert.Equal(t, 0.8, loaded.Options["temperature"]) assert.Equal(t, float64(2000), loaded.Options["max_tokens"]) // Prompts assert.Len(t, loaded.Prompts, 2) // Prompt presets assert.NotNil(t, loaded.PromptPresets) assert.Contains(t, loaded.PromptPresets, "default") // Placeholder assert.NotNil(t, loaded.Placeholder) assert.Equal(t, "Test Placeholder", loaded.Placeholder.Title) assert.Equal(t, "This is a test placeholder", loaded.Placeholder.Description) assert.Len(t, loaded.Placeholder.Prompts, 2) // Script from source assert.NotNil(t, loaded.HookScript) assert.NotEmpty(t, loaded.Source) // Dependencies require.NotNil(t, loaded.Dependencies) assert.Len(t, loaded.Dependencies, 2) assert.Equal(t, "^1.0.0", loaded.Dependencies["echo"]) assert.Equal(t, ">=2.0.0", loaded.Dependencies["customer"]) // Execute the Create 
hook to verify it works ctx := newStoreTestContext("test-chat-all-fields", assistantID) messages := []context.Message{{Role: "user", Content: "Test message"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err, "Create hook should execute without error") require.NotNil(t, res, "Create hook should return a response") // Verify hook returned expected values require.NotNil(t, res.Temperature, "Temperature should be set") assert.Equal(t, 0.5, *res.Temperature, "Temperature should be 0.5") require.NotNil(t, res.Metadata, "Metadata should be set") assert.Equal(t, "Test All Fields", res.Metadata["assistant_name"], "assistant_name should match") assert.Equal(t, true, res.Metadata["executed"], "executed should be true") } // TestLoadStoreHookWithTypeScript tests that TypeScript features work in Source field func TestLoadStoreHookWithTypeScript(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-typescript-hook" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test TypeScript Hook", Type: "assistant", Connector: "gpt-4o", Prompts: []store.Prompt{ {Role: "system", Content: "Default prompt"}, }, // TypeScript code with type annotations and interfaces Source: ` // TypeScript interfaces interface CreateContext { chat_id: string; assistant_id: string; locale: string; authorized?: { user_id: string; team_id: string; }; } interface Message { role: string; content: string | object; } interface CreateResponse { temperature?: number; messages?: Message[]; metadata?: Record; } // Create hook with full TypeScript syntax function Create(ctx: CreateContext, messages: Message[]): CreateResponse | null { // Type-safe access to context const chatId: string = ctx.chat_id || "unknown"; const locale: string = ctx.locale || "en-us"; const userId: string = ctx.authorized?.user_id || "anonymous"; // Process messages const userMessages: Message[] 
= messages.filter((m: Message) => m.role === "user"); const messageCount: number = userMessages.length; // Return typed response return { temperature: 0.7, messages: [ { role: "system", content: "TypeScript hook executed successfully" } ], metadata: { chat_id: chatId, locale: locale, user_id: userId, message_count: messageCount, typescript_features: true } }; } `, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) require.NotNil(t, loaded.HookScript, "HookScript should be compiled from TypeScript Source") // Execute the Create hook ctx := newStoreTestContext("ts-test-chat", assistantID) messages := []context.Message{ {Role: "user", Content: "Hello"}, {Role: "assistant", Content: "Hi there"}, {Role: "user", Content: "How are you?"}, } res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err, "TypeScript Create hook should execute without error") require.NotNil(t, res, "Create hook should return a response") // Verify temperature require.NotNil(t, res.Temperature) assert.Equal(t, 0.7, *res.Temperature) // Verify messages require.Len(t, res.Messages, 1) assert.Equal(t, context.RoleSystem, res.Messages[0].Role) assert.Equal(t, "TypeScript hook executed successfully", res.Messages[0].Content) // Verify metadata require.NotNil(t, res.Metadata) assert.Equal(t, "ts-test-chat", res.Metadata["chat_id"]) assert.Equal(t, "en-us", res.Metadata["locale"]) assert.Equal(t, "test-user-123", res.Metadata["user_id"]) assert.Equal(t, float64(2), res.Metadata["message_count"]) // 2 user messages assert.Equal(t, true, res.Metadata["typescript_features"]) } // TestLoadStoreHookReturnNull tests that hook returning null works correctly func 
TestLoadStoreHookReturnNull(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-hook-null" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Hook Return Null", Type: "assistant", Connector: "gpt-4o", Source: ` function Create(ctx: any, messages: any[]): any { // Return null to indicate no modifications return null; } `, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) require.NotNil(t, loaded.HookScript) ctx := newStoreTestContext("null-test-chat", assistantID) messages := []context.Message{{Role: "user", Content: "Hello"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err, "Hook returning null should not error") assert.Nil(t, res, "Hook returning null should return nil response") } // TestLoadStoreHookWithPromptPreset tests that hook can return prompt_preset func TestLoadStoreHookWithPromptPreset(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-hook-preset" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Hook Prompt Preset", Type: "assistant", Connector: "gpt-4o", Prompts: []store.Prompt{ {Role: "system", Content: "Default prompt"}, }, PromptPresets: map[string][]store.Prompt{ "friendly": { {Role: "system", Content: "You are a friendly assistant."}, }, "professional": { {Role: "system", Content: "You are a professional assistant."}, }, }, Source: ` function Create(ctx: any, messages: any[]): any { // Check first message to determine preset const firstMsg = messages[0]; if (firstMsg && typeof 
firstMsg.content === "string") { if (firstMsg.content.includes("friendly")) { return { prompt_preset: "friendly" }; } if (firstMsg.content.includes("professional")) { return { prompt_preset: "professional" }; } } return null; } `, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) require.NotNil(t, loaded.HookScript) // Test friendly preset selection t.Run("SelectFriendlyPreset", func(t *testing.T) { ctx := newStoreTestContext("preset-test-1", assistantID) messages := []context.Message{{Role: "user", Content: "Be friendly please"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, res) assert.Equal(t, "friendly", res.PromptPreset) }) // Test professional preset selection t.Run("SelectProfessionalPreset", func(t *testing.T) { ctx := newStoreTestContext("preset-test-2", assistantID) messages := []context.Message{{Role: "user", Content: "Be professional"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, res) assert.Equal(t, "professional", res.PromptPreset) }) // Test no preset (returns null) t.Run("NoPreset", func(t *testing.T) { ctx := newStoreTestContext("preset-test-3", assistantID) messages := []context.Message{{Role: "user", Content: "Hello"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) assert.Nil(t, res) }) } // TestLoadStoreHookDisableGlobalPrompts tests that hook can disable global prompts func TestLoadStoreHookDisableGlobalPrompts(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-hook-disable-global" now := time.Now().UnixNano() ast := 
&assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Hook Disable Global", Type: "assistant", Connector: "gpt-4o", Source: ` function Create(ctx: any, messages: any[]): any { const firstMsg = messages[0]; if (firstMsg && typeof firstMsg.content === "string") { if (firstMsg.content.includes("disable_global")) { return { disable_global_prompts: true }; } if (firstMsg.content.includes("enable_global")) { return { disable_global_prompts: false }; } } return null; } `, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) require.NotNil(t, loaded.HookScript) // Test disable global prompts t.Run("DisableGlobalPrompts", func(t *testing.T) { ctx := newStoreTestContext("disable-test-1", assistantID) messages := []context.Message{{Role: "user", Content: "disable_global prompts"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, res) require.NotNil(t, res.DisableGlobalPrompts) assert.True(t, *res.DisableGlobalPrompts) }) // Test enable global prompts t.Run("EnableGlobalPrompts", func(t *testing.T) { ctx := newStoreTestContext("disable-test-2", assistantID) messages := []context.Message{{Role: "user", Content: "enable_global prompts"}} res, _, err := loaded.HookScript.Create(ctx, messages, &context.Options{}) require.NoError(t, err) require.NotNil(t, res) require.NotNil(t, res.DisableGlobalPrompts) assert.False(t, *res.DisableGlobalPrompts) }) } // TestLoadStoreWithSearchConfig tests loading assistant with search configuration from database func TestLoadStoreWithSearchConfig(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-with-search" now := 
time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test With Search Config", Type: "assistant", Connector: "gpt-4o", Uses: &context.Uses{ Vision: "agent", Audio: "mcp:audio-server", Fetch: "agent", Web: "builtin", Keyword: "builtin", QueryDSL: "builtin", Rerank: "builtin", }, Search: &searchTypes.Config{ Web: &searchTypes.WebConfig{ Provider: "tavily", MaxResults: 15, }, KB: &searchTypes.KBConfig{ Collections: []string{"docs", "faq"}, Threshold: 0.8, Graph: true, }, DB: &searchTypes.DBConfig{ Models: []string{"user", "product"}, MaxResults: 50, }, Keyword: &searchTypes.KeywordConfig{ MaxKeywords: 8, Language: "auto", }, QueryDSL: &searchTypes.QueryDSLConfig{ Strict: true, }, Rerank: &searchTypes.RerankConfig{ TopN: 5, }, Citation: &searchTypes.CitationConfig{ Format: "#cite:{id}", AutoInjectPrompt: false, CustomPrompt: "Please cite sources.", }, Weights: &searchTypes.WeightsConfig{ User: 1.0, Hook: 0.85, Auto: 0.65, }, Options: &searchTypes.OptionsConfig{ SkipThreshold: 3, }, }, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) // Verify Uses require.NotNil(t, loaded.Uses) assert.Equal(t, "agent", loaded.Uses.Vision) assert.Equal(t, "mcp:audio-server", loaded.Uses.Audio) assert.Equal(t, "agent", loaded.Uses.Fetch) assert.Equal(t, "builtin", loaded.Uses.Web) assert.Equal(t, "builtin", loaded.Uses.Keyword) assert.Equal(t, "builtin", loaded.Uses.QueryDSL) assert.Equal(t, "builtin", loaded.Uses.Rerank) // Verify Search config require.NotNil(t, loaded.Search) // Web config require.NotNil(t, loaded.Search.Web) assert.Equal(t, "tavily", loaded.Search.Web.Provider) assert.Equal(t, 15, loaded.Search.Web.MaxResults) // KB config 
require.NotNil(t, loaded.Search.KB) assert.Equal(t, []string{"docs", "faq"}, loaded.Search.KB.Collections) assert.Equal(t, 0.8, loaded.Search.KB.Threshold) assert.True(t, loaded.Search.KB.Graph) // DB config require.NotNil(t, loaded.Search.DB) assert.Equal(t, []string{"user", "product"}, loaded.Search.DB.Models) assert.Equal(t, 50, loaded.Search.DB.MaxResults) // Keyword config require.NotNil(t, loaded.Search.Keyword) assert.Equal(t, 8, loaded.Search.Keyword.MaxKeywords) assert.Equal(t, "auto", loaded.Search.Keyword.Language) // QueryDSL config require.NotNil(t, loaded.Search.QueryDSL) assert.True(t, loaded.Search.QueryDSL.Strict) // Rerank config require.NotNil(t, loaded.Search.Rerank) assert.Equal(t, 5, loaded.Search.Rerank.TopN) // Citation config require.NotNil(t, loaded.Search.Citation) assert.Equal(t, "#cite:{id}", loaded.Search.Citation.Format) assert.False(t, loaded.Search.Citation.AutoInjectPrompt) assert.Equal(t, "Please cite sources.", loaded.Search.Citation.CustomPrompt) // Weights config require.NotNil(t, loaded.Search.Weights) assert.Equal(t, 1.0, loaded.Search.Weights.User) assert.Equal(t, 0.85, loaded.Search.Weights.Hook) assert.Equal(t, 0.65, loaded.Search.Weights.Auto) // Options config require.NotNil(t, loaded.Search.Options) assert.Equal(t, 3, loaded.Search.Options.SkipThreshold) } // TestLoadStoreWithPartialSearchConfig tests loading assistant with partial search configuration func TestLoadStoreWithPartialSearchConfig(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-partial-search" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test Partial Search Config", Type: "assistant", Connector: "gpt-4o", Search: &searchTypes.Config{ Web: &searchTypes.WebConfig{ Provider: "serper", }, // Only web config, others are nil }, CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := 
assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) // Verify Search config require.NotNil(t, loaded.Search) // Web config should be set require.NotNil(t, loaded.Search.Web) assert.Equal(t, "serper", loaded.Search.Web.Provider) // Other configs should be nil assert.Nil(t, loaded.Search.KB) assert.Nil(t, loaded.Search.DB) assert.Nil(t, loaded.Search.Keyword) assert.Nil(t, loaded.Search.QueryDSL) assert.Nil(t, loaded.Search.Rerank) assert.Nil(t, loaded.Search.Citation) assert.Nil(t, loaded.Search.Weights) assert.Nil(t, loaded.Search.Options) } // TestLoadStoreWithoutSearchConfig tests loading assistant without search configuration func TestLoadStoreWithoutSearchConfig(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) assistantID := "test.store-no-search" now := time.Now().UnixNano() ast := &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: assistantID, Name: "Test No Search Config", Type: "assistant", Connector: "gpt-4o", // No Search config CreatedAt: now, UpdatedAt: now, }, } err := ast.Save() require.NoError(t, err) defer func() { storage := assistant.GetStorage() if storage != nil { storage.DeleteAssistant(assistantID) } assistant.GetCache().Clear() }() assistant.GetCache().Clear() loaded, err := assistant.Get(assistantID) require.NoError(t, err) require.NotNil(t, loaded) // Search config should be nil assert.Nil(t, loaded.Search) } ================================================ FILE: agent/assistant/load_system.go ================================================ package assistant import ( "fmt" "path/filepath" "strings" "github.com/yaoapp/gou/application" "github.com/yaoapp/gou/connector" gouOpenAI "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/i18n" store "github.com/yaoapp/yao/agent/store/types" 
"github.com/yaoapp/yao/data" "gopkg.in/yaml.v3" ) // systemAgents defines the system agents loaded from bindata // These are internal agents used by the system (e.g., keyword extraction, querydsl generation) // The directory name is without __yao. prefix, prefix is added during loading // Format: directory name -> bindata path prefix var systemAgents = []string{ "keyword", "querydsl", "title", "prompt", "robot_prompt", "needsearch", "entity", } // SystemConfig holds the system agents connector configuration // This is set from agent.yml system block type SystemConfig struct { Default string // Default connector for all system agents Keyword string // Connector for __yao.keyword agent QueryDSL string // Connector for __yao.querydsl agent Title string // Connector for __yao.title agent Prompt string // Connector for __yao.prompt agent RobotPrompt string // Connector for __yao.robot_prompt agent NeedSearch string // Connector for __yao.needsearch agent Entity string // Connector for __yao.entity agent } // systemConfig holds the system agents configuration (global variable like others in load.go) var systemConfig *SystemConfig = nil // SetSystemConfig sets the system agents configuration func SetSystemConfig(config *SystemConfig) { systemConfig = config } // GetSystemConfig returns the system agents configuration func GetSystemConfig() *SystemConfig { return systemConfig } // LoadSystemAgents loads the system agents from bindata // These are internal agents like __yao.keyword and __yao.querydsl // They are loaded before application assistants // Behavior is same as LoadBuiltIn, just reads from bindata instead of filesystem func LoadSystemAgents() error { // Get all existing system agents (for cleanup) deletedSystem := map[string]bool{} if storage != nil { // System agents have "system" tag tags := []string{"system"} builtIn := true res, err := storage.GetAssistants(store.AssistantFilter{ Tags: tags, BuiltIn: &builtIn, Select: []string{"assistant_id", "id"}, }) if err 
!= nil { log.Warn("Failed to get existing system agents: %v", err) } else { for _, assistant := range res.Data { deletedSystem[assistant.ID] = true } } } sort := 1 for _, name := range systemAgents { // Build agent ID with __yao. prefix id := "__yao." + name pathPrefix := "yao/assistants/" + name assistant, err := loadSystemAgent(id, pathPrefix) if err != nil { log.Warn("Failed to load system agent %s: %v", id, err) continue } // Set sort order if assistant.Sort == 0 { assistant.Sort = sort } // Save to storage if err := assistant.Save(); err != nil { log.Warn("Failed to save system agent %s: %v", id, err) continue } // Initialize the assistant if err := assistant.initialize(); err != nil { log.Warn("Failed to initialize system agent %s: %v", id, err) continue } sort++ loaded.Put(assistant) log.Trace("Loaded system agent: %s", id) // Remove from deleted list delete(deletedSystem, id) } // Remove deleted system agents if len(deletedSystem) > 0 { assistantIDs := []string{} for assistantID := range deletedSystem { assistantIDs = append(assistantIDs, assistantID) } if _, err := storage.DeleteAssistants(store.AssistantFilter{AssistantIDs: assistantIDs}); err != nil { log.Warn("Failed to delete obsolete system agents: %v", err) } } return nil } // loadSystemAgent loads a single system agent from bindata // This follows the same pattern as LoadPath but reads from bindata func loadSystemAgent(id, pathPrefix string) (*Assistant, error) { // Read package.yao from bindata pkgPath := pathPrefix + "/package.yao" pkgContent, err := data.Read(pkgPath) if err != nil { return nil, fmt.Errorf("failed to read %s: %w", pkgPath, err) } // Parse package.yao var pkgData map[string]interface{} if err := application.Parse(pkgPath, pkgContent, &pkgData); err != nil { return nil, fmt.Errorf("failed to parse %s: %w", pkgPath, err) } // Set assistant_id (no path - system agents are loaded from storage, not filesystem) pkgData["assistant_id"] = id // Set type if not specified if _, has := 
pkgData["type"]; !has { pkgData["type"] = "assistant" } // Resolve connector for this system agent connectorID := resolveSystemConnector(id) if connectorID != "" { pkgData["connector"] = connectorID } // Read prompts.yml from bindata (default prompts) promptsPath := pathPrefix + "/prompts.yml" promptsContent, err := data.Read(promptsPath) if err == nil { var prompts []store.Prompt if err := yaml.Unmarshal(promptsContent, &prompts); err == nil && len(prompts) > 0 { pkgData["prompts"] = prompts } } // Read prompt_presets from prompts directory presets := loadSystemPromptPresets(pathPrefix) if len(presets) > 0 { pkgData["prompt_presets"] = presets } // Load scripts from src directory (hook script source and other scripts sources) // These will be compiled by loadMap -> LoadScriptsFromData hookScriptSource, scriptsSource := loadSystemScripts(pathPrefix) if hookScriptSource != "" { pkgData["script"] = hookScriptSource } if len(scriptsSource) > 0 { pkgData["scripts"] = scriptsSource } // Read locales locales, err := loadSystemLocales(pathPrefix) if err == nil && len(locales) > 0 { pkgData["locales"] = locales } // Mark as system agent pkgData["readonly"] = true pkgData["built_in"] = true pkgData["tags"] = []string{"system"} // Load from map (same as LoadPath, includes initialize()) return loadMap(pkgData) } // resolveSystemConnector resolves the connector for a system agent // Priority: specific agent config > system.default > defaultConnector > fallback to first capable connector func resolveSystemConnector(agentID string) string { // Try specific agent config first if systemConfig != nil { switch agentID { case "__yao.keyword": if systemConfig.Keyword != "" { return systemConfig.Keyword } case "__yao.querydsl": if systemConfig.QueryDSL != "" { return systemConfig.QueryDSL } case "__yao.title": if systemConfig.Title != "" { return systemConfig.Title } case "__yao.prompt": if systemConfig.Prompt != "" { return systemConfig.Prompt } case "__yao.robot_prompt": if 
systemConfig.RobotPrompt != "" { return systemConfig.RobotPrompt } case "__yao.needsearch": if systemConfig.NeedSearch != "" { return systemConfig.NeedSearch } case "__yao.entity": if systemConfig.Entity != "" { return systemConfig.Entity } } // Try system default if systemConfig.Default != "" { return systemConfig.Default } } // Try global default connector if defaultConnector != "" { return defaultConnector } // Fallback: find first connector that supports tool calling return findCapableConnector() } // findCapableConnector finds the first connector that supports tool calling func findCapableConnector() string { for id, conn := range connector.Connectors { if !conn.Is(connector.OPENAI) { continue } if connOpenAI, ok := conn.(*gouOpenAI.Connector); ok { if connOpenAI.Options.Capabilities != nil && connOpenAI.Options.Capabilities.ToolCalls { return id } } } return "" } // loadSystemPromptPresets loads prompt presets from bindata prompts directory func loadSystemPromptPresets(pathPrefix string) map[string][]store.Prompt { presets := make(map[string][]store.Prompt) promptsDir := pathPrefix + "/prompts" // Try common preset files presetFiles := []string{"chat.yml", "task.yml", "code.yml", "analysis.yml"} for _, filename := range presetFiles { presetPath := promptsDir + "/" + filename content, err := data.Read(presetPath) if err != nil { continue } var prompts []store.Prompt if err := yaml.Unmarshal(content, &prompts); err == nil && len(prompts) > 0 { presetName := strings.TrimSuffix(filename, ".yml") presets[presetName] = prompts } } return presets } // loadSystemScripts loads scripts source from bindata src directory // Returns hook script source and other scripts sources (as strings) // These will be compiled by loadMap -> LoadScriptsFromData func loadSystemScripts(pathPrefix string) (string, map[string]string) { srcDir := pathPrefix + "/src" // Try to load hook script (index.ts) var hookScriptSource string indexPath := srcDir + "/index.ts" indexContent, err := 
data.Read(indexPath) if err == nil && len(indexContent) > 0 { hookScriptSource = string(indexContent) } // Try to load other scripts scripts := make(map[string]string) scriptFiles := []string{"utils.ts", "helpers.ts", "tools.ts"} for _, filename := range scriptFiles { scriptPath := srcDir + "/" + filename content, err := data.Read(scriptPath) if err != nil { continue } scriptName := strings.TrimSuffix(filename, ".ts") scripts[scriptName] = string(content) } if len(scripts) == 0 { scripts = nil } return hookScriptSource, scripts } // loadSystemLocales loads locales from bindata func loadSystemLocales(pathPrefix string) (i18n.Map, error) { locales := make(i18n.Map) // Try to load common locale files localeFiles := []string{"en-us.yml", "zh-cn.yml", "en.yml", "zh.yml"} localesDir := pathPrefix + "/locales" for _, filename := range localeFiles { localePath := filepath.Join(localesDir, filename) content, err := data.Read(localePath) if err != nil { continue } // Parse locale file locale := strings.TrimSuffix(filename, ".yml") var messages map[string]any if err := yaml.Unmarshal(content, &messages); err != nil { continue } locales[locale] = i18n.I18n{ Locale: locale, Messages: messages, } } return locales, nil } ================================================ FILE: agent/assistant/load_test.go ================================================ package assistant_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent" "github.com/yaoapp/yao/agent/assistant" store "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func prepare(t *testing.T) { test.Prepare(t, config.Conf) } func prepareAgent(t *testing.T) { test.Prepare(t, config.Conf) err := agent.Load(config.Conf) require.NoError(t, err, "agent.Load should succeed") } // TestLoadPath tests loading assistant from path func TestLoadPath(t *testing.T) { prepare(t) defer test.Clean() 
t.Run("LoadFullFieldsAssistant", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // Basic fields assert.Equal(t, "tests.fullfields", assistant.ID) assert.Equal(t, "Full Fields Test Assistant", assistant.Name) assert.Equal(t, "assistant", assistant.Type) assert.Equal(t, "/api/__yao/app/icons/app.png", assistant.Avatar) assert.Equal(t, "gpt-4o", assistant.Connector) assert.Equal(t, "/assistants/tests/fullfields", assistant.Path) assert.Equal(t, "Test assistant with all available fields for unit testing", assistant.Description) // Boolean fields assert.True(t, assistant.Public) assert.True(t, assistant.Readonly) assert.True(t, assistant.Mentionable) assert.False(t, assistant.Automated) assert.True(t, assistant.DisableGlobalPrompts) // Share field assert.Equal(t, "team", assistant.Share) // Sort field assert.Equal(t, 100, assistant.Sort) // Tags assert.NotNil(t, assistant.Tags) assert.Contains(t, assistant.Tags, "Test") assert.Contains(t, assistant.Tags, "Development") assert.Contains(t, assistant.Tags, "FullFields") // Options assert.NotNil(t, assistant.Options) assert.Equal(t, 0.7, assistant.Options["temperature"]) assert.Equal(t, float64(2000), assistant.Options["max_tokens"]) // Prompts (default prompts from prompts.yml) assert.NotNil(t, assistant.Prompts) assert.GreaterOrEqual(t, len(assistant.Prompts), 1) assert.Equal(t, "system", assistant.Prompts[0].Role) // Script (from src/index.ts) assert.NotNil(t, assistant.HookScript) }) t.Run("LoadConnectorOptions", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // ConnectorOptions assert.NotNil(t, assistant.ConnectorOptions) assert.NotNil(t, assistant.ConnectorOptions.Optional) assert.True(t, *assistant.ConnectorOptions.Optional) assert.NotNil(t, assistant.ConnectorOptions.Connectors) assert.Contains(t, 
assistant.ConnectorOptions.Connectors, "gpt-4o") assert.Contains(t, assistant.ConnectorOptions.Connectors, "gpt-4o-mini") assert.Contains(t, assistant.ConnectorOptions.Connectors, "deepseek") assert.NotNil(t, assistant.ConnectorOptions.Filters) assert.Len(t, assistant.ConnectorOptions.Filters, 2) }) t.Run("LoadPromptPresets", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // PromptPresets (from prompts directory) assert.NotNil(t, assistant.PromptPresets) // Top-level presets: chat.yml -> "chat", task.yml -> "task" chatPreset, hasChat := assistant.PromptPresets["chat"] assert.True(t, hasChat, "Should have 'chat' preset") assert.NotEmpty(t, chatPreset) taskPreset, hasTask := assistant.PromptPresets["task"] assert.True(t, hasTask, "Should have 'task' preset") assert.NotEmpty(t, taskPreset) // Nested presets: chat/friendly.yml -> "chat.friendly" friendlyPreset, hasFriendly := assistant.PromptPresets["chat.friendly"] assert.True(t, hasFriendly, "Should have 'chat.friendly' preset") assert.NotEmpty(t, friendlyPreset) professionalPreset, hasProfessional := assistant.PromptPresets["chat.professional"] assert.True(t, hasProfessional, "Should have 'chat.professional' preset") assert.NotEmpty(t, professionalPreset) // task/analysis.yml -> "task.analysis" analysisPreset, hasAnalysis := assistant.PromptPresets["task.analysis"] assert.True(t, hasAnalysis, "Should have 'task.analysis' preset") assert.NotEmpty(t, analysisPreset) }) t.Run("LoadKnowledgeBase", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // KB assert.NotNil(t, assistant.KB) assert.NotNil(t, assistant.KB.Collections) assert.Contains(t, assistant.KB.Collections, "test-collection") assert.NotNil(t, assistant.KB.Options) assert.Equal(t, float64(5), assistant.KB.Options["top_k"]) }) t.Run("LoadMCPServers", func(t *testing.T) { 
assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // MCP assert.NotNil(t, assistant.MCP) assert.NotNil(t, assistant.MCP.Servers) assert.Len(t, assistant.MCP.Servers, 1) assert.Equal(t, "echo", assistant.MCP.Servers[0].ServerID) assert.Contains(t, assistant.MCP.Servers[0].Tools, "ping") assert.Contains(t, assistant.MCP.Servers[0].Tools, "echo") }) t.Run("LoadWorkflow", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // Workflow assert.NotNil(t, assistant.Workflow) assert.NotNil(t, assistant.Workflow.Workflows) assert.Contains(t, assistant.Workflow.Workflows, "test-workflow") assert.NotNil(t, assistant.Workflow.Options) assert.Equal(t, float64(10), assistant.Workflow.Options["max_steps"]) }) t.Run("LoadPlaceholder", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // Placeholder assert.NotNil(t, assistant.Placeholder) assert.Equal(t, "Full Fields Test", assistant.Placeholder.Title) assert.Equal(t, "Test assistant with complete field coverage", assistant.Placeholder.Description) assert.NotNil(t, assistant.Placeholder.Prompts) assert.Len(t, assistant.Placeholder.Prompts, 3) }) t.Run("LoadLocales", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // Locales assert.NotNil(t, assistant.Locales) enLocale, hasEn := assistant.Locales["en-us"] assert.True(t, hasEn, "Should have en-us locale") assert.NotNil(t, enLocale) zhLocale, hasZh := assistant.Locales["zh-cn"] assert.True(t, hasZh, "Should have zh-cn locale") assert.NotNil(t, zhLocale) }) t.Run("LoadDependencies", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) require.NotNil(t, assistant) // Dependencies 
assert.NotNil(t, assistant.Dependencies) assert.Len(t, assistant.Dependencies, 2) assert.Equal(t, "^1.0.0", assistant.Dependencies["echo"]) assert.Equal(t, ">=2.0.0", assistant.Dependencies["customer"]) }) t.Run("LoadNonExistentAssistant", func(t *testing.T) { _, err := assistant.LoadPath("/assistants/non-existent") assert.Error(t, err) }) } // TestLoadPathMCPTest tests loading the MCP test assistant func TestLoadPathMCPTest(t *testing.T) { prepare(t) defer test.Clean() assistant, err := assistant.LoadPath("/assistants/tests/mcptest") require.NoError(t, err) require.NotNil(t, assistant) assert.Equal(t, "tests.mcptest", assistant.ID) assert.Equal(t, "MCP Test Assistant", assistant.Name) assert.Equal(t, "gpt-4o", assistant.Connector) // MCP configuration assert.NotNil(t, assistant.MCP) assert.Len(t, assistant.MCP.Servers, 1) assert.Equal(t, "echo", assistant.MCP.Servers[0].ServerID) // Locales assert.NotNil(t, assistant.Locales) assert.Contains(t, assistant.Locales, "en-us") assert.Contains(t, assistant.Locales, "zh-cn") } // TestLoadPathBuildRequest tests loading the build request test assistant func TestLoadPathBuildRequest(t *testing.T) { prepare(t) defer test.Clean() assistant, err := assistant.LoadPath("/assistants/tests/buildrequest") require.NoError(t, err) require.NotNil(t, assistant) assert.Equal(t, "tests.buildrequest", assistant.ID) assert.Equal(t, "Build Request Test", assistant.Name) // HookScript should be loaded assert.NotNil(t, assistant.HookScript) // Options assert.NotNil(t, assistant.Options) assert.Equal(t, 0.5, assistant.Options["temperature"]) } // TestCache tests the assistant cache functionality func TestCache(t *testing.T) { // Clear any existing cache assistant.ClearCache() // Set small cache for testing assistant.SetCache(3) assert.NotNil(t, assistant.GetCache()) // Create test assistants ast1 := &assistant.Assistant{AssistantModel: store.AssistantModel{ID: "id1", Name: "Assistant 1"}} ast2 := &assistant.Assistant{AssistantModel: 
store.AssistantModel{ID: "id2", Name: "Assistant 2"}} ast3 := &assistant.Assistant{AssistantModel: store.AssistantModel{ID: "id3", Name: "Assistant 3"}} ast4 := &assistant.Assistant{AssistantModel: store.AssistantModel{ID: "id4", Name: "Assistant 4"}} t.Run("PutAndGet", func(t *testing.T) { assistant.GetCache().Put(ast1) assert.Equal(t, 1, assistant.GetCache().Len()) cached, exists := assistant.GetCache().Get("id1") assert.True(t, exists) assert.Equal(t, ast1, cached) }) t.Run("CacheEviction", func(t *testing.T) { assistant.GetCache().Put(ast2) assistant.GetCache().Put(ast3) assert.Equal(t, 3, assistant.GetCache().Len()) // Access ast1 to make it recently used assistant.GetCache().Get("id1") // Add ast4, should evict ast2 (least recently used) assistant.GetCache().Put(ast4) assert.Equal(t, 3, assistant.GetCache().Len()) _, exists := assistant.GetCache().Get("id2") assert.False(t, exists, "ast2 should be evicted") _, exists = assistant.GetCache().Get("id1") assert.True(t, exists, "ast1 should still exist") _, exists = assistant.GetCache().Get("id4") assert.True(t, exists, "ast4 should exist") }) t.Run("ClearCache", func(t *testing.T) { assistant.ClearCache() assert.Nil(t, assistant.GetCache()) }) t.Run("SetCacheAfterClear", func(t *testing.T) { assistant.SetCache(100) assert.NotNil(t, assistant.GetCache()) }) } // TestClone tests the assistant Clone method func TestClone(t *testing.T) { prepare(t) defer test.Clean() t.Run("CloneFullFieldsAssistant", func(t *testing.T) { original, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) clone := original.Clone() require.NotNil(t, clone) // Basic fields should be equal assert.Equal(t, original.ID, clone.ID) assert.Equal(t, original.Name, clone.Name) assert.Equal(t, original.Type, clone.Type) assert.Equal(t, original.Connector, clone.Connector) assert.Equal(t, original.Description, clone.Description) // Verify deep copy - modifying original should not affect clone if len(original.Tags) > 0 { 
originalTag := original.Tags[0] original.Tags[0] = "modified" assert.NotEqual(t, original.Tags[0], clone.Tags[0]) original.Tags[0] = originalTag // restore } if original.Options != nil { original.Options["test_key"] = "test_value" _, exists := clone.Options["test_key"] assert.False(t, exists, "Clone should not have modified key") delete(original.Options, "test_key") // cleanup } if original.Dependencies != nil { original.Dependencies["test_dep"] = "^9.9.9" _, exists := clone.Dependencies["test_dep"] assert.False(t, exists, "Clone dependencies should not have modified key") delete(original.Dependencies, "test_dep") // cleanup } }) t.Run("CloneNil", func(t *testing.T) { var nilAssistant *assistant.Assistant assert.Nil(t, nilAssistant.Clone()) }) } // TestUpdate tests the assistant Update method func TestUpdate(t *testing.T) { prepare(t) defer test.Clean() t.Run("UpdateBasicFields", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) updates := map[string]interface{}{ "name": "Updated Name", "description": "Updated description", "tags": []string{"updated", "tags"}, } err = assistant.Update(updates) require.NoError(t, err) assert.Equal(t, "Updated Name", assistant.Name) assert.Equal(t, "Updated description", assistant.Description) assert.Equal(t, []string{"updated", "tags"}, assistant.Tags) }) t.Run("UpdateConnectorOptions", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) updates := map[string]interface{}{ "connector_options": map[string]interface{}{ "optional": false, "connectors": []string{"new-connector"}, }, } err = assistant.Update(updates) require.NoError(t, err) assert.NotNil(t, assistant.ConnectorOptions) assert.NotNil(t, assistant.ConnectorOptions.Optional) assert.False(t, *assistant.ConnectorOptions.Optional) assert.Contains(t, assistant.ConnectorOptions.Connectors, "new-connector") }) t.Run("UpdatePromptPresets", func(t *testing.T) 
{ assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) updates := map[string]interface{}{ "prompt_presets": map[string]interface{}{ "custom": []map[string]interface{}{ {"role": "system", "content": "Custom preset"}, }, }, } err = assistant.Update(updates) require.NoError(t, err) assert.NotNil(t, assistant.PromptPresets) customPreset, exists := assistant.PromptPresets["custom"] assert.True(t, exists) assert.Len(t, customPreset, 1) }) t.Run("UpdateSource", func(t *testing.T) { assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) updates := map[string]interface{}{ "source": "function Create(ctx, messages) { return { messages: messages }; }", } err = assistant.Update(updates) require.NoError(t, err) assert.Equal(t, "function Create(ctx, messages) { return { messages: messages }; }", assistant.Source) }) t.Run("UpdateNilAssistant", func(t *testing.T) { var nilAssistant *assistant.Assistant err := nilAssistant.Update(map[string]interface{}{"name": "test"}) assert.Error(t, err) }) } // TestMap tests the assistant Map method func TestMap(t *testing.T) { prepare(t) defer test.Clean() assistant, err := assistant.LoadPath("/assistants/tests/fullfields") require.NoError(t, err) m := assistant.Map() require.NotNil(t, m) // Check all fields are present assert.Equal(t, assistant.ID, m["assistant_id"]) assert.Equal(t, assistant.Name, m["name"]) assert.Equal(t, assistant.Type, m["type"]) assert.Equal(t, assistant.Connector, m["connector"]) assert.Equal(t, assistant.Description, m["description"]) assert.Equal(t, assistant.Path, m["path"]) assert.Equal(t, assistant.Tags, m["tags"]) assert.Equal(t, assistant.Options, m["options"]) assert.Equal(t, assistant.Prompts, m["prompts"]) assert.Equal(t, assistant.KB, m["kb"]) assert.Equal(t, assistant.MCP, m["mcp"]) assert.Equal(t, assistant.Workflow, m["workflow"]) assert.Equal(t, assistant.Placeholder, m["placeholder"]) assert.Equal(t, assistant.Locales, 
m["locales"]) // New fields assert.Equal(t, assistant.ConnectorOptions, m["connector_options"]) assert.Equal(t, assistant.PromptPresets, m["prompt_presets"]) assert.Equal(t, assistant.Source, m["source"]) assert.Equal(t, assistant.Dependencies, m["dependencies"]) } // TestLoadSystemAgents tests loading system agents from bindata func TestLoadSystemAgents(t *testing.T) { prepareAgent(t) defer test.Clean() // Clear cache first assistant.ClearCache() assistant.SetCache(200) t.Run("LoadSystemAgents", func(t *testing.T) { err := assistant.LoadSystemAgents() require.NoError(t, err) // Check __yao.keyword keywordAst, keywordExists := assistant.GetCache().Get("__yao.keyword") require.True(t, keywordExists, "__yao.keyword should be loaded") assert.Equal(t, "__yao.keyword", keywordAst.ID) assert.Equal(t, "Keyword Extractor", keywordAst.Name) assert.True(t, keywordAst.Readonly) assert.True(t, keywordAst.BuiltIn) assert.Contains(t, keywordAst.Tags, "system") assert.NotNil(t, keywordAst.Prompts) assert.Greater(t, len(keywordAst.Prompts), 0) // Check __yao.querydsl querydslAst, querydslExists := assistant.GetCache().Get("__yao.querydsl") require.True(t, querydslExists, "__yao.querydsl should be loaded") assert.Equal(t, "__yao.querydsl", querydslAst.ID) assert.Equal(t, "Query Builder", querydslAst.Name) assert.True(t, querydslAst.Readonly) assert.True(t, querydslAst.BuiltIn) assert.Contains(t, querydslAst.Tags, "system") assert.NotNil(t, querydslAst.Prompts) assert.Greater(t, len(querydslAst.Prompts), 0) // Check __yao.title titleAst, titleExists := assistant.GetCache().Get("__yao.title") require.True(t, titleExists, "__yao.title should be loaded") assert.Equal(t, "__yao.title", titleAst.ID) assert.Equal(t, "Title Generator", titleAst.Name) assert.True(t, titleAst.Readonly) assert.True(t, titleAst.BuiltIn) // Check __yao.prompt promptAst, promptExists := assistant.GetCache().Get("__yao.prompt") require.True(t, promptExists, "__yao.prompt should be loaded") assert.Equal(t, 
"__yao.prompt", promptAst.ID) assert.Equal(t, "Prompt Optimizer", promptAst.Name) assert.True(t, promptAst.Readonly) assert.True(t, promptAst.BuiltIn) // Check __yao.needsearch needsearchAst, needsearchExists := assistant.GetCache().Get("__yao.needsearch") require.True(t, needsearchExists, "__yao.needsearch should be loaded") assert.Equal(t, "__yao.needsearch", needsearchAst.ID) assert.Equal(t, "Reference Checker", needsearchAst.Name) assert.True(t, needsearchAst.Readonly) assert.True(t, needsearchAst.BuiltIn) }) t.Run("SystemAgentsSavedToStorage", func(t *testing.T) { // System agents should be saved to storage require.NotNil(t, assistant.GetStore(), "storage should be initialized") // Check __yao.keyword in storage builtIn := true tags := []string{"system"} res, err := assistant.GetStore().GetAssistants(store.AssistantFilter{ BuiltIn: &builtIn, Tags: tags, Select: []string{"assistant_id", "name"}, }) require.NoError(t, err) require.Greater(t, len(res.Data), 0, "System agents should be in storage") // Verify at least one system agent exists found := false for _, ast := range res.Data { if ast.ID == "__yao.keyword" || ast.ID == "__yao.querydsl" { found = true break } } assert.True(t, found, "System agents should be found in storage") }) t.Run("SystemAgentsGetFromStorage", func(t *testing.T) { // Clear cache to force loading from storage assistant.GetCache().Clear() // Test Get for each system agent systemAgents := []string{ "__yao.keyword", "__yao.querydsl", "__yao.title", "__yao.prompt", "__yao.needsearch", "__yao.entity", } for _, agentID := range systemAgents { ast, err := assistant.Get(agentID) require.NoError(t, err, "Get(%s) should succeed", agentID) require.NotNil(t, ast, "Get(%s) should return assistant", agentID) assert.Equal(t, agentID, ast.ID) assert.True(t, ast.BuiltIn, "%s should be built-in", agentID) assert.True(t, ast.Readonly, "%s should be readonly", agentID) assert.Contains(t, ast.Tags, "system", "%s should have system tag", agentID) 
assert.Equal(t, "worker", ast.Type, "%s should be worker type", agentID) assert.NotNil(t, ast.Prompts, "%s should have prompts", agentID) assert.Greater(t, len(ast.Prompts), 0, "%s should have at least one prompt", agentID) } }) } // TestLoadPathSandboxV2 tests loading assistants with V2 sandbox configuration (standalone sandbox.yao) func TestLoadPathSandboxV2(t *testing.T) { prepare(t) defer test.Clean() t.Run("OneshotCLI", func(t *testing.T) { ast, err := assistant.LoadPath("/assistants/tests/sandbox-v2/oneshot-cli") require.NoError(t, err) require.NotNil(t, ast) assert.Equal(t, "Sandbox V2 Oneshot CLI", ast.Name) assert.Contains(t, ast.Tags, "SandboxV2") // V2 sandbox should be loaded from sandbox.yao require.NotNil(t, ast.SandboxV2, "SandboxV2 should be loaded") assert.Equal(t, "2.0", ast.SandboxV2.Version) assert.Equal(t, "yaoapp/tai-sandbox-claude:latest", ast.SandboxV2.Computer.Image) assert.Equal(t, "2GB", ast.SandboxV2.Computer.Memory) assert.Equal(t, float64(2), ast.SandboxV2.Computer.CPUs) assert.Equal(t, "/workspace", ast.SandboxV2.Computer.WorkDir) assert.Equal(t, "claude", ast.SandboxV2.Runner.Name) assert.Equal(t, "cli", ast.SandboxV2.Runner.Mode) assert.Equal(t, "oneshot", ast.SandboxV2.Lifecycle) // Runner options assert.NotNil(t, ast.SandboxV2.Runner.Options) assert.Equal(t, float64(5), ast.SandboxV2.Runner.Options["max_turns"]) // V1 Sandbox should be nil assert.Nil(t, ast.Sandbox, "V1 Sandbox should be nil when V2 is present") // ConfigHash should be computed assert.NotEmpty(t, ast.ConfigHash, "ConfigHash should be computed for V2 sandbox") // HasSandboxV2 helper assert.True(t, ast.HasSandboxV2()) }) t.Run("SessionCLI", func(t *testing.T) { ast, err := assistant.LoadPath("/assistants/tests/sandbox-v2/session-cli") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.SandboxV2) assert.Equal(t, "session", ast.SandboxV2.Lifecycle) assert.Equal(t, "10m", ast.SandboxV2.IdleTimeout) // Prepare steps require.Len(t, 
ast.SandboxV2.Prepare, 1) assert.Equal(t, "exec", ast.SandboxV2.Prepare[0].Action) assert.True(t, ast.SandboxV2.Prepare[0].Once) }) t.Run("LongrunningCLI", func(t *testing.T) { ast, err := assistant.LoadPath("/assistants/tests/sandbox-v2/longrunning-cli") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.SandboxV2) assert.Equal(t, "longrunning", ast.SandboxV2.Lifecycle) assert.Equal(t, "15m", ast.SandboxV2.IdleTimeout) assert.Equal(t, "2h", ast.SandboxV2.MaxLifetime) assert.Equal(t, "5s", ast.SandboxV2.StopTimeout) assert.Equal(t, "4GB", ast.SandboxV2.Computer.Memory) assert.Equal(t, "rw", ast.SandboxV2.Computer.MountMode) // Environment assert.Equal(t, "test", ast.SandboxV2.Environment["NODE_ENV"]) assert.Equal(t, "longrunning", ast.SandboxV2.Environment["V2_TEST_MODE"]) // Secrets assert.Equal(t, "sandbox-v2-longrunning-secret", ast.SandboxV2.Secrets["TEST_SECRET"]) // Prepare steps require.Len(t, ast.SandboxV2.Prepare, 3) assert.True(t, ast.SandboxV2.Prepare[2].IgnoreError) // MCP (from package.yao) require.NotNil(t, ast.MCP) require.Len(t, ast.MCP.Servers, 1) assert.Equal(t, "echo", ast.MCP.Servers[0].ServerID) // ConfigHash should include MCP servers hashWithMCP := ast.ConfigHash assert.NotEmpty(t, hashWithMCP) }) t.Run("HooksOnly_YaoRunner", func(t *testing.T) { ast, err := assistant.LoadPath("/assistants/tests/sandbox-v2/hooks-only") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.SandboxV2) assert.Equal(t, "yao", ast.SandboxV2.Runner.Name) assert.Equal(t, "oneshot", ast.SandboxV2.Lifecycle) assert.Equal(t, float64(1), ast.SandboxV2.Computer.CPUs) // Runner mode should be empty (yao runner ignores mode) assert.Empty(t, ast.SandboxV2.Runner.Mode) }) t.Run("FullPrepare", func(t *testing.T) { ast, err := assistant.LoadPath("/assistants/tests/sandbox-v2/full-prepare") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.SandboxV2) assert.Equal(t, "session", ast.SandboxV2.Lifecycle) assert.Equal(t, "15m", 
ast.SandboxV2.IdleTimeout) // Prepare: 5 steps with mixed actions require.Len(t, ast.SandboxV2.Prepare, 5) assert.Equal(t, "copy", ast.SandboxV2.Prepare[0].Action) assert.Equal(t, "skills", ast.SandboxV2.Prepare[0].Src) assert.Equal(t, "~/.claude/skills", ast.SandboxV2.Prepare[0].Dst) assert.Equal(t, "exec", ast.SandboxV2.Prepare[1].Action) assert.True(t, ast.SandboxV2.Prepare[1].Once) assert.True(t, ast.SandboxV2.Prepare[3].IgnoreError) // Environment + Secrets assert.Equal(t, "full", ast.SandboxV2.Environment["V2_PREPARE_TEST"]) assert.Equal(t, "v2-full-prepare-key", ast.SandboxV2.Secrets["TEST_API_KEY"]) // Runner options assert.Equal(t, "acceptEdits", ast.SandboxV2.Runner.Options["permission_mode"]) }) t.Run("HostMode", func(t *testing.T) { ast, err := assistant.LoadPath("/assistants/tests/sandbox-v2/host-mode") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.SandboxV2) // Host mode: no image assert.Empty(t, ast.SandboxV2.Computer.Image) assert.Equal(t, "/tmp/yao-sandbox-v2-host-test", ast.SandboxV2.Computer.WorkDir) assert.Equal(t, "session", ast.SandboxV2.Lifecycle) }) t.Run("ConfigHashDeterministic", func(t *testing.T) { ast1, err := assistant.LoadPath("/assistants/tests/sandbox-v2/oneshot-cli") require.NoError(t, err) ast2, err := assistant.LoadPath("/assistants/tests/sandbox-v2/oneshot-cli") require.NoError(t, err) assert.Equal(t, ast1.ConfigHash, ast2.ConfigHash, "same config should produce same hash") }) t.Run("ConfigHashDiffers", func(t *testing.T) { ast1, err := assistant.LoadPath("/assistants/tests/sandbox-v2/oneshot-cli") require.NoError(t, err) ast2, err := assistant.LoadPath("/assistants/tests/sandbox-v2/longrunning-cli") require.NoError(t, err) assert.NotEqual(t, ast1.ConfigHash, ast2.ConfigHash, "different configs should produce different hashes") }) } // TestValidate tests the assistant Validate method func TestValidate(t *testing.T) { tests := []struct { name string ast *assistant.Assistant wantErr bool }{ { name: 
"ValidAssistant", ast: &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: "test-id", Name: "Test Assistant", Connector: "gpt-4o", }, }, wantErr: false, }, { name: "MissingID", ast: &assistant.Assistant{ AssistantModel: store.AssistantModel{ Name: "Test Assistant", Connector: "gpt-4o", }, }, wantErr: true, }, { name: "MissingName", ast: &assistant.Assistant{ AssistantModel: store.AssistantModel{ ID: "test-id", Connector: "gpt-4o", }, }, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.ast.Validate() if tt.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) } }) } } ================================================ FILE: agent/assistant/mcp.go ================================================ package assistant import ( "context" "fmt" "strings" jsoniter "github.com/json-iterator/go" gouJson "github.com/yaoapp/gou/json" "github.com/yaoapp/gou/mcp" mcpTypes "github.com/yaoapp/gou/mcp/types" agentContext "github.com/yaoapp/yao/agent/context" storeTypes "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/trace/types" ) const ( // MaxMCPTools maximum number of MCP tools to include (to avoid overwhelming the LLM) MaxMCPTools = 20 ) // MCPToolName formats a tool name with MCP server prefix // Format: server_id__tool_name (double underscore separator) // Dots in server_id are replaced with single underscores // Examples: // - ("echo", "ping") → "echo__ping" // - ("github.enterprise", "search") → "github_enterprise__search" // // Naming constraint: MCP server_id MUST NOT contain underscores (_) // Only dots (.), letters, numbers, and hyphens (-) are allowed in server_id func MCPToolName(serverID, toolName string) string { if serverID == "" || toolName == "" { return "" } // Replace dots with single underscores in server_id cleanServerID := strings.ReplaceAll(serverID, ".", "_") // Use double underscore as separator return fmt.Sprintf("%s__%s", cleanServerID, toolName) } // ParseMCPToolName parses a 
// ParseMCPToolName splits a formatted MCP tool name back into its server ID
// and tool name.
//
// The input must contain exactly one double-underscore separator with a
// non-empty part on each side. Single underscores in the server part are
// turned back into dots; the tool part is returned untouched.
//
// Examples:
//   - "echo__ping"                → ("echo", "ping", true)
//   - "github_enterprise__search" → ("github.enterprise", "search", true)
//   - "server__middle__tool"      → ("", "", false)
//
// Returns (serverID, toolName, true) on success and ("", "", false) otherwise.
func ParseMCPToolName(formattedName string) (string, string, bool) {
	serverPart, toolPart, found := strings.Cut(formattedName, "__")

	switch {
	case !found:
		// No separator at all (this also covers the empty string).
		return "", "", false
	case serverPart == "" || toolPart == "":
		// Separator present but one side is missing.
		return "", "", false
	case strings.Contains(toolPart, "__"):
		// A second separator makes the split ambiguous.
		return "", "", false
	}

	// Undo the dot → underscore substitution applied to the server ID.
	return strings.ReplaceAll(serverPart, "_", "."), toolPart, true
}
nil { mcpCtx = context.Background() } allTools := make([]MCPTool, 0) samplesBuilder := strings.Builder{} hasSamples := false // Process each MCP server in order for _, serverConfig := range servers { if len(allTools) >= MaxMCPTools { ctx.Logger.Warn("Reached maximum tool limit (%d), skipping remaining servers", MaxMCPTools) break } // Get MCP client client, err := mcp.Select(serverConfig.ServerID) if err != nil { ctx.Logger.Warn("Failed to select MCP client '%s': %v", serverConfig.ServerID, err) continue } // Get tools list (filter by serverConfig.Tools if specified) toolsResponse, err := client.ListTools(mcpCtx, "") if err != nil { ctx.Logger.Warn("Failed to list tools for '%s': %v", serverConfig.ServerID, err) continue } // Build tool filter map if specified toolFilter := make(map[string]bool) if len(serverConfig.Tools) > 0 { for _, toolName := range serverConfig.Tools { toolFilter[toolName] = true } } // Process each tool for _, tool := range toolsResponse.Tools { // Check tool limit if len(allTools) >= MaxMCPTools { break } // Apply tool filter if specified if len(toolFilter) > 0 && !toolFilter[tool.Name] { continue } // Format tool name with server prefix formattedName := MCPToolName(serverConfig.ServerID, tool.Name) // Convert MCP tool to MCPTool format mcpTool := MCPTool{ Name: formattedName, Description: tool.Description, Parameters: tool.InputSchema, } allTools = append(allTools, mcpTool) // Try to get samples for this tool samples, err := client.ListSamples(mcpCtx, mcpTypes.SampleTool, tool.Name) if err == nil && len(samples.Samples) > 0 { if !hasSamples { samplesBuilder.WriteString("\n\n## MCP Tool Usage Examples\n\n") samplesBuilder.WriteString("The following examples demonstrate how to use MCP tools correctly:\n\n") hasSamples = true } samplesBuilder.WriteString(fmt.Sprintf("### %s\n\n", formattedName)) if tool.Description != "" { samplesBuilder.WriteString(fmt.Sprintf("**Description**: %s\n\n", tool.Description)) } for i, sample := range 
samples.Samples { if i >= 3 { // Limit to 3 examples per tool break } samplesBuilder.WriteString(fmt.Sprintf("**Example %d", i+1)) if sample.Name != "" { samplesBuilder.WriteString(fmt.Sprintf(" - %s", sample.Name)) } samplesBuilder.WriteString("**:\n") // Check metadata for description if sample.Metadata != nil { if desc, ok := sample.Metadata["description"].(string); ok && desc != "" { samplesBuilder.WriteString(fmt.Sprintf("- Description: %s\n", desc)) } } if sample.Input != nil { samplesBuilder.WriteString(fmt.Sprintf("- Input: `%v`\n", sample.Input)) } if sample.Output != nil { samplesBuilder.WriteString(fmt.Sprintf("- Output: `%v`\n", sample.Output)) } samplesBuilder.WriteString("\n") } } } ctx.Logger.Debug("Loaded %d tools from server '%s'", len(toolsResponse.Tools), serverConfig.ServerID) } samplesPrompt := "" if hasSamples { samplesPrompt = samplesBuilder.String() } ctx.Logger.Debug("Total MCP tools loaded: %d", len(allTools)) return allTools, samplesPrompt, nil } // ToolCallResult represents the result of a tool call execution // executeToolCalls executes tool calls with intelligent strategy and trace logging: // - Single tool: use CallTool, single trace node // - Multiple tools: use CallToolsParallel with parallel trace nodes, fallback to sequential on certain errors // Returns (results, hasErrors) func (ast *Assistant) executeToolCalls(ctx *agentContext.Context, toolCalls []agentContext.ToolCall, attempt int) ([]ToolCallResult, bool) { if len(toolCalls) == 0 { return nil, false } ctx.Logger.Debug("Executing %d tool calls (attempt %d)", len(toolCalls), attempt) // Single tool call if len(toolCalls) == 1 { return ast.executeSingleToolCall(ctx, toolCalls[0]) } // Multiple tool calls - try parallel first return ast.executeMultipleToolCallsParallel(ctx, toolCalls) } // executeSingleToolCall executes a single tool call with trace logging func (ast *Assistant) executeSingleToolCall(ctx *agentContext.Context, toolCall agentContext.ToolCall) ([]ToolCallResult, 
bool) { ctx.Logger.ToolStart(toolCall.Function.Name) trace, _ := ctx.Trace() // Use the agent context for cancellation and timeout control mcpCtx := ctx.Context if mcpCtx == nil { mcpCtx = context.Background() } result := ToolCallResult{ ToolCallID: toolCall.ID, Name: toolCall.Function.Name, } // Parse tool name serverID, toolName, ok := ParseMCPToolName(toolCall.Function.Name) if !ok { result.Error = fmt.Errorf("invalid MCP tool name format: %s", toolCall.Function.Name) result.Content = result.Error.Error() ctx.Logger.Error("Invalid MCP tool name format: %s", toolCall.Function.Name) ctx.Logger.ToolComplete(toolCall.Function.Name, false) return []ToolCallResult{result}, true } // Get MCP client client, err := mcp.Select(serverID) if err != nil { result.Error = fmt.Errorf("failed to select MCP client '%s': %w", serverID, err) result.Content = result.Error.Error() result.IsRetryableError = false // MCP client selection error is not retryable ctx.Logger.Error("Failed to select MCP client '%s': %v", serverID, err) ctx.Logger.ToolComplete(toolCall.Function.Name, false) return []ToolCallResult{result}, true } // Get tool info for description and schema toolsResponse, err := client.ListTools(mcpCtx, "") var toolDescription string var toolSchema interface{} if err == nil { for _, t := range toolsResponse.Tools { if t.Name == toolName { toolDescription = t.Description toolSchema = t.InputSchema break } } } if toolDescription == "" { toolDescription = fmt.Sprintf("MCP tool '%s'", toolName) } // Add trace node for this tool call var toolNode types.Node if trace != nil { toolNode, _ = trace.Add( map[string]any{ "tool_call_id": toolCall.ID, "server": serverID, "tool": toolName, "arguments": toolCall.Function.Arguments, }, types.TraceNodeOption{ Label: toolDescription, Type: "mcp_tool", Icon: "build", Description: fmt.Sprintf("Calling '%s' on server '%s'", toolName, serverID), }, ) } // Parse arguments with repair support for better tolerance var args map[string]interface{} if 
toolCall.Function.Arguments != "" { parsed, err := gouJson.Parse(toolCall.Function.Arguments) if err != nil { result.Error = fmt.Errorf("failed to parse arguments: %w", err) result.Content = result.Error.Error() result.IsRetryableError = true // Argument parsing error is retryable by LLM ctx.Logger.Error("Failed to parse arguments: %v", err) ctx.Logger.ToolComplete(toolCall.Function.Name, false) if toolNode != nil { toolNode.Fail(result.Error) } return []ToolCallResult{result}, true } // Convert to map if argsMap, ok := parsed.(map[string]interface{}); ok { args = argsMap } else { result.Error = fmt.Errorf("arguments must be an object, got %T", parsed) result.Content = result.Error.Error() result.IsRetryableError = true // Type error is retryable by LLM ctx.Logger.Error("Arguments must be an object, got %T", parsed) ctx.Logger.ToolComplete(toolCall.Function.Name, false) if toolNode != nil { toolNode.Fail(result.Error) } return []ToolCallResult{result}, true } // Validate arguments against tool schema if available if toolSchema != nil { if err := gouJson.Validate(args, toolSchema); err != nil { result.Error = fmt.Errorf("argument validation failed: %w", err) result.Content = result.Error.Error() result.IsRetryableError = true // Validation error is retryable by LLM ctx.Logger.Error("Argument validation failed: %v", err) ctx.Logger.ToolComplete(toolCall.Function.Name, false) if toolNode != nil { toolNode.Fail(result.Error) } return []ToolCallResult{result}, true } } } // Call the tool with agent context as extra argument ctx.Logger.Debug("Calling tool: %s (server: %s)", toolName, serverID) // Pass agent context as extra argument (only used for Process transport) callResult, err := client.CallTool(mcpCtx, toolName, args, ctx) if err != nil { result.Error = fmt.Errorf("tool call failed: %w", err) result.Content = result.Error.Error() // Check if error is retryable (parameter/validation errors) result.IsRetryableError = isRetryableToolError(err) ctx.Logger.Error("Tool 
call failed: %v (retryable: %v)", err, result.IsRetryableError) ctx.Logger.ToolComplete(toolCall.Function.Name, false) if toolNode != nil { toolNode.Fail(result.Error) } return []ToolCallResult{result}, true } // Check if result is an error if callResult.IsError { result.Error = fmt.Errorf("MCP tool error") result.IsRetryableError = false // MCP internal error is not retryable } // Serialize the Content field only ([]ToolContent) contentBytes, err := jsoniter.Marshal(callResult.Content) if err != nil { result.Error = fmt.Errorf("failed to serialize result: %w", err) result.Content = result.Error.Error() result.IsRetryableError = false ctx.Logger.Error("Failed to serialize result: %v", err) ctx.Logger.ToolComplete(toolCall.Function.Name, false) if toolNode != nil { toolNode.Fail(result.Error) } return []ToolCallResult{result}, true } result.Content = string(contentBytes) ctx.Logger.ToolComplete(toolCall.Function.Name, true) if toolNode != nil { toolNode.Complete(map[string]any{ "result": callResult, }) } return []ToolCallResult{result}, false } // executeMultipleToolCallsParallel executes multiple tool calls in parallel with trace logging func (ast *Assistant) executeMultipleToolCallsParallel(ctx *agentContext.Context, toolCalls []agentContext.ToolCall) ([]ToolCallResult, bool) { trace, _ := ctx.Trace() // Use the agent context for cancellation and timeout control mcpCtx := ctx.Context if mcpCtx == nil { mcpCtx = context.Background() } // Group tool calls by server serverGroups := make(map[string][]agentContext.ToolCall) for _, tc := range toolCalls { serverID, _, ok := ParseMCPToolName(tc.Function.Name) if !ok { ctx.Logger.Warn("Invalid tool name format: %s", tc.Function.Name) continue } serverGroups[serverID] = append(serverGroups[serverID], tc) } results := make([]ToolCallResult, 0, len(toolCalls)) hasErrors := false // Process each server's tools for serverID, calls := range serverGroups { client, err := mcp.Select(serverID) if err != nil { 
ctx.Logger.Error("Failed to select MCP client '%s': %v", serverID, err) // Add error results for all calls to this server for _, tc := range calls { results = append(results, ToolCallResult{ ToolCallID: tc.ID, Name: tc.Function.Name, Content: fmt.Sprintf("Failed to select MCP client: %v", err), Error: err, }) } hasErrors = true continue } // Try parallel execution serverResults, serverHasErrors := ast.executeServerToolsParallelWithTrace( mcpCtx, ctx, trace, client, serverID, calls, ) // If parallel execution failed with retryable error, try sequential if serverHasErrors && ast.shouldRetrySequential(serverResults) { ctx.Logger.Warn("Parallel execution had parameter errors for server '%s', retrying sequentially", serverID) serverResults, serverHasErrors = ast.executeServerToolsSequentialWithTrace( mcpCtx, ctx, trace, client, serverID, calls, ) } results = append(results, serverResults...) if serverHasErrors { hasErrors = true } } return results, hasErrors } // isRetryableToolError checks if an error is retryable by LLM (parameter/validation errors) // Returns true for errors that LLM can potentially fix by adjusting parameters // Returns false for MCP internal errors (network, auth, service unavailable, etc.) 
func isRetryableToolError(err error) bool { if err == nil { return false } errMsg := strings.ToLower(err.Error()) // These are NOT retryable (MCP internal issues) nonRetryablePatterns := []string{ "network", "timeout", "connection", "unauthorized", "forbidden", "unavailable", "failed to select", "context canceled", "context deadline", "server error", "internal error", } for _, pattern := range nonRetryablePatterns { if strings.Contains(errMsg, pattern) { return false } } // These ARE retryable (parameter/validation issues LLM can fix) retryablePatterns := []string{ "invalid", "required", "missing", "validation", "schema", "type", "format", "parse", "argument", "parameter", } for _, pattern := range retryablePatterns { if strings.Contains(errMsg, pattern) { return true } } // Default: assume it's retryable unless proven otherwise // This allows LLM to attempt fixes for unknown error types return true } // shouldRetrySequential checks if errors are retryable (parameter issues, not network/service issues) func (ast *Assistant) shouldRetrySequential(results []ToolCallResult) bool { // Check if any result has a retryable error hasRetryable := false for _, result := range results { if result.Error != nil && result.IsRetryableError { hasRetryable = true break } } return hasRetryable } // executeServerToolsParallelWithTrace executes tools for a single server in parallel with trace func (ast *Assistant) executeServerToolsParallelWithTrace(mcpCtx context.Context, ctx *agentContext.Context, trace types.Manager, client mcp.Client, serverID string, toolCalls []agentContext.ToolCall) ([]ToolCallResult, bool) { // Prepare parallel trace inputs var parallelInputs []types.TraceParallelInput mcpCalls := make([]mcpTypes.ToolCall, 0, len(toolCalls)) callMap := make(map[string]agentContext.ToolCall) for _, tc := range toolCalls { _, toolName, ok := ParseMCPToolName(tc.Function.Name) if !ok { continue } var args map[string]interface{} if tc.Function.Arguments != "" { if err := 
jsoniter.UnmarshalFromString(tc.Function.Arguments, &args); err != nil { ctx.Logger.Error("Failed to parse arguments for %s: %v", toolName, err) continue } } mcpCalls = append(mcpCalls, mcpTypes.ToolCall{ Name: toolName, Arguments: args, }) callMap[toolName] = tc ctx.Logger.ToolStart(tc.Function.Name) // Add trace input for this tool parallelInputs = append(parallelInputs, types.TraceParallelInput{ Input: map[string]any{ "tool_call_id": tc.ID, "server": serverID, "tool": toolName, "arguments": tc.Function.Arguments, }, Option: types.TraceNodeOption{ Label: fmt.Sprintf("Tool: %s", toolName), Type: "mcp_tool", Icon: "build", Description: fmt.Sprintf("Calling MCP tool '%s' on server '%s'", toolName, serverID), }, }) } // Create parallel trace nodes var toolNodes []types.Node if trace != nil && len(parallelInputs) > 0 { var err error toolNodes, err = trace.Parallel(parallelInputs) if err != nil { ctx.Logger.Debug("trace.Parallel() failed: %v", err) } } // Call tools in parallel with agent context as extra argument ctx.Logger.Debug("Calling %d tools in parallel on server '%s'", len(mcpCalls), serverID) // Pass agent context as extra argument (only used for Process transport) mcpResponse, err := client.CallToolsParallel(mcpCtx, mcpCalls, ctx) if err != nil { ctx.Logger.Error("Parallel call failed: %v", err) for i, node := range toolNodes { if node != nil { node.Fail(err) } if i < len(mcpCalls) { if tc, ok := callMap[mcpCalls[i].Name]; ok { ctx.Logger.ToolComplete(tc.Function.Name, false) } } } return nil, true } // Process results results := make([]ToolCallResult, 0, len(mcpResponse.Results)) hasErrors := false for i, mcpResult := range mcpResponse.Results { toolName := mcpCalls[i].Name originalCall := callMap[toolName] var toolNode types.Node if i < len(toolNodes) { toolNode = toolNodes[i] } result := ToolCallResult{ ToolCallID: originalCall.ID, Name: originalCall.Function.Name, } // Serialize content contentBytes, err := jsoniter.Marshal(mcpResult.Content) if err != 
nil { result.Error = fmt.Errorf("failed to serialize result: %w", err) result.Content = result.Error.Error() result.IsRetryableError = false // Serialization error is not retryable hasErrors = true ctx.Logger.ToolComplete(originalCall.Function.Name, false) if toolNode != nil { toolNode.Fail(result.Error) } } else { result.Content = string(contentBytes) // Check if it's an error result if mcpResult.IsError { result.Error = fmt.Errorf("tool call error: %s", result.Content) result.IsRetryableError = isRetryableToolError(result.Error) hasErrors = true ctx.Logger.Error("Tool call failed: %s - %s (retryable: %v)", toolName, result.Content, result.IsRetryableError) ctx.Logger.ToolComplete(originalCall.Function.Name, false) if toolNode != nil { toolNode.Fail(result.Error) } } else { ctx.Logger.ToolComplete(originalCall.Function.Name, true) if toolNode != nil { toolNode.Complete(map[string]any{ "result": mcpResult.Content, }) } } } results = append(results, result) } return results, hasErrors } // executeServerToolsSequentialWithTrace executes tools for a single server sequentially with trace func (ast *Assistant) executeServerToolsSequentialWithTrace(mcpCtx context.Context, ctx *agentContext.Context, trace types.Manager, client mcp.Client, serverID string, toolCalls []agentContext.ToolCall) ([]ToolCallResult, bool) { results := make([]ToolCallResult, 0, len(toolCalls)) hasErrors := false ctx.Logger.Debug("Calling %d tools sequentially on server '%s'", len(toolCalls), serverID) for _, tc := range toolCalls { ctx.Logger.ToolStart(tc.Function.Name) _, toolName, ok := ParseMCPToolName(tc.Function.Name) if !ok { results = append(results, ToolCallResult{ ToolCallID: tc.ID, Name: tc.Function.Name, Content: fmt.Sprintf("Invalid tool name format: %s", tc.Function.Name), Error: fmt.Errorf("invalid tool name format"), }) ctx.Logger.ToolComplete(tc.Function.Name, false) hasErrors = true continue } // Get tool schema for validation toolsResponse, err := client.ListTools(mcpCtx, "") var 
toolSchema interface{} if err == nil { for _, t := range toolsResponse.Tools { if t.Name == toolName { toolSchema = t.InputSchema break } } } // Add trace node for this tool call var toolNode types.Node if trace != nil { toolNode, _ = trace.Add( map[string]any{ "tool_call_id": tc.ID, "server": serverID, "tool": toolName, "arguments": tc.Function.Arguments, }, types.TraceNodeOption{ Label: fmt.Sprintf("Tool: %s (sequential retry)", toolName), Type: "mcp_tool", Icon: "build", Description: fmt.Sprintf("Retrying MCP tool '%s' on server '%s' sequentially", toolName, serverID), }, ) } // Parse arguments with repair support var args map[string]interface{} if tc.Function.Arguments != "" { parsed, err := gouJson.Parse(tc.Function.Arguments) if err != nil { result := ToolCallResult{ ToolCallID: tc.ID, Name: tc.Function.Name, Content: fmt.Sprintf("Failed to parse arguments: %v", err), Error: err, IsRetryableError: true, // Parsing error is retryable } results = append(results, result) hasErrors = true ctx.Logger.ToolComplete(tc.Function.Name, false) if toolNode != nil { toolNode.Fail(err) } continue } // Convert to map if argsMap, ok := parsed.(map[string]interface{}); ok { args = argsMap } else { err := fmt.Errorf("arguments must be an object, got %T", parsed) result := ToolCallResult{ ToolCallID: tc.ID, Name: tc.Function.Name, Content: err.Error(), Error: err, IsRetryableError: true, // Type error is retryable } results = append(results, result) hasErrors = true ctx.Logger.ToolComplete(tc.Function.Name, false) if toolNode != nil { toolNode.Fail(err) } continue } // Validate arguments against tool schema if available if toolSchema != nil { if err := gouJson.Validate(args, toolSchema); err != nil { result := ToolCallResult{ ToolCallID: tc.ID, Name: tc.Function.Name, Content: fmt.Sprintf("Argument validation failed: %v", err), Error: err, IsRetryableError: true, // Validation error is retryable } results = append(results, result) hasErrors = true 
ctx.Logger.ToolComplete(tc.Function.Name, false) if toolNode != nil { toolNode.Fail(err) } continue } } } // Call single tool with agent context as extra argument ctx.Logger.Debug("Calling tool: %s", toolName) mcpResult, err := client.CallTool(mcpCtx, toolName, args, ctx) result := ToolCallResult{ ToolCallID: tc.ID, Name: tc.Function.Name, } if err != nil { result.Error = err result.Content = fmt.Sprintf("Tool call failed: %v", err) result.IsRetryableError = isRetryableToolError(err) hasErrors = true ctx.Logger.Error("Tool call failed: %s - %v (retryable: %v)", toolName, err, result.IsRetryableError) ctx.Logger.ToolComplete(tc.Function.Name, false) if toolNode != nil { toolNode.Fail(err) } } else { // Check if result is an error if mcpResult.IsError { result.Error = fmt.Errorf("MCP tool error") result.IsRetryableError = false // MCP internal error is not retryable hasErrors = true } // Serialize the Content field only ([]ToolContent) contentBytes, err := jsoniter.Marshal(mcpResult.Content) if err != nil { result.Error = err result.Content = fmt.Sprintf("Failed to serialize result: %v", err) result.IsRetryableError = false // Serialization error is not retryable hasErrors = true ctx.Logger.ToolComplete(tc.Function.Name, false) if toolNode != nil { toolNode.Fail(err) } } else { result.Content = string(contentBytes) ctx.Logger.ToolComplete(tc.Function.Name, !mcpResult.IsError) if toolNode != nil { toolNode.Complete(map[string]any{ "result": mcpResult.Content, }) } } } results = append(results, result) } return results, hasErrors } ================================================ FILE: agent/assistant/mcp_test.go ================================================ package assistant_test import ( "context" "testing" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/mcp" mcpTypes "github.com/yaoapp/gou/mcp/types" "github.com/yaoapp/yao/agent/assistant" agentContext "github.com/yaoapp/yao/agent/context" 
"github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) func TestMCPToolName(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) tests := []struct { name string serverID string toolName string wantResult string }{ { name: "Simple tool name", serverID: "github", toolName: "search_repos", wantResult: "github__search_repos", }, { name: "Server with dots", serverID: "github.enterprise", toolName: "search_repos", wantResult: "github_enterprise__search_repos", }, { name: "Tool with underscores", serverID: "customer-db", toolName: "create_customer", wantResult: "customer-db__create_customer", }, { name: "Complex server with multiple dots", serverID: "com.example.mcp", toolName: "tool_name", wantResult: "com_example_mcp__tool_name", }, { name: "Empty server ID", serverID: "", toolName: "tool", wantResult: "", }, { name: "Empty tool name", serverID: "server", toolName: "", wantResult: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := assistant.MCPToolName(tt.serverID, tt.toolName) if result != tt.wantResult { t.Errorf("MCPToolName() = %v, want %v", result, tt.wantResult) } }) } } func TestParseMCPToolName(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) tests := []struct { name string formattedName string wantServerID string wantToolName string wantOK bool }{ { name: "Valid simple format", formattedName: "github__search_repos", wantServerID: "github", wantToolName: "search_repos", wantOK: true, }, { name: "Server with dots restored", formattedName: "github_enterprise__search_repos", wantServerID: "github.enterprise", wantToolName: "search_repos", wantOK: true, }, { name: "Complex server ID with multiple dots", formattedName: "com_example_mcp_server__tool_name", wantServerID: "com.example.mcp.server", wantToolName: "tool_name", wantOK: true, }, { name: "Tool name with underscores", formattedName: "server__create_new_user", wantServerID: "server", wantToolName: "create_new_user", wantOK: 
true, }, { name: "Server with hyphens", formattedName: "mcp-server__tool", wantServerID: "mcp-server", wantToolName: "tool", wantOK: true, }, { name: "Invalid format - no double underscore", formattedName: "invalid", wantServerID: "", wantToolName: "", wantOK: false, }, { name: "Invalid format - empty string", formattedName: "", wantServerID: "", wantToolName: "", wantOK: false, }, { name: "Invalid format - only double underscore", formattedName: "__", wantServerID: "", wantToolName: "", wantOK: false, }, { name: "Invalid format - ends with double underscore", formattedName: "server__", wantServerID: "", wantToolName: "", wantOK: false, }, { name: "Invalid format - starts with double underscore", formattedName: "__tool", wantServerID: "", wantToolName: "", wantOK: false, }, { name: "Invalid format - multiple double underscores", formattedName: "server__middle__tool", wantServerID: "", wantToolName: "", wantOK: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { serverID, toolName, ok := assistant.ParseMCPToolName(tt.formattedName) if serverID != tt.wantServerID { t.Errorf("ParseMCPToolName() serverID = %v, want %v", serverID, tt.wantServerID) } if toolName != tt.wantToolName { t.Errorf("ParseMCPToolName() toolName = %v, want %v", toolName, tt.wantToolName) } if ok != tt.wantOK { t.Errorf("ParseMCPToolName() ok = %v, want %v", ok, tt.wantOK) } }) } } func TestMCPToolName_RoundTrip(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) tests := []struct { name string serverID string toolName string }{ { name: "Simple IDs", serverID: "github", toolName: "search_repos", }, { name: "Server with dots", serverID: "github.enterprise", toolName: "search", }, { name: "Complex server ID", serverID: "com.example.mcp.server", toolName: "tool_name", }, { name: "Server with dashes", serverID: "mcp-server-123", toolName: "tool_with_underscores", }, { name: "Mixed dots and dashes", serverID: "github.enterprise-prod", toolName: "api_call", }, } for _, 
tt := range tests { t.Run(tt.name, func(t *testing.T) { // Format formatted := assistant.MCPToolName(tt.serverID, tt.toolName) if formatted == "" { t.Fatal("MCPToolName() returned empty string") } // Parse serverID, toolName, ok := assistant.ParseMCPToolName(formatted) // Verify round-trip if !ok { t.Fatal("ParseMCPToolName() failed") } if serverID != tt.serverID { t.Errorf("Round-trip failed: serverID = %v, want %v", serverID, tt.serverID) } if toolName != tt.toolName { t.Errorf("Round-trip failed: toolName = %v, want %v", toolName, tt.toolName) } t.Logf("✓ Round-trip successful: (%s, %s) → %s → (%s, %s)", tt.serverID, tt.toolName, formatted, serverID, toolName) }) } } // TestMCPToolContextPassing tests that agent context is correctly passed to MCP tools func TestMCPToolContextPassing(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Get the echo MCP client client, err := mcp.Select("echo") assert.NoError(t, err, "Failed to select echo MCP client") assert.NotNil(t, client, "MCP client should not be nil") // Create a test agent context authorized := &types.AuthorizedInfo{ UserID: "test-user-123", TenantID: "test-tenant-456", } ctx := agentContext.New(context.Background(), authorized, "test-chat-789") ctx.AssistantID = "test-assistant-mcptest" ctx.Locale = "en" ctx.Theme = "dark" // Call the echo tool with context args := map[string]interface{}{ "message": "test message from context test", } // Call the tool - the agent context will be passed as extra parameter result, err := client.CallTool(ctx.Context, "echo", args, ctx) assert.NoError(t, err, "CallTool should not return error") assert.NotNil(t, result, "Result should not be nil") assert.False(t, result.IsError, "Result should not be an error") assert.Greater(t, len(result.Content), 0, "Result should have content") // Parse the result content var echoResult map[string]interface{} err = jsoniter.Unmarshal([]byte(result.Content[0].Text), &echoResult) assert.NoError(t, err, "Failed to parse result 
content") t.Logf("Echo result: %+v", echoResult) // Verify the context was received contextData, ok := echoResult["context"].(map[string]interface{}) assert.True(t, ok, "Result should contain context field") assert.NotNil(t, contextData, "Context data should not be nil") // Verify context has_context flag hasContext, ok := contextData["has_context"].(bool) assert.True(t, ok, "Context should have has_context field") assert.True(t, hasContext, "Context should indicate it has context") // Verify chat_id and assistant_id have values (main verification) chatID, ok := contextData["chat_id"].(string) assert.True(t, ok, "Context should have chat_id field") assert.NotEmpty(t, chatID, "chat_id should have a value") assert.Equal(t, "test-chat-789", chatID, "chat_id should match") assistantID, ok := contextData["assistant_id"].(string) assert.True(t, ok, "Context should have assistant_id field") assert.NotEmpty(t, assistantID, "assistant_id should have a value") assert.Equal(t, "test-assistant-mcptest", assistantID, "assistant_id should match") // Verify authorized information authorizedData, ok := contextData["authorized"].(map[string]interface{}) assert.True(t, ok, "Context should have authorized field") assert.NotNil(t, authorizedData, "Authorized data should not be nil") userID, ok := authorizedData["user_id"].(string) assert.True(t, ok, "Authorized should have user_id field") assert.Equal(t, "test-user-123", userID, "User ID should match") tenantID, ok := authorizedData["tenant_id"].(string) assert.True(t, ok, "Authorized should have tenant_id field") assert.Equal(t, "test-tenant-456", tenantID, "Tenant ID should match") t.Logf("✓ Context successfully passed to MCP tool") t.Logf(" - ChatID: %s", chatID) t.Logf(" - AssistantID: %s", assistantID) t.Logf(" - UserID: %s", userID) t.Logf(" - TenantID: %s", tenantID) } // TestMCPToolContextPassingParallel tests that agent context is correctly passed in parallel calls func TestMCPToolContextPassingParallel(t *testing.T) { 
testutils.Prepare(t) defer testutils.Clean(t) // Get the echo MCP client client, err := mcp.Select("echo") assert.NoError(t, err, "Failed to select echo MCP client") assert.NotNil(t, client, "MCP client should not be nil") // Create a test agent context authorized := &types.AuthorizedInfo{ UserID: "parallel-user-123", TenantID: "parallel-tenant-456", } ctx := agentContext.New(context.Background(), authorized, "parallel-chat-789") ctx.AssistantID = "test-assistant-parallel" ctx.Locale = "zh-CN" // Call multiple echo tools in parallel toolCalls := []mcpTypes.ToolCall{ { Name: "echo", Arguments: map[string]interface{}{ "message": "parallel message 1", }, }, { Name: "echo", Arguments: map[string]interface{}{ "message": "parallel message 2", }, }, } // Call tools in parallel - the agent context will be passed as extra parameter results, err := client.CallToolsParallel(ctx.Context, toolCalls, ctx) assert.NoError(t, err, "CallToolsParallel should not return error") assert.NotNil(t, results, "Results should not be nil") assert.Equal(t, 2, len(results.Results), "Should have 2 results") // Verify both results received the context for i, result := range results.Results { assert.False(t, result.IsError, "Result %d should not be an error", i) assert.Greater(t, len(result.Content), 0, "Result %d should have content", i) // Parse the result content var echoResult map[string]interface{} err = jsoniter.Unmarshal([]byte(result.Content[0].Text), &echoResult) assert.NoError(t, err, "Failed to parse result %d content", i) // Verify the context was received contextData, ok := echoResult["context"].(map[string]interface{}) assert.True(t, ok, "Result %d should contain context field", i) assert.NotNil(t, contextData, "Context data %d should not be nil", i) hasContext, ok := contextData["has_context"].(bool) assert.True(t, ok, "Context %d should have has_context field", i) assert.True(t, hasContext, "Context %d should indicate it has context", i) // Verify chat_id in parallel call chatID, 
ok := contextData["chat_id"].(string) assert.True(t, ok, "Context %d should have chat_id field", i) assert.Equal(t, "parallel-chat-789", chatID, "Chat ID in result %d should match", i) // Verify authorized information in parallel call authorizedData, ok := contextData["authorized"].(map[string]interface{}) assert.True(t, ok, "Context %d should have authorized field", i) if userID, ok := authorizedData["user_id"].(string); ok { assert.Equal(t, "parallel-user-123", userID, "User ID in result %d should match", i) } t.Logf("✓ Result %d successfully received context", i) } t.Log("✓ Context successfully passed to all parallel MCP tool calls") } ================================================ FILE: agent/assistant/next.go ================================================ package assistant import ( "fmt" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" ) // processNextResponse processes the Next hook's response and handles agent delegation or custom data func (ast *Assistant) processNextResponse(npc *NextProcessContext) (*agentContext.Response, error) { // If no Next hook response, return standard response if npc.NextResponse == nil { return ast.buildStandardResponse(npc), nil } // Handle Delegate: call another agent // Note: User input is already buffered by root agent, delegated agent will skip buffering if npc.NextResponse.Delegate != nil { return ast.handleDelegation(npc.Context, npc.NextResponse.Delegate, npc.StreamHandler) } // Handle custom Data: return as-is wrapped in standard Response if npc.NextResponse.Data != nil { return &agentContext.Response{ ContextID: npc.Context.ID, RequestID: npc.Context.RequestID(), TraceID: npc.Context.TraceID(), ChatID: npc.Context.ChatID, AssistantID: ast.ID, Create: npc.CreateResponse, Next: npc.NextResponse.Data, // Put custom data in Next field Completion: npc.CompletionResponse, Tools: npc.ToolCallResponses, }, nil } // No delegate or data, return standard response return 
ast.buildStandardResponse(npc), nil } // handleDelegation handles calling another agent based on DelegateConfig func (ast *Assistant) handleDelegation( ctx *agentContext.Context, delegate *agentContext.DelegateConfig, streamHandler func(message.StreamChunkType, []byte) int, ) (*agentContext.Response, error) { // Load the target assistant targetAssistant, err := Get(delegate.AgentID) if err != nil { return nil, fmt.Errorf("failed to load delegated assistant '%s': %w", delegate.AgentID, err) } // Mark this as an agent-to-agent call for proper source tracking ctx.Referer = agentContext.RefererAgent // Call the delegated assistant with the same context // The delegated assistant's Stream method will: // 1. Call EnterStack() to push itself onto the Stack (creating parent-child relationship) // 2. Execute with the same Context (preserving ID, Space, Writer, etc.) // 3. Call done() to pop from Stack when finished // This ensures proper Stack tracing: parent assistant -> delegated assistant // Convert options map from delegate config to Options struct delegateOpts := agentContext.OptionsFromMap(delegate.Options) return targetAssistant.Stream(ctx, delegate.Messages, delegateOpts) } // buildStandardResponse builds the standard agent response when no custom Next hook processing is needed func (ast *Assistant) buildStandardResponse(npc *NextProcessContext) *agentContext.Response { var next interface{} = nil if npc.NextResponse != nil { next = npc.NextResponse } return &agentContext.Response{ ContextID: npc.Context.ID, RequestID: npc.Context.RequestID(), TraceID: npc.Context.TraceID(), ChatID: npc.Context.ChatID, AssistantID: ast.ID, Create: npc.CreateResponse, Next: next, Completion: npc.CompletionResponse, Tools: npc.ToolCallResponses, } } ================================================ FILE: agent/assistant/permission.go ================================================ package assistant import ( "fmt" "github.com/yaoapp/yao/agent/context" ) func (ast *Assistant) 
checkPermissions(ctx *context.Context) error { if ctx.Authorized == nil { return fmt.Errorf("authorized information not found") } return nil } ================================================ FILE: agent/assistant/sandbox.go ================================================ package assistant import ( stdContext "context" "encoding/json" "fmt" "os" "path/filepath" "sync" "time" "github.com/yaoapp/gou/connector" gouMCP "github.com/yaoapp/gou/mcp" mcpProcess "github.com/yaoapp/gou/mcp/process" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" agentsandbox "github.com/yaoapp/yao/agent/sandbox" "github.com/yaoapp/yao/config" infraSandbox "github.com/yaoapp/yao/sandbox" "github.com/yaoapp/yao/sandbox/ipc" traceTypes "github.com/yaoapp/yao/trace/types" ) var ( sandboxManager *infraSandbox.Manager sandboxManagerOnce sync.Once sandboxManagerErr error ) // GetSandboxManager returns the sandbox manager singleton // Returns nil and error if sandbox is not configured or Docker is unavailable func GetSandboxManager() (*infraSandbox.Manager, error) { sandboxManagerOnce.Do(func() { // Create sandbox config from Yao config cfg := &infraSandbox.Config{} // Use YAO_DATA_ROOT for workspace and IPC paths dataRoot := config.Conf.DataRoot if dataRoot != "" { cfg.Init(dataRoot) } // Create manager (will fail if Docker is not available) sandboxManager, sandboxManagerErr = infraSandbox.NewManager(cfg) }) return sandboxManager, sandboxManagerErr } // HasSandbox returns true if the assistant has sandbox configuration func (ast *Assistant) HasSandbox() bool { return ast.Sandbox != nil && ast.Sandbox.Command != "" } // initSandbox initializes the sandbox executor // Returns the full Executor (for LLM calls), cleanup function, and any error // This is called BEFORE hooks so that hooks can access ctx.sandbox // The executor implements both agentsandbox.Executor and context.SandboxExecutor interfaces func (ast *Assistant) 
initSandbox(ctx *context.Context, opts *context.Options) (agentsandbox.Executor, func(), string, error) { // Get sandbox manager (singleton) manager, err := GetSandboxManager() if err != nil { ctx.Logger.Error("Sandbox manager initialization failed: %v", err) return nil, nil, "", fmt.Errorf("sandbox manager not available: %w", err) } if manager == nil { return nil, nil, "", fmt.Errorf("sandbox manager not initialized") } // Build executor options from assistant config execOpts, err := ast.buildSandboxOptions(ctx, opts) if err != nil { ctx.Logger.Error("Failed to build sandbox options: %v", err) return nil, nil, "", fmt.Errorf("failed to build sandbox options: %w", err) } // Log sandbox creation ctx.Logger.Info("Creating sandbox container for command: %s", ast.Sandbox.Command) // Add trace for sandbox creation trace, traceErr := ctx.Trace() if traceErr == nil && trace != nil { trace.Info("Creating sandbox container...") } // Send loading message to user loadingMsg := &message.Message{ Type: message.TypeLoading, Props: map[string]interface{}{ "message": i18n.T(ctx.Locale, "sandbox.preparing"), }, } loadingMsgID, _ := ctx.SendStream(loadingMsg) // Create executor (container starts here) executor, err := agentsandbox.New(manager, execOpts) if err != nil { ctx.Logger.Error("Sandbox creation failed: %v", err) if traceErr == nil && trace != nil { trace.Error("Sandbox creation failed: %v", err) } // End loading message with done:true if loadingMsgID != "" { doneMsg := &message.Message{ MessageID: loadingMsgID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]interface{}{ "message": i18n.T(ctx.Locale, "sandbox.failed"), "done": true, }, } ctx.Send(doneMsg) } return nil, nil, "", fmt.Errorf("failed to create sandbox executor: %w", err) } // Log sandbox ready ctx.Logger.Info("Sandbox container ready") if traceErr == nil && trace != nil { trace.Info("Sandbox container ready") } // Return cleanup function cleanup := func() { if err := 
executor.Close(); err != nil { ctx.Logger.Error("Failed to close sandbox executor: %v", err) } } // Keep loadingMsgID open - it will be closed when first output is received // This provides better UX: user sees "Preparing..." until actual content appears return executor, cleanup, loadingMsgID, nil } // executeSandboxStream executes the request using sandbox (Claude CLI, etc.) // This is called when ast.Sandbox is configured // NOTE: The executor is passed directly from initSandbox, no type assertion needed func (ast *Assistant) executeSandboxStream( ctx *context.Context, completionMessages []context.Message, agentNode traceTypes.Node, streamHandler message.StreamFunc, executor agentsandbox.Executor, loadingMsgID string, ) (*context.CompletionResponse, error) { // Mark the agentNode as used to avoid unused variable error _ = agentNode if executor == nil { return nil, fmt.Errorf("sandbox executor not initialized (call initSandbox first)") } // Log sandbox execution ctx.Logger.Info("Executing via sandbox (command: %s)", ast.Sandbox.Command) // Pass the "preparing sandbox" loading message ID to executor // It will be closed when first output (text or tool) is received if loadingMsgID != "" { executor.SetLoadingMsgID(loadingMsgID) } // Execute LLM call via sandbox // The loadingMsgID will be closed when first output is received // Tool calls will create their own loading messages below the text resp, err := executor.Stream(ctx, completionMessages, streamHandler) if err != nil { // Close loading message on error if loadingMsgID != "" { doneMsg := &message.Message{ MessageID: loadingMsgID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]interface{}{ "message": i18n.T(ctx.Locale, "sandbox.failed"), "done": true, }, } ctx.Send(doneMsg) } // Send error message to client errMsg := &message.Message{ Type: message.TypeError, Props: map[string]interface{}{ "message": err.Error(), }, } ctx.Send(errMsg) return nil, fmt.Errorf("sandbox 
execution failed: %w", err) } return resp, nil } // buildSandboxOptions builds executor options from assistant config func (ast *Assistant) buildSandboxOptions(ctx *context.Context, opts *context.Options) (*agentsandbox.Options, error) { if ast.Sandbox == nil { return nil, fmt.Errorf("sandbox configuration is required") } execOpts := &agentsandbox.Options{ Command: ast.Sandbox.Command, Image: ast.Sandbox.Image, MaxMemory: ast.Sandbox.MaxMemory, MaxCPU: ast.Sandbox.MaxCPU, Arguments: ast.Sandbox.Arguments, } // Parse timeout string (e.g., "10m") to duration if ast.Sandbox.Timeout != "" { timeout, err := time.ParseDuration(ast.Sandbox.Timeout) if err != nil { return nil, fmt.Errorf("invalid timeout format: %w", err) } execOpts.Timeout = timeout } // Set user and chat IDs for workspace isolation if ctx.Authorized != nil && ctx.Authorized.UserID != "" { execOpts.UserID = ctx.Authorized.UserID } else { execOpts.UserID = "anonymous" } execOpts.ChatID = ctx.ChatID // Set skills directory (auto-resolved from assistant path) // Only set if the directory actually exists if ast.Path != "" { appRoot := config.Conf.AppSource skillsDir := filepath.Join(appRoot, ast.Path, "skills") if info, err := os.Stat(skillsDir); err == nil && info.IsDir() { execOpts.SkillsDir = skillsDir ctx.Logger.Debug("Skills directory found: %s", skillsDir) } } // Check if assistant has prompts (from prompts.yml) // If prompts are configured, we need to call Claude CLI if len(ast.Prompts) > 0 { // Extract system prompt from prompts for _, prompt := range ast.Prompts { if prompt.Role == "system" && prompt.Content != "" { execOpts.SystemPrompt = prompt.Content break } } } // Resolve connector settings conn, _, err := ast.GetConnector(ctx, opts) if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } // Determine connector type for sandbox proxy behavior // Anthropic connectors bypass the proxy (Claude CLI connects directly) if conn.Is(connector.ANTHROPIC) { execOpts.ConnectorType = 
"anthropic" } else { execOpts.ConnectorType = "openai" } setting := conn.Setting() if host, ok := setting["host"].(string); ok { execOpts.ConnectorHost = host } if key, ok := setting["key"].(string); ok { execOpts.ConnectorKey = key } if model, ok := setting["model"].(string); ok { execOpts.Model = model } // Extract extra connector options (thinking, max_tokens, temperature, etc.) // These are backend-specific parameters that need to be passed through to the proxy connectorOptions := make(map[string]interface{}) for k, v := range setting { // Skip standard fields that are already handled switch k { case "host", "key", "model", "azure", "capabilities": continue default: // Include all other fields as extra options connectorOptions[k] = v } } if len(connectorOptions) > 0 { execOpts.ConnectorOptions = connectorOptions ctx.Logger.Debug("Connector options extracted: %v", connectorOptions) } // Extract secrets from sandbox config (e.g., GITHUB_TOKEN: "$ENV.GITHUB_TOKEN") if ast.Sandbox != nil && len(ast.Sandbox.Secrets) > 0 { secrets := make(map[string]string) for k, v := range ast.Sandbox.Secrets { // Resolve $ENV.XXX references resolved := resolveEnvValue(v) if resolved != "" { secrets[k] = resolved } } if len(secrets) > 0 { execOpts.Secrets = secrets ctx.Logger.Debug("Secrets extracted: %d items", len(secrets)) } } // Build MCP config and load tools if the assistant has MCP servers configured if ast.MCP != nil && len(ast.MCP.Servers) > 0 { // Build MCP config for Claude CLI mcpConfig, err := ast.BuildMCPConfigForSandbox(ctx) if err != nil { ctx.Logger.Warn("Failed to build MCP config for sandbox: %v", err) // Non-fatal: sandbox can work without MCP } else { execOpts.MCPConfig = mcpConfig ctx.Logger.Debug("MCP config built for sandbox (%d bytes)", len(mcpConfig)) } // Load MCP tools for IPC session mcpTools, err := ast.loadMCPToolsForIPC(ctx) if err != nil { ctx.Logger.Warn("Failed to load MCP tools for IPC: %v", err) // Non-fatal: IPC will have no tools } else if 
len(mcpTools) > 0 { execOpts.MCPTools = mcpTools ctx.Logger.Debug("Loaded %d MCP tools for IPC", len(mcpTools)) } } return execOpts, nil } // loadMCPToolsForIPC loads MCP tools from configured servers and converts them to IPC format func (ast *Assistant) loadMCPToolsForIPC(ctx *context.Context) (map[string]*ipc.MCPTool, error) { if ast.MCP == nil || len(ast.MCP.Servers) == 0 { return nil, nil } tools := make(map[string]*ipc.MCPTool) stdCtx := ctx.Context if stdCtx == nil { stdCtx = stdContext.Background() } for _, serverConfig := range ast.MCP.Servers { if serverConfig.ServerID == "" { continue } // Get MCP client client, err := gouMCP.Select(serverConfig.ServerID) if err != nil { ctx.Logger.Warn("MCP server '%s' not found: %v", serverConfig.ServerID, err) continue } // List tools from the MCP client toolsResp, err := client.ListTools(stdCtx, "") if err != nil { ctx.Logger.Warn("Failed to list tools from MCP server '%s': %v", serverConfig.ServerID, err) continue } // Get tool mapping for process names mapping, ok := mcpProcess.GetMapping(serverConfig.ServerID) if !ok { ctx.Logger.Warn("No mapping found for MCP server '%s'", serverConfig.ServerID) continue } // Filter tools if specified in config toolFilter := make(map[string]bool) if len(serverConfig.Tools) > 0 { for _, t := range serverConfig.Tools { toolFilter[t] = true } } // Convert tools to IPC format // Tool names are prefixed with server ID to avoid conflicts // e.g., "echo" server's "ping" tool becomes "echo__ping" for _, tool := range toolsResp.Tools { // Apply tool filter if specified if len(toolFilter) > 0 && !toolFilter[tool.Name] { continue } // Find the process name from mapping processName := "" if toolSchema, ok := mapping.Tools[tool.Name]; ok { processName = toolSchema.Process } if processName == "" { ctx.Logger.Warn("No process mapping for tool '%s' in server '%s'", tool.Name, serverConfig.ServerID) continue } // Prefixed tool name: serverID__toolName // This matches Claude's MCP naming: 
// resolveEnvValue resolves an environment variable reference in a string.
//
// A value of the form "$ENV.VAR_NAME" (with a non-empty name) is replaced by
// the value of VAR_NAME, which is the empty string when the variable is not
// set. Any other value, including the empty string and a bare "$ENV.", is
// returned unchanged.
func resolveEnvValue(value string) string {
	const envPrefix = "$ENV."

	// Not a reference: too short to carry a variable name, or wrong prefix
	if len(value) <= len(envPrefix) || value[:len(envPrefix)] != envPrefix {
		return value
	}

	return os.Getenv(value[len(envPrefix):])
}
ast.HasSandbox()) // The condition in agent.go is: // if ast.Prompts != nil || ast.MCP != nil { // // ... execute LLM // if ast.HasSandbox() { // // sandbox path // } else { // // direct LLM path // } // } // So we need Prompts or MCP to be non-nil if ast.Prompts == nil && ast.MCP == nil { t.Log("WARNING: Neither Prompts nor MCP is set, LLM phase will be skipped entirely!") } } ================================================ FILE: agent/assistant/sandbox_e2e_test.go ================================================ package assistant_test import ( stdContext "context" "fmt" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newSandboxE2EContext creates a Context for sandbox E2E testing // Uses unique chatID to avoid container name conflicts func newSandboxE2EContext(chatIDPrefix, assistantID string) *context.Context { // Generate unique chatID using timestamp to avoid container conflicts chatID := fmt.Sprintf("%s-%d", chatIDPrefix, time.Now().UnixNano()) authorized := &types.AuthorizedInfo{ Subject: "sandbox-e2e-test-user", ClientID: "sandbox-e2e-test-client", Scope: "openid profile", SessionID: "sandbox-e2e-test-session", UserID: "sandbox-user-123", TeamID: "sandbox-team-456", TenantID: "sandbox-tenant-789", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "SandboxE2ETest/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) return ctx } // TestSandboxBasicE2E tests the basic sandbox assistant end-to-end // This test verifies that: // 1. Sandbox is correctly initialized // 2. 
// truncateString shortens s to at most maxLen bytes and appends "..." when
// anything was cut off. Strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: if byte maxLen is a
// continuation byte, the cut moves back to the start of that rune, so the
// result can be shorter than maxLen bytes but stays valid UTF-8 for valid
// input. A negative maxLen is treated as 0 instead of panicking.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// len(s) > maxLen here, so s[cut] is always in range.
	// UTF-8 continuation bytes have the bit pattern 10xxxxxx.
	cut := maxLen
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
"HookScript should be loaded") t.Logf("✓ Full sandbox configured: command=%s, MCP servers=%d", ast.Sandbox.Command, len(ast.MCP.Servers)) // Verify MCP configuration assert.Len(t, ast.MCP.Servers, 1) assert.Equal(t, "echo", ast.MCP.Servers[0].ServerID) t.Logf("✓ MCP server: %s with tools %v", ast.MCP.Servers[0].ServerID, ast.MCP.Servers[0].Tools) // Create context ctx := newSandboxE2EContext("sandbox-full-e2e", "tests.sandbox.full") // Test messages messages := []context.Message{ {Role: context.RoleUser, Content: "test full sandbox with MCP and skills"}, } // Execute stream response, err := ast.Stream(ctx, messages) if err != nil { errStr := err.Error() if strings.Contains(errStr, "Docker") || strings.Contains(errStr, "sandbox") || strings.Contains(errStr, "container") || strings.Contains(errStr, "image") { t.Skipf("Skipping test: Docker/sandbox not available: %v", err) } t.Fatalf("Stream failed: %v", err) } require.NotNil(t, response, "Response should not be nil") t.Log("✓ Full sandbox E2E test passed") } // TestSandboxContextAccess tests that sandbox is accessible in hooks via ctx.sandbox func TestSandboxContextAccess(t *testing.T) { if testing.Short() { t.Skip("Skipping sandbox context access test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the hooks sandbox assistant ast, err := assistant.Get("tests.sandbox.hooks") if err != nil { t.Skipf("Skipping test: sandbox hooks assistant not available: %v", err) } require.NotNil(t, ast.HookScript, "HookScript should be loaded") // Create context ctx := newSandboxE2EContext("sandbox-ctx-access", "tests.sandbox.hooks") // Test Create Hook - it should have access to ctx.sandbox messages := []context.Message{ {Role: context.RoleUser, Content: "test sandbox context access"}, } // Execute Create hook directly // This tests that the hook runs without error (sandbox operations tested within) opts := &context.Options{} response, _, err := ast.HookScript.Create(ctx, messages, opts) // The hook might 
fail if sandbox isn't initialized yet (that's done in Stream) // But we can at least verify the hook exists and can be called if err != nil { // If the error is about sandbox not being available, that's expected // because we haven't initialized the sandbox yet if strings.Contains(err.Error(), "sandbox") { t.Logf("Expected error: sandbox not available in direct hook call: %v", err) } else { t.Fatalf("Unexpected error: %v", err) } } // Response might be nil, that's okay t.Logf("Create hook response: %v", response) t.Log("✓ Sandbox context access test passed") } // TestSandboxMCPToolCall tests that Claude actually calls MCP tools via IPC // This test specifically asks Claude to use the echo tool and verifies the result func TestSandboxMCPToolCall(t *testing.T) { if testing.Short() { t.Skip("Skipping sandbox MCP tool call test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the full sandbox assistant (has MCP echo tool) ast, err := assistant.Get("tests.sandbox.full") if err != nil { t.Skipf("Skipping test: full sandbox assistant not available: %v", err) } // Verify MCP is configured with echo tools require.NotNil(t, ast.MCP, "MCP should be configured") require.NotEmpty(t, ast.MCP.Servers, "MCP servers should be configured") t.Logf("✓ MCP configured with server: %s, tools: %v", ast.MCP.Servers[0].ServerID, ast.MCP.Servers[0].Tools) // Create context ctx := newSandboxE2EContext("sandbox-mcp-tool", "tests.sandbox.full") // Explicit prompt to use echo tool // This tells Claude to use the MCP tool specifically messages := []context.Message{ { Role: context.RoleUser, Content: `Please use the 'ping' MCP tool to send a ping with message "MCP_TEST_SUCCESS". Just call the tool and show me the result. 
Do not explain, just use the tool.`, }, } // Collect all response content var responseContent strings.Builder // Execute stream response, err := ast.Stream(ctx, messages) if err != nil { errStr := err.Error() if strings.Contains(errStr, "Docker") || strings.Contains(errStr, "sandbox") || strings.Contains(errStr, "container") || strings.Contains(errStr, "image") { t.Skipf("Skipping test: Docker/sandbox not available: %v", err) } t.Fatalf("Stream failed: %v", err) } require.NotNil(t, response, "Response should not be nil") // Get the response content fullResponse := "" if response.Completion != nil && response.Completion.Content != nil { if contentStr, ok := response.Completion.Content.(string); ok { fullResponse = contentStr responseContent.WriteString(contentStr) } } t.Logf("Claude response: %s", fullResponse) // Check if Claude acknowledged using the tool or returned tool results // The response should contain either: // 1. Evidence of tool call (tool_use block in response) // 2. The ping result "pong" or "MCP_TEST_SUCCESS" // 3. 
Some indication that it attempted to use the MCP tool hasToolEvidence := strings.Contains(fullResponse, "pong") || strings.Contains(fullResponse, "MCP_TEST_SUCCESS") || strings.Contains(fullResponse, "ping") || strings.Contains(fullResponse, "tool") if hasToolEvidence { t.Log("✓ Claude appears to have used the MCP tool") } else { t.Logf("⚠ Claude response does not clearly show MCP tool usage") t.Logf("Response: %s", fullResponse) } // At minimum, verify we got a response if fullResponse == "" { t.Log("⚠ Response content is empty") } t.Log("✓ Sandbox MCP tool call test completed") } // TestSandboxMCPEchoTool tests the echo MCP tool specifically // This test uses a more explicit prompt to force tool usage func TestSandboxMCPEchoTool(t *testing.T) { if testing.Short() { t.Skip("Skipping sandbox MCP echo test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the full sandbox assistant ast, err := assistant.Get("tests.sandbox.full") if err != nil { t.Skipf("Skipping test: full sandbox assistant not available: %v", err) } // Create context ctx := newSandboxE2EContext("sandbox-mcp-echo", "tests.sandbox.full") // Very explicit prompt for echo tool messages := []context.Message{ { Role: context.RoleUser, Content: `Call the 'echo' MCP tool with message "ECHO_VERIFICATION_12345" and uppercase=true. 
Show me the exact response from the tool.`, }, } response, err := ast.Stream(ctx, messages) if err != nil { errStr := err.Error() if strings.Contains(errStr, "Docker") || strings.Contains(errStr, "sandbox") || strings.Contains(errStr, "container") || strings.Contains(errStr, "image") { t.Skipf("Skipping test: Docker/sandbox not available: %v", err) } t.Fatalf("Stream failed: %v", err) } require.NotNil(t, response) fullResponse := "" if response.Completion != nil && response.Completion.Content != nil { if contentStr, ok := response.Completion.Content.(string); ok { fullResponse = contentStr } } t.Logf("Claude response for echo tool: %s", fullResponse) // The echo tool with uppercase=true should return "ECHO_VERIFICATION_12345" // Check if this appears in the response if strings.Contains(fullResponse, "ECHO_VERIFICATION_12345") { t.Log("✓ MCP echo tool executed successfully - found verification string in response") } else if strings.Contains(fullResponse, "echo") || strings.Contains(fullResponse, "ECHO") { t.Log("✓ MCP echo tool appears to have been used (found 'echo' in response)") } else { t.Logf("⚠ Could not verify echo tool execution. 
Response: %s", fullResponse) } t.Log("✓ Sandbox MCP echo tool test completed") } // TestSandboxLoadConfiguration verifies that sandbox assistants load correctly func TestSandboxLoadConfiguration(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) testCases := []struct { name string assistantID string expectSandbox bool expectMCP bool expectHooks bool }{ { name: "BasicSandbox", assistantID: "tests.sandbox.basic", expectSandbox: true, expectMCP: false, expectHooks: false, }, { name: "HooksSandbox", assistantID: "tests.sandbox.hooks", expectSandbox: true, expectMCP: false, expectHooks: true, }, { name: "FullSandbox", assistantID: "tests.sandbox.full", expectSandbox: true, expectMCP: true, expectHooks: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ast, err := assistant.Get(tc.assistantID) if err != nil { t.Skipf("Skipping: assistant %s not available: %v", tc.assistantID, err) } // Check sandbox if tc.expectSandbox { require.NotNil(t, ast.Sandbox, "Expected sandbox to be configured") assert.Equal(t, "claude", ast.Sandbox.Command) t.Logf("✓ %s: Sandbox configured with command=%s", tc.name, ast.Sandbox.Command) } // Check MCP if tc.expectMCP { require.NotNil(t, ast.MCP, "Expected MCP to be configured") assert.True(t, len(ast.MCP.Servers) > 0, "Expected at least one MCP server") t.Logf("✓ %s: MCP configured with %d servers", tc.name, len(ast.MCP.Servers)) } // Check hooks if tc.expectHooks { require.NotNil(t, ast.HookScript, "Expected hooks to be loaded") t.Logf("✓ %s: Hooks loaded", tc.name) } }) } } ================================================ FILE: agent/assistant/sandbox_integration_test.go ================================================ package assistant_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent" "github.com/yaoapp/yao/agent/assistant" agentContext "github.com/yaoapp/yao/agent/context" agentsandbox "github.com/yaoapp/yao/agent/sandbox" 
"github.com/yaoapp/yao/agent/sandbox/claude" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestSandboxOptionsBuilding tests that sandbox options are correctly built from assistant config func TestSandboxOptionsBuilding(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent to ensure connectors are available err := agent.Load(config.Conf) require.NoError(t, err, "agent.Load should succeed") // Load the full test assistant ast, err := assistant.LoadPath("/assistants/tests/sandbox/full") require.NoError(t, err) require.NotNil(t, ast) // Verify sandbox is configured require.NotNil(t, ast.Sandbox) assert.Equal(t, "claude", ast.Sandbox.Command) assert.Equal(t, "5m", ast.Sandbox.Timeout) // Verify arguments are set require.NotNil(t, ast.Sandbox.Arguments) assert.Equal(t, float64(10), ast.Sandbox.Arguments["max_turns"]) assert.Equal(t, "acceptEdits", ast.Sandbox.Arguments["permission_mode"]) // Verify MCP configuration require.NotNil(t, ast.MCP) assert.Len(t, ast.MCP.Servers, 1) assert.Equal(t, "echo", ast.MCP.Servers[0].ServerID) t.Logf("Sandbox config: command=%s, timeout=%s", ast.Sandbox.Command, ast.Sandbox.Timeout) t.Logf("Sandbox arguments: %v", ast.Sandbox.Arguments) t.Logf("MCP servers: %v", ast.MCP.Servers) } // TestClaudeCommandBuilding tests that Claude CLI commands are correctly built func TestClaudeCommandBuilding(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create test messages messages := []agentContext.Message{ {Role: "system", Content: "You are a helpful coding assistant."}, {Role: "user", Content: "Hello, how are you?"}, } // Create options similar to what buildSandboxOptions would produce opts := &claude.Options{ Command: "claude", UserID: "test-user", ChatID: "test-chat", ConnectorHost: "https://ark.cn-beijing.volces.com/api/v3", ConnectorKey: "test-api-key", Model: "ep-xxxxx", Arguments: map[string]interface{}{ "max_turns": 10, "permission_mode": "acceptEdits", }, } // Build the command 
cmd, env, err := claude.BuildCommand(messages, opts) require.NoError(t, err) // Verify command structure // Command is now: ["bash", "-c", "cat << 'INPUTEOF' | claude -p ... INPUTEOF"] assert.NotEmpty(t, cmd) assert.Equal(t, "bash", cmd[0], "Command should start with bash") assert.Equal(t, "-c", cmd[1], "Second arg should be -c") assert.Contains(t, cmd[2], "claude -p", "Bash command should contain claude -p") assert.Contains(t, cmd[2], "--permission-mode", "Should include permission mode") assert.Contains(t, cmd[2], "--input-format", "Should include input-format flag") assert.Contains(t, cmd[2], "--output-format", "Should include output-format flag") assert.Contains(t, cmd[2], "--verbose", "Should include verbose flag") assert.Contains(t, cmd[2], "stream-json", "Should use stream-json format") assert.Contains(t, cmd[2], "INPUTEOF", "Should use heredoc for input") t.Logf("Built command: %v", cmd) // Verify environment variables (claude-proxy mode) assert.NotEmpty(t, env) assert.Equal(t, "http://127.0.0.1:3456", env["ANTHROPIC_BASE_URL"], "Should set proxy base URL") assert.Equal(t, "dummy", env["ANTHROPIC_API_KEY"], "Should set dummy API key for proxy") // max_turns is passed via CLI flag // system prompt is written to file via heredoc, then referenced via --append-system-prompt-file assert.Contains(t, cmd[2], "--max-turns", "Should include max-turns flag") assert.Contains(t, cmd[2], "cat << 'PROMPTEOF' > /tmp/.system-prompt.txt", "Should use heredoc for system prompt") assert.Contains(t, cmd[2], "--append-system-prompt-file", "Should include append-system-prompt-file flag") assert.Contains(t, cmd[2], "You are a helpful coding assistant", "Command should contain system prompt") t.Logf("Built environment: %v", env) } // TestClaudeProxyConfigBuilding tests that claude-proxy config is correctly built func TestClaudeProxyConfigBuilding(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() opts := &claude.Options{ ConnectorHost: 
"https://ark.cn-beijing.volces.com/api/v3", ConnectorKey: "test-api-key", Model: "ep-xxxxx", } configJSON, err := claude.BuildProxyConfig(opts) require.NoError(t, err) require.NotEmpty(t, configJSON) t.Logf("Proxy config: %s", string(configJSON)) // Verify the JSON contains expected fields for claude-proxy assert.Contains(t, string(configJSON), "backend") assert.Contains(t, string(configJSON), "api_key") assert.Contains(t, string(configJSON), "model") assert.Contains(t, string(configJSON), "test-api-key") assert.Contains(t, string(configJSON), "ep-xxxxx") // Backend URL should end with /chat/completions assert.Contains(t, string(configJSON), "/chat/completions") } // TestDefaultImageSelection tests that default images are correctly selected func TestDefaultImageSelection(t *testing.T) { tests := []struct { command string expectedImage string }{ {"claude", "yaoapp/sandbox-claude:latest"}, {"cursor", "yaoapp/sandbox-cursor:latest"}, {"unknown", ""}, } for _, tt := range tests { t.Run(tt.command, func(t *testing.T) { image := agentsandbox.DefaultImage(tt.command) assert.Equal(t, tt.expectedImage, image) }) } } // TestSandboxCommandValidation tests that command validation works correctly func TestSandboxCommandValidation(t *testing.T) { tests := []struct { command string valid bool }{ {"claude", true}, {"cursor", true}, {"invalid", false}, {"", false}, } for _, tt := range tests { t.Run(tt.command, func(t *testing.T) { result := agentsandbox.IsValidCommand(tt.command) assert.Equal(t, tt.valid, result) }) } } // TestHasSandboxMethod tests the HasSandbox method on Assistant func TestHasSandboxMethod(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Test assistant with sandbox astWithSandbox, err := assistant.LoadPath("/assistants/tests/sandbox/basic") require.NoError(t, err) assert.True(t, astWithSandbox.HasSandbox(), "Assistant with sandbox config should return true") // Test assistant without sandbox astWithoutSandbox, err := 
assistant.LoadPath("/assistants/tests/simple-greeting") require.NoError(t, err) assert.False(t, astWithoutSandbox.HasSandbox(), "Assistant without sandbox config should return false") } ================================================ FILE: agent/assistant/sandbox_test.go ================================================ package assistant_test import ( "context" "encoding/json" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent" "github.com/yaoapp/yao/agent/assistant" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestLoadSandboxBasicAssistant tests loading the basic sandbox test assistant func TestLoadSandboxBasicAssistant(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ast, err := assistant.LoadPath("/assistants/tests/sandbox/basic") require.NoError(t, err) require.NotNil(t, ast) // Verify basic fields assert.Equal(t, "tests.sandbox.basic", ast.ID) assert.Equal(t, "Sandbox Basic Test", ast.Name) assert.Equal(t, "deepseek.v3", ast.Connector) // Verify sandbox configuration require.NotNil(t, ast.Sandbox, "Sandbox should be configured") assert.Equal(t, "claude", ast.Sandbox.Command) assert.Equal(t, "5m", ast.Sandbox.Timeout) // Verify HasSandbox returns true assert.True(t, ast.HasSandbox(), "HasSandbox should return true") } // TestLoadSandboxHooksAssistant tests loading the hooks sandbox test assistant func TestLoadSandboxHooksAssistant(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ast, err := assistant.LoadPath("/assistants/tests/sandbox/hooks") require.NoError(t, err) require.NotNil(t, ast) // Verify basic fields assert.Equal(t, "tests.sandbox.hooks", ast.ID) assert.Equal(t, "Sandbox Hooks Test", ast.Name) assert.Equal(t, "deepseek.v3", ast.Connector) // Verify sandbox configuration require.NotNil(t, ast.Sandbox, "Sandbox should be configured") assert.Equal(t, "claude", 
ast.Sandbox.Command) // Verify hooks are loaded assert.NotNil(t, ast.HookScript, "HookScript should be loaded") } // TestLoadSandboxFullAssistant tests loading the full sandbox test assistant with MCPs and Skills func TestLoadSandboxFullAssistant(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent to ensure MCPs are available err := agent.Load(config.Conf) require.NoError(t, err, "agent.Load should succeed") ast, err := assistant.LoadPath("/assistants/tests/sandbox/full") require.NoError(t, err) require.NotNil(t, ast) // Verify basic fields assert.Equal(t, "tests.sandbox.full", ast.ID) assert.Equal(t, "Sandbox Full Test", ast.Name) assert.Equal(t, "deepseek.v3", ast.Connector) // Verify sandbox configuration require.NotNil(t, ast.Sandbox, "Sandbox should be configured") assert.Equal(t, "claude", ast.Sandbox.Command) assert.Equal(t, "5m", ast.Sandbox.Timeout) // Verify sandbox arguments (command-specific options) require.NotNil(t, ast.Sandbox.Arguments, "Sandbox arguments should be configured") assert.Equal(t, float64(10), ast.Sandbox.Arguments["max_turns"]) assert.Equal(t, "acceptEdits", ast.Sandbox.Arguments["permission_mode"]) // Verify MCP configuration require.NotNil(t, ast.MCP, "MCP should be configured") require.NotNil(t, ast.MCP.Servers, "MCP.Servers should be configured") assert.Len(t, ast.MCP.Servers, 1, "Should have 1 MCP server configured") assert.Equal(t, "echo", ast.MCP.Servers[0].ServerID, "MCP server ID should be 'echo'") assert.Contains(t, ast.MCP.Servers[0].Tools, "ping", "MCP tools should contain 'ping'") assert.Contains(t, ast.MCP.Servers[0].Tools, "echo", "MCP tools should contain 'echo'") // Verify hooks are loaded assert.NotNil(t, ast.HookScript, "HookScript should be loaded") } // TestSandboxConfigValidation tests sandbox configuration validation func TestSandboxConfigValidation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() tests := []struct { name string path string hasError bool }{ { name: "Basic 
sandbox config", path: "/assistants/tests/sandbox/basic", hasError: false, }, { name: "Hooks sandbox config", path: "/assistants/tests/sandbox/hooks", hasError: false, }, { name: "Full sandbox config with MCPs", path: "/assistants/tests/sandbox/full", hasError: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ast, err := assistant.LoadPath(tt.path) if tt.hasError { assert.Error(t, err) return } require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.Sandbox) assert.NotEmpty(t, ast.Sandbox.Command) }) } } // TestSkillsDirectoryResolution tests that skills directory exists and has correct structure // Note: Skills are auto-discovered from skills/ directory, not stored in AssistantModel func TestSkillsDirectoryResolution(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ast, err := assistant.LoadPath("/assistants/tests/sandbox/full") require.NoError(t, err) require.NotNil(t, ast) // Get app root from environment appRoot := os.Getenv("YAO_ROOT") require.NotEmpty(t, appRoot, "YAO_ROOT should be set") // Verify assistant path is set assert.NotEmpty(t, ast.Path, "Assistant path should be set") // Build expected skills directory path // ast.Path is like "/assistants/tests/sandbox/full" expectedSkillsDir := filepath.Join(appRoot, ast.Path, "skills") // Verify skills directory exists info, err := os.Stat(expectedSkillsDir) require.NoError(t, err, "Skills directory should exist: %s", expectedSkillsDir) assert.True(t, info.IsDir(), "Skills path should be a directory") // Verify skills directory structure entries, err := os.ReadDir(expectedSkillsDir) require.NoError(t, err, "Should be able to read skills directory") // Find echo-test skill var foundEchoTest bool for _, entry := range entries { if entry.IsDir() && entry.Name() == "echo-test" { foundEchoTest = true // Verify SKILL.md exists (required) skillMdPath := filepath.Join(expectedSkillsDir, "echo-test", "SKILL.md") _, err := os.Stat(skillMdPath) assert.NoError(t, err, 
"SKILL.md should exist") // Verify scripts directory exists (optional but we created it) scriptsDir := filepath.Join(expectedSkillsDir, "echo-test", "scripts") _, err = os.Stat(scriptsDir) assert.NoError(t, err, "scripts directory should exist") // Verify echo.sh exists echoShPath := filepath.Join(scriptsDir, "echo.sh") _, err = os.Stat(echoShPath) assert.NoError(t, err, "echo.sh should exist") break } } assert.True(t, foundEchoTest, "echo-test skill should exist in skills directory") } // TestMCPConfiguration tests that MCP is correctly loaded for sandbox assistant func TestMCPConfiguration(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent to ensure MCPs are available err := agent.Load(config.Conf) require.NoError(t, err, "agent.Load should succeed") ast, err := assistant.LoadPath("/assistants/tests/sandbox/full") require.NoError(t, err) require.NotNil(t, ast) // Verify MCP configuration structure require.NotNil(t, ast.MCP, "MCP should not be nil") require.NotNil(t, ast.MCP.Servers, "MCP.Servers should not be nil") assert.Len(t, ast.MCP.Servers, 1, "Should have 1 MCP server configured") // Verify echo server configuration echoServer := ast.MCP.Servers[0] assert.Equal(t, "echo", echoServer.ServerID, "Server ID should be 'echo'") assert.Len(t, echoServer.Tools, 3, "Should have 3 tools configured") assert.Contains(t, echoServer.Tools, "ping") assert.Contains(t, echoServer.Tools, "echo") assert.Contains(t, echoServer.Tools, "status") } // TestBuildMCPConfigForSandbox tests that MCP configuration is correctly built for sandbox func TestBuildMCPConfigForSandbox(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent to ensure MCPs are available err := agent.Load(config.Conf) require.NoError(t, err, "agent.Load should succeed") ast, err := assistant.LoadPath("/assistants/tests/sandbox/full") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.MCP, "MCP configuration should exist") // Create a mock context 
for the test ctx := agentContext.New(context.Background(), nil, "test-mcp-config-build") // Call BuildMCPConfigForSandbox and verify the result mcpConfig, err := ast.BuildMCPConfigForSandbox(ctx) require.NoError(t, err, "BuildMCPConfigForSandbox should not error") require.NotEmpty(t, mcpConfig, "MCP config should not be empty") t.Logf("MCP config JSON: %s", string(mcpConfig)) // Parse and verify the JSON structure var config map[string]interface{} err = json.Unmarshal(mcpConfig, &config) require.NoError(t, err, "MCP config should be valid JSON") // Verify mcpServers key exists mcpServers, ok := config["mcpServers"].(map[string]interface{}) require.True(t, ok, "mcpServers should be a map") require.NotEmpty(t, mcpServers, "mcpServers should not be empty") // Verify "yao" server exists (single server using yao-bridge for IPC) yaoServer, ok := mcpServers["yao"].(map[string]interface{}) require.True(t, ok, "yao server should exist in mcpServers") // Verify server structure - uses yao-bridge to connect to IPC socket assert.Equal(t, "yao-bridge", yaoServer["command"], "command should be yao-bridge") args, ok := yaoServer["args"].([]interface{}) require.True(t, ok, "args should be an array") require.Len(t, args, 1, "args should have 1 element") assert.Equal(t, "/tmp/yao.sock", args[0], "first arg should be IPC socket path") t.Logf("✓ MCP config verified: uses yao-bridge with IPC socket /tmp/yao.sock") } // TestSandboxMCPAndSkillsOptions tests that sandbox options include MCP and Skills func TestSandboxMCPAndSkillsOptions(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent to ensure MCPs are available err := agent.Load(config.Conf) require.NoError(t, err, "agent.Load should succeed") ast, err := assistant.LoadPath("/assistants/tests/sandbox/full") require.NoError(t, err) require.NotNil(t, ast) // Verify sandbox configuration is present require.NotNil(t, ast.Sandbox, "Sandbox should be configured") assert.Equal(t, "claude", ast.Sandbox.Command) // 
Verify MCP is configured (will be passed to sandbox) require.NotNil(t, ast.MCP, "MCP should be configured") assert.Len(t, ast.MCP.Servers, 1, "Should have 1 MCP server") // Verify skills directory exists appRoot := os.Getenv("YAO_ROOT") require.NotEmpty(t, appRoot, "YAO_ROOT should be set") skillsDir := filepath.Join(appRoot, ast.Path, "skills") info, err := os.Stat(skillsDir) require.NoError(t, err, "Skills directory should exist") assert.True(t, info.IsDir(), "Skills should be a directory") // Verify echo-test skill exists echoTestDir := filepath.Join(skillsDir, "echo-test") info, err = os.Stat(echoTestDir) require.NoError(t, err, "echo-test skill should exist") assert.True(t, info.IsDir(), "echo-test should be a directory") // Verify SKILL.md exists skillMd := filepath.Join(echoTestDir, "SKILL.md") _, err = os.Stat(skillMd) require.NoError(t, err, "SKILL.md should exist") } ================================================ FILE: agent/assistant/sandbox_v2.go ================================================ package assistant import ( "fmt" "log" "os" "path/filepath" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" sandboxv2 "github.com/yaoapp/yao/agent/sandbox/v2" sandboxTypes "github.com/yaoapp/yao/agent/sandbox/v2/types" "github.com/yaoapp/yao/config" infraV2 "github.com/yaoapp/yao/sandbox/v2" traceTypes "github.com/yaoapp/yao/trace/types" "github.com/yaoapp/yao/workspace" ) // HasSandboxV2 returns true if the assistant has a V2 sandbox configuration. func (ast *Assistant) HasSandboxV2() bool { return ast.SandboxV2 != nil } // initSandboxV2 initializes the V2 sandbox: obtains a Computer, gets a Runner, // runs Prepare, and returns the runner, computer, cleanup closure, loading // message ID, and any error. 
func (ast *Assistant) initSandboxV2(ctx *context.Context, opts *context.Options) (
	sandboxTypes.Runner, infraV2.Computer, func(), string, error,
) {
	// NOTE(review): cfg aliases ast.SandboxV2 (no copy is made here), so the
	// cfg.DisplayName assignment in step 2 writes to the assistant's own
	// config — confirm this is safe if several requests share one Assistant.
	cfg := ast.SandboxV2
	manager := infraV2.M()

	// Emit a "preparing" loading message; its ID is reused by every
	// updateLoadingV2/closeLoadingV2 call below. The SendStream error is
	// ignored, so loadingMsgID may be empty (the helpers tolerate that).
	loadingMsg := &message.Message{
		Type: message.TypeLoading,
		Props: map[string]any{
			"message": i18n.T(ctx.Locale, "sandbox.preparing"),
		},
	}
	loadingMsgID, _ := ctx.SendStream(loadingMsg)

	stdCtx := ctx.Context

	// 1. Resolve connector (before Computer so proxy env vars can be injected).
	// A connector error is tolerated only when the runner is named "yao".
	conn, _, err := ast.GetConnector(ctx, opts)
	if err != nil && cfg.Runner.Name != "yao" {
		closeLoadingV2(ctx, loadingMsgID, "sandbox.failed")
		return nil, nil, nil, "", fmt.Errorf("get connector: %w", err)
	}

	// 2. Build human-readable DisplayName from real Agent name + Workspace name.
	cfg.DisplayName = buildBoxDisplayName(ctx, ast.ID, ast.Name)

	// 2.5. Image existence check + pull (for box mode).
	// Every failure in this step is only logged; initialization continues.
	if cfg.Computer.Image != "" && manager != nil {
		nodeID, kind, _ := sandboxv2.ResolveNodeID(ctx, cfg, manager)
		if kind == "box" && nodeID != "" {
			updateLoadingV2(ctx, loadingMsgID, "sandbox.starting")
			exists, existsErr := manager.ImageExists(stdCtx, nodeID, cfg.Computer.Image)
			if existsErr != nil {
				log.Printf("[sandbox/v2] image exists check failed on node %s: %v", nodeID, existsErr)
			}
			if existsErr == nil && !exists {
				updateLoadingV2(ctx, loadingMsgID, "sandbox.pulling_image")
				ch, pullErr := manager.PullImage(stdCtx, nodeID, cfg.Computer.Image, infraV2.ImagePullOptions{})
				if pullErr != nil {
					log.Printf("[sandbox/v2] image pull failed on node %s: %v (will retry in Create)", nodeID, pullErr)
				} else if ch != nil {
					// Consume progress events until the channel closes or the
					// first event carrying an error.
					// NOTE(review): the break stops reading without draining
					// ch — confirm the producer does not block on send.
					for p := range ch {
						if p.Error != "" {
							log.Printf("[sandbox/v2] image pull progress error: %s", p.Error)
							break
						}
					}
				}
			}
		}
	}

	// 3. Obtain Computer (passes connector for OPENAI_PROXY_* env injection).
	updateLoadingV2(ctx, loadingMsgID, "sandbox.starting")
	computer, identifier, err := sandboxv2.GetComputer(ctx, cfg, manager, conn)
	if err != nil {
		closeLoadingV2(ctx, loadingMsgID, "sandbox.failed")
		return nil, nil, nil, "", fmt.Errorf("getComputer failed: %w", err)
	}
	_ = identifier // returned by GetComputer but unused in this function

	// 4. Get Runner.
	// From here on a Computer exists, so failure paths call
	// sandboxv2.LifecycleAction before returning.
	runner, err := sandboxv2.Get(cfg.Runner.Name)
	if err != nil {
		sandboxv2.LifecycleAction(stdCtx, cfg, computer, manager)
		closeLoadingV2(ctx, loadingMsgID, "sandbox.failed")
		return nil, nil, nil, "", fmt.Errorf("get runner %q: %w", cfg.Runner.Name, err)
	}

	// 5. Resolve assistant directory and skills subdirectory.
	// skillsDir stays empty unless <assistantDir>/skills exists and is a
	// directory.
	assistantDir := ""
	skillsDir := ""
	if ast.Path != "" {
		assistantDir = filepath.Join(config.Conf.AppSource, ast.Path)
		dir := filepath.Join(assistantDir, "skills")
		if info, e := os.Stat(dir); e == nil && info.IsDir() {
			skillsDir = dir
		}
	}

	// 6. Convert MCP servers.
	var mcpServers []sandboxTypes.MCPServer
	if ast.MCP != nil {
		for _, s := range ast.MCP.Servers {
			mcpServers = append(mcpServers, sandboxTypes.MCPServer{
				ServerID:  s.ServerID,
				Resources: s.Resources,
				Tools:     s.Tools,
			})
		}
	}

	// 7. Runner.Prepare (standard context).
	err = runner.Prepare(stdCtx, &sandboxTypes.PrepareRequest{
		Computer:     computer,
		Config:       cfg,
		Connector:    conn,
		SkillsDir:    skillsDir,
		AssistantDir: assistantDir,
		MCPServers:   mcpServers,
		ConfigHash:   ast.ConfigHash,
		RunSteps:     sandboxv2.RunPrepareSteps,
	})
	if err != nil {
		// Order on failure: runner cleanup, then lifecycle action, then the
		// loading message is closed.
		runner.Cleanup(stdCtx, computer)
		sandboxv2.LifecycleAction(stdCtx, cfg, computer, manager)
		closeLoadingV2(ctx, loadingMsgID, "sandbox.failed")
		return nil, nil, nil, "", fmt.Errorf("runner.Prepare: %w", err)
	}

	// Inject computer + workspace into context so Create/Next hooks
	// can access ctx.computer and ctx.workspace.
	// NOTE(review): only SetComputer is called here; no workspace is set in
	// this function — confirm where ctx.workspace comes from.
	ctx.SetComputer(computer)

	// The returned cleanup closure is currently a no-op.
	cleanup := func() {
		// Defensive fallback — executeSandboxV2Stream defer handles the
		// normal case; this covers paths that never reach execution.
	}

	return runner, computer, cleanup, loadingMsgID, nil
}

// executeSandboxV2Stream calls the V2 Runner.Stream and wraps it in the
// standard completion response.
//
// It builds a StreamRequest from the assistant's first non-empty system
// prompt, the resolved connector and the completion messages, issues a
// sandbox token when the context is authorized, and delegates to
// sandboxv2.ExecuteSandboxStream. agentNode is accepted but unused.
func (ast *Assistant) executeSandboxV2Stream(
	ctx *context.Context,
	completionMessages []context.Message,
	agentNode traceTypes.Node,
	streamHandler message.StreamFunc,
	runner sandboxTypes.Runner,
	computer infraV2.Computer,
	loadingMsgID string,
) (*context.CompletionResponse, error) {
	_ = agentNode

	cfg := ast.SandboxV2
	manager := infraV2.M()

	// Build system prompt.
	// Only the first prompt with Role "system" and non-empty Content is used.
	var systemPrompt string
	if len(ast.Prompts) > 0 {
		for _, p := range ast.Prompts {
			if p.Role == "system" && p.Content != "" {
				systemPrompt = p.Content
				break
			}
		}
	}

	// Resolve connector for Stream.
	// The error is discarded here, so conn may be its zero value.
	conn, _, _ := ast.GetConnector(ctx)

	// A token is issued only for authorized contexts; otherwise tok stays nil.
	var tok *sandboxTypes.SandboxToken
	if ctx.Authorized != nil {
		var err error
		tok, err = sandboxv2.IssueSandboxToken(ctx.Authorized.TeamID, ctx.Authorized.UserID)
		if err != nil {
			return nil, fmt.Errorf("issue sandbox token: %w", err)
		}
	}

	streamReq := &sandboxTypes.StreamRequest{
		Computer:     computer,
		Config:       cfg,
		Connector:    conn,
		Messages:     completionMessages,
		SystemPrompt: systemPrompt,
		ChatID:       ctx.ChatID,
		Token:        tok,
	}

	execReq := &sandboxv2.ExecuteRequest{
		Computer:     computer,
		Runner:       runner,
		Config:       cfg,
		StreamReq:    streamReq,
		Manager:      manager,
		LoadingMsgID: loadingMsgID,
	}

	return sandboxv2.ExecuteSandboxStream(ctx, execReq, streamHandler)
}

// initStandaloneWorkspace loads the workspace FS into context when no sandbox
// is configured but the user selected a workspace (metadata["workspace_id"]).
func (ast *Assistant) initStandaloneWorkspace(ctx *context.Context) { if ctx.Metadata == nil { return } wsID, _ := ctx.Metadata["workspace_id"].(string) if wsID == "" { return } stdCtx := ctx.Context wsFS, err := workspace.M().FS(stdCtx, wsID) if err != nil { log.Printf("[assistant] initStandaloneWorkspace: failed to load workspace %s: %v", wsID, err) return } ctx.SetWorkspace(wsFS) } // buildBoxDisplayName constructs a human-readable display name for a Box // using the locale-resolved Agent name and Workspace name (matching the UI list pages). func buildBoxDisplayName(ctx *context.Context, assistantID, rawName string) string { agentName := i18n.Tr(assistantID, ctx.Locale, rawName) wsName := "" if ctx.Metadata != nil { if wsID, ok := ctx.Metadata["workspace_id"].(string); ok && wsID != "" { if wsm := workspace.M(); wsm != nil { if ws, err := wsm.Get(ctx.Context, wsID); err == nil && ws != nil { wsName = ws.Name } } } } if agentName != "" && wsName != "" { return agentName + " / " + wsName } if agentName != "" { return agentName } if wsName != "" { return wsName } return "" } func updateLoadingV2(ctx *context.Context, loadingMsgID, msgKey string) { if loadingMsgID == "" || ctx == nil || msgKey == "" { return } msg := &message.Message{ MessageID: loadingMsgID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]any{ "message": i18n.T(ctx.Locale, msgKey), }, } ctx.Send(msg) } func closeLoadingV2(ctx *context.Context, loadingMsgID, msgKey string) { if loadingMsgID == "" || ctx == nil { return } props := map[string]any{"done": true} if msgKey != "" { props["message"] = i18n.T(ctx.Locale, msgKey) } else { props["message"] = "" } doneMsg := &message.Message{ MessageID: loadingMsgID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: props, } ctx.Send(doneMsg) } ================================================ FILE: agent/assistant/scripts.go ================================================ package 
assistant

import (
	"context"
	"fmt"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/yaoapp/gou/application"
	"github.com/yaoapp/gou/process"
	v8 "github.com/yaoapp/gou/runtime/v8"
	"github.com/yaoapp/kun/exception"
	"github.com/yaoapp/yao/agent/assistant/hook"
)

// scriptsMutex serializes v8.Load calls and the script-map writes made while
// loading.
var scriptsMutex sync.Mutex

// Execute runs the named method of the script without authorization info.
// It is shorthand for ExecuteWithAuthorized with a nil authorized map.
func (s *Script) Execute(ctx context.Context, method string, args ...interface{}) (interface{}, error) {
	return s.ExecuteWithAuthorized(ctx, method, nil, args...)
}

// ExecuteWithAuthorized runs the named method of the script in a fresh script
// context, attaching the authorized information when it is non-nil.
//
// A nil receiver or a receiver without a loaded script yields (nil, nil).
// Errors from the call (including "not defined" errors) are returned as-is.
func (s *Script) ExecuteWithAuthorized(ctx context.Context, method string, authorized map[string]interface{}, args ...interface{}) (interface{}, error) {
	if s == nil || s.Script == nil {
		return nil, nil
	}

	runCtx, err := s.NewContext("", nil)
	if err != nil {
		return nil, err
	}
	defer runCtx.Close()

	// Attach authorized information only when the caller supplied it.
	if authorized != nil {
		runCtx.WithAuthorized(authorized)
	}

	// Arguments are forwarded untouched; the result and error pass straight
	// through to the caller.
	return runCtx.CallWith(ctx, method, args...)
}

// LoadScripts loads every .ts/.js script found under srcDir.
//
// The root-level index.ts / index.js is returned as the hook script; all
// other scripts are returned in a map keyed by the ID produced by
// generateScriptID. Files ending in _test.ts / _test.js are skipped. A
// missing srcDir is not an error: all three results are nil.
func LoadScripts(srcDir string) (*hook.Script, map[string]*Script, error) {
	found, err := application.App.Exists(srcDir)
	if err != nil {
		return nil, nil, err
	}
	if !found {
		return nil, nil, nil // No src directory
	}

	var (
		hookScript *hook.Script
		loadErr    error
	)
	scripts := make(map[string]*Script)

	// visit handles one entry of the walk; root is srcDir and file is the
	// full path of the entry.
	visit := func(root, file string, isdir bool) error {
		if isdir {
			return nil
		}

		rel := strings.TrimPrefix(file, root+"/")

		switch {
		case strings.HasSuffix(rel, "_test.ts"), strings.HasSuffix(rel, "_test.js"):
			// Test files are never loaded.
			return nil

		case rel == "index.ts", rel == "index.js":
			// Only src/index.ts (or .js) is the hook script, not
			// src/foo/index.ts.
			scriptsMutex.Lock()
			loaded, hookErr := loadScriptFile(file)
			scriptsMutex.Unlock()
			if hookErr != nil {
				loadErr = fmt.Errorf("failed to load hook script %s: %w", file, hookErr)
				return loadErr
			}
			hookScript = loaded
			return nil
		}

		// Any other script: derive its ID, then load and register it while
		// holding the lock (v8.Load is not thread-safe).
		id := generateScriptID(file, root)
		scriptsMutex.Lock()
		loaded, scriptErr := loadScriptV8(file)
		if scriptErr != nil {
			scriptsMutex.Unlock()
			loadErr = fmt.Errorf("failed to load script %s: %w", file, scriptErr)
			return loadErr
		}
		scripts[id] = &Script{Script: loaded}
		scriptsMutex.Unlock()
		return nil
	}

	exts := []string{"*.ts", "*.js"}
	err = application.App.Walk(srcDir, visit, exts...)

	// A load failure recorded by visit takes precedence over the walk error.
	if loadErr != nil {
		return nil, nil, loadErr
	}
	if err != nil {
		return nil, nil, fmt.Errorf("failed to walk src directory: %w", err)
	}

	return hookScript, scripts, nil
}

// generateScriptID derives a dot-separated script ID from a file path
// relative to srcDir, with the file extension removed.
// Example: assistants/test/src/foo/bar/test.ts -> foo.bar.test
func generateScriptID(filePath string, srcDir string) string {
	// Work on slash-separated paths regardless of platform.
	normFile := filepath.ToSlash(filePath)
	normSrc := filepath.ToSlash(srcDir)

	// Strip the src directory prefix and any leftover leading slash.
	rel := strings.TrimPrefix(strings.TrimPrefix(normFile, normSrc+"/"), "/")

	// Drop the extension, then turn path separators into dots.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))
	return strings.ReplaceAll(rel, "/", ".")
}

// loadScriptFile loads a hook script from file and wraps it in hook.Script.
func loadScriptFile(file string) (*hook.Script, error) {
	loaded, err := loadScriptV8(file)
	if err != nil {
		return nil, err
	}
	return &hook.Script{Script: loaded}, nil
}

// loadScriptFromSource builds a script from in-memory source code.
// It uses MakeScriptInMemory, which supports TypeScript syntax without file
// resolution.
func loadScriptFromSource(source string, file string) (*v8.Script, error) {
	compiled, err := v8.MakeScriptInMemory([]byte(source), file, 5*time.Second, true)
	if err != nil {
		return nil, err
	}
	return compiled, nil
}

// loadScriptV8 loads a v8.Script from file (used for non-hook scripts).
// The v8 script ID is derived from the file path by makeScriptID.
func loadScriptV8(file string) (*v8.Script, error) {
	loaded, err := v8.Load(file, makeScriptID(file, ""))
	if err != nil {
		return nil, err
	}
	return loaded, nil
}

// makeScriptID generates the script ID for v8.Load
// Converts file path to a dot-separated ID
// Example: assistants/tests/fullfields/src/index.ts -> assistants.tests.fullfields.src.index
func makeScriptID(file string, root string) string {
	// Remove root prefix if provided
	id := file
	if root != "" {
		id = strings.TrimPrefix(file, root+"/")
	}

	// Remove extension
	id =
strings.TrimSuffix(id, filepath.Ext(id)) // Replace path separators with dots id = strings.ReplaceAll(id, "/", ".") id = strings.ReplaceAll(id, string(filepath.Separator), ".") return id } // LoadScriptsFromData loads scripts from data map // Handles script/scripts/source fields with priority: script > scripts > source > file system func LoadScriptsFromData(data map[string]interface{}, assistantID string) (*hook.Script, map[string]*Script, error) { // Priority 1: script field (hook script from string source) if data["script"] != nil { switch v := data["script"].(type) { case string: file := fmt.Sprintf("assistants/%s/src/index.ts", assistantID) script, err := loadScriptFromSource(v, file) if err != nil { return nil, nil, err } hookScript := &hook.Script{Script: script} // Load other scripts if provided scripts, err := loadScriptsField(data["scripts"]) if err != nil { return nil, nil, err } return hookScript, scripts, nil case *hook.Script: scripts, err := loadScriptsField(data["scripts"]) if err != nil { return nil, nil, err } return v, scripts, nil case *v8.Script: scripts, err := loadScriptsField(data["scripts"]) if err != nil { return nil, nil, err } return &hook.Script{Script: v}, scripts, nil } } // Priority 2: scripts field (map of scripts) if data["scripts"] != nil { // First extract index if present var hookScript *hook.Script if scriptsMap, ok := data["scripts"].(map[string]interface{}); ok { if indexSource, hasIndex := scriptsMap["index"]; hasIndex { switch v := indexSource.(type) { case string: file := fmt.Sprintf("assistants/%s/src/index.ts", assistantID) script, err := loadScriptFromSource(v, file) if err != nil { return nil, nil, err } hookScript = &hook.Script{Script: script} case *Script: hookScript = &hook.Script{Script: v.Script} case *v8.Script: hookScript = &hook.Script{Script: v} } } } // Then load other scripts (loadScriptsField automatically filters out index) scripts, err := loadScriptsField(data["scripts"]) if err != nil { return nil, nil, 
err } return hookScript, scripts, nil } // Priority 3: source field (legacy hook script from source) if source, ok := data["source"].(string); ok && source != "" { script, err := loadSource(source, assistantID) if err != nil { return nil, nil, err } return script, nil, nil } // Priority 4: file system (scan src directory) srcDir := fmt.Sprintf("assistants/%s/src", assistantID) return LoadScripts(srcDir) } // loadScriptsField parses scripts field from data // Note: "index" is always filtered out as it's reserved for HookScript func loadScriptsField(scriptsData interface{}) (map[string]*Script, error) { if scriptsData == nil { return nil, nil } scripts := make(map[string]*Script) switch v := scriptsData.(type) { case map[string]*Script: for id, script := range v { if id == "index" { continue // Skip index } scripts[id] = script } return scripts, nil case map[string]*v8.Script: for id, script := range v { if id == "index" { continue // Skip index } scripts[id] = &Script{Script: script} } return scripts, nil case map[string]interface{}: for id, item := range v { if id == "index" { continue // Skip index } switch s := item.(type) { case *Script: scripts[id] = s case *v8.Script: scripts[id] = &Script{Script: s} case string: // Load script from source code file := fmt.Sprintf("script_%s", id) script, err := loadScriptFromSource(s, file) if err != nil { return nil, fmt.Errorf("failed to load script %s: %w", id, err) } scripts[id] = &Script{Script: script} } } return scripts, nil } return nil, nil } // RegisterScripts registers all scripts as process handlers // Handler naming: agents.. 
func (ast *Assistant) RegisterScripts() error { if len(ast.Scripts) == 0 { return nil } assistantID := ast.ID handlers := make(map[string]process.Handler) for scriptID, script := range ast.Scripts { // Create handler for this script handlers[scriptID] = makeScriptHandler(script) } // Register the handler group dynamically groupName := fmt.Sprintf("agents.%s", assistantID) process.RegisterDynamicGroup(groupName, handlers) return nil } // UnregisterScripts unregisters all scripts from process handlers func (ast *Assistant) UnregisterScripts() error { if len(ast.Scripts) == 0 { return nil } assistantID := ast.ID for scriptID := range ast.Scripts { handlerID := fmt.Sprintf("agents.%s.%s", strings.ToLower(assistantID), strings.ToLower(scriptID)) delete(process.Handlers, handlerID) } return nil } // makeScriptHandler creates a process handler for a script func makeScriptHandler(script *Script) process.Handler { return func(p *process.Process) interface{} { // Extract method name from process method := p.Method // Get arguments from process args := p.Args // Convert authorized info to map if available var authorized map[string]interface{} if p.Authorized != nil { authorized = p.Authorized.AuthorizedToMap() } // Execute the script with authorized information result, err := script.ExecuteWithAuthorized(p.Context, method, authorized, args...) 
if err != nil { exception.New(err.Error(), 500).Throw() } return result } } ================================================ FILE: agent/assistant/scripts_process_test.go ================================================ package assistant_test import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/process" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/testutils" ) func TestScriptsProcessFlow(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Get the mcpload assistant assistantID := "tests.mcpload" ast, err := assistant.Get(assistantID) assert.NoError(t, err) assert.NotNil(t, ast, "Assistant should be loaded") // Check that scripts were loaded assert.NotNil(t, ast.Scripts) assert.Greater(t, len(ast.Scripts), 0, "Should have loaded at least one script") // Verify tools.ts was loaded toolsScript, hasTools := ast.Scripts["tools"] assert.True(t, hasTools, "Should have loaded tools script") assert.NotNil(t, toolsScript) // Register scripts as process handlers err = ast.RegisterScripts() assert.NoError(t, err) // Test 1: Call Hello function t.Run("CallHelloFunction", func(t *testing.T) { handlerID := "agents.tests.mcpload.tools" handler, exists := process.Handlers[handlerID] assert.True(t, exists, "Handler should be registered") p := &process.Process{ ID: handlerID + ".Hello", Method: "Hello", Args: []interface{}{map[string]interface{}{"name": "Yao"}}, Context: context.Background(), } result := handler(p) assert.NotNil(t, result) resultStr, ok := result.(string) assert.True(t, ok, "Result should be a string") assert.Contains(t, resultStr, "Hello, Yao") }) // Test 2: Call Ping function t.Run("CallPingFunction", func(t *testing.T) { handlerID := "agents.tests.mcpload.tools" handler, exists := process.Handlers[handlerID] assert.True(t, exists, "Handler should be registered") p := &process.Process{ ID: handlerID + ".Ping", Method: "Ping", Args: []interface{}{map[string]interface{}{"message": "test"}}, 
Context: context.Background(), } result := handler(p) assert.NotNil(t, result) resultMap, ok := result.(map[string]interface{}) assert.True(t, ok, "Result should be a map") assert.Equal(t, "test", resultMap["message"]) assert.Contains(t, resultMap["echo"], "Pong") }) // Test 3: Call Calculate function t.Run("CallCalculateFunction", func(t *testing.T) { handlerID := "agents.tests.mcpload.tools" handler, exists := process.Handlers[handlerID] assert.True(t, exists, "Handler should be registered") p := &process.Process{ ID: handlerID + ".Calculate", Method: "Calculate", Args: []interface{}{map[string]interface{}{ "operation": "add", "a": float64(10), "b": float64(5), }}, Context: context.Background(), } result := handler(p) assert.NotNil(t, result) resultMap, ok := result.(map[string]interface{}) assert.True(t, ok, "Result should be a map") assert.Equal(t, float64(15), resultMap["result"]) }) // Test 4: Unregister scripts t.Run("UnregisterScripts", func(t *testing.T) { err := ast.UnregisterScripts() assert.NoError(t, err) // Verify handlers are removed handlerID := "agents.tests.mcpload.tools" _, exists := process.Handlers[handlerID] assert.False(t, exists, "Handler should be unregistered") }) } func TestScriptsProcessUsing(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Get the mcpload assistant assistantID := "tests.mcpload" ast, err := assistant.Get(assistantID) assert.NoError(t, err) assert.NotNil(t, ast) // Register scripts err = ast.RegisterScripts() assert.NoError(t, err) defer ast.UnregisterScripts() // Test 1: Call Hello using process.New().Execute() t.Run("ProcessHello", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.Hello", map[string]interface{}{ "name": "Yao", }) err := proc.Execute() assert.NoError(t, err) result := proc.Value() assert.NotNil(t, result) resultStr, ok := result.(string) assert.True(t, ok, "Result should be a string") assert.Contains(t, resultStr, "Hello, Yao") }) // Test 2: Call Ping using 
process.New().Execute() t.Run("ProcessPing", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.Ping", map[string]interface{}{ "message": "test message", }) err := proc.Execute() assert.NoError(t, err) result := proc.Value() assert.NotNil(t, result) resultMap, ok := result.(map[string]interface{}) assert.True(t, ok, "Result should be a map") assert.Equal(t, "test message", resultMap["message"]) assert.Contains(t, resultMap["echo"], "Pong") }) // Test 3: Call Calculate using process.New().Execute() t.Run("ProcessCalculate", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.Calculate", map[string]interface{}{ "operation": "multiply", "a": float64(6), "b": float64(7), }) err := proc.Execute() assert.NoError(t, err) result := proc.Value() assert.NotNil(t, result) resultMap, ok := result.(map[string]interface{}) assert.True(t, ok, "Result should be a map") assert.Equal(t, float64(42), resultMap["result"]) assert.Equal(t, "multiply", resultMap["operation"]) }) } func TestScriptsProcessError(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Get the mcpload assistant assistantID := "tests.mcpload" ast, err := assistant.Get(assistantID) assert.NoError(t, err) assert.NotNil(t, ast) // Register scripts err = ast.RegisterScripts() assert.NoError(t, err) defer ast.UnregisterScripts() // Test calling non-existent method t.Run("CallNonExistentMethod", func(t *testing.T) { proc := process.New("agents.tests.mcpload.tools.NonExistent") err := proc.Execute() assert.NotNil(t, err, "Should return error when calling non-existent method") assert.Contains(t, err.Error(), "Exception|500") }) } ================================================ FILE: agent/assistant/scripts_test.go ================================================ package assistant import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/process" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) 
// TestLoadScripts tests loading scripts from file system // Note: These tests are commented out due to path format differences // The functionality is tested by existing integration tests in the codebase func TestLoadScriptsFromData(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() t.Run("LoadFromScriptField", func(t *testing.T) { // Use JavaScript instead of TypeScript to avoid compilation path issues data := map[string]interface{}{ "script": `function Create(ctx) { return null; }`, } // Need to provide a real assistant path for compilation data["path"] = "assistants/tests/mcpload" hookScript, scripts, err := LoadScriptsFromData(data, "tests.mcpload") require.NoError(t, err) assert.NotNil(t, hookScript, "HookScript should be loaded from script field") assert.Nil(t, scripts, "Scripts should be nil when only script field is provided") t.Logf("✓ Successfully loaded from script field") }) t.Run("LoadFromScriptsField", func(t *testing.T) { data := map[string]interface{}{ "scripts": map[string]interface{}{ "tool1": `function tool1() { return "tool1"; }`, "tool2": `function tool2() { return "tool2"; }`, }, } hookScript, scripts, err := LoadScriptsFromData(data, "test.assistant") require.NoError(t, err) assert.Nil(t, hookScript, "HookScript should be nil when no index in scripts") require.NotNil(t, scripts, "Scripts should be loaded") assert.Len(t, scripts, 2, "Should have 2 scripts") assert.Contains(t, scripts, "tool1") assert.Contains(t, scripts, "tool2") t.Logf("✓ Successfully loaded from scripts field") }) t.Run("LoadFromScriptsFieldWithIndex", func(t *testing.T) { // Test that index is properly extracted and not present in Scripts map // Note: We skip actual script compilation here to avoid path issues data := map[string]interface{}{ "scripts": map[string]interface{}{ "tool1": `function tool1() { return "tool1"; }`, "tool2": `function tool2() { return "tool2"; }`, }, } hookScript, scripts, err := LoadScriptsFromData(data, "test.assistant") 
require.NoError(t, err) // Without index in scripts field, hookScript should be nil assert.Nil(t, hookScript, "HookScript should be nil when no index in scripts") require.NotNil(t, scripts, "Scripts should be loaded") assert.Len(t, scripts, 2, "Should have 2 scripts") assert.Contains(t, scripts, "tool1") assert.Contains(t, scripts, "tool2") assert.NotContains(t, scripts, "index", "index should never be in Scripts map") t.Logf("✓ Successfully loaded from scripts field, index properly filtered") }) t.Run("LoadFromSourceField", func(t *testing.T) { data := map[string]interface{}{ "source": `function Create(ctx) { return null; }`, } hookScript, scripts, err := LoadScriptsFromData(data, "test.assistant") require.NoError(t, err) assert.NotNil(t, hookScript, "HookScript should be loaded from source field") assert.Nil(t, scripts, "Scripts should be nil when only source field is provided") t.Logf("✓ Successfully loaded from source field") }) t.Run("PriorityOrder", func(t *testing.T) { // script field should take priority over scripts field data := map[string]interface{}{ "script": `function Create1() { return null; }`, "scripts": map[string]interface{}{ "tool1": `function tool1() { return "tool1"; }`, }, "source": `function Create2() { return null; }`, "path": "assistants/tests/mcpload", } hookScript, scripts, err := LoadScriptsFromData(data, "tests.mcpload") require.NoError(t, err) assert.NotNil(t, hookScript, "HookScript should be loaded") require.NotNil(t, scripts, "Scripts should be loaded") assert.Len(t, scripts, 1, "Should have 1 script from scripts field") t.Logf("✓ Priority order works: script > scripts > source") }) } func TestGenerateScriptID(t *testing.T) { tests := []struct { name string filePath string srcDir string expected string }{ { name: "Simple file", filePath: "assistants/test/src/tools.ts", srcDir: "assistants/test/src", expected: "tools", }, { name: "Nested directory", filePath: "assistants/test/src/foo/bar/test.ts", srcDir: "assistants/test/src", 
expected: "foo.bar.test", }, { name: "Single level nested", filePath: "assistants/test/src/utils/helper.js", srcDir: "assistants/test/src", expected: "utils.helper", }, { name: "Deep nesting", filePath: "assistants/test/src/a/b/c/d/file.ts", srcDir: "assistants/test/src", expected: "a.b.c.d.file", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := generateScriptID(tt.filePath, tt.srcDir) assert.Equal(t, tt.expected, result, "Script ID should match expected value") t.Logf("✓ %s: %s → %s", tt.name, tt.filePath, result) }) } } // TestLoadScriptsThreadSafety tests concurrent script loading // Note: This test is commented out due to path format differences // Thread safety is ensured by the scriptsMutex in LoadScripts function func TestExecuteWithAuthorized(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() t.Run("ExecuteWithAuthorizedInfo", func(t *testing.T) { // Create a script that returns the authorized info from __yao_data scriptSource := ` function GetAuth() { if (typeof __yao_data !== 'undefined' && __yao_data.AUTHORIZED) { return __yao_data.AUTHORIZED; } return null; } ` data := map[string]interface{}{ "scripts": map[string]interface{}{ "auth_test": scriptSource, }, } _, scripts, err := LoadScriptsFromData(data, "test.authorized") require.NoError(t, err) require.NotNil(t, scripts) require.Contains(t, scripts, "auth_test") script := scripts["auth_test"] // Create authorized info authorized := map[string]interface{}{ "user_id": "user123", "team_id": "team456", "scope": "read write", "constraints": map[string]interface{}{ "team_only": true, }, } // Execute with authorized info ctx := context.Background() result, err := script.ExecuteWithAuthorized(ctx, "GetAuth", authorized) require.NoError(t, err) require.NotNil(t, result) // Verify the authorized info was passed correctly resultMap, ok := result.(map[string]interface{}) require.True(t, ok, "Result should be a map") assert.Equal(t, "user123", resultMap["user_id"]) 
assert.Equal(t, "team456", resultMap["team_id"]) assert.Equal(t, "read write", resultMap["scope"]) constraints, ok := resultMap["constraints"].(map[string]interface{}) require.True(t, ok, "Constraints should be a map") assert.Equal(t, true, constraints["team_only"]) t.Logf("✓ Authorized info passed correctly to script") }) t.Run("ExecuteWithoutAuthorizedInfo", func(t *testing.T) { // Create a script that checks for authorized info scriptSource := ` function CheckAuth() { if (typeof __yao_data !== 'undefined' && __yao_data.AUTHORIZED) { return { hasAuth: true, data: __yao_data.AUTHORIZED }; } return { hasAuth: false }; } ` data := map[string]interface{}{ "scripts": map[string]interface{}{ "no_auth_test": scriptSource, }, } _, scripts, err := LoadScriptsFromData(data, "test.noauth") require.NoError(t, err) require.NotNil(t, scripts) require.Contains(t, scripts, "no_auth_test") script := scripts["no_auth_test"] // Execute without authorized info ctx := context.Background() result, err := script.Execute(ctx, "CheckAuth") require.NoError(t, err) require.NotNil(t, result) resultMap, ok := result.(map[string]interface{}) require.True(t, ok) assert.Equal(t, false, resultMap["hasAuth"]) t.Logf("✓ Script executed correctly without authorized info") }) t.Run("MakeScriptHandlerWithAuthorized", func(t *testing.T) { // Create a script that returns authorized user_id scriptSource := ` function GetUserID() { if (typeof __yao_data !== 'undefined' && __yao_data.AUTHORIZED) { return __yao_data.AUTHORIZED.user_id || null; } return null; } ` data := map[string]interface{}{ "scripts": map[string]interface{}{ "handler_test": scriptSource, }, } _, scripts, err := LoadScriptsFromData(data, "test.handler") require.NoError(t, err) require.NotNil(t, scripts) require.Contains(t, scripts, "handler_test") script := scripts["handler_test"] // Create a process handler handler := makeScriptHandler(script) require.NotNil(t, handler) // Create a mock process with authorized info ctx := 
context.Background() p := &process.Process{ Method: "GetUserID", Args: []interface{}{}, Context: ctx, Authorized: &process.AuthorizedInfo{ UserID: "user999", TeamID: "team888", Scope: "admin", }, } // Execute the handler result := handler(p) require.NotNil(t, result) // Verify the result assert.Equal(t, "user999", result) t.Logf("✓ Process handler correctly passed authorized info") }) } ================================================ FILE: agent/assistant/search.go ================================================ package assistant import ( "encoding/json" "fmt" "strings" "time" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/search" "github.com/yaoapp/yao/agent/search/nlp/keyword" searchTypes "github.com/yaoapp/yao/agent/search/types" storeTypes "github.com/yaoapp/yao/agent/store/types" traceTypes "github.com/yaoapp/yao/trace/types" ) // shouldAutoSearch determines if auto search should be executed // Returns nil if search should be skipped, otherwise returns SearchIntent with types to search // Search is skipped if: // - opts.Skip.Search is true // - createResponse.Search is false // - uses.search is "disabled" // - assistant has no search configuration // - needsearch intent detection returns false func (ast *Assistant) shouldAutoSearch(ctx *context.Context, messages []context.Message, createResponse *context.HookCreateResponse, opts *context.Options) *SearchIntent { // Check if search is skipped via options if opts != nil && opts.Skip != nil && opts.Skip.Search { ctx.Logger.Debug("Auto search skipped by opts.Skip.Search") return nil } // Check if search is skipped via ctx.Metadata["__disable_search"] if ctx != nil && ctx.Metadata != nil { disableSearch := getBool(ctx.Metadata, "__disable_search") if disableSearch { ctx.Logger.Debug("Auto search skipped by ctx.Metadata['__disable_search']") return nil } } // Check createResponse.Search field (highest priority from 
Create hook) // Supports: bool | SearchIntent | nil if createResponse != nil && createResponse.Search != nil { intent := parseSearchField(createResponse.Search) if intent != nil { if !intent.NeedSearch { ctx.Logger.Info("Auto search disabled by createResponse.Search") return nil } ctx.Logger.Info("Auto search controlled by createResponse.Search: types=%v", intent.SearchTypes) return intent } } // Get merged uses configuration uses := ast.getMergedSearchUses(createResponse, opts) // Check if search is explicitly disabled if uses != nil && uses.Search == "disabled" { ctx.Logger.Info("Auto search disabled by uses.search=disabled") return nil } // Check if assistant has search configuration if ast.Search == nil && (uses == nil || uses.Search == "") { return nil } // Check search intent using __yao.needsearch agent intent := ast.checkSearchIntent(ctx, messages) if intent == nil || !intent.NeedSearch { ctx.Logger.Info("Auto search skipped: intent detection returned false") return nil } return intent } // parseSearchField parses the Search field from HookCreateResponse // Supports: bool | SearchIntent | map[string]any | nil func parseSearchField(search any) *SearchIntent { if search == nil { return nil } switch v := search.(type) { case bool: // bool: true = enable all, false = disable all if v { return &SearchIntent{ NeedSearch: true, SearchTypes: []string{"web", "kb", "db"}, Confidence: 1.0, Reason: "enabled by hook", } } return &SearchIntent{ NeedSearch: false, SearchTypes: []string{}, Confidence: 1.0, Reason: "disabled by hook", } case *SearchIntent: // SearchIntent is alias for context.SearchIntent, so this covers both return v case SearchIntent: return &v case map[string]any: // Parse from map (e.g., from JSON) intent := &SearchIntent{ NeedSearch: false, SearchTypes: []string{}, Confidence: 0.5, } if needSearch, ok := v["need_search"].(bool); ok { intent.NeedSearch = needSearch } if types, ok := v["search_types"].([]any); ok { for _, t := range types { if typeStr, 
ok := t.(string); ok { intent.SearchTypes = append(intent.SearchTypes, typeStr) } } } if confidence, ok := v["confidence"].(float64); ok { intent.Confidence = confidence } if reason, ok := v["reason"].(string); ok { intent.Reason = reason } return intent default: return nil } } // checkSearchIntent uses __yao.needsearch agent to determine if search is needed // Returns SearchIntent with search types and confidence func (ast *Assistant) checkSearchIntent(ctx *context.Context, messages []context.Message) *SearchIntent { // Default intent: no search needed (fallback when agent unavailable or fails) defaultIntent := &SearchIntent{ NeedSearch: false, SearchTypes: []string{}, Confidence: 0, } // Build a single text message with conversation context intentMessages := buildContextMessage(messages) if len(intentMessages) == 0 { return defaultIntent // No messages, skip search } // Try to get __yao.needsearch agent needsearchAst, err := Get("__yao.needsearch") if err != nil { ctx.Logger.Debug("__yao.needsearch agent not available: %v, skipping search", err) return defaultIntent // Agent not available, skip search } // === Output: Send loading message === loadingID := ast.sendIntentLoading(ctx) // Call the needsearch agent (Stack will auto-track) // IMPORTANT: Skip search to prevent infinite loop, skip output to prevent JSON showing in UI opts := &context.Options{ Skip: &context.Skip{ History: true, // Don't save to history Search: true, // Skip search to prevent infinite loop Output: true, // Skip output to prevent JSON showing in UI }, } result, err := needsearchAst.Stream(ctx, intentMessages, opts) if err != nil { ctx.Logger.Debug("__yao.needsearch failed: %v, skipping search", err) // === Output: Send done (error case, skip search) === ast.sendIntentDone(ctx, loadingID, false, "") return defaultIntent // On error, skip search } // Parse the result // Next hook returns {data: {need_search: bool, search_types: [], confidence: float}} // First try to get from Next hook 
response if result.Next != nil { if nextData, ok := result.Next.(map[string]interface{}); ok { // Check for data field (from Next hook's {data: result}) var intentData map[string]interface{} if data, ok := nextData["data"].(map[string]interface{}); ok { intentData = data } else { intentData = nextData } intent := parseSearchIntent(intentData) if intent != nil { ctx.Logger.Debug("Search intent (from Next): need_search=%v, types=%v, confidence=%.2f, reason=%s", intent.NeedSearch, intent.SearchTypes, intent.Confidence, intent.Reason) ast.sendIntentDone(ctx, loadingID, intent.NeedSearch, intent.Reason) return intent } } } // Fallback: parse from Completion.Content if Next hook didn't process if result.Completion != nil { content, ok := result.Completion.Content.(string) if !ok || content == "" { ast.sendIntentDone(ctx, loadingID, false, "") return defaultIntent } intent := parseSearchIntentFromContent(content) ctx.Logger.Debug("Search intent (from Content): need_search=%v, types=%v, confidence=%.2f, reason=%s", intent.NeedSearch, intent.SearchTypes, intent.Confidence, intent.Reason) ast.sendIntentDone(ctx, loadingID, intent.NeedSearch, intent.Reason) return intent } // Default: skip search if we can't parse the result // === Output: Send done (default case) === ast.sendIntentDone(ctx, loadingID, false, "") return defaultIntent } // parseSearchIntent parses SearchIntent from intent data map func parseSearchIntent(intentData map[string]interface{}) *SearchIntent { if intentData == nil { return nil } needSearch, ok := intentData["need_search"].(bool) if !ok { return nil } intent := &SearchIntent{ NeedSearch: needSearch, SearchTypes: []string{}, Confidence: 0.5, // Default confidence } // Parse search_types if types, ok := intentData["search_types"].([]interface{}); ok { for _, t := range types { if typeStr, ok := t.(string); ok { // Validate type typeStr = strings.ToLower(typeStr) if typeStr == "web" || typeStr == "kb" || typeStr == "db" { intent.SearchTypes = 
append(intent.SearchTypes, typeStr) } } } } // Parse confidence if confidence, ok := intentData["confidence"].(float64); ok { intent.Confidence = confidence } // Parse reason if reason, ok := intentData["reason"].(string); ok { intent.Reason = reason } return intent } // parseSearchIntentFromContent parses SearchIntent from LLM completion content // Handles JSON wrapped in markdown code blocks func parseSearchIntentFromContent(content string) *SearchIntent { // Default intent: no search needed defaultIntent := &SearchIntent{ NeedSearch: false, SearchTypes: []string{}, Confidence: 0, } // Remove markdown code block if present content = strings.TrimSpace(content) if strings.HasPrefix(content, "```json") { content = strings.TrimPrefix(content, "```json") content = strings.TrimSuffix(content, "```") content = strings.TrimSpace(content) } else if strings.HasPrefix(content, "```") { content = strings.TrimPrefix(content, "```") content = strings.TrimSuffix(content, "```") content = strings.TrimSpace(content) } // Try to parse JSON var result map[string]interface{} if err := json.Unmarshal([]byte(content), &result); err != nil { // Failed to parse, default to no search return defaultIntent } intent := parseSearchIntent(result) if intent == nil { return defaultIntent } return intent } // sendIntentLoading sends the initial intent detection loading message // Returns the message ID for later replacement func (ast *Assistant) sendIntentLoading(ctx *context.Context) string { loadingMsg := i18n.T(ctx.Locale, "search.intent.loading") msg := &message.Message{ Type: "loading", Props: map[string]any{ "message": loadingMsg, }, } // Send and get message ID msgID, err := ctx.SendStream(msg) if err != nil { ctx.Logger.Warn("Failed to send intent loading message: %v", err) return "" } return msgID } // sendIntentDone replaces loading with result // Only marks as done when needSearch is false (no further loading will follow) // When needSearch is true, the search loading will continue 
func (ast *Assistant) sendIntentDone(ctx *context.Context, loadingID string, needSearch bool, reason string) { if loadingID == "" { return } var resultMsg string if needSearch { resultMsg = i18n.T(ctx.Locale, "search.intent.need_search") } else { resultMsg = i18n.T(ctx.Locale, "search.intent.no_search") } msg := &message.Message{ MessageID: loadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: "loading", Props: map[string]any{ "message": resultMsg, "done": true, // Intent detection loading is independent, always close it }, } if err := ctx.Send(msg); err != nil { ctx.Logger.Warn("Failed to send intent done message: %v", err) } } // getMergedSearchUses returns the merged uses configuration for search // Priority: createResponse > options.Uses > assistant func (ast *Assistant) getMergedSearchUses(createResponse *context.HookCreateResponse, opts ...*context.Options) *context.Uses { // Start with assistant uses var uses *context.Uses if ast.Uses != nil { uses = &context.Uses{ Search: ast.Uses.Search, Web: ast.Uses.Web, Keyword: ast.Uses.Keyword, QueryDSL: ast.Uses.QueryDSL, Rerank: ast.Uses.Rerank, } } // Override with options.Uses if provided (highest priority) if len(opts) > 0 && opts[0] != nil && opts[0].Uses != nil { if uses == nil { uses = &context.Uses{} } if opts[0].Uses.Search != "" { uses.Search = opts[0].Uses.Search } if opts[0].Uses.Web != "" { uses.Web = opts[0].Uses.Web } if opts[0].Uses.Keyword != "" { uses.Keyword = opts[0].Uses.Keyword } if opts[0].Uses.QueryDSL != "" { uses.QueryDSL = opts[0].Uses.QueryDSL } if opts[0].Uses.Rerank != "" { uses.Rerank = opts[0].Uses.Rerank } } // Override with createResponse.Uses if provided (highest priority) if createResponse != nil && createResponse.Uses != nil { if uses == nil { uses = &context.Uses{} } if createResponse.Uses.Search != "" { uses.Search = createResponse.Uses.Search } if createResponse.Uses.Web != "" { uses.Web = createResponse.Uses.Web } if createResponse.Uses.Keyword != "" { uses.Keyword 
= createResponse.Uses.Keyword } if createResponse.Uses.QueryDSL != "" { uses.QueryDSL = createResponse.Uses.QueryDSL } if createResponse.Uses.Rerank != "" { uses.Rerank = createResponse.Uses.Rerank } } return uses } // executeAutoSearch executes auto search based on configuration and intent // Returns ReferenceContext with results and formatted context // intent specifies which search types to execute (from needsearch agent) // opts is optional, used to check Skip.Keyword func (ast *Assistant) executeAutoSearch(ctx *context.Context, messages []context.Message, createResponse *context.HookCreateResponse, intent *SearchIntent, opts ...*context.Options) *searchTypes.ReferenceContext { ctx.Logger.Phase("Search") defer ctx.Logger.PhaseComplete("Search") // Get merged uses configuration uses := ast.getMergedSearchUses(createResponse, opts...) // Convert to search.Uses searchUses := &search.Uses{} if uses != nil { searchUses.Search = uses.Search searchUses.Web = uses.Web searchUses.Keyword = uses.Keyword searchUses.QueryDSL = uses.QueryDSL searchUses.Rerank = uses.Rerank } // Get merged search config searchConfig := ast.GetMergedSearchConfig() // Create searcher searcher := search.New(searchConfig, searchUses) // Extract query from messages (save original for storage) originalQuery := extractQueryFromMessages(messages) if originalQuery == "" { ctx.Logger.Info("No query found in messages, skipping auto search") return nil } // Build query with conversation context for better keyword extraction // This helps the keyword extractor understand the full context contextMessages := buildContextMessage(messages) query := originalQuery if len(contextMessages) > 0 { if contextStr, ok := contextMessages[0].Content.(string); ok { query = contextStr } } // Check if keyword extraction should be skipped skipKeyword := false if len(opts) > 0 && opts[0] != nil && opts[0].Skip != nil { skipKeyword = opts[0].Skip.Keyword } // Build search requests based on configuration and intent // Keyword 
extraction is done inside buildSearchRequests for web search buildOpts := &buildSearchRequestsOptions{ skipKeyword: skipKeyword, usesKeyword: searchUses.Keyword, } requests, extractedKeywords := ast.buildSearchRequests(ctx, query, searchConfig, intent, buildOpts) if len(requests) == 0 { ctx.Logger.Info("No search requests to execute") return nil } // Update query if keywords were extracted (for web search) if len(extractedKeywords) > 0 { query = keywordsToQuery(extractedKeywords) } // === Output: Send loading message === loadingID := ast.sendSearchLoading(ctx) // === Trace: Create search trace node === searchNode := ast.createSearchTrace(ctx, query, requests) // Execute searches in parallel // Build provider info for logging providerInfo := ast.getSearchProviderInfo(searchConfig, searchUses) ctx.Logger.Info("Executing %d search requests via %s for query: %s", len(requests), providerInfo, truncateString(query, 50)) startTime := time.Now() results, err := searcher.All(ctx, requests) duration := time.Since(startTime).Milliseconds() if err != nil { // Log error but don't fail - search errors shouldn't block the main flow ctx.Logger.Error("Auto search failed: %v", err) // === Output: Send failed message === ast.sendSearchDone(ctx, loadingID, 0, true) // === Trace: Mark as failed === ast.completeSearchTrace(searchNode, 0, err) // === Storage: Save failed search === ast.saveSearch(ctx, &SearchExecutionResult{ Query: originalQuery, Keywords: extractedKeywords, Config: ast.configToMap(searchConfig), Duration: duration, Error: err, SearchType: "auto", }) return nil } // Build reference context (includes references, XML, and prompt) var citationConfig *searchTypes.CitationConfig if searchConfig != nil { citationConfig = searchConfig.Citation } refCtx := search.BuildReferenceContext(results, citationConfig) resultCount := len(refCtx.References) // === Output: Send result message, then done === ast.sendSearchResult(ctx, loadingID, resultCount) ast.sendSearchDone(ctx, loadingID, 
resultCount, false) // === Trace: Mark as completed === ast.completeSearchTrace(searchNode, resultCount, nil) // === Storage: Save successful search === ast.saveSearch(ctx, &SearchExecutionResult{ Query: originalQuery, Keywords: extractedKeywords, Config: ast.configToMap(searchConfig), RefCtx: refCtx, Results: results, Duration: duration, SearchType: "auto", }) if resultCount == 0 { ctx.Logger.Info("No search results found") return nil } ctx.Logger.Info("Auto search completed: %d references", resultCount) return refCtx } // ============================================================================ // Output: Loading Replace Pattern // ============================================================================ // sendSearchLoading sends the initial loading message // Returns the message ID for later replacement func (ast *Assistant) sendSearchLoading(ctx *context.Context) string { loadingMsg := i18n.T(ctx.Locale, "search.loading") msg := &message.Message{ Type: "loading", Props: map[string]any{ "message": loadingMsg, }, } // Send and get message ID msgID, err := ctx.SendStream(msg) if err != nil { ctx.Logger.Warn("Failed to send search loading message: %v", err) return "" } return msgID } // sendKeywordLoading sends the keyword extraction loading message // Returns the message ID for later replacement func (ast *Assistant) sendKeywordLoading(ctx *context.Context) string { loadingMsg := i18n.T(ctx.Locale, "search.keyword.loading") msg := &message.Message{ Type: "loading", Props: map[string]any{ "message": loadingMsg, }, } // Send and get message ID msgID, err := ctx.SendStream(msg) if err != nil { ctx.Logger.Warn("Failed to send keyword loading message: %v", err) return "" } return msgID } // sendKeywordDone replaces keyword loading with done message func (ast *Assistant) sendKeywordDone(ctx *context.Context, loadingID string, success bool) { if loadingID == "" { return } resultMsg := i18n.T(ctx.Locale, "search.keyword.done") msg := &message.Message{ MessageID: 
loadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: "loading", Props: map[string]any{ "message": resultMsg, "done": true, }, } if err := ctx.Send(msg); err != nil { ctx.Logger.Warn("Failed to send keyword done message: %v", err) } } // sendSearchResult replaces loading with result message (without done flag) func (ast *Assistant) sendSearchResult(ctx *context.Context, loadingID string, count int) { if loadingID == "" { return } var resultMsg string if count == 0 { resultMsg = i18n.T(ctx.Locale, "search.no_results") } else if count == 1 { resultMsg = i18n.T(ctx.Locale, "search.success.one") } else { resultMsg = fmt.Sprintf(i18n.T(ctx.Locale, "search.success"), count) } msg := &message.Message{ MessageID: loadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: "loading", Props: map[string]any{ "message": resultMsg, }, } if err := ctx.Send(msg); err != nil { ctx.Logger.Warn("Failed to send search result message: %v", err) } } // sendSearchDone sends the final done message (removes loading indicator) func (ast *Assistant) sendSearchDone(ctx *context.Context, loadingID string, count int, failed bool) { if loadingID == "" { return } var resultMsg string if failed { resultMsg = i18n.T(ctx.Locale, "search.failed") } else if count == 0 { resultMsg = i18n.T(ctx.Locale, "search.no_results") } else if count == 1 { resultMsg = i18n.T(ctx.Locale, "search.success.one") } else { resultMsg = fmt.Sprintf(i18n.T(ctx.Locale, "search.success"), count) } msg := &message.Message{ MessageID: loadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: "loading", Props: map[string]any{ "message": resultMsg, "done": true, // Frontend will remove loading indicator }, } if err := ctx.Send(msg); err != nil { ctx.Logger.Warn("Failed to send search done message: %v", err) } } // ============================================================================ // Trace: Search Node // ============================================================================ // 
createSearchTrace creates a trace node for search operation func (ast *Assistant) createSearchTrace(ctx *context.Context, query string, requests []*searchTypes.Request) traceTypes.Node { trace, _ := ctx.Trace() if trace == nil { return nil } // Build search types list var searchTypes []string for _, req := range requests { searchTypes = append(searchTypes, string(req.Type)) } input := map[string]any{ "query": query, "types": searchTypes, } node, err := trace.Add(input, traceTypes.TraceNodeOption{ Label: i18n.T(ctx.Locale, "search.trace.label"), Type: "search", Icon: "search", Description: i18n.T(ctx.Locale, "search.trace.description"), }) if err != nil { ctx.Logger.Warn("Failed to create search trace node: %v", err) return nil } // Log search start node.Info("Starting search", map[string]any{ "query": query, "types": searchTypes, }) return node } // completeSearchTrace marks the search trace node as completed or failed func (ast *Assistant) completeSearchTrace(node traceTypes.Node, resultCount int, err error) { if node == nil { return } if err != nil { node.Warn("Search failed", map[string]any{"error": err.Error()}) node.Fail(err) return } // Log completion node.Info("Search completed", map[string]any{ "result_count": resultCount, }) // Complete with output node.Complete(map[string]any{ "result_count": resultCount, }) } // buildSearchRequestsOptions contains options for building search requests type buildSearchRequestsOptions struct { skipKeyword bool // Skip keyword extraction usesKeyword string // Keyword extractor config: "builtin", "", "mcp:." 
// buildSearchRequests builds the list of search requests to execute, based on
// the search configuration, the assistant's KB/DB settings and the detected
// intent.
//
// Parameters:
//   - query: the query text used for every request (web search may replace it
//     with extracted keywords)
//   - config: merged search config; may be nil, in which case no web request is
//     built and KB/DB fall back to built-in defaults
//   - intent: limits which search types run; a nil intent or an empty
//     SearchTypes list allows all types
//   - opts: keyword extraction options; may be nil (no extraction)
//
// Returns the requests and the keywords extracted for web search (nil when no
// extraction happened or it produced nothing).
func (ast *Assistant) buildSearchRequests(ctx *context.Context, query string, config *searchTypes.Config, intent *SearchIntent, opts *buildSearchRequestsOptions) ([]*searchTypes.Request, []searchTypes.Keyword) {
	var requests []*searchTypes.Request
	var extractedKeywords []searchTypes.Keyword

	// Helper to check if a search type is allowed by intent
	isTypeAllowed := func(searchType string) bool {
		if intent == nil || len(intent.SearchTypes) == 0 {
			return true // No intent or empty types means all types allowed
		}
		for _, t := range intent.SearchTypes {
			if t == searchType {
				return true
			}
		}
		return false
	}

	// Web search - check if web search is configured and allowed by intent
	if config != nil && config.Web != nil && isTypeAllowed("web") {
		webQuery := query

		// Extract keywords for web search if configured
		if opts != nil && !opts.skipKeyword && opts.usesKeyword != "" {
			// === Output: Send keyword extraction loading ===
			keywordLoadingID := ast.sendKeywordLoading(ctx)

			extractor := keyword.NewExtractor(opts.usesKeyword, config.Keyword)
			keywords, err := extractor.Extract(ctx, query, nil)
			if err != nil {
				// Extraction is best-effort: keep the original query on failure
				ctx.Logger.Warn("Keyword extraction failed, using original query: %v", err)
				ast.sendKeywordDone(ctx, keywordLoadingID, false)
			} else if len(keywords) > 0 {
				extractedKeywords = keywords
				// Use extracted keywords as the search query for web search
				webQuery = keywordsToQuery(keywords)
				ctx.Logger.Info("Extracted keywords for web search: %s -> %s", truncateString(query, 30), webQuery)
				ast.sendKeywordDone(ctx, keywordLoadingID, true)
			} else {
				// Extraction succeeded but returned nothing: keep the original query
				ast.sendKeywordDone(ctx, keywordLoadingID, true)
			}
		}

		requests = append(requests, &searchTypes.Request{
			Type:   searchTypes.SearchTypeWeb,
			Query:  webQuery,
			Source: searchTypes.SourceAuto,
			Limit:  config.Web.MaxResults,
		})
	}

	// KB search - check if KB is configured and allowed by intent
	if ast.KB != nil && len(ast.KB.Collections) > 0 && isTypeAllowed("kb") {
		// NOTE(review): limit is fixed at 10 here; unlike threshold it is never
		// overridden from config - confirm this is intended.
		limit := 10
		threshold := 0.7
		if config != nil && config.KB != nil {
			if config.KB.Threshold > 0 {
				threshold = config.KB.Threshold
			}
		}

		// Filter collections by authorization (Collection-level permission check)
		allowedCollections := FilterKBCollectionsByAuth(ctx, ast.KB.Collections)
		if len(allowedCollections) == 0 {
			// No KB request is built when nothing is accessible
			ctx.Logger.Info("No accessible KB collections after auth filter")
		} else {
			// Build KB request
			kbReq := &searchTypes.Request{
				Type:        searchTypes.SearchTypeKB,
				Query:       query, // KB uses the query argument as-is (no keyword rewrite)
				Source:      searchTypes.SourceAuto,
				Limit:       limit,
				Collections: allowedCollections,
				Threshold:   threshold,
				Graph:       config != nil && config.KB != nil && config.KB.Graph,
			}
			requests = append(requests, kbReq)
		}
	}

	// DB search - check if DB is configured and allowed by intent
	if ast.DB != nil && len(ast.DB.Models) > 0 && isTypeAllowed("db") {
		limit := 20
		if config != nil && config.DB != nil && config.DB.MaxResults > 0 {
			limit = config.DB.MaxResults
		}

		// Build DB request with auth where clauses
		dbReq := &searchTypes.Request{
			Type:   searchTypes.SearchTypeDB,
			Query:  query, // DB uses the query argument as-is for QueryDSL generation
			Source: searchTypes.SourceAuto,
			Limit:  limit,
			Models: ast.DB.Models,
		}

		// Apply authorization where clauses (nil means no filter is added)
		if authWheres := BuildDBAuthWheres(ctx); authWheres != nil {
			dbReq.Wheres = authWheres
		}

		requests = append(requests, dbReq)
	}

	return requests, extractedKeywords
}
append(contentParts, refCtx.Prompt) } // Add XML context if refCtx.XML != "" { contentParts = append(contentParts, refCtx.XML) } if len(contentParts) == 0 { return messages } // Create system message with search context searchMessage := context.Message{ Role: "system", Content: strings.Join(contentParts, "\n\n"), } // Find the position to insert the search message // Insert after any existing system messages but before user messages insertIndex := 0 for i, msg := range messages { if msg.Role == "system" { insertIndex = i + 1 } else { break } } // Insert the search message result := make([]context.Message, 0, len(messages)+1) result = append(result, messages[:insertIndex]...) result = append(result, searchMessage) result = append(result, messages[insertIndex:]...) return result } // extractTextContent extracts text-only content from a message // For multimodal messages, concatenates all text parts // Returns empty string if no text content found func extractTextContent(msg context.Message) string { content := msg.Content // Handle string content if str, ok := content.(string); ok { return str } // Handle content parts (array of objects) - extract only text parts if parts, ok := content.([]interface{}); ok { var texts []string for _, part := range parts { if partMap, ok := part.(map[string]interface{}); ok { if partMap["type"] == "text" { if text, ok := partMap["text"].(string); ok { texts = append(texts, text) } } } } if len(texts) > 0 { return strings.Join(texts, "\n") } } // Handle []context.ContentPart if parts, ok := content.([]context.ContentPart); ok { var texts []string for _, part := range parts { if part.Type == context.ContentText && part.Text != "" { texts = append(texts, part.Text) } } if len(texts) > 0 { return strings.Join(texts, "\n") } } return "" } // buildContextMessage builds a single user message with conversation context // Filters out system messages and extracts text-only content // Only takes the last 5 messages for efficiency // Returns a 
slice with one message containing the full context, or empty slice if no content func buildContextMessage(messages []context.Message) []context.Message { const maxMessages = 5 // Take only the last maxMessages (excluding system messages) var recentMessages []context.Message for i := len(messages) - 1; i >= 0 && len(recentMessages) < maxMessages; i-- { if messages[i].Role != "system" { recentMessages = append(recentMessages, messages[i]) } } // Reverse to maintain chronological order for i, j := 0, len(recentMessages)-1; i < j; i, j = i+1, j-1 { recentMessages[i], recentMessages[j] = recentMessages[j], recentMessages[i] } var contextParts []string var lastUserMessage string for _, msg := range recentMessages { textContent := extractTextContent(msg) if textContent == "" { continue } // Format message with role label switch msg.Role { case "user": contextParts = append(contextParts, "[User]: "+textContent) lastUserMessage = textContent case "assistant": contextParts = append(contextParts, "[Assistant]: "+textContent) default: contextParts = append(contextParts, "["+string(msg.Role)+"]: "+textContent) } } // Build single message with context var result []context.Message if len(contextParts) > 1 { // Multiple messages: include conversation context fullContext := "=== Conversation Context ===\n" + strings.Join(contextParts, "\n\n") + "\n=== End Context ===\n\nCurrent user request: " + lastUserMessage result = append(result, context.Message{ Role: "user", Content: fullContext, }) } else if lastUserMessage != "" { // Single user message: just use it directly result = append(result, context.Message{ Role: "user", Content: lastUserMessage, }) } return result } // extractQueryFromMessages extracts the search query from messages // Uses the last user message as the query func extractQueryFromMessages(messages []context.Message) string { // Find the last user message for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { content := messages[i].Content // 
// truncateString shortens s to at most maxLen bytes and appends "..." when
// anything was cut off. Strings that already fit are returned unchanged.
//
// The cut position is moved back to the nearest UTF-8 rune boundary, so a
// multi-byte character is never split and valid input always yields valid
// output. A negative maxLen is treated as 0.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes: cutting in front
	// of one would split a rune. A rune has at most 3 continuation bytes, so
	// backing off more than that only happens on malformed input.
	cut := maxLen
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}

	return s[:cut] + "..."
}
Duration int64 // Search duration in ms Error error // Error if failed SearchType string // "auto", "web", "kb", "db" } // keywordsToQuery converts keywords with weights to a search query string // Keywords are sorted by weight (descending) and joined with spaces func keywordsToQuery(keywords []searchTypes.Keyword) string { if len(keywords) == 0 { return "" } // Sort by weight descending (higher weight first) sorted := make([]searchTypes.Keyword, len(keywords)) copy(sorted, keywords) for i := 0; i < len(sorted)-1; i++ { for j := i + 1; j < len(sorted); j++ { if sorted[j].W > sorted[i].W { sorted[i], sorted[j] = sorted[j], sorted[i] } } } // Join keywords parts := make([]string, len(sorted)) for i, kw := range sorted { parts[i] = kw.K } return strings.Join(parts, " ") } // keywordsToStrings converts keywords to string slice for storage func keywordsToStrings(keywords []searchTypes.Keyword) []string { if len(keywords) == 0 { return nil } result := make([]string, len(keywords)) for i, kw := range keywords { result[i] = kw.K } return result } // containsSearchType checks if a search type is in the list func containsSearchType(types []string, searchType string) bool { for _, t := range types { if t == searchType { return true } } return false } // saveSearch saves search results to storage // Called after search execution completes (success or failure) func (ast *Assistant) saveSearch(ctx *context.Context, execResult *SearchExecutionResult) { // Get store store := GetStore() if store == nil { ctx.Logger.Debug("Storage not configured, skipping search save") return } // Build search record searchRecord := &storeTypes.Search{ RequestID: ctx.RequestID(), ChatID: ctx.ChatID, Query: execResult.Query, Keywords: keywordsToStrings(execResult.Keywords), Config: execResult.Config, Source: execResult.SearchType, Duration: execResult.Duration, CreatedAt: time.Now(), } // Set error if present if execResult.Error != nil { searchRecord.Error = execResult.Error.Error() } // Convert 
references if available if execResult.RefCtx != nil { searchRecord.References = convertToStoreReferences(execResult.RefCtx.References) searchRecord.XML = execResult.RefCtx.XML searchRecord.Prompt = execResult.RefCtx.Prompt } // Extract DSL from DB search results if execResult.Results != nil { for _, result := range execResult.Results { if result != nil && result.Type == searchTypes.SearchTypeDB && result.DSL != nil { searchRecord.DSL = result.DSL break // Only store the first DSL (usually there's only one DB search) } } } // Save to store if err := store.SaveSearch(searchRecord); err != nil { ctx.Logger.Warn("Failed to save search record: %v", err) return } ctx.Logger.Debug("Search record saved: request_id=%s, refs=%d", searchRecord.RequestID, len(searchRecord.References)) } // convertToStoreReferences converts search References to store References func convertToStoreReferences(refs []*searchTypes.Reference) []storeTypes.Reference { if len(refs) == 0 { return nil } storeRefs := make([]storeTypes.Reference, len(refs)) for i, ref := range refs { if ref == nil { continue } // Parse citation ID as integer (e.g., "1", "2", "3") index := i + 1 // Default to position-based index if ref.ID != "" { if n, err := fmt.Sscanf(ref.ID, "%d", &index); n != 1 || err != nil { index = i + 1 } } storeRefs[i] = storeTypes.Reference{ Index: index, Type: string(ref.Type), Title: ref.Title, URL: ref.URL, Snippet: truncateString(ref.Content, 200), // Short snippet Content: ref.Content, Metadata: map[string]any{ "weight": ref.Weight, "score": ref.Score, "source": string(ref.Source), }, } } return storeRefs } // configToMap converts search config to map for storage func (ast *Assistant) configToMap(config *searchTypes.Config) map[string]any { if config == nil { return nil } result := make(map[string]any) if config.Web != nil { result["web"] = map[string]any{ "provider": config.Web.Provider, "max_results": config.Web.MaxResults, } } if config.KB != nil { result["kb"] = map[string]any{ 
"threshold": config.KB.Threshold, "graph": config.KB.Graph, } } if config.DB != nil { result["db"] = map[string]any{ "max_results": config.DB.MaxResults, } } if config.Weights != nil { result["weights"] = map[string]any{ "user": config.Weights.User, "hook": config.Weights.Hook, "auto": config.Weights.Auto, } } return result } // getSearchProviderInfo returns a human-readable string describing the search provider(s) func (ast *Assistant) getSearchProviderInfo(config *searchTypes.Config, uses *search.Uses) string { var parts []string // Web search provider - always show when web search is being executed webMode := "" if uses != nil { webMode = uses.Web } if webMode == "" || webMode == "builtin" { // Builtin mode: show the actual provider (tavily/serper/serpapi) provider := "tavily" // default if config != nil && config.Web != nil && config.Web.Provider != "" { provider = config.Web.Provider } parts = append(parts, "web:"+provider) } else if strings.HasPrefix(webMode, "mcp:") { parts = append(parts, "web:"+webMode) } else { parts = append(parts, "web:agent:"+webMode) } // KB search if config != nil && config.KB != nil && len(config.KB.Collections) > 0 { parts = append(parts, "kb") } // DB search if config != nil && config.DB != nil && len(config.DB.Models) > 0 { parts = append(parts, "db") } return strings.Join(parts, ", ") } ================================================ FILE: agent/assistant/search_auth_db.go ================================================ package assistant import ( "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/yao/agent/context" ) // BuildDBAuthWheres builds where clauses for DB search based on authorization // This applies permission-based filtering to database queries // Returns gou.Where clauses to filter records by authorization scope func BuildDBAuthWheres(ctx *context.Context) []gou.Where { if ctx == nil || ctx.Authorized == nil { return nil } authInfo := ctx.Authorized // No constraints, no filter needed if 
// BuildDBAuthWheres builds where clauses for DB search based on authorization.
// This applies permission-based filtering to database queries.
//
// Returns nil (no filter) when:
//   - ctx or ctx.Authorized is nil
//   - neither the TeamOnly nor the OwnerOnly constraint is set
//   - a constraint is set but the ID it needs is empty (TeamOnly without
//     TeamID, OwnerOnly without UserID) and the other branch does not apply
//
// NOTE(review): the last case means a constrained caller with a missing ID
// gets no filter at all rather than being denied - confirm this is intended.
//
// Otherwise a single grouped gou.Where is returned. The TeamOnly branch is
// checked first and wins when both constraints are set.
func BuildDBAuthWheres(ctx *context.Context) []gou.Where {
	if ctx == nil || ctx.Authorized == nil {
		return nil
	}

	authInfo := ctx.Authorized

	// No constraints, no filter needed
	if !authInfo.Constraints.TeamOnly && !authInfo.Constraints.OwnerOnly {
		return nil
	}

	var wheres []gou.Where

	// Team only - User can access:
	// 1. Public records (public = true)
	// 2. Records in their team where:
	//    - They created the record (__yao_created_by matches)
	//    - OR the record is shared with team (share = "team")
	if authInfo.Constraints.TeamOnly && authInfo.TeamID != "" {
		wheres = append(wheres, gou.Where{
			Wheres: []gou.Where{
				// Public records
				{Condition: gou.Condition{
					Field: &gou.Expression{Field: "public"},
					Value: true,
					OP:    "=",
					OR:    true,
				}},
				// Team records: team id matches AND (created by user OR shared with team)
				{
					Wheres: []gou.Where{
						{Condition: gou.Condition{
							Field: &gou.Expression{Field: "__yao_team_id"},
							Value: authInfo.TeamID,
							OP:    "=",
						}},
						{Wheres: []gou.Where{
							{Condition: gou.Condition{
								Field: &gou.Expression{Field: "__yao_created_by"},
								Value: authInfo.UserID,
								OP:    "=",
							}},
							{Condition: gou.Condition{
								Field: &gou.Expression{Field: "share"},
								Value: "team",
								OP:    "=",
								OR:    true,
							}},
						}},
					},
				},
			},
		})
		return wheres
	}

	// Owner only - User can access:
	// 1. Public records (public = true)
	// 2. Records they created where:
	//    - __yao_team_id is null (not team records)
	//    - __yao_created_by matches their user ID
	if authInfo.Constraints.OwnerOnly && authInfo.UserID != "" {
		wheres = append(wheres, gou.Where{
			Wheres: []gou.Where{
				// Public records
				{Condition: gou.Condition{
					Field: &gou.Expression{Field: "public"},
					Value: true,
					OP:    "=",
					OR:    true,
				}},
				// Owner records
				{
					Wheres: []gou.Where{
						{Condition: gou.Condition{
							Field: &gou.Expression{Field: "__yao_team_id"},
							OP:    "null",
						}},
						{Condition: gou.Condition{
							Field: &gou.Expression{Field: "__yao_created_by"},
							Value: authInfo.UserID,
							OP:    "=",
						}},
					},
				},
			},
		})
		return wheres
	}

	// A constraint was requested but the ID it needs is empty: wheres is still
	// nil here, so the caller applies no filter.
	return wheres
}
// cleanup removes all test collections created for this run.
// Removal errors are ignored; only successful removals are logged.
func (c *authTestCollections) cleanup(ctx context.Context, t *testing.T) {
	collections := []string{c.Team1, c.Team2, c.Public}
	for _, id := range collections {
		if result, err := kb.API.RemoveCollection(ctx, id); err == nil && result.Removed {
			t.Logf(" Removed: %s", id)
		}
	}
}

// ========== KB Collection-Level Auth Filter Tests ==========

// Note: KB permission filtering works at the Collection level.
// The Collection metadata contains __yao_team_id, __yao_created_by, public, share fields.
// FilterKBCollectionsByAuth filters collections based on user authorization.

// TestKBCollectionAuthFilter checks FilterKBCollectionsByAuth against three
// collections created for this run: one owned by UserA in Team1, one owned by
// UserB in Team2, and one created by UserA in Team1 with the public flag set.
// It requires an initialized KB API and fails immediately without one.
func TestKBCollectionAuthFilter(t *testing.T) {
	testutils.Prepare(t)
	defer testutils.Clean(t)

	if kb.API == nil {
		t.Fatal("KB API not initialized")
	}

	ctx := context.Background()
	cols := newAuthTestCollections()
	defer cols.cleanup(ctx, t)

	// Create test collections
	t.Log("Creating test collections...")
	createAuthCollection(ctx, t, cols.Team1, TestUserA, TestTeam1, false, "team")
	createAuthCollection(ctx, t, cols.Team2, TestUserB, TestTeam2, false, "team")
	createAuthCollection(ctx, t, cols.Public, TestUserA, TestTeam1, true, "")

	t.Run("TeamMemberCanAccessTeamCollection", func(t *testing.T) {
		// UserA from Team1 should access Team1 collection
		authCtx := createAuthContext(TestUserA, TestTeam1, true, false)
		collections := []string{cols.Team1, cols.Team2}
		allowed := assistant.FilterKBCollectionsByAuth(authCtx, collections)
		assert.Contains(t, allowed, cols.Team1, "Team1 member should access Team1 collection")
		t.Logf(" Allowed collections: %v", allowed)
	})

	t.Run("TeamMemberCannotAccessOtherTeamCollection", func(t *testing.T) {
		// UserA from Team1 should NOT access Team2 collection
		authCtx := createAuthContext(TestUserA, TestTeam1, true, false)
		collections := []string{cols.Team2}
		allowed := assistant.FilterKBCollectionsByAuth(authCtx, collections)
		assert.NotContains(t, allowed, cols.Team2, "Team1 member should NOT access Team2 collection")
		t.Logf(" Allowed collections: %v (expected empty)", allowed)
	})

	t.Run("OwnerCanAccessOwnCollection", func(t *testing.T) {
		// UserA with OwnerOnly should access collections they created
		authCtx := createAuthContext(TestUserA, "", false, true)
		collections := []string{cols.Team1, cols.Team2}
		allowed := assistant.FilterKBCollectionsByAuth(authCtx, collections)
		assert.Contains(t, allowed, cols.Team1, "Owner should access own collection")
		assert.NotContains(t, allowed, cols.Team2, "Owner should NOT access other's collection")
		t.Logf(" Allowed collections: %v", allowed)
	})

	// NOTE: despite its name, this subtest only asserts access through the
	// owner check (UserA created the collection); it does not assert that a
	// non-owner can reach the public collection.
	t.Run("PublicCollectionAccessibleToAll", func(t *testing.T) {
		// Note: The 'public' field in Metadata is not automatically saved to the database
		// by the current KB API. This test documents the expected behavior.
		// When public=true is properly set in DB, this should pass.

		// First, check the collection metadata
		collection, err := kb.API.GetCollection(ctx, cols.Public)
		assert.NoError(t, err)

		// Check if public is set correctly (logged only, not asserted)
		publicVal := collection["public"]
		t.Logf(" Public collection public field: %v (type: %T)", publicVal, publicVal)

		// If public is not set (0 or false), the test documents current behavior
		// The collection should be accessible via owner check since UserA created it
		authCtx := createAuthContext(TestUserA, TestTeam1, false, true) // Owner check
		collections := []string{cols.Public}
		allowed := assistant.FilterKBCollectionsByAuth(authCtx, collections)
		assert.Contains(t, allowed, cols.Public, "Owner should access their collection")
		t.Logf(" Allowed collections (owner check): %v", allowed)
	})

	t.Run("NoConstraintsMeansFullAccess", func(t *testing.T) {
		// User with no constraints should access all collections
		authCtx := createAuthContext(TestUserA, TestTeam1, false, false)
		collections := []string{cols.Team1, cols.Team2, cols.Public}
		allowed := assistant.FilterKBCollectionsByAuth(authCtx, collections)
		assert.Len(t, allowed, 3, "No constraints should allow all collections")
		t.Logf(" Allowed collections: %v", allowed)
	})

	t.Run("NilContextMeansFullAccess", func(t *testing.T) {
		// A nil context applies no filtering at all
		collections := []string{cols.Team1, cols.Team2}
		allowed := assistant.FilterKBCollectionsByAuth(nil, collections)
		assert.Len(t, allowed, 2, "Nil context should allow all collections")
		t.Logf(" Allowed collections: %v", allowed)
	})
}
t.Run("OwnerOnlyGeneratesCorrectWheres", func(t *testing.T) { ctx := createAuthContext(TestUserA, "", false, true) wheres := assistant.BuildDBAuthWheres(ctx) assert.NotNil(t, wheres) assert.Len(t, wheres, 1) // Verify structure: should have 2 top-level conditions (public OR owner filter) where := wheres[0] assert.Len(t, where.Wheres, 2, "Should have 2 conditions: public OR owner") // First condition: public = true (OR) publicCond := where.Wheres[0] assert.NotNil(t, publicCond.Condition.Field) assert.Equal(t, "public", publicCond.Condition.Field.Field) assert.Equal(t, true, publicCond.Condition.Value) assert.True(t, publicCond.Condition.OR) // Second condition: owner filter with nested conditions ownerCond := where.Wheres[1] assert.Len(t, ownerCond.Wheres, 2, "Owner filter should have team_id IS NULL and created_by") // Team ID is null check teamNullCond := ownerCond.Wheres[0] assert.Equal(t, "__yao_team_id", teamNullCond.Condition.Field.Field) assert.Equal(t, "null", teamNullCond.Condition.OP) // Created by check createdByCond := ownerCond.Wheres[1] assert.Equal(t, "__yao_created_by", createdByCond.Condition.Field.Field) assert.Equal(t, TestUserA, createdByCond.Condition.Value) t.Logf(" OwnerOnly: Verified created_by=%s, team_id IS NULL", TestUserA) }) t.Run("NoConstraintsReturnsNil", func(t *testing.T) { ctx := createAuthContext(TestUserA, TestTeam1, false, false) wheres := assistant.BuildDBAuthWheres(ctx) assert.Nil(t, wheres, "No constraints should return nil") t.Log(" No constraints: nil wheres (no filter)") }) t.Run("EmptyTeamIDReturnsNil", func(t *testing.T) { ctx := createAuthContext(TestUserA, "", true, false) wheres := assistant.BuildDBAuthWheres(ctx) assert.Nil(t, wheres, "Empty TeamID with TeamOnly should return nil") t.Log(" Empty TeamID with TeamOnly: nil wheres") }) t.Run("EmptyUserIDReturnsNil", func(t *testing.T) { ctx := createAuthContext("", TestTeam1, false, true) wheres := assistant.BuildDBAuthWheres(ctx) assert.Nil(t, wheres, "Empty UserID with 
OwnerOnly should return nil") t.Log(" Empty UserID with OwnerOnly: nil wheres") }) t.Run("NilContextReturnsNil", func(t *testing.T) { wheres := assistant.BuildDBAuthWheres(nil) assert.Nil(t, wheres, "Nil context should return nil") t.Log(" Nil context: nil wheres") }) t.Run("NilAuthorizedReturnsNil", func(t *testing.T) { ctx := agentContext.New(context.Background(), nil, "test-chat") wheres := assistant.BuildDBAuthWheres(ctx) assert.Nil(t, wheres, "Nil Authorized should return nil") t.Log(" Nil Authorized: nil wheres") }) } // ========== KB Search Integration Tests ========== func TestKBSearchIntegration(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) if kb.API == nil { t.Fatal("KB API not initialized") } ctx := context.Background() cols := newAuthTestCollections() defer cols.cleanup(ctx, t) // Create test collections with documents t.Log("Creating test collections with documents...") createAuthCollection(ctx, t, cols.Team1, TestUserA, TestTeam1, false, "team") addAuthDocument(ctx, t, cols.Team1, "Team1 Doc1", "Team1 private document about quantum physics and relativity theory.") addAuthDocument(ctx, t, cols.Team1, "Team1 Doc2", "Team1 shared document about machine learning and neural networks.") createAuthCollection(ctx, t, cols.Team2, TestUserB, TestTeam2, false, "team") addAuthDocument(ctx, t, cols.Team2, "Team2 Doc1", "Team2 private document about deep learning algorithms.") addAuthDocument(ctx, t, cols.Team2, "Team2 Doc2", "Team2 shared document about computer vision techniques.") createAuthCollection(ctx, t, cols.Public, TestUserA, TestTeam1, true, "") addAuthDocument(ctx, t, cols.Public, "Public Doc1", "Public document about artificial intelligence and robotics.") addAuthDocument(ctx, t, cols.Public, "Public Doc2", "Public document about natural language processing.") // Wait for indexing t.Log("Waiting for indexing...") time.Sleep(2 * time.Second) t.Run("TeamMemberSearchOnlyFindsTeamData", func(t *testing.T) { // UserA from Team1 searches - 
should ONLY find Team1 data authCtx := createAuthContext(TestUserA, TestTeam1, true, false) // Filter collections first allCollections := []string{cols.Team1, cols.Team2} allowed := assistant.FilterKBCollectionsByAuth(authCtx, allCollections) // Should only allow Team1 assert.Contains(t, allowed, cols.Team1) assert.NotContains(t, allowed, cols.Team2) assert.Len(t, allowed, 1, "Should only have 1 allowed collection") // Search on allowed collections result := executeKBSearchOnCollections(t, allowed, "quantum physics deep learning") assert.Greater(t, len(result.Items), 0, "Should find Team1 documents") // Verify ALL results are from Team1 collection only for _, item := range result.Items { assert.Equal(t, cols.Team1, item.Collection, "All results should be from Team1 collection, got: %s", item.Collection) } t.Logf(" ✓ Team1 member found %d items, all from Team1 collection", len(result.Items)) }) t.Run("TeamMemberCannotAccessOtherTeamData", func(t *testing.T) { // UserA from Team1 tries to access Team2 - should be blocked authCtx := createAuthContext(TestUserA, TestTeam1, true, false) // Try to filter Team2 collection collections := []string{cols.Team2} allowed := assistant.FilterKBCollectionsByAuth(authCtx, collections) // Should be empty - no access assert.Empty(t, allowed, "Team1 member should NOT have access to Team2 collection") t.Log(" ✓ Team1 member correctly blocked from Team2 collection") }) t.Run("OwnerSearchOnlyFindsOwnData", func(t *testing.T) { // UserA with OwnerOnly - should only find collections they created authCtx := createAuthContext(TestUserA, "", false, true) // Filter all collections allCollections := []string{cols.Team1, cols.Team2, cols.Public} allowed := assistant.FilterKBCollectionsByAuth(authCtx, allCollections) // UserA created Team1 and Public, not Team2 assert.Contains(t, allowed, cols.Team1, "Owner should access Team1 (created by UserA)") assert.Contains(t, allowed, cols.Public, "Owner should access Public (created by UserA)") 
assert.NotContains(t, allowed, cols.Team2, "Owner should NOT access Team2 (created by UserB)") // Search and verify results result := executeKBSearchOnCollections(t, allowed, "quantum artificial intelligence") assert.Greater(t, len(result.Items), 0, "Should find owner's documents") // Verify NO results from Team2 for _, item := range result.Items { assert.NotEqual(t, cols.Team2, item.Collection, "Should NOT have results from Team2, got: %s", item.Collection) } t.Logf(" ✓ Owner found %d items, none from Team2", len(result.Items)) }) t.Run("NoConstraintsSearchFindsAllData", func(t *testing.T) { // User with no constraints - should find all data authCtx := createAuthContext(TestUserA, TestTeam1, false, false) // Filter all collections allCollections := []string{cols.Team1, cols.Team2, cols.Public} allowed := assistant.FilterKBCollectionsByAuth(authCtx, allCollections) // Should have access to all assert.Len(t, allowed, 3, "No constraints should allow all collections") // Search and verify results from multiple collections result := executeKBSearchOnCollections(t, allowed, "quantum deep learning artificial") // Should find results from multiple collections collectionsFound := make(map[string]bool) for _, item := range result.Items { collectionsFound[item.Collection] = true } assert.Greater(t, len(collectionsFound), 1, "Should find results from multiple collections") t.Logf(" ✓ No constraints: found %d items from %d collections", len(result.Items), len(collectionsFound)) }) t.Run("SearchResultsMatchCollectionFilter", func(t *testing.T) { // Verify that search results ONLY come from allowed collections authCtx := createAuthContext(TestUserB, TestTeam2, true, false) // UserB from Team2 - should only access Team2 allCollections := []string{cols.Team1, cols.Team2, cols.Public} allowed := assistant.FilterKBCollectionsByAuth(authCtx, allCollections) assert.Contains(t, allowed, cols.Team2, "Team2 member should access Team2") assert.NotContains(t, allowed, cols.Team1, "Team2 
member should NOT access Team1") // Search result := executeKBSearchOnCollections(t, allowed, "deep learning computer vision") // Verify results if len(result.Items) > 0 { for _, item := range result.Items { // Results should only be from allowed collections assert.Contains(t, allowed, item.Collection, "Result from %s should be in allowed list %v", item.Collection, allowed) } t.Logf(" ✓ Team2 member found %d items, all from allowed collections", len(result.Items)) } else { t.Log(" ✓ Team2 member found 0 items (collection may be empty)") } }) } // ========== Helper Functions ========== func createAuthContext(userID, teamID string, teamOnly, ownerOnly bool) *agentContext.Context { authorized := &oauthtypes.AuthorizedInfo{ UserID: userID, TeamID: teamID, Constraints: oauthtypes.DataConstraints{ TeamOnly: teamOnly, OwnerOnly: ownerOnly, }, } return agentContext.New(context.Background(), authorized, "test-chat") } func createAuthCollection(ctx context.Context, t *testing.T, id, userID, teamID string, public bool, share string) { params := &api.CreateCollectionParams{ ID: id, Metadata: map[string]interface{}{ "name": id, "public": public, "share": share, }, EmbeddingProviderID: "__yao.openai", EmbeddingOptionID: "text-embedding-3-small", Locale: "en", Config: &graphragtypes.CreateCollectionOptions{ Distance: "cosine", IndexType: "hnsw", }, AuthScope: map[string]interface{}{ "__yao_created_by": userID, "__yao_team_id": teamID, }, } _, err := kb.API.CreateCollection(ctx, params) if err != nil { t.Fatalf("Failed to create collection %s: %v", id, err) } t.Logf(" ✓ Created: %s", id) } func addAuthDocument(ctx context.Context, t *testing.T, collectionID, title, content string) { params := &api.AddTextParams{ CollectionID: collectionID, Text: content, DocID: fmt.Sprintf("%s__%s", collectionID, sanitizeForID(title)), Metadata: map[string]interface{}{ "title": title, }, Chunking: &api.ProviderConfigParams{ ProviderID: "__yao.structured", OptionID: "standard", }, Embedding: 
&api.ProviderConfigParams{ ProviderID: "__yao.openai", OptionID: "text-embedding-3-small", }, } _, err := kb.API.AddText(ctx, params) if err != nil { t.Logf(" Warning: Failed to add document '%s': %v", title, err) return } t.Logf(" ✓ Added: %s", title) } func sanitizeForID(s string) string { result := "" for _, c := range s { if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') { result += string(c) } else if c == ' ' { result += "_" } } return result } func executeKBSearchOnCollections(t *testing.T, collections []string, query string) *searchTypes.Result { if len(collections) == 0 { return &searchTypes.Result{Items: []*searchTypes.ResultItem{}} } cfg := &searchTypes.Config{ KB: &searchTypes.KBConfig{ Collections: collections, Threshold: 0.3, }, } searcher := search.New(cfg, nil) req := &searchTypes.Request{ Type: searchTypes.SearchTypeKB, Query: query, Collections: collections, Threshold: 0.3, Limit: 20, Source: searchTypes.SourceAuto, } result, err := searcher.Search(nil, req) if err != nil { t.Fatalf("Search failed: %v", err) } return result } ================================================ FILE: agent/assistant/search_auth_kb.go ================================================ package assistant import ( "context" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/kb" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // FilterKBCollectionsByAuth filters collections based on user authorization. // Returns only collections that the user has permission to access. // Permission is determined by Collection's metadata (public, share, __yao_team_id, __yao_created_by). 
func FilterKBCollectionsByAuth(ctx *agentContext.Context, collections []string) []string { if ctx == nil || ctx.Authorized == nil { return collections // No auth context, return all } authInfo := ctx.Authorized // No constraints, return all collections if !authInfo.Constraints.TeamOnly && !authInfo.Constraints.OwnerOnly { return collections } // Check KB API if kb.API == nil { return collections // KB not initialized, return all } var allowed []string bgCtx := context.Background() for _, collectionID := range collections { // Get collection metadata collection, err := kb.API.GetCollection(bgCtx, collectionID) if err != nil { continue // Skip if can't get collection } if hasCollectionAccess(authInfo, collection) { allowed = append(allowed, collectionID) } } return allowed } // hasCollectionAccess checks if user has access to a collection based on its metadata. func hasCollectionAccess(authInfo *oauthtypes.AuthorizedInfo, collection map[string]interface{}) bool { if authInfo == nil { return true } // No constraints, allow access if !authInfo.Constraints.TeamOnly && !authInfo.Constraints.OwnerOnly { return true } // Check public access (handle different types: bool, int, float64) if isPublicValue(collection["public"]) { return true } // Get metadata for permission fields metadata, _ := collection["metadata"].(map[string]interface{}) if metadata == nil { metadata = collection } // Team only check if authInfo.Constraints.TeamOnly && authInfo.TeamID != "" { teamID, _ := metadata["__yao_team_id"].(string) if teamID == "" { teamID, _ = collection["__yao_team_id"].(string) } if teamID == authInfo.TeamID { createdBy, _ := metadata["__yao_created_by"].(string) if createdBy == "" { createdBy, _ = collection["__yao_created_by"].(string) } share, _ := metadata["share"].(string) if share == "" { share, _ = collection["share"].(string) } if createdBy == authInfo.UserID || share == "team" { return true } } } // Owner only check if authInfo.Constraints.OwnerOnly && authInfo.UserID 
!= "" { createdBy, _ := metadata["__yao_created_by"].(string) if createdBy == "" { createdBy, _ = collection["__yao_created_by"].(string) } if createdBy == authInfo.UserID { return true } } return false } // isPublicValue checks if a value represents "public" access func isPublicValue(v interface{}) bool { switch val := v.(type) { case bool: return val case int: return val == 1 case int64: return val == 1 case float64: return val == 1 case string: return val == "true" || val == "1" } return false } ================================================ FILE: agent/assistant/search_auto_disabled_test.go ================================================ package assistant_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newSearchAutoDisabledTestContext creates a test context func newSearchAutoDisabledTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.ID = chatID ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Client = context.Client{ Type: "web", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.IDGenerator = message.NewIDGenerator() ctx.Metadata = make(map[string]interface{}) return ctx } func TestSearchAutoDisabled(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.LoadPath("/assistants/tests/search-auto-disabled") require.NoError(t, err) require.NotNil(t, ast) t.Run("ShouldHaveSearchConfig", func(t *testing.T) { // Search config is set but uses.search is disabled assert.NotNil(t, ast.Search, "search config should be set") 
assert.NotNil(t, ast.Search.Web, "web search config should be set") }) t.Run("ShouldHaveDisabledUses", func(t *testing.T) { assert.NotNil(t, ast.Uses, "uses config should be set") assert.Equal(t, "disabled", ast.Uses.Search, "uses.search should be disabled") }) t.Run("StreamShouldNotExecuteSearch", func(t *testing.T) { // Get agent via assistant.Get (required for Stream) agent, err := assistant.Get("tests.search-auto-disabled") require.NoError(t, err) require.NotNil(t, agent) // Create context ctx := newSearchAutoDisabledTestContext("test-search-auto-disabled", "tests.search-auto-disabled") // Create messages messages := []context.Message{ { Role: "user", Content: "Hello, how are you?", }, } // Execute stream - should NOT trigger search because uses.search is "disabled" response, err := agent.Stream(ctx, messages) require.NoError(t, err) require.NotNil(t, response) assert.NotNil(t, response.Completion, "should have completion") t.Logf("✓ Stream executed without search (disabled)") }) } ================================================ FILE: agent/assistant/search_auto_full_test.go ================================================ package assistant_test import ( stdContext "context" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newSearchAutoFullTestContext creates a test context func newSearchAutoFullTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.ID = chatID ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Client = context.Client{ Type: "web", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = 
context.AcceptWebCUI ctx.IDGenerator = message.NewIDGenerator() ctx.Metadata = make(map[string]interface{}) return ctx } func TestSearchAutoFull(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.LoadPath("/assistants/tests/search-auto-full") require.NoError(t, err) require.NotNil(t, ast) t.Run("ShouldHaveWebSearchConfig", func(t *testing.T) { assert.NotNil(t, ast.Search, "search config should be set") assert.NotNil(t, ast.Search.Web, "web search config should be set") assert.Equal(t, "tavily", ast.Search.Web.Provider) assert.Equal(t, 3, ast.Search.Web.MaxResults) }) t.Run("ShouldHaveKBSearchConfig", func(t *testing.T) { assert.NotNil(t, ast.Search.KB, "kb search config should be set") assert.Equal(t, 0.7, ast.Search.KB.Threshold) assert.False(t, ast.Search.KB.Graph) }) t.Run("ShouldHaveDBSearchConfig", func(t *testing.T) { assert.NotNil(t, ast.Search.DB, "db search config should be set") assert.Equal(t, 10, ast.Search.DB.MaxResults) }) t.Run("ShouldHaveKBCollections", func(t *testing.T) { assert.NotNil(t, ast.KB, "kb config should be set") assert.Contains(t, ast.KB.Collections, "test-collection") }) t.Run("ShouldHaveDBModels", func(t *testing.T) { assert.NotNil(t, ast.DB, "db config should be set") assert.Contains(t, ast.DB.Models, "user") assert.Contains(t, ast.DB.Models, "article") }) t.Run("ShouldHaveCitationConfig", func(t *testing.T) { assert.NotNil(t, ast.Search.Citation, "citation config should be set") assert.Equal(t, "xml", ast.Search.Citation.Format) assert.True(t, ast.Search.Citation.AutoInjectPrompt) }) t.Run("ShouldHaveUsesConfig", func(t *testing.T) { assert.NotNil(t, ast.Uses, "uses config should be set") assert.Equal(t, "builtin", ast.Uses.Search) assert.Equal(t, "builtin", ast.Uses.Web) }) t.Run("StreamShouldExecuteMultipleSearchTypes", func(t *testing.T) { // Get agent via assistant.Get (required for Stream) agent, err := assistant.Get("tests.search-auto-full") require.NoError(t, err) require.NotNil(t, agent) // 
Create context ctx := newSearchAutoFullTestContext("test-search-auto-full", "tests.search-auto-full") // Create messages with a search query messages := []context.Message{ { Role: "user", Content: "Find information about machine learning", }, } // Execute stream - should trigger Web + KB + DB searches response, err := agent.Stream(ctx, messages) // Assert no error (if API key is configured) if err != nil { // If error contains "API key", it's expected in CI without keys if strings.Contains(err.Error(), "API key") || strings.Contains(err.Error(), "api_key") { t.Logf("Expected error without API key: %v", err) return } // Other errors should fail require.NoError(t, err) } require.NotNil(t, response) assert.NotNil(t, response.Completion, "should have completion") t.Logf("✓ Stream executed with full search config (Web + KB + DB)") }) } ================================================ FILE: agent/assistant/search_auto_hook_disable_test.go ================================================ package assistant_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newSearchAutoHookDisableTestContext creates a test context func newSearchAutoHookDisableTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.ID = chatID ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Client = context.Client{ Type: "web", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.IDGenerator = message.NewIDGenerator() ctx.Metadata = make(map[string]interface{}) return ctx } func TestSearchAutoHookDisable(t 
*testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.LoadPath("/assistants/tests/search-auto-hook-disable") require.NoError(t, err) require.NotNil(t, ast) t.Run("ShouldHaveSearchConfigEnabled", func(t *testing.T) { // Search config is enabled in package.yao assert.NotNil(t, ast.Search, "search config should be set") assert.NotNil(t, ast.Uses, "uses config should be set") assert.Equal(t, "builtin", ast.Uses.Search, "uses.search should be builtin in config") }) t.Run("ShouldHaveHookScript", func(t *testing.T) { // Hook script should be loaded assert.NotNil(t, ast.HookScript, "hook script should be loaded") }) t.Run("HookShouldDisableSearch", func(t *testing.T) { // Create context ctx := newSearchAutoHookDisableTestContext("test-chat-id", "tests.search-auto-hook-disable") // Create messages messages := []context.Message{ { Role: "user", Content: "Test message", }, } // Call Create hook directly opts := &context.Options{} response, _, err := ast.HookScript.Create(ctx, messages, opts) require.NoError(t, err) require.NotNil(t, response) // Verify hook returns uses.search = "disabled" assert.NotNil(t, response.Uses, "hook should return uses") assert.Equal(t, "disabled", response.Uses.Search, "hook should disable search") }) t.Run("StreamShouldRespectHookDisable", func(t *testing.T) { // Get agent via assistant.Get (required for Stream) agent, err := assistant.Get("tests.search-auto-hook-disable") require.NoError(t, err) require.NotNil(t, agent) // Create context ctx := newSearchAutoHookDisableTestContext("test-search-hook-disable", "tests.search-auto-hook-disable") // Create messages messages := []context.Message{ { Role: "user", Content: "What is AI?", }, } // Execute stream - hook will disable search response, err := agent.Stream(ctx, messages) require.NoError(t, err) require.NotNil(t, response) assert.NotNil(t, response.Completion, "should have completion") t.Logf("✓ Stream executed with hook disabling search") }) } 
================================================ FILE: agent/assistant/search_auto_keyword_test.go ================================================ package assistant_test import ( stdContext "context" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newKeywordTestContext creates a test context for keyword extraction tests func newKeywordTestContext(chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.ID = chatID ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Client = context.Client{ Type: "web", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.IDGenerator = message.NewIDGenerator() ctx.Metadata = make(map[string]interface{}) return ctx } func TestSearchAutoKeyword(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.LoadPath("/assistants/tests/search-auto-keyword") require.NoError(t, err) require.NotNil(t, ast) t.Run("ShouldHaveKeywordConfig", func(t *testing.T) { assert.NotNil(t, ast.Search, "search config should be set") assert.NotNil(t, ast.Search.Keyword, "keyword config should be set") assert.Equal(t, 5, ast.Search.Keyword.MaxKeywords) assert.Equal(t, "auto", ast.Search.Keyword.Language) }) t.Run("ShouldHaveKeywordInUses", func(t *testing.T) { assert.NotNil(t, ast.Uses, "uses config should be set") assert.Equal(t, "builtin", ast.Uses.Keyword) }) t.Run("StreamWithKeywordExtraction", func(t *testing.T) { // Get agent via assistant.Get (required for Stream) agent, err := assistant.Get("tests.search-auto-keyword") require.NoError(t, err) require.NotNil(t, agent) // 
Create context ctx := newKeywordTestContext("test-search-keyword", "tests.search-auto-keyword") // Create messages with a verbose query that should benefit from keyword extraction messages := []context.Message{ { Role: "user", Content: "I want to find the best wireless headphones under 100 dollars for programming and music", }, } // Execute stream without Skip.Keyword (keyword extraction should happen) response, err := agent.Stream(ctx, messages) // Assert no error (if API key is configured) if err != nil { if strings.Contains(err.Error(), "API key") || strings.Contains(err.Error(), "api_key") { t.Logf("Expected error without API key: %v", err) return } require.NoError(t, err) } require.NotNil(t, response) assert.NotNil(t, response.Completion, "should have completion") t.Logf("✓ Stream with keyword extraction executed successfully") }) t.Run("StreamWithSkipKeyword", func(t *testing.T) { // Get agent via assistant.Get (required for Stream) agent, err := assistant.Get("tests.search-auto-keyword") require.NoError(t, err) require.NotNil(t, agent) // Create context ctx := newKeywordTestContext("test-search-skip-keyword", "tests.search-auto-keyword") // Create messages messages := []context.Message{ { Role: "user", Content: "I want to find the best wireless headphones under 100 dollars", }, } // Execute stream with Skip.Keyword = true (keyword extraction should be skipped) opts := &context.Options{ Skip: &context.Skip{ Keyword: true, }, } response, err := agent.Stream(ctx, messages, opts) // Assert no error (if API key is configured) if err != nil { if strings.Contains(err.Error(), "API key") || strings.Contains(err.Error(), "api_key") { t.Logf("Expected error without API key: %v", err) return } require.NoError(t, err) } require.NotNil(t, response) assert.NotNil(t, response.Completion, "should have completion") t.Logf("✓ Stream with Skip.Keyword executed successfully") }) } func TestSearchAutoKeywordNotConfigured(t *testing.T) { testutils.Prepare(t) defer 
testutils.Clean(t) // Use the search-auto-web assistant which does NOT have uses.keyword configured ast, err := assistant.LoadPath("/assistants/tests/search-auto-web") require.NoError(t, err) require.NotNil(t, ast) t.Run("ShouldNotHaveKeywordInUses", func(t *testing.T) { // uses.keyword should be empty (not configured) if ast.Uses != nil { assert.Empty(t, ast.Uses.Keyword, "uses.keyword should be empty") } }) t.Run("StreamShouldSkipKeywordExtraction", func(t *testing.T) { // Get agent via assistant.Get (required for Stream) agent, err := assistant.Get("tests.search-auto-web") require.NoError(t, err) require.NotNil(t, agent) // Create context ctx := newKeywordTestContext("test-no-keyword", "tests.search-auto-web") // Create messages messages := []context.Message{ { Role: "user", Content: "What is the latest news about AI?", }, } // Execute stream - keyword extraction should NOT happen because uses.keyword is not set response, err := agent.Stream(ctx, messages) // Assert no error (if API key is configured) if err != nil { if strings.Contains(err.Error(), "API key") || strings.Contains(err.Error(), "api_key") { t.Logf("Expected error without API key: %v", err) return } require.NoError(t, err) } require.NotNil(t, response) assert.NotNil(t, response.Completion, "should have completion") t.Logf("✓ Stream without keyword config executed successfully") }) } ================================================ FILE: agent/assistant/search_auto_web_test.go ================================================ package assistant_test import ( stdContext "context" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // newSearchAutoTestContext creates a test context for search auto tests func newSearchAutoTestContext(chatID, assistantID 
string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.ID = chatID ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Client = context.Client{ Type: "web", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.IDGenerator = message.NewIDGenerator() ctx.Metadata = make(map[string]interface{}) return ctx } func TestSearchAutoWeb(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ast, err := assistant.LoadPath("/assistants/tests/search-auto-web") require.NoError(t, err) require.NotNil(t, ast) t.Run("ShouldHaveSearchConfig", func(t *testing.T) { assert.NotNil(t, ast.Search, "search config should be set") assert.NotNil(t, ast.Search.Web, "web search config should be set") assert.Equal(t, "tavily", ast.Search.Web.Provider) assert.Equal(t, 3, ast.Search.Web.MaxResults) }) t.Run("ShouldHaveUsesConfig", func(t *testing.T) { assert.NotNil(t, ast.Uses, "uses config should be set") assert.Equal(t, "builtin", ast.Uses.Search) assert.Equal(t, "builtin", ast.Uses.Web) }) t.Run("ShouldHaveCitationConfig", func(t *testing.T) { assert.NotNil(t, ast.Search.Citation, "citation config should be set") assert.Equal(t, "xml", ast.Search.Citation.Format) assert.True(t, ast.Search.Citation.AutoInjectPrompt) }) t.Run("StreamShouldExecuteAutoSearch", func(t *testing.T) { // Get agent via assistant.Get (required for Stream) agent, err := assistant.Get("tests.search-auto-web") require.NoError(t, err) require.NotNil(t, agent) // Create context ctx := newSearchAutoTestContext("test-search-auto-web", "tests.search-auto-web") // Create messages with a search query messages := []context.Message{ { Role: "user", Content: "What is the latest news about artificial intelligence?", }, } // Execute stream response, err := agent.Stream(ctx, messages) // Assert no error (if API key is configured) if err != nil { 
// If error contains "API key", it's expected in CI without keys if strings.Contains(err.Error(), "API key") || strings.Contains(err.Error(), "api_key") { t.Logf("Expected error without API key: %v", err) return } // Other errors should fail require.NoError(t, err) } require.NotNil(t, response) assert.NotNil(t, response.Completion, "should have completion") t.Logf("✓ Stream executed successfully with auto search") }) } ================================================ FILE: agent/assistant/source.go ================================================ package assistant import ( "fmt" "strings" "time" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/assistant/hook" ) // loadSource loads hook script from source code string // The source field stores TypeScript code directly (but without imports) // Priority: script field > source field (if script exists, source is ignored) // Note: Uses MakeScriptInMemory which supports TypeScript syntax without file resolution. func loadSource(source string, assistantID string) (*hook.Script, error) { if source == "" { return nil, nil } // Use virtual .ts path for TypeScript support // MakeScriptInMemory handles TypeScript transform without file system access virtualFile := fmt.Sprintf("assistants/%s/source.ts", strings.ReplaceAll(assistantID, ".", "/")) script, err := v8.MakeScriptInMemory([]byte(source), virtualFile, 5*time.Second, true) if err != nil { return nil, fmt.Errorf("failed to compile source script: %w", err) } return &hook.Script{Script: script}, nil } // TODO: Future enhancement - support multiple files merged with special comment delimiter // Format: // file: index.ts // This would allow splitting large scripts into multiple logical files while storing as single source // func loadSourceMultiFile(source string, assistantID string) (*hook.Script, error) { // // Parse source by "// file: xxx.ts" delimiter // // Merge and compile // } ================================================ FILE: 
agent/assistant/trace.go ================================================ package assistant import ( "fmt" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/trace/types" ) // initAgentTraceNode creates and returns the agent trace node func (ast *Assistant) initAgentTraceNode(ctx *context.Context, inputMessages []context.Message) types.Node { trace, _ := ctx.Trace() if trace == nil { return nil } agentNode, _ := trace.Add(inputMessages, types.TraceNodeOption{ Label: i18n.Tr(ast.ID, ctx.Locale, "assistant.agent.stream.label"), // "Assistant {{name}}" Type: "agent", Icon: "assistant", Description: i18n.Tr(ast.ID, ctx.Locale, "assistant.agent.stream.description"), // "Assistant {{name}} is processing the request" }) return agentNode } // traceAgentHistory logs the chat history to the agent trace node func (ast *Assistant) traceAgentHistory(ctx *context.Context, agentNode types.Node, fullMessages []context.Message) { if agentNode == nil { return } agentNode.Info( i18n.Tr(ast.ID, ctx.Locale, "assistant.agent.stream.history"), // "Get Chat History" map[string]any{"messages": fullMessages}, ) } // traceCreateHook logs the create hook response to the agent trace node func (ast *Assistant) traceCreateHook(agentNode types.Node, createResponse *context.HookCreateResponse) { if agentNode == nil { return } agentNode.Debug("Call Create Hook", map[string]any{"response": createResponse}) } // traceConnectorCapabilities logs the connector capabilities to the agent trace node func (ast *Assistant) traceConnectorCapabilities(agentNode types.Node, capabilities *openai.Capabilities) { if agentNode == nil { return } agentNode.Debug("Get Connector Capabilities", map[string]any{"capabilities": capabilities}) } // traceLLMRequest adds a LLM trace node to the trace func (ast *Assistant) traceLLMRequest(ctx *context.Context, connID string, completionMessages []context.Message, 
completionOptions *context.CompletionOptions) { trace, _ := ctx.Trace() if trace == nil { return } trace.Add( map[string]any{"messages": completionMessages, "options": completionOptions}, types.TraceNodeOption{ Label: fmt.Sprintf(i18n.Tr(ast.ID, ctx.Locale, "llm.openai.stream.label"), connID), // "LLM %s" Type: "llm", Icon: "psychology", Description: fmt.Sprintf(i18n.Tr(ast.ID, ctx.Locale, "llm.openai.stream.description"), connID), // "LLM %s is processing the request" }, ) } // traceLLMComplete marks the LLM request as complete in the trace func (ast *Assistant) traceLLMComplete(ctx *context.Context, completionResponse *context.CompletionResponse) { trace, _ := ctx.Trace() if trace == nil { return } trace.Complete(completionResponse) } // traceLLMFail marks the LLM request as failed in the trace func (ast *Assistant) traceLLMFail(ctx *context.Context, err error) { trace, _ := ctx.Trace() if trace == nil { return } trace.Fail(err) } // traceAgentCompletion creates a completion node to report the final output func (ast *Assistant) traceAgentCompletion(ctx *context.Context, createResponse *context.HookCreateResponse, nextResponse *context.NextHookResponse, completionResponse *context.CompletionResponse, finalResponse interface{}) { trace, _ := ctx.Trace() if trace == nil { return } // Prepare the input data (the raw responses before processing) input := map[string]interface{}{ "create": createResponse, "next": nextResponse, "completion": completionResponse, } // Create a dedicated completion node completionNode, err := trace.Add( input, types.TraceNodeOption{ Label: i18n.Tr(ast.ID, ctx.Locale, "assistant.agent.completion.label"), // "Agent Completion" Type: "agent_completion", Icon: "check_circle", Description: i18n.Tr(ast.ID, ctx.Locale, "assistant.agent.completion.description"), // "Final output from assistant" }, ) if err != nil { log.Trace("[TRACE] Failed to create completion node: %v", err) return } // Immediately mark it as complete with the final response if 
completionNode != nil { completionNode.Complete(finalResponse) } } // traceAgentOutput sets the output of the agent trace node // Deprecated: Use traceAgentCompletion instead for better trace structure func (ast *Assistant) traceAgentOutput(agentNode types.Node, createResponse *context.HookCreateResponse, nextResponse interface{}, completionResponse *context.CompletionResponse) { if agentNode == nil { return } output := context.Response{ Create: createResponse, Next: nextResponse, Completion: completionResponse, } agentNode.Complete(output) } // traceAgentFail marks the agent trace node as failed func (ast *Assistant) traceAgentFail(agentNode types.Node, err error) { if agentNode == nil { return } agentNode.Fail(err) } // traceLLMRetryRequest adds a LLM retry trace node to the trace func (ast *Assistant) traceLLMRetryRequest(ctx *context.Context, connID string, completionMessages []context.Message, completionOptions *context.CompletionOptions) { trace, _ := ctx.Trace() if trace == nil { return } trace.Add( map[string]any{"messages": completionMessages, "options": completionOptions}, types.TraceNodeOption{ Label: fmt.Sprintf("LLM %s (Tool Retry)", connID), Type: "llm_retry", Icon: "refresh", Description: fmt.Sprintf("LLM %s is retrying with tool call error feedback", connID), }, ) } ================================================ FILE: agent/assistant/types.go ================================================ package assistant import ( jsoniter "github.com/json-iterator/go" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/assistant/hook" chatctx "github.com/yaoapp/yao/agent/context" outputMessage "github.com/yaoapp/yao/agent/output/message" store "github.com/yaoapp/yao/agent/store/types" ) const ( // HookErrorMethodNotFound is the error message for method not found HookErrorMethodNotFound = "method not found" ) // API the assistant API interface type API interface { GetPlaceholder(locale string) *store.Placeholder } // Script the script scripts 
// ToolCallResult represents the result of a single tool call execution.
// It is used to track the outcome of MCP tool invocations during agent execution.
type ToolCallResult struct {
	ToolCallID       string // Tool call ID from the LLM (matches the ID in the LLM's tool_calls response)
	Name             string // Tool name (formatted with server prefix, e.g., "server_id__tool_name")
	Content          string // Result content (JSON string of the tool's output or error message)
	Error            error  // Error if the call failed (nil if successful)
	IsRetryableError bool   // Whether the error should be sent back to the LLM for a retry
	// true: parameter/validation errors that the LLM can fix (e.g., "missing required field")
	// false: MCP internal errors that the LLM cannot fix (e.g., "network error", "service unavailable")
}

// Server returns the MCP server ID portion of the formatted tool name.
// Example: "echo__ping" -> "echo"
//
// The remaining return values of ParseMCPToolName are discarded.
// NOTE(review): the third value is presumably a parse status/error — confirm
// that an empty server ID is an acceptable result for malformed names.
func (r *ToolCallResult) Server() string {
	serverID, _, _ := ParseMCPToolName(r.Name)
	return serverID
}

// Tool returns the original tool name without the server prefix.
// Example: "echo__ping" -> "ping"
//
// The remaining return values of ParseMCPToolName are discarded, as in Server.
func (r *ToolCallResult) Tool() string {
	_, toolName, _ := ParseMCPToolName(r.Name)
	return toolName
}
NextProcessContext encapsulates all the context needed to process Next hook responses // This simplifies function signatures and makes it easier to add new fields in the future type NextProcessContext struct { Context *chatctx.Context // Agent context NextResponse *chatctx.NextHookResponse // Response from Next hook (already converted from JS) CompletionResponse *chatctx.CompletionResponse // LLM completion response FullMessages []chatctx.Message // Full conversation history ToolCallResponses []chatctx.ToolCallResponse // Tool call results (if any) StreamHandler outputMessage.StreamFunc // Stream handler for output CreateResponse *chatctx.HookCreateResponse // Create hook response } // SearchIntent is an alias for context.SearchIntent // Used for search intent detection from __yao.needsearch agent type SearchIntent = chatctx.SearchIntent // ParsedContent extracts the actual tool return value from MCP ToolContent array // According to MCP protocol: // - Content is []ToolContent array // - For "text" type, the actual value is in Text field (usually JSON string) // - For "image" type, returns the Data field // - For "resource" type, returns the Resource object // If there are multiple content items, returns an array of parsed values func (r *ToolCallResult) ParsedContent() (interface{}, error) { if r.Content == "" { return nil, nil } // Parse Content as []ToolContent var toolContents []map[string]interface{} if err := jsoniter.UnmarshalFromString(r.Content, &toolContents); err != nil { // If parsing fails, return the string content directly (error message) return r.Content, nil } // Extract actual values from ToolContent items var results []interface{} for _, tc := range toolContents { contentType, _ := tc["type"].(string) switch contentType { case "text": // For text type, parse the Text field (usually JSON) if textStr, ok := tc["text"].(string); ok { // Try to parse as JSON var parsed interface{} if err := jsoniter.UnmarshalFromString(textStr, &parsed); err == nil { 
// getTimestamp converts v into an int64 timestamp.
//
// Supported inputs:
//   - nil: 0
//   - int64, int, int32, float64: returned as an int64 without any unit
//     conversion (float64 is truncated; it is what JSON decoding produces)
//   - time.Time: UnixNano, or 0 for the zero Time (UnixNano is undefined there)
//   - string: RFC3339, then MySQL "2006-01-02 15:04:05" (both returned as
//     UnixNano), then a base-10 integer returned unchanged
//
// An error is returned for unsupported types and for strings that match none
// of the accepted formats.
func getTimestamp(v interface{}) (int64, error) {
	switch v := v.(type) {
	case nil:
		return 0, nil

	case int64:
		return v, nil

	case int:
		return int64(v), nil

	case int32:
		return int64(v), nil

	case float64:
		return int64(v), nil

	case time.Time:
		// The zero Time is outside the int64 nanosecond range; report it as
		// "no timestamp", matching the nil case.
		if v.IsZero() {
			return 0, nil
		}
		return v.UnixNano(), nil

	case string:
		if ts, err := time.Parse(time.RFC3339, v); err == nil {
			return ts.UnixNano(), nil
		}

		// MySQL format
		if ts, err := time.Parse("2006-01-02 15:04:05", v); err == nil {
			return ts.UnixNano(), nil
		}

		// UnixNano format
		if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
			return ts, nil
		}

		// The type is fine, the value is not: say so explicitly.
		return 0, fmt.Errorf("invalid timestamp value %q", v)
	}

	return 0, fmt.Errorf("invalid timestamp type %T", v)
}

// getBool reads data[key] and interprets it as a boolean.
//
// bool values are returned as-is; numeric values are true when non-zero;
// strings are true for anything strconv.ParseBool accepts as true ("1", "t",
// "true", "TRUE", ...) and for the common spellings of "enabled", "yes" and
// "on". A missing key, nil, or any other type yields false.
func getBool(data map[string]interface{}, key string) bool {
	switch v := data[key].(type) {
	case bool:
		return v
	case int64:
		return v != 0
	case int:
		return v != 0
	case int32:
		return v != 0
	case float64:
		return v != 0
	case string:
		if b, err := strconv.ParseBool(v); err == nil {
			return b
		}
		switch v {
		case "enabled", "Enabled", "ENABLED",
			"yes", "Yes", "YES",
			"on", "On", "ON":
			return true
		}
		return false
	}
	return false
}

// stringHash returns the hex-encoded SHA-256 digest of v.
func stringHash(v string) string {
	sum := sha256.Sum256([]byte(v))
	return hex.EncodeToString(sum[:])
}
ParseJSON attempts to parse a potentially malformed JSON string func ParseJSON(jsonStr string, v interface{}) error { // Try parsing as-is first err := jsoniter.UnmarshalFromString(jsonStr, v) if err == nil { return nil } originalErr := err // Try adding a closing brace if err := jsoniter.UnmarshalFromString(jsonStr+"}", v); err == nil { return nil } // Try repairing the JSON repaired, err := jsonrepair.JSONRepair(jsonStr) if err != nil { return originalErr } // Try parsing the repaired JSON if err := jsoniter.UnmarshalFromString(repaired, v); err == nil { return nil } // If all attempts fail, return the original error return originalErr } ================================================ FILE: agent/caller/caller.go ================================================ // Package caller provides a shared interface for calling agents // This package is used by both content and search packages to avoid circular dependencies package caller import ( agentContext "github.com/yaoapp/yao/agent/context" ) // AgentCaller interface for calling agents (to avoid circular dependency) // Used by content handlers (vision, audio, etc.) and search handlers (agent mode) type AgentCaller interface { Stream(ctx *agentContext.Context, messages []agentContext.Message, options ...*agentContext.Options) (*agentContext.Response, error) } // AgentGetterFunc is a function type that gets an agent by ID // This should be set by the assistant package during initialization var AgentGetterFunc func(agentID string) (AgentCaller, error) ================================================ FILE: agent/caller/context.go ================================================ package caller import ( "context" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/openapi/oauth/types" ) // NewHeadlessContext creates a headless agent context from a ProcessCallRequest. 
// This is the Process equivalent of openapi.GetCompletionRequest — constructs // a Context + Options without HTTP dependencies (no Writer, no Interrupt). // // Key behaviors: // - parent context controls timeout/cancellation (caller is responsible) // - skip.output = true (forced): no Writer available, must skip output // - skip.history = true (forced): Process calls don't save chat history // - authorized info is passed in (from authorized.ProcessAuthInfo by caller) // - chatID is auto-generated if not provided func NewHeadlessContext(parent context.Context, authInfo *types.AuthorizedInfo, req *ProcessCallRequest) (*agentContext.Context, *agentContext.Options) { chatID := req.ChatID if chatID == "" { chatID = agentContext.GenChatID() } ctx := agentContext.New(parent, authInfo, chatID) ctx.AssistantID = req.AssistantID ctx.Referer = agentContext.RefererProcess ctx.Locale = req.Locale ctx.Route = req.Route ctx.Metadata = req.Metadata // Force skip for headless context — no Writer, no chat history skip := req.Skip if skip == nil { skip = &agentContext.Skip{} } skip.Output = true // no Writer available skip.History = true // Process calls don't save chat history opts := &agentContext.Options{Skip: skip} if req.Model != "" { opts.Connector = req.Model } return ctx, opts } ================================================ FILE: agent/caller/integration_test.go ================================================ package caller_test import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/caller" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) func TestIntegration_Call_RealAgent(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the simple-greeting agent ast, err := 
assistant.Get("tests.simple-greeting") require.NoError(t, err) require.NotNil(t, ast) // Create authorized info for the context authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } // Create a context with authorization ctx := agentContext.New(context.Background(), authorized, "test-chat-integration") ctx.AssistantID = "tests.agent-caller" // Create JSAPI api := caller.NewJSAPI(ctx) // Call the simple-greeting agent messages := []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello!", }, } opts := map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, } result := api.Call("tests.simple-greeting", messages, opts) require.NotNil(t, result) r, ok := result.(*caller.Result) require.True(t, ok) assert.Equal(t, "tests.simple-greeting", r.AgentID) // Should either have content or error if r.Error != "" { t.Logf("Agent call error: %s", r.Error) } else { t.Logf("Agent response content: %s", r.Content) assert.NotEmpty(t, r.Content) } } func TestIntegration_All_RealAgents(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Create authorized info for the context authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } // Create a context with authorization ctx := agentContext.New(context.Background(), authorized, "test-chat-all") ctx.AssistantID = "tests.agent-caller" // Create JSAPI api := caller.NewJSAPI(ctx) // Call multiple agents in parallel requests := []interface{}{ map[string]interface{}{ "agent": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello from test 1!", }, }, "options": map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, }, }, map[string]interface{}{ "agent": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{ "role": "user", 
"content": "Hello from test 2!", }, }, "options": map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, }, }, } results := api.All(requests) require.Len(t, results, 2) for i, result := range results { r, ok := result.(*caller.Result) require.True(t, ok, "result %d should be *caller.Result", i) assert.Equal(t, "tests.simple-greeting", r.AgentID) t.Logf("Result[%d]: content=%s, error=%s", i, r.Content, r.Error) } } func TestIntegration_Any_RealAgents(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Create authorized info for the context authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } // Create a context with authorization ctx := agentContext.New(context.Background(), authorized, "test-chat-any") ctx.AssistantID = "tests.agent-caller" // Create JSAPI api := caller.NewJSAPI(ctx) // Call multiple agents - return when any succeeds requests := []interface{}{ map[string]interface{}{ "agent": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello from any test 1!", }, }, "options": map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, }, }, map[string]interface{}{ "agent": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello from any test 2!", }, }, "options": map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, }, }, } results := api.Any(requests) require.Len(t, results, 2) // At least one should have a result hasResult := false for i, result := range results { if result != nil { r, ok := result.(*caller.Result) if ok && r != nil && r.Error == "" { hasResult = true t.Logf("Any Result[%d]: content=%s", i, r.Content) } } } assert.True(t, hasResult, "At least one result should succeed") } func TestIntegration_Race_RealAgents(t *testing.T) { if 
testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Create authorized info for the context authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } // Create a context with authorization ctx := agentContext.New(context.Background(), authorized, "test-chat-race") ctx.AssistantID = "tests.agent-caller" // Create JSAPI api := caller.NewJSAPI(ctx) // Call multiple agents - return when any completes requests := []interface{}{ map[string]interface{}{ "agent": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello from race test 1!", }, }, "options": map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, }, }, map[string]interface{}{ "agent": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello from race test 2!", }, }, "options": map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, }, }, } results := api.Race(requests) require.Len(t, results, 2) // At least one should have completed hasResult := false for i, result := range results { if result != nil { r, ok := result.(*caller.Result) if ok && r != nil { hasResult = true t.Logf("Race Result[%d]: content=%s, error=%s", i, r.Content, r.Error) } } } assert.True(t, hasResult, "At least one result should complete") } ================================================ FILE: agent/caller/jsapi.go ================================================ package caller import ( agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" ) // JSAPI implements context.AgentAPI and context.AgentAPIWithCallback interfaces // Provides ctx.agent.Call(), ctx.agent.All(), ctx.agent.Any(), ctx.agent.Race() // and their *WithHandler variants for streaming callback support type JSAPI struct { ctx *agentContext.Context orchestrator *Orchestrator } 
// Ensure JSAPI implements AgentAPIWithCallback var _ agentContext.AgentAPIWithCallback = (*JSAPI)(nil) // NewJSAPI creates a new agent JSAPI instance func NewJSAPI(ctx *agentContext.Context) *JSAPI { return &JSAPI{ ctx: ctx, orchestrator: NewOrchestrator(ctx), } } // Call executes a single agent call // Usage: ctx.agent.Call("assistant-id", messages, options?) // Returns: { agent_id, response, content, error } // Note: For sub-agent calls, skip.history = true is automatically set // to prevent A2A messages from being saved to chat history. // Sub-agents output normally with ThreadID for SSE stream isolation. func (api *JSAPI) Call(agentID string, messages []interface{}, opts map[string]interface{}) interface{} { req := api.buildRequest(agentID, messages, opts) // Force skip options for sub-agent calls api.forceSkipForSubAgent(req) result := api.orchestrator.callAgent(req) return result } // All executes all agent calls and waits for all to complete (like Promise.all) // Each request should have: // - agent: string - target agent ID // - messages: array - messages to send // - options?: object - call options func (api *JSAPI) All(requests []interface{}) []interface{} { reqs := api.parseRequests(requests) results := api.orchestrator.All(reqs) return api.convertResults(results) } // Any returns as soon as any agent call succeeds (like Promise.any) // Each request should have: // - agent: string - target agent ID // - messages: array - messages to send // - options?: object - call options func (api *JSAPI) Any(requests []interface{}) []interface{} { reqs := api.parseRequests(requests) results := api.orchestrator.Any(reqs) return api.convertResults(results) } // Race returns as soon as any agent call completes (like Promise.race) // Each request should have: // - agent: string - target agent ID // - messages: array - messages to send // - options?: object - call options func (api *JSAPI) Race(requests []interface{}) []interface{} { reqs := api.parseRequests(requests) 
results := api.orchestrator.Race(reqs) return api.convertResults(results) } // ============================================================================ // AgentAPIWithCallback Implementation // ============================================================================ // CallWithHandler executes a single agent call with an OnMessage handler // Note: For sub-agent calls, skip.history = true is automatically set. // Sub-agents output normally with ThreadID. Use the handler callback // to receive streaming messages. func (api *JSAPI) CallWithHandler(agentID string, messages []interface{}, opts map[string]interface{}, handler agentContext.OnMessageFunc) interface{} { req := api.buildRequest(agentID, messages, opts) req.Handler = handler // Force skip options for sub-agent calls api.forceSkipForSubAgent(req) result := api.orchestrator.callAgent(req) return result } // AllWithHandler executes all agent calls with handlers func (api *JSAPI) AllWithHandler(requests []interface{}, globalHandler agentContext.BatchOnMessageFunc) []interface{} { reqs := api.parseRequestsWithHandlers(requests, globalHandler) results := api.orchestrator.All(reqs) return api.convertResults(results) } // AnyWithHandler executes agent calls and returns on first success, with handlers func (api *JSAPI) AnyWithHandler(requests []interface{}, globalHandler agentContext.BatchOnMessageFunc) []interface{} { reqs := api.parseRequestsWithHandlers(requests, globalHandler) results := api.orchestrator.Any(reqs) return api.convertResults(results) } // RaceWithHandler executes agent calls and returns on first completion, with handlers func (api *JSAPI) RaceWithHandler(requests []interface{}, globalHandler agentContext.BatchOnMessageFunc) []interface{} { reqs := api.parseRequestsWithHandlers(requests, globalHandler) results := api.orchestrator.Race(reqs) return api.convertResults(results) } // forceSkipForSubAgent ensures proper A2A call behavior: // - skip.history = true: always set — A2A messages are not 
saved to chat history // - skip.output: defaults to false (sub-agents output with ThreadID for SSE stream isolation), // but if the caller explicitly sets skip.output = true, it is respected. // This allows internal worker agents (e.g. classifiers) to run silently. func (api *JSAPI) forceSkipForSubAgent(req *Request) { if req.Options == nil { req.Options = &CallOptions{} } // Preserve caller's explicit skip.output = true before overwriting Skip struct callerSkipOutput := req.Options.Skip != nil && req.Options.Skip.Output if req.Options.Skip == nil { req.Options.Skip = &agentContext.Skip{} } req.Options.Skip.History = true if callerSkipOutput { req.Options.Skip.Output = true } // else: skip.output remains false (default zero value) — sub-agent outputs normally } // parseRequestsWithHandlers parses requests and attaches handlers // It checks for per-request _handler fields and wraps globalHandler with agentID/index // For all calls, this automatically sets: // - skip.history = true: prevents A2A messages from being saved to chat history // - skip.output = false: ensures sub-agents output with ThreadID (overrides user settings) func (api *JSAPI) parseRequestsWithHandlers(requests []interface{}, globalHandler agentContext.BatchOnMessageFunc) []*Request { reqs := make([]*Request, 0, len(requests)) for i, r := range requests { reqMap, ok := r.(map[string]interface{}) if !ok { continue } // Get agent ID agentID, ok := reqMap["agent"].(string) if !ok { continue } // Get messages messages, ok := reqMap["messages"].([]interface{}) if !ok { continue } // Get options (optional) var opts map[string]interface{} if o, ok := reqMap["options"].(map[string]interface{}); ok { opts = o } req := api.buildRequest(agentID, messages, opts) // Force skip.output = true for all sub-agent calls api.forceSkipForSubAgent(req) // Check for per-request handler first (takes precedence) if handler, ok := reqMap["_handler"].(agentContext.OnMessageFunc); ok && handler != nil { req.Handler = handler } 
else if globalHandler != nil { // Wrap global handler with agentID and index idx := i // Capture index for closure aid := agentID req.Handler = func(msg *message.Message) int { return globalHandler(aid, idx, msg) } } reqs = append(reqs, req) } return reqs } // buildRequest builds a Request from agentID, messages, and options func (api *JSAPI) buildRequest(agentID string, messages []interface{}, opts map[string]interface{}) *Request { req := &Request{ AgentID: agentID, Messages: api.parseMessages(messages), } if opts != nil { req.Options = api.parseCallOptions(opts) } return req } // parseMessages converts []interface{} to []agentContext.Message func (api *JSAPI) parseMessages(messages []interface{}) []agentContext.Message { result := make([]agentContext.Message, 0, len(messages)) for _, m := range messages { msg, ok := m.(map[string]interface{}) if !ok { continue } ctxMsg := agentContext.Message{} // Parse role if role, ok := msg["role"].(string); ok { ctxMsg.Role = agentContext.MessageRole(role) } // Parse content (can be string or array) ctxMsg.Content = msg["content"] // Parse name if name, ok := msg["name"].(string); ok { ctxMsg.Name = &name } // Parse tool_call_id if toolCallID, ok := msg["tool_call_id"].(string); ok { ctxMsg.ToolCallID = &toolCallID } // Parse tool_calls if toolCalls, ok := msg["tool_calls"].([]interface{}); ok { ctxMsg.ToolCalls = api.parseToolCalls(toolCalls) } // Parse refusal if refusal, ok := msg["refusal"].(string); ok { ctxMsg.Refusal = &refusal } result = append(result, ctxMsg) } return result } // parseToolCalls converts []interface{} to []agentContext.ToolCall func (api *JSAPI) parseToolCalls(toolCalls []interface{}) []agentContext.ToolCall { result := make([]agentContext.ToolCall, 0, len(toolCalls)) for _, tc := range toolCalls { tcMap, ok := tc.(map[string]interface{}) if !ok { continue } toolCall := agentContext.ToolCall{} if id, ok := tcMap["id"].(string); ok { toolCall.ID = id } if tcType, ok := tcMap["type"].(string); ok { 
toolCall.Type = agentContext.ToolCallType(tcType) } if fn, ok := tcMap["function"].(map[string]interface{}); ok { if name, ok := fn["name"].(string); ok { toolCall.Function.Name = name } if args, ok := fn["arguments"].(string); ok { toolCall.Function.Arguments = args } } result = append(result, toolCall) } return result } // parseCallOptions converts map to CallOptions func (api *JSAPI) parseCallOptions(opts map[string]interface{}) *CallOptions { callOpts := &CallOptions{} if connector, ok := opts["connector"].(string); ok { callOpts.Connector = connector } if mode, ok := opts["mode"].(string); ok { callOpts.Mode = mode } if metadata, ok := opts["metadata"].(map[string]interface{}); ok { callOpts.Metadata = metadata } // Parse skip configuration if skip, ok := opts["skip"].(map[string]interface{}); ok { callOpts.Skip = &agentContext.Skip{} if history, ok := skip["history"].(bool); ok { callOpts.Skip.History = history } if trace, ok := skip["trace"].(bool); ok { callOpts.Skip.Trace = trace } if output, ok := skip["output"].(bool); ok { callOpts.Skip.Output = output } if keyword, ok := skip["keyword"].(bool); ok { callOpts.Skip.Keyword = keyword } if search, ok := skip["search"].(bool); ok { callOpts.Skip.Search = search } if contentParsing, ok := skip["content_parsing"].(bool); ok { callOpts.Skip.ContentParsing = contentParsing } } return callOpts } // parseRequests parses an array of request objects into typed Requests func (api *JSAPI) parseRequests(requests []interface{}) []*Request { return api.parseRequestsWithHandlers(requests, nil) } // convertResults converts typed Results to interface slice for JS func (api *JSAPI) convertResults(results []*Result) []interface{} { out := make([]interface{}, len(results)) for i, r := range results { out[i] = r } return out } // SetJSAPIFactory sets the factory function for creating AgentAPI instances // Called by assistant package during initialization func SetJSAPIFactory() { agentContext.AgentAPIFactory = func(ctx 
*agentContext.Context) agentContext.AgentAPI { return NewJSAPI(ctx) } } ================================================ FILE: agent/caller/jsapi_test.go ================================================ package caller_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/caller" "github.com/yaoapp/yao/agent/context" ) func TestNewJSAPI(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") api := caller.NewJSAPI(ctx) require.NotNil(t, api) } func TestJSAPI_Call_NoAgentGetter(t *testing.T) { // Reset AgentGetterFunc originalGetter := caller.AgentGetterFunc caller.AgentGetterFunc = nil defer func() { caller.AgentGetterFunc = originalGetter }() ctx := context.New(stdContext.Background(), nil, "test-chat") api := caller.NewJSAPI(ctx) messages := []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello", }, } result := api.Call("test-agent", messages, nil) require.NotNil(t, result) r, ok := result.(*caller.Result) require.True(t, ok) assert.Equal(t, "test-agent", r.AgentID) assert.Contains(t, r.Error, "agent getter not initialized") } func TestJSAPI_All_Empty(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") api := caller.NewJSAPI(ctx) results := api.All([]interface{}{}) assert.Len(t, results, 0) } func TestJSAPI_Any_Empty(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") api := caller.NewJSAPI(ctx) results := api.Any([]interface{}{}) assert.Len(t, results, 0) } func TestJSAPI_Race_Empty(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") api := caller.NewJSAPI(ctx) results := api.Race([]interface{}{}) assert.Len(t, results, 0) } func TestJSAPI_All_InvalidRequests(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") api := caller.NewJSAPI(ctx) // Mix of invalid and valid requests requests := []interface{}{ "invalid", // Not a map 
map[string]interface{}{ "messages": []interface{}{}, // Missing agent }, map[string]interface{}{ "agent": "test-agent", // Missing messages }, } results := api.All(requests) // None should produce a result (all invalid) assert.Len(t, results, 0) } func TestJSAPI_Call_WithOptions(t *testing.T) { // Reset AgentGetterFunc originalGetter := caller.AgentGetterFunc caller.AgentGetterFunc = nil defer func() { caller.AgentGetterFunc = originalGetter }() ctx := context.New(stdContext.Background(), nil, "test-chat") api := caller.NewJSAPI(ctx) messages := []interface{}{ map[string]interface{}{ "role": "user", "content": "Hello", }, } opts := map[string]interface{}{ "connector": "gpt4", "mode": "chat", "metadata": map[string]interface{}{ "key": "value", }, "skip": map[string]interface{}{ "history": true, "trace": true, }, } result := api.Call("test-agent", messages, opts) require.NotNil(t, result) r, ok := result.(*caller.Result) require.True(t, ok) assert.Equal(t, "test-agent", r.AgentID) // Still errors because AgentGetterFunc is nil assert.Contains(t, r.Error, "agent getter not initialized") } func TestSetJSAPIFactory(t *testing.T) { // Reset factory context.AgentAPIFactory = nil // Set factory caller.SetJSAPIFactory() // Verify factory is set require.NotNil(t, context.AgentAPIFactory) // Create a mock context ctx := context.New(stdContext.Background(), nil, "test-chat") // Get agent API agentAPI := context.AgentAPIFactory(ctx) require.NotNil(t, agentAPI) } func TestJSAPI_ImplementsAgentAPI(t *testing.T) { // Verify JSAPI implements context.AgentAPI interface ctx := context.New(stdContext.Background(), nil, "test-chat") var _ context.AgentAPI = caller.NewJSAPI(ctx) } ================================================ FILE: agent/caller/orchestrator.go ================================================ package caller import ( "fmt" "sync" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/trace/types" ) // Orchestrator handles parallel agent calls with 
different concurrency patterns // Modeled after JavaScript Promise patterns (all, any, race) type Orchestrator struct { ctx *agentContext.Context } // NewOrchestrator creates a new Orchestrator for parallel agent calls func NewOrchestrator(ctx *agentContext.Context) *Orchestrator { return &Orchestrator{ctx: ctx} } // callResult is used internally to pass results through channels type callResult struct { idx int result *Result } // All executes all agent calls and waits for all to complete (like Promise.all) // Returns results in the same order as requests, regardless of completion order // Each call uses a forked context to avoid race conditions on shared state func (o *Orchestrator) All(reqs []*Request) []*Result { if len(reqs) == 0 { return []*Result{} } results := make([]*Result, len(reqs)) var wg sync.WaitGroup var mu sync.Mutex for i, req := range reqs { wg.Add(1) go func(idx int, r *Request) { defer wg.Done() defer func() { if err := recover(); err != nil { mu.Lock() results[idx] = &Result{ AgentID: r.AgentID, Error: "agent call panic recovered", } mu.Unlock() } }() // Use forked context to avoid race conditions result := o.callAgentWithForkedContext(r) mu.Lock() results[idx] = result mu.Unlock() }(i, req) } wg.Wait() return results } // Any returns as soon as any agent call succeeds (has non-error result) (like Promise.any) // Other calls continue in background but results are discarded after first success // Returns all results received so far when first success is found // Each call uses a forked context to avoid race conditions on shared state func (o *Orchestrator) Any(reqs []*Request) []*Result { if len(reqs) == 0 { return []*Result{} } results := make([]*Result, len(reqs)) resultChan := make(chan callResult, len(reqs)) var wg sync.WaitGroup done := make(chan struct{}) for i, req := range reqs { wg.Add(1) go func(idx int, r *Request) { defer wg.Done() defer func() { if err := recover(); err != nil { // Send panic result through channel select { case 
<-done: case resultChan <- callResult{idx: idx, result: &Result{ AgentID: r.AgentID, Error: "agent call panic recovered", }}: } } }() // Check if done before starting select { case <-done: return default: } // Use forked context to avoid race conditions result := o.callAgentWithForkedContext(r) // Try to send result select { case <-done: // Already found a successful result case resultChan <- callResult{idx: idx, result: result}: } }(i, req) } // Close channel when all goroutines complete go func() { wg.Wait() close(resultChan) }() // Collect results until we find one with success (no error and has content) var foundSuccess bool for res := range resultChan { results[res.idx] = res.result // Check if this result is successful (no error) if !foundSuccess && res.result != nil && res.result.Error == "" { foundSuccess = true close(done) // Signal other goroutines to stop } } return results } // Race returns as soon as any agent call completes (like Promise.race) // Returns immediately when first result arrives, regardless of success/failure // Note: Still waits for all goroutines to complete before returning to avoid resource leaks // Each call uses a forked context to avoid race conditions on shared state func (o *Orchestrator) Race(reqs []*Request) []*Result { if len(reqs) == 0 { return []*Result{} } results := make([]*Result, len(reqs)) resultChan := make(chan callResult, len(reqs)) var wg sync.WaitGroup done := make(chan struct{}) for i, req := range reqs { wg.Add(1) go func(idx int, r *Request) { defer wg.Done() defer func() { if err := recover(); err != nil { // Send panic result through channel select { case <-done: case resultChan <- callResult{idx: idx, result: &Result{ AgentID: r.AgentID, Error: "agent call panic recovered", }}: } } }() // Check if done before starting select { case <-done: return default: } // Use forked context to avoid race conditions result := o.callAgentWithForkedContext(r) // Try to send result select { case <-done: // Already got first 
result case resultChan <- callResult{idx: idx, result: result}: } }(i, req) } // Close channel when all goroutines complete go func() { wg.Wait() close(resultChan) }() // Get first result and signal others to stop var gotFirst bool for res := range resultChan { results[res.idx] = res.result if !gotFirst { gotFirst = true close(done) // Signal other goroutines to stop } } return results } // callAgent executes a single agent call using the AgentGetterFunc // This method handles context sharing and result extraction func (o *Orchestrator) callAgent(req *Request) *Result { return o.callAgentWithContext(o.ctx, req) } // callAgentWithForkedContext executes a single agent call with a forked context // This is used by batch operations (All/Any/Race) to avoid race conditions // when multiple goroutines modify shared context state (Stack, Logger, etc.) func (o *Orchestrator) callAgentWithForkedContext(req *Request) *Result { // Fork the context to get independent Stack and Logger forkedCtx := o.ctx.Fork() return o.callAgentWithContext(forkedCtx, req) } // callAgentWithContext executes a single agent call with the given context // This is the core implementation used by both callAgent and callAgentWithForkedContext func (o *Orchestrator) callAgentWithContext(ctx *agentContext.Context, req *Request) *Result { if req == nil { return &Result{Error: "nil request"} } // Get the agent using the getter function if AgentGetterFunc == nil { return NewResult(req.AgentID, nil, fmt.Errorf("agent getter not initialized")) } agent, err := AgentGetterFunc(req.AgentID) if err != nil { return NewResult(req.AgentID, nil, fmt.Errorf("failed to get agent: %w", err)) } // Mark this as an agent-to-agent fork call for proper source tracking // RefererAgentFork distinguishes ctx.agent.Call from delegate calls ctx.Referer = agentContext.RefererAgentFork // Build context options for the call var ctxOpts *agentContext.Options if req.Options != nil { ctxOpts = req.Options.ToContextOptions() } else { 
ctxOpts = &agentContext.Options{} } // If request has a handler, set OnMessage callback if req.Handler != nil { if ctxOpts == nil { ctxOpts = &agentContext.Options{} } // Set OnMessage to receive SSE messages ctxOpts.OnMessage = req.Handler } // Add trace node for A2A call using the ORIGINAL parent context's trace // (forked contexts have nil trace and nil Stack, so ctx.Trace() would create a new orphan trace) parentTrace, _ := o.ctx.Trace() var a2aNode types.Node if parentTrace != nil { a2aNode, _ = parentTrace.Add( map[string]any{ "agent_id": req.AgentID, "referer": string(ctx.Referer), }, types.TraceNodeOption{ Label: fmt.Sprintf("Agent: %s", req.AgentID), Type: "agent_call", Icon: "smart_toy", Description: fmt.Sprintf("A2A call to '%s'", req.AgentID), }, ) } resp, err := agent.Stream(ctx, req.Messages, ctxOpts) if err != nil { if a2aNode != nil { a2aNode.Fail(err) } return NewResult(req.AgentID, nil, fmt.Errorf("agent call failed: %w", err)) } if a2aNode != nil { a2aNode.Complete(map[string]any{ "agent_id": req.AgentID, "status": "completed", }) } return NewResult(req.AgentID, resp, nil) } // extractContentFromCompletion extracts the text content from a completion response func extractContentFromCompletion(completion *agentContext.CompletionResponse) string { if completion == nil { return "" } // Content can be string or []ContentPart switch content := completion.Content.(type) { case string: return content case []interface{}: // Handle array of content parts - extract text parts var texts []string for _, part := range content { if partMap, ok := part.(map[string]interface{}); ok { if partType, ok := partMap["type"].(string); ok && partType == "text" { if text, ok := partMap["text"].(string); ok { texts = append(texts, text) } } } } if len(texts) > 0 { return texts[0] // Return first text content } } return "" } ================================================ FILE: agent/caller/orchestrator_test.go ================================================ package 
caller_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/caller" "github.com/yaoapp/yao/agent/context" ) func TestNewOrchestrator(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) require.NotNil(t, orch) } func TestOrchestrator_All_Empty(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) results := orch.All([]*caller.Request{}) assert.Len(t, results, 0) } func TestOrchestrator_Any_Empty(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) results := orch.Any([]*caller.Request{}) assert.Len(t, results, 0) } func TestOrchestrator_Race_Empty(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) results := orch.Race([]*caller.Request{}) assert.Len(t, results, 0) } func TestOrchestrator_All_NoGetter(t *testing.T) { // Reset AgentGetterFunc originalGetter := caller.AgentGetterFunc caller.AgentGetterFunc = nil defer func() { caller.AgentGetterFunc = originalGetter }() ctx := context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) reqs := []*caller.Request{ { AgentID: "agent1", Messages: []context.Message{{Role: "user", Content: "Hello"}}, }, { AgentID: "agent2", Messages: []context.Message{{Role: "user", Content: "World"}}, }, } results := orch.All(reqs) require.Len(t, results, 2) // All should have errors because no getter for i, r := range results { require.NotNil(t, r, "result %d should not be nil", i) assert.Contains(t, r.Error, "agent getter not initialized") } } func TestOrchestrator_Any_NoGetter(t *testing.T) { // Reset AgentGetterFunc originalGetter := caller.AgentGetterFunc caller.AgentGetterFunc = nil defer func() { caller.AgentGetterFunc = originalGetter }() ctx := 
context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) reqs := []*caller.Request{ { AgentID: "agent1", Messages: []context.Message{{Role: "user", Content: "Hello"}}, }, { AgentID: "agent2", Messages: []context.Message{{Role: "user", Content: "World"}}, }, } results := orch.Any(reqs) require.Len(t, results, 2) // At least one result should exist hasResult := false for _, r := range results { if r != nil { hasResult = true assert.Contains(t, r.Error, "agent getter not initialized") } } assert.True(t, hasResult) } func TestOrchestrator_Race_NoGetter(t *testing.T) { // Reset AgentGetterFunc originalGetter := caller.AgentGetterFunc caller.AgentGetterFunc = nil defer func() { caller.AgentGetterFunc = originalGetter }() ctx := context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) reqs := []*caller.Request{ { AgentID: "agent1", Messages: []context.Message{{Role: "user", Content: "Hello"}}, }, { AgentID: "agent2", Messages: []context.Message{{Role: "user", Content: "World"}}, }, } results := orch.Race(reqs) require.Len(t, results, 2) // At least one result should exist (first to complete) hasResult := false for _, r := range results { if r != nil { hasResult = true } } assert.True(t, hasResult) } func TestOrchestrator_All_NilRequest(t *testing.T) { ctx := context.New(stdContext.Background(), nil, "test-chat") orch := caller.NewOrchestrator(ctx) reqs := []*caller.Request{ nil, { AgentID: "agent1", Messages: []context.Message{{Role: "user", Content: "Hello"}}, }, } // Reset AgentGetterFunc originalGetter := caller.AgentGetterFunc caller.AgentGetterFunc = nil defer func() { caller.AgentGetterFunc = originalGetter }() results := orch.All(reqs) require.Len(t, results, 2) // First result should have "nil request" error assert.Contains(t, results[0].Error, "nil request") } ================================================ FILE: agent/caller/process.go ================================================ package 
caller import ( "context" "encoding/json" "fmt" "time" "github.com/yaoapp/gou/process" "github.com/yaoapp/kun/exception" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/openapi/oauth/authorized" ) func init() { process.Register("agent.Call", processAgentCall) } // processAgentCall implements the agent.Call Process handler. // Enables agent-to-agent calls from contexts without agent.Context (e.g., YaoJob). // // Usage: Process("agent.Call", { assistant_id, messages, model?, ... }) // Returns: *Result (same structure as ctx.agent.Call in JSAPI) func processAgentCall(p *process.Process) interface{} { // 1. Parse parameters via struct — fail fast on invalid input if len(p.Args) == 0 { exception.New("agent.Call: argument is required", 400).Throw() } var req ProcessCallRequest raw, err := json.Marshal(p.Args[0]) if err != nil { exception.New("agent.Call: invalid argument: %s", 400, err.Error()).Throw() } if err := json.Unmarshal(raw, &req); err != nil { exception.New("agent.Call: failed to parse request: %s", 400, err.Error()).Throw() } if req.AssistantID == "" { exception.New("agent.Call: assistant_id is required", 400).Throw() } if len(req.Messages) == 0 { exception.New("agent.Call: messages is required", 400).Throw() } // 2. Auto-inject authorization info from process context authInfo := authorized.ProcessAuthInfo(p) // 3. Build timeout context — LLM calls can take minutes (tool use, multi-turn) // Default: 10 minutes (DefaultProcessTimeout). Caller can override via `timeout` field. timeoutSec := req.Timeout if timeoutSec <= 0 { timeoutSec = DefaultProcessTimeout } parent := p.Context if parent == nil { parent = context.Background() } timeoutCtx, cancel := context.WithTimeout(parent, time.Duration(timeoutSec)*time.Second) defer cancel() // 4. Build headless context + options (encapsulated in context.go) ctx, opts := NewHeadlessContext(timeoutCtx, authInfo, &req) defer ctx.Release() // 5. 
Parse messages from []map[string]interface{} to []agentContext.Message messages := ParseMessages(req.Messages) // 6. Get agent and execute if AgentGetterFunc == nil { return NewResult(req.AssistantID, nil, fmt.Errorf("agent getter not initialized")) } agent, err := AgentGetterFunc(req.AssistantID) if err != nil { return NewResult(req.AssistantID, nil, fmt.Errorf("failed to get agent: %w", err)) } resp, err := agent.Stream(ctx, messages, opts) if err != nil { return NewResult(req.AssistantID, nil, fmt.Errorf("agent call failed: %w", err)) } // 7. Return *Result — shared with ctx.agent.Call() via NewResult() return NewResult(req.AssistantID, resp, nil) } // ParseMessages converts []map[string]interface{} to []agentContext.Message. // Extracted as a package-level function so it can be reused by both // processAgentCall and JSAPI.parseMessages. func ParseMessages(raw []map[string]interface{}) []agentContext.Message { result := make([]agentContext.Message, 0, len(raw)) for _, msg := range raw { ctxMsg := agentContext.Message{} // Parse role if role, ok := msg["role"].(string); ok { ctxMsg.Role = agentContext.MessageRole(role) } // Parse content (can be string or array of content parts) ctxMsg.Content = msg["content"] // Parse name if name, ok := msg["name"].(string); ok { ctxMsg.Name = &name } // Parse tool_call_id if toolCallID, ok := msg["tool_call_id"].(string); ok { ctxMsg.ToolCallID = &toolCallID } // Parse tool_calls if toolCalls, ok := msg["tool_calls"].([]interface{}); ok { ctxMsg.ToolCalls = parseToolCalls(toolCalls) } // Parse refusal if refusal, ok := msg["refusal"].(string); ok { ctxMsg.Refusal = &refusal } result = append(result, ctxMsg) } return result } // parseToolCalls converts []interface{} to []agentContext.ToolCall func parseToolCalls(toolCalls []interface{}) []agentContext.ToolCall { result := make([]agentContext.ToolCall, 0, len(toolCalls)) for _, tc := range toolCalls { tcMap, ok := tc.(map[string]interface{}) if !ok { continue } toolCall := 
agentContext.ToolCall{} if id, ok := tcMap["id"].(string); ok { toolCall.ID = id } if tcType, ok := tcMap["type"].(string); ok { toolCall.Type = agentContext.ToolCallType(tcType) } if fn, ok := tcMap["function"].(map[string]interface{}); ok { if name, ok := fn["name"].(string); ok { toolCall.Function.Name = name } if args, ok := fn["arguments"].(string); ok { toolCall.Function.Arguments = args } } result = append(result, toolCall) } return result } ================================================ FILE: agent/caller/process_e2e_test.go ================================================ package caller_test import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/process" "github.com/yaoapp/yao/agent/caller" "github.com/yaoapp/yao/agent/testutils" ) // newLLMProcess creates a process.Process with a 120s outer timeout for LLM calls. // agent.Call has its own internal default timeout (DefaultProcessTimeout = 600s), // but the outer context (120s) takes precedence via context.WithTimeout chaining. func newLLMProcess(t *testing.T, name string, args ...interface{}) *process.Process { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) t.Cleanup(cancel) return process.NewWithContext(ctx, name, args...) } // ============================================================================ // A. 
Pure LLM scenarios (tests.simple-greeting — no hooks) // ============================================================================ func TestProcessCall_LLM_Basic(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM (source env.local.sh)") } testutils.Prepare(t) defer testutils.Clean(t) proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "Hello!"}, }, }) err := proc.Execute() require.NoError(t, err) val := proc.Value() require.NotNil(t, val, "process should return a value") result, ok := val.(*caller.Result) require.True(t, ok, "value should be *caller.Result, got %T", val) assert.Equal(t, "tests.simple-greeting", result.AgentID) assert.Empty(t, result.Error, "should not have error") assert.NotEmpty(t, result.Content, "should have LLM content") assert.NotNil(t, result.Response, "should have full response") t.Logf("LLM response: %s", result.Content) } func TestProcessCall_LLM_MultipleMessages(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM") } testutils.Prepare(t) defer testutils.Clean(t) proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{"role": "system", "content": "Always reply in JSON format."}, map[string]interface{}{"role": "user", "content": "Say hello"}, }, }) err := proc.Execute() require.NoError(t, err) result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Empty(t, result.Error) assert.NotEmpty(t, result.Content) t.Logf("Multi-message response: %s", result.Content) } func TestProcessCall_LLM_WithMetadata(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM") } testutils.Prepare(t) defer testutils.Clean(t) proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.simple-greeting", "messages": []interface{}{ 
map[string]interface{}{"role": "user", "content": "Hi there!"}, }, "metadata": map[string]interface{}{"source": "e2e-test", "mode": "task"}, "locale": "zh-CN", "route": "/test/e2e", "chat_id": "e2e-test-chat-001", }) err := proc.Execute() require.NoError(t, err) result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Empty(t, result.Error) assert.NotEmpty(t, result.Content) t.Logf("With-metadata response: %s", result.Content) } func TestProcessCall_LLM_SkipOutputForced(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM") } testutils.Prepare(t) defer testutils.Clean(t) // Explicitly pass skip.output=false — headless context MUST force it to true // If the force logic fails, this would panic (nil Writer). proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "Hello!"}, }, "skip": map[string]interface{}{ "output": false, "history": false, }, }) err := proc.Execute() require.NoError(t, err, "should NOT panic even with skip.output=false — headless forces true") result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Empty(t, result.Error) assert.NotEmpty(t, result.Content) } // ============================================================================ // B. 
Create Hook scenarios (tests.create) // ============================================================================ func TestProcessCall_CreateHook_Default(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM") } testutils.Prepare(t) defer testutils.Clean(t) // Send a generic message — Create Hook routes to scenarioDefault proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.create", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "hello world"}, }, }) err := proc.Execute() require.NoError(t, err) result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Equal(t, "tests.create", result.AgentID) assert.Empty(t, result.Error) assert.NotEmpty(t, result.Content, "Create Hook should still produce LLM response") t.Logf("CreateHook default response: %s", result.Content) } func TestProcessCall_CreateHook_ReturnFull(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM") } testutils.Prepare(t) defer testutils.Clean(t) // Send "return_full" — Create Hook returns full HookCreateResponse with custom messages proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.create", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "return_full"}, }, }) err := proc.Execute() require.NoError(t, err) result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Equal(t, "tests.create", result.AgentID) assert.Empty(t, result.Error) // The Create Hook overrides messages with system + user, then LLM responds assert.NotEmpty(t, result.Content) t.Logf("CreateHook return_full response: %s", result.Content) } // ============================================================================ // C. 
Next Hook scenarios (tests.next) // ============================================================================ func TestProcessCall_NextHook_Standard(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM") } testutils.Prepare(t) defer testutils.Clean(t) // Send "standard" — Next Hook returns null, standard LLM response is used proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.next", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "standard"}, }, }) err := proc.Execute() require.NoError(t, err) result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Equal(t, "tests.next", result.AgentID) assert.Empty(t, result.Error) assert.NotEmpty(t, result.Content, "standard scenario should return LLM content") t.Logf("NextHook standard response: %s", result.Content) } func TestProcessCall_NextHook_CustomData(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM") } testutils.Prepare(t) defer testutils.Clean(t) // Send "return_custom_data" — Next Hook returns custom data proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.next", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "return_custom_data"}, }, }) err := proc.Execute() require.NoError(t, err) result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Equal(t, "tests.next", result.AgentID) assert.Empty(t, result.Error) assert.NotNil(t, result.Response, "should have response") // Next Hook custom data is available in response.Next if result.Response != nil && result.Response.Next != nil { t.Logf("NextHook custom data: %+v", result.Response.Next) nextMap, ok := result.Response.Next.(map[string]interface{}) if ok { // The Next Hook returns { data: { message, test, timestamp } } if dataMap, ok := nextMap["data"].(map[string]interface{}); ok { assert.Equal(t, "Custom response from Next Hook", dataMap["message"]) assert.Equal(t, true, 
dataMap["test"]) } } } } // ============================================================================ // D. Timeout scenarios // ============================================================================ func TestProcessCall_Timeout_Short(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires real LLM (source env.local.sh)") } testutils.Prepare(t) defer testutils.Clean(t) // Set timeout=2 seconds — LLM round-trip will certainly exceed this. // Verifies that the timeout parameter is respected and produces an error. proc := newLLMProcess(t, "agent.call", map[string]interface{}{ "assistant_id": "tests.simple-greeting", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "Tell me a very long story about the history of computing."}, }, "timeout": 2, }) err := proc.Execute() if err != nil { // Timeout may surface as a process-level error (context deadline exceeded) t.Logf("Process error (expected timeout): %s", err.Error()) assert.Contains(t, err.Error(), "deadline exceeded", "error should indicate context deadline exceeded") return } // Or the agent.Stream may catch the timeout and return it in Result.Error val := proc.Value() require.NotNil(t, val, "process should return a value") result, ok := val.(*caller.Result) require.True(t, ok, "value should be *caller.Result, got %T", val) assert.NotEmpty(t, result.Error, "should have timeout error in result") t.Logf("Timeout error in result: %s", result.Error) } // ============================================================================ // E. 
Error / validation scenarios // ============================================================================ func TestProcessCall_Error_MissingAssistantID(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) proc := process.New("agent.call", map[string]interface{}{ "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "Hello"}, }, }) err := proc.Execute() require.Error(t, err, "should fail: assistant_id is required") t.Logf("Expected error: %s", err.Error()) assert.Contains(t, err.Error(), "assistant_id") } func TestProcessCall_Error_EmptyMessages(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) proc := process.New("agent.call", map[string]interface{}{ "assistant_id": "tests.simple-greeting", "messages": []interface{}{}, }) err := proc.Execute() require.Error(t, err, "should fail: messages is required") t.Logf("Expected error: %s", err.Error()) assert.Contains(t, err.Error(), "messages") } func TestProcessCall_Error_InvalidArgument(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Pass a string instead of map — json.Marshal will succeed but Unmarshal will fail proc := process.New("agent.call", "not-a-map") err := proc.Execute() require.Error(t, err, "should fail: argument must be a map") t.Logf("Expected error: %s", err.Error()) } func TestProcessCall_Error_NoArgument(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) proc := process.New("agent.call") err := proc.Execute() require.Error(t, err, "should fail: argument is required") t.Logf("Expected error: %s", err.Error()) } func TestProcessCall_Error_NonexistentAgent(t *testing.T) { if testing.Short() { t.Skip("Skipping: requires environment") } testutils.Prepare(t) defer testutils.Clean(t) proc := process.New("agent.call", map[string]interface{}{ "assistant_id": "does.not.exist.agent", "messages": []interface{}{ map[string]interface{}{"role": "user", "content": "Hello"}, }, }) err := proc.Execute() require.NoError(t, err, "process should not 
error — error is in Result") result, ok := proc.Value().(*caller.Result) require.True(t, ok) assert.Equal(t, "does.not.exist.agent", result.AgentID) assert.NotEmpty(t, result.Error, "should have error for nonexistent agent") assert.Contains(t, result.Error, "failed to get agent") t.Logf("Expected error in result: %s", result.Error) } ================================================ FILE: agent/caller/process_test.go ================================================ package caller_test import ( "context" "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/caller" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/openapi/oauth/types" ) // --- NewHeadlessContext tests --- func TestNewHeadlessContext_Basic(t *testing.T) { authInfo := &types.AuthorizedInfo{ TeamID: "team-123", UserID: "user-456", } req := &caller.ProcessCallRequest{ AssistantID: "yao.keeper.classify", Messages: []map[string]interface{}{ {"role": "user", "content": "hello"}, }, Locale: "zh-CN", } ctx, opts := caller.NewHeadlessContext(context.Background(), authInfo, req) defer ctx.Release() assert.Equal(t, "yao.keeper.classify", ctx.AssistantID) assert.Equal(t, agentContext.RefererProcess, ctx.Referer) assert.Equal(t, "zh-CN", ctx.Locale) assert.NotEmpty(t, ctx.ChatID) // auto-generated require.NotNil(t, opts) require.NotNil(t, opts.Skip) assert.True(t, opts.Skip.Output, "skip.output must be forced true for headless context") assert.True(t, opts.Skip.History, "skip.history must be forced true for headless context") assert.Empty(t, opts.Connector) } func TestNewHeadlessContext_WithModel(t *testing.T) { req := &caller.ProcessCallRequest{ AssistantID: "test.agent", Messages: []map[string]interface{}{{"role": "user", "content": "hi"}}, Model: "deepseek.v3", } ctx, opts := caller.NewHeadlessContext(context.Background(), nil, req) defer ctx.Release() assert.Equal(t, "deepseek.v3", opts.Connector) } func 
TestNewHeadlessContext_WithChatID(t *testing.T) { req := &caller.ProcessCallRequest{ AssistantID: "test.agent", Messages: []map[string]interface{}{{"role": "user", "content": "hi"}}, ChatID: "custom-chat-id", } ctx, _ := caller.NewHeadlessContext(context.Background(), nil, req) defer ctx.Release() assert.Equal(t, "custom-chat-id", ctx.ChatID) } func TestNewHeadlessContext_ForceSkipOverridesUserSkip(t *testing.T) { req := &caller.ProcessCallRequest{ AssistantID: "test.agent", Messages: []map[string]interface{}{{"role": "user", "content": "hi"}}, Skip: &agentContext.Skip{Output: false, History: false, Trace: true}, } _, opts := caller.NewHeadlessContext(context.Background(), nil, req) // Output and History must be forced true regardless of user input assert.True(t, opts.Skip.Output, "skip.output must be forced true") assert.True(t, opts.Skip.History, "skip.history must be forced true") // User-specified skip.trace should be preserved assert.True(t, opts.Skip.Trace, "skip.trace should be preserved from user input") } func TestNewHeadlessContext_WithMetadata(t *testing.T) { req := &caller.ProcessCallRequest{ AssistantID: "test.agent", Messages: []map[string]interface{}{{"role": "user", "content": "hi"}}, Metadata: map[string]interface{}{"key": "value"}, Route: "/test", } ctx, _ := caller.NewHeadlessContext(context.Background(), nil, req) defer ctx.Release() assert.Equal(t, "value", ctx.Metadata["key"]) assert.Equal(t, "/test", ctx.Route) } func TestNewHeadlessContext_WithTimeout(t *testing.T) { req := &caller.ProcessCallRequest{ AssistantID: "test.agent", Messages: []map[string]interface{}{{"role": "user", "content": "hi"}}, Timeout: 30, } // Pass a context with timeout to verify it propagates ctx, _ := caller.NewHeadlessContext(context.Background(), nil, req) defer ctx.Release() // Timeout field is consumed by processAgentCall, not NewHeadlessContext. // Here we just verify the field is correctly set in the struct. 
assert.Equal(t, 30, req.Timeout) } func TestProcessCallRequest_DefaultTimeout(t *testing.T) { req := &caller.ProcessCallRequest{ AssistantID: "test.agent", Messages: []map[string]interface{}{{"role": "user", "content": "hi"}}, } // When Timeout is 0 (zero value), the default should be used assert.Equal(t, 0, req.Timeout, "zero value means use default") assert.Equal(t, 600, caller.DefaultProcessTimeout, "default timeout should be 600 seconds") } // --- ParseMessages tests --- func TestParseMessages_Basic(t *testing.T) { raw := []map[string]interface{}{ {"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi there"}, } messages := caller.ParseMessages(raw) require.Len(t, messages, 2) assert.Equal(t, agentContext.MessageRole("user"), messages[0].Role) assert.Equal(t, "hello", messages[0].Content) assert.Equal(t, agentContext.MessageRole("assistant"), messages[1].Role) assert.Equal(t, "hi there", messages[1].Content) } func TestParseMessages_WithOptionalFields(t *testing.T) { name := "test-name" raw := []map[string]interface{}{ { "role": "tool", "content": "result", "name": name, "tool_call_id": "tc-1", }, } messages := caller.ParseMessages(raw) require.Len(t, messages, 1) msg := messages[0] assert.Equal(t, agentContext.MessageRole("tool"), msg.Role) require.NotNil(t, msg.Name) assert.Equal(t, name, *msg.Name) require.NotNil(t, msg.ToolCallID) assert.Equal(t, "tc-1", *msg.ToolCallID) } func TestParseMessages_Empty(t *testing.T) { messages := caller.ParseMessages(nil) assert.Empty(t, messages) } // --- NewResult tests --- func TestNewResult_Success(t *testing.T) { resp := &agentContext.Response{ Completion: &agentContext.CompletionResponse{ Content: "answer text", }, } result := caller.NewResult("test.agent", resp, nil) assert.Equal(t, "test.agent", result.AgentID) assert.Equal(t, "answer text", result.Content) assert.Empty(t, result.Error) assert.NotNil(t, result.Response) } func TestNewResult_WithError(t *testing.T) { result := 
caller.NewResult("test.agent", nil, errors.New("something failed")) assert.Equal(t, "test.agent", result.AgentID) assert.Equal(t, "something failed", result.Error) assert.Empty(t, result.Content) assert.Nil(t, result.Response) } func TestNewResult_NilResponse(t *testing.T) { result := caller.NewResult("test.agent", nil, nil) assert.Equal(t, "test.agent", result.AgentID) assert.Empty(t, result.Content) assert.Empty(t, result.Error) assert.Nil(t, result.Response) } func TestNewResult_NilCompletion(t *testing.T) { resp := &agentContext.Response{Completion: nil} result := caller.NewResult("test.agent", resp, nil) assert.Empty(t, result.Content, "content should be empty when completion is nil") assert.NotNil(t, result.Response) } ================================================ FILE: agent/caller/sandbox_integration_test.go ================================================ package caller_test import ( "context" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/caller" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // TestSandboxE2E_ClaudeCLIExecution tests the full sandbox + claude-proxy integration // This test verifies: // 1. Assistant loads with sandbox and prompts configured // 2. Claude CLI is invoked (not skipped) because prompts exist // 3. claude-proxy correctly translates requests to OpenAI backend // 4. 
Response is received with actual content func TestSandboxE2E_ClaudeCLIExecution(t *testing.T) { if testing.Short() { t.Skip("Skipping sandbox E2E test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the e2e-test assistant ast, err := assistant.Get("tests.sandbox.e2e-test") if err != nil { t.Skipf("Skipping test: e2e-test assistant not available: %v", err) } // Verify configuration require.NotNil(t, ast.Sandbox, "Sandbox should be configured") require.NotEmpty(t, ast.Prompts, "Prompts should be configured (required for Claude CLI)") t.Logf("✓ Assistant loaded: sandbox=%s, prompts=%d", ast.Sandbox.Command, len(ast.Prompts)) // Create authorized info authorized := &types.AuthorizedInfo{ Subject: "sandbox-e2e-test", UserID: "e2e-user-123", TenantID: "e2e-tenant", } // Create context with unique chat ID chatID := "sandbox-e2e-" + time.Now().Format("20060102-150405") ctx := agentContext.New(context.Background(), authorized, chatID) ctx.AssistantID = "tests.sandbox.e2e-test" // Create JSAPI api := caller.NewJSAPI(ctx) // Test 1: Simple echo command t.Run("EchoCommand", func(t *testing.T) { messages := []interface{}{ map[string]interface{}{ "role": "user", "content": "Run this command: echo 'SANDBOX_E2E_SUCCESS_12345'", }, } opts := map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, } startTime := time.Now() result := api.Call("tests.sandbox.e2e-test", messages, opts) duration := time.Since(startTime) t.Logf("Execution time: %v", duration) require.NotNil(t, result, "Result should not be nil") r, ok := result.(*caller.Result) require.True(t, ok, "Result should be *caller.Result") // Check for errors if r.Error != "" { // Check if it's a Docker/sandbox availability issue if strings.Contains(r.Error, "Docker") || strings.Contains(r.Error, "sandbox") || strings.Contains(r.Error, "container") { t.Skipf("Skipping: Docker/sandbox not available: %s", r.Error) } t.Fatalf("Agent call failed: %s", r.Error) } // Verify response 
t.Logf("Response content: %s", truncateStr(r.Content, 500)) assert.NotEmpty(t, r.Content, "Response content should not be empty") // Check if Claude executed the command if strings.Contains(r.Content, "SANDBOX_E2E_SUCCESS_12345") { t.Log("✓ Echo command executed successfully - found verification string") } else if strings.Contains(strings.ToLower(r.Content), "echo") || strings.Contains(r.Content, "SANDBOX") { t.Log("✓ Response mentions the command or partial output") } else { t.Log("⚠ Response does not contain expected output") } }) } // TestSandboxE2E_FileCreation tests that Claude can create files in the sandbox func TestSandboxE2E_FileCreation(t *testing.T) { if testing.Short() { t.Skip("Skipping sandbox E2E test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the e2e-test assistant ast, err := assistant.Get("tests.sandbox.e2e-test") if err != nil { t.Skipf("Skipping test: e2e-test assistant not available: %v", err) } require.NotNil(t, ast.Sandbox) require.NotEmpty(t, ast.Prompts) // Create context authorized := &types.AuthorizedInfo{ Subject: "sandbox-e2e-test", UserID: "e2e-user-456", TenantID: "e2e-tenant", } chatID := "sandbox-file-" + time.Now().Format("20060102-150405") ctx := agentContext.New(context.Background(), authorized, chatID) ctx.AssistantID = "tests.sandbox.e2e-test" api := caller.NewJSAPI(ctx) messages := []interface{}{ map[string]interface{}{ "role": "user", "content": "Create a file named 'test-output.txt' with the content 'FILE_CREATION_VERIFIED_67890', then read it back and show me the content.", }, } opts := map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, } startTime := time.Now() result := api.Call("tests.sandbox.e2e-test", messages, opts) duration := time.Since(startTime) t.Logf("Execution time: %v", duration) require.NotNil(t, result) r, ok := result.(*caller.Result) require.True(t, ok) if r.Error != "" { if strings.Contains(r.Error, "Docker") || strings.Contains(r.Error, "sandbox") { 
t.Skipf("Skipping: Docker/sandbox not available: %s", r.Error) } t.Fatalf("Agent call failed: %s", r.Error) } t.Logf("Response: %s", truncateStr(r.Content, 800)) // Verify file was created and read back if strings.Contains(r.Content, "FILE_CREATION_VERIFIED_67890") { t.Log("✓ File creation and read verified") } else if strings.Contains(strings.ToLower(r.Content), "created") || strings.Contains(strings.ToLower(r.Content), "wrote") || strings.Contains(r.Content, "test-output.txt") { t.Log("✓ File operation appears successful") } else { t.Log("⚠ Could not verify file creation") } } // TestSandboxE2E_HookOnlyMode tests that hooks can work without Claude CLI func TestSandboxE2E_HookOnlyMode(t *testing.T) { if testing.Short() { t.Skip("Skipping sandbox E2E test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the hook-only assistant (no prompts) ast, err := assistant.Get("tests.sandbox.hook-only") if err != nil { t.Skipf("Skipping test: hook-only assistant not available: %v", err) } // Verify configuration - no prompts means Claude CLI should be skipped require.NotNil(t, ast.Sandbox) require.Empty(t, ast.Prompts, "Hook-only mode should have no prompts") t.Logf("✓ Hook-only assistant loaded: sandbox=%s, prompts=%d (should be 0)", ast.Sandbox.Command, len(ast.Prompts)) // Create context authorized := &types.AuthorizedInfo{ Subject: "sandbox-hook-test", UserID: "hook-user-789", TenantID: "hook-tenant", } chatID := "sandbox-hook-" + time.Now().Format("20060102-150405") ctx := agentContext.New(context.Background(), authorized, chatID) ctx.AssistantID = "tests.sandbox.hook-only" api := caller.NewJSAPI(ctx) messages := []interface{}{ map[string]interface{}{ "role": "user", "content": "test hook-only mode", }, } opts := map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, } startTime := time.Now() result := api.Call("tests.sandbox.hook-only", messages, opts) duration := time.Since(startTime) t.Logf("Execution time: %v", duration) 
require.NotNil(t, result) r, ok := result.(*caller.Result) require.True(t, ok) if r.Error != "" { if strings.Contains(r.Error, "Docker") || strings.Contains(r.Error, "sandbox") { t.Skipf("Skipping: Docker/sandbox not available: %s", r.Error) } t.Fatalf("Agent call failed: %s", r.Error) } t.Logf("Response: %s", r.Content) t.Log("✓ Hook-only mode executed successfully") } // TestSandboxE2E_StreamingResponse verifies streaming works correctly func TestSandboxE2E_StreamingResponse(t *testing.T) { if testing.Short() { t.Skip("Skipping sandbox E2E test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Load the e2e-test assistant ast, err := assistant.Get("tests.sandbox.e2e-test") if err != nil { t.Skipf("Skipping test: e2e-test assistant not available: %v", err) } require.NotNil(t, ast.Sandbox) require.NotEmpty(t, ast.Prompts) // Create context authorized := &types.AuthorizedInfo{ Subject: "sandbox-stream-test", UserID: "stream-user", TenantID: "stream-tenant", } chatID := "sandbox-stream-" + time.Now().Format("20060102-150405") ctx := agentContext.New(context.Background(), authorized, chatID) ctx.AssistantID = "tests.sandbox.e2e-test" api := caller.NewJSAPI(ctx) // Ask for a slightly longer response to verify streaming messages := []interface{}{ map[string]interface{}{ "role": "user", "content": "Say 'Hello World' and nothing else.", }, } opts := map[string]interface{}{ "skip": map[string]interface{}{ "history": true, }, } startTime := time.Now() result := api.Call("tests.sandbox.e2e-test", messages, opts) duration := time.Since(startTime) t.Logf("Execution time: %v", duration) require.NotNil(t, result) r, ok := result.(*caller.Result) require.True(t, ok) if r.Error != "" { if strings.Contains(r.Error, "Docker") || strings.Contains(r.Error, "sandbox") { t.Skipf("Skipping: Docker/sandbox not available: %s", r.Error) } t.Fatalf("Agent call failed: %s", r.Error) } t.Logf("Response: %s", r.Content) // Verify we got a response assert.NotEmpty(t, r.Content, 
"Should have response content") if strings.Contains(strings.ToLower(r.Content), "hello") { t.Log("✓ Streaming response received with expected content") } else { t.Log("✓ Streaming response received") } } func truncateStr(s string, maxLen int) string { s = strings.ReplaceAll(s, "\n", " ") if len(s) <= maxLen { return s } return s[:maxLen] + "..." } ================================================ FILE: agent/caller/types.go ================================================ // Package caller provides types and utilities for agent-to-agent calls package caller import ( agentContext "github.com/yaoapp/yao/agent/context" ) // DefaultProcessTimeout is the default timeout (in seconds) for agent.Call Process. // LLM calls with tool use can take minutes; 10 minutes provides safe headroom. const DefaultProcessTimeout = 600 // Request represents a request to call an agent type Request struct { AgentID string `json:"agent"` // Target agent ID Messages []agentContext.Message `json:"messages"` // Messages to send Options *CallOptions `json:"options,omitempty"` // Call options Handler agentContext.OnMessageFunc `json:"-"` // OnMessage handler for this request (not serialized) } // CallOptions represents options for an agent call type CallOptions struct { Connector string `json:"connector,omitempty"` // Override connector Mode string `json:"mode,omitempty"` // Agent mode (chat, etc.) Metadata map[string]interface{} `json:"metadata,omitempty"` // Custom metadata passed to hooks Skip *agentContext.Skip `json:"skip,omitempty"` // Skip configuration (history, trace, output, etc.) 
} // Result represents the result of an agent call type Result struct { AgentID string `json:"agent_id"` // Agent ID that was called Response *agentContext.Response `json:"response,omitempty"` // Full response from agent Content string `json:"content,omitempty"` // Final text content (extracted from completion) Error string `json:"error,omitempty"` // Error message if call failed } // ProcessCallRequest is the parameter structure for the agent.Call Process. // Fields mirror CompletionRequest + HTTP header semantics, enabling headless // agent calls from contexts without agent.Context (e.g., YaoJob async tasks). type ProcessCallRequest struct { AssistantID string `json:"assistant_id"` // Required: target assistant ID (maps to X-Yao-Assistant header) Messages []map[string]interface{} `json:"messages"` // Required: message list (maps to CompletionRequest.Messages) Model string `json:"model,omitempty"` // Optional: connector ID override (maps to CompletionRequest.Model) Skip *agentContext.Skip `json:"skip,omitempty"` // Optional: skip config (maps to CompletionRequest.Skip) Metadata map[string]interface{} `json:"metadata,omitempty"` // Optional: passed to hooks (maps to CompletionRequest.Metadata) Locale string `json:"locale,omitempty"` // Optional (maps to locale query param) Route string `json:"route,omitempty"` // Optional (maps to CompletionRequest.Route) ChatID string `json:"chat_id,omitempty"` // Optional: auto-generated if empty (maps to chat_id query/header) Timeout int `json:"timeout,omitempty"` // Optional: timeout in seconds (default: DefaultProcessTimeout = 600) } // NewResult builds a Result from an agent call response. // Used by both ctx.agent.Call (orchestrator) and Process("agent.Call") to // ensure consistent result construction. 
func NewResult(agentID string, resp *agentContext.Response, err error) *Result { result := &Result{AgentID: agentID} if err != nil { result.Error = err.Error() return result } result.Response = resp if resp != nil && resp.Completion != nil { result.Content = extractContentFromCompletion(resp.Completion) } return result } // ToContextOptions converts CallOptions to context.Options for the agent call func (o *CallOptions) ToContextOptions() *agentContext.Options { if o == nil { return nil } return &agentContext.Options{ Connector: o.Connector, Mode: o.Mode, Metadata: o.Metadata, Skip: o.Skip, } } ================================================ FILE: agent/caller/types_test.go ================================================ package caller_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/caller" "github.com/yaoapp/yao/agent/context" ) func TestCallOptions_ToContextOptions_Nil(t *testing.T) { var opts *caller.CallOptions ctxOpts := opts.ToContextOptions() assert.Nil(t, ctxOpts) } func TestCallOptions_ToContextOptions_Empty(t *testing.T) { opts := &caller.CallOptions{} ctxOpts := opts.ToContextOptions() require.NotNil(t, ctxOpts) assert.Empty(t, ctxOpts.Connector) assert.Empty(t, ctxOpts.Mode) assert.Nil(t, ctxOpts.Metadata) assert.Nil(t, ctxOpts.Skip) } func TestCallOptions_ToContextOptions_Full(t *testing.T) { opts := &caller.CallOptions{ Connector: "gpt4", Mode: "chat", Metadata: map[string]interface{}{ "key": "value", }, Skip: &context.Skip{ History: true, Trace: true, Output: false, }, } ctxOpts := opts.ToContextOptions() require.NotNil(t, ctxOpts) assert.Equal(t, "gpt4", ctxOpts.Connector) assert.Equal(t, "chat", ctxOpts.Mode) assert.Equal(t, "value", ctxOpts.Metadata["key"]) require.NotNil(t, ctxOpts.Skip) assert.True(t, ctxOpts.Skip.History) assert.True(t, ctxOpts.Skip.Trace) assert.False(t, ctxOpts.Skip.Output) } func TestRequest_Basic(t *testing.T) { req := &caller.Request{ 
AgentID: "test-agent", Messages: []context.Message{ {Role: "user", Content: "Hello"}, }, } assert.Equal(t, "test-agent", req.AgentID) assert.Len(t, req.Messages, 1) assert.Equal(t, context.MessageRole("user"), req.Messages[0].Role) } func TestResult_Basic(t *testing.T) { result := &caller.Result{ AgentID: "test-agent", Content: "Hello response", } assert.Equal(t, "test-agent", result.AgentID) assert.Equal(t, "Hello response", result.Content) assert.Empty(t, result.Error) } func TestResult_WithError(t *testing.T) { result := &caller.Result{ AgentID: "test-agent", Error: "something went wrong", } assert.Equal(t, "test-agent", result.AgentID) assert.Equal(t, "something went wrong", result.Error) assert.Empty(t, result.Content) } ================================================ FILE: agent/content/content.go ================================================ package content import ( "fmt" "strings" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/content/docx" "github.com/yaoapp/yao/agent/content/image" "github.com/yaoapp/yao/agent/content/pdf" "github.com/yaoapp/yao/agent/content/pptx" "github.com/yaoapp/yao/agent/content/text" "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" searchTypes "github.com/yaoapp/yao/agent/search/types" ) // ParseUserInput ParseUserInput func ParseUserInput(ctx *agentContext.Context, messages []agentContext.Message, options *types.Options) ([]agentContext.Message, *searchTypes.ReferenceContext, error) { var referenceContext *searchTypes.ReferenceContext = nil var parsedMessages []agentContext.Message = make([]agentContext.Message, 0) for _, message := range messages { // Only process user messages (current or from history) if message.Role != agentContext.RoleUser { parsedMessages = append(parsedMessages, message) continue } // Parse user input message (Ignore errors) parsedMessage, refs, err := parseUserInputMessage(ctx, message, options) if err != nil { parsedMessages = 
append(parsedMessages, message) log.Error("Failed to parse user input message: %v, %v", message.Content, err) continue } parsedMessages = append(parsedMessages, parsedMessage) // Add reference to reference context if refs != nil { if referenceContext == nil { referenceContext = &searchTypes.ReferenceContext{} } referenceContext.References = append(referenceContext.References, refs...) } } return parsedMessages, referenceContext, nil } // parseUserInputMessage parse a user input message func parseUserInputMessage(ctx *agentContext.Context, message agentContext.Message, options *types.Options) (agentContext.Message, []*searchTypes.Reference, error) { // Context content type switch content := message.Content.(type) { case string: return message, nil, nil case []agentContext.ContentPart: return parseContentParts(ctx, message, content, options) case []interface{}: // Handle content loaded from history/JSON ([]interface{} instead of []ContentPart) parts, ok := convertToContentParts(content) if !ok { return message, nil, nil } return parseContentParts(ctx, message, parts, options) } return message, nil, fmt.Errorf("unsupported content type: %T", message.Content) } // parseContentParts parses content parts and returns the parsed message func parseContentParts(ctx *agentContext.Context, message agentContext.Message, content []agentContext.ContentPart, options *types.Options) (agentContext.Message, []*searchTypes.Reference, error) { allRefs := []*searchTypes.Reference{} parts := make([]agentContext.ContentPart, 0, len(content)) for _, part := range content { parsedPart, refs, err := parseContentPart(ctx, part, options) if err != nil { parts = append(parts, part) continue } parts = append(parts, parsedPart) if refs != nil { allRefs = append(allRefs, refs...) 
} } parsedMessage := message parsedMessage.Content = parts return parsedMessage, allRefs, nil } // parseContentPart parse a content part func parseContentPart(ctx *agentContext.Context, content agentContext.ContentPart, options *types.Options) (agentContext.ContentPart, []*searchTypes.Reference, error) { switch content.Type { case agentContext.ContentText: return content, nil, nil case agentContext.ContentImageURL: return image.New(options).Parse(ctx, content) case agentContext.ContentInputAudio: return content, nil, nil case agentContext.ContentFile: return parseFileContent(ctx, content, options) case agentContext.ContentData: return content, nil, nil default: return content, nil, fmt.Errorf("unsupported content part type: %s", content.Type) } } // parseFileContent parses file content based on file type func parseFileContent(ctx *agentContext.Context, content agentContext.ContentPart, options *types.Options) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return content, nil, nil } // Determine file type from filename filename := strings.ToLower(content.File.Filename) // Check file type and route to appropriate handler switch { case strings.HasSuffix(filename, ".pdf"): return pdf.New(options).Parse(ctx, content) case strings.HasSuffix(filename, ".docx"): return docx.New(options).Parse(ctx, content) case strings.HasSuffix(filename, ".pptx"): return pptx.New(options).Parse(ctx, content) case text.IsSupportedExtension(filename): return text.New(options).Parse(ctx, content) } // For unsupported file types, try to read as text // This allows any file to be converted to text content return text.New(options).ParseRaw(ctx, content) } // convertToContentParts converts []interface{} to []ContentPart // This is needed when content is loaded from JSON/history and is []interface{} instead of []ContentPart func convertToContentParts(content []interface{}) ([]agentContext.ContentPart, bool) { parts := 
make([]agentContext.ContentPart, 0, len(content)) for _, item := range content { // Each item should be a map m, ok := item.(map[string]interface{}) if !ok { continue } // Get type field typeStr, _ := m["type"].(string) if typeStr == "" { continue } part := agentContext.ContentPart{ Type: agentContext.ContentPartType(typeStr), } switch typeStr { case "text": if text, ok := m["text"].(string); ok { part.Text = text } case "image_url": if imgData, ok := m["image_url"].(map[string]interface{}); ok { part.ImageURL = &agentContext.ImageURL{} if url, ok := imgData["url"].(string); ok { part.ImageURL.URL = url } if detail, ok := imgData["detail"].(string); ok { part.ImageURL.Detail = agentContext.ImageDetailLevel(detail) } } case "file": if fileData, ok := m["file"].(map[string]interface{}); ok { part.File = &agentContext.FileAttachment{} if url, ok := fileData["url"].(string); ok { part.File.URL = url } if filename, ok := fileData["filename"].(string); ok { part.File.Filename = filename } } case "input_audio": if audioData, ok := m["input_audio"].(map[string]interface{}); ok { part.InputAudio = &agentContext.InputAudio{} if data, ok := audioData["data"].(string); ok { part.InputAudio.Data = data } if format, ok := audioData["format"].(string); ok { part.InputAudio.Format = format } } case "data": if dataContent, ok := m["data"].(map[string]interface{}); ok { part.Data = &agentContext.DataContent{} if sources, ok := dataContent["sources"].([]interface{}); ok { part.Data.Sources = make([]agentContext.DataSource, 0, len(sources)) for _, src := range sources { if srcMap, ok := src.(map[string]interface{}); ok { source := agentContext.DataSource{} if t, ok := srcMap["type"].(string); ok { source.Type = agentContext.DataSourceType(t) } if name, ok := srcMap["name"].(string); ok { source.Name = name } if id, ok := srcMap["id"].(string); ok { source.ID = id } if filters, ok := srcMap["filters"].(map[string]interface{}); ok { source.Filters = filters } if metadata, ok := 
srcMap["metadata"].(map[string]interface{}); ok { source.Metadata = metadata } part.Data.Sources = append(part.Data.Sources, source) } } } } } parts = append(parts, part) } if len(parts) == 0 { return nil, false } return parts, true } ================================================ FILE: agent/content/docx/docx.go ================================================ package docx import ( "fmt" "os" "strings" "github.com/yaoapp/gou/office" "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/attachment" ) // Docx handles DOCX content type Docx struct { options *types.Options } // New creates a new DOCX handler func New(options *types.Options) *Docx { return &Docx{options: options} } // Parse parses DOCX content and returns text func (h *Docx) Parse(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return content, nil, fmt.Errorf("file content missing URL") } url := content.File.URL // Check cache first cachedText, found, err := h.readFromCache(ctx, url) if err == nil && found { return agentContext.ContentPart{ Type: agentContext.ContentText, Text: cachedText, }, nil, nil } // Read DOCX file data, err := h.readFile(ctx, url) if err != nil { return content, nil, fmt.Errorf("failed to read DOCX: %w", err) } // Parse DOCX using gou/office parser := office.NewParser() result, err := parser.Parse(data) if err != nil { return content, nil, fmt.Errorf("failed to parse DOCX: %w", err) } text := result.Markdown if text == "" { return content, nil, fmt.Errorf("no text content extracted from DOCX") } // Cache the result if err := h.saveToCache(ctx, url, text); err != nil { // Log warning but don't fail fmt.Printf("Warning: failed to cache DOCX text: %v\n", err) } return agentContext.ContentPart{ Type: agentContext.ContentText, Text: text, }, 
nil, nil } // readFile reads DOCX content from various sources func (h *Docx) readFile(ctx *agentContext.Context, url string) ([]byte, error) { if strings.HasPrefix(url, "__") { return h.readFromUploader(ctx, url) } if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { return nil, fmt.Errorf("HTTP URL fetch not implemented yet: %s", url) } // Try to read as local file path if _, err := os.Stat(url); err == nil { return os.ReadFile(url) } return nil, fmt.Errorf("unsupported DOCX source: %s", url) } // readFromUploader reads DOCX content from file uploader func (h *Docx) readFromUploader(ctx *agentContext.Context, wrapper string) ([]byte, error) { uploaderName, fileID, ok := attachment.Parse(wrapper) if !ok { return nil, fmt.Errorf("invalid uploader wrapper format: %s", wrapper) } manager, exists := attachment.Managers[uploaderName] if !exists { return nil, fmt.Errorf("uploader '%s' not found", uploaderName) } data, err := manager.Read(ctx.Context, fileID) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } return data, nil } // readFromCache reads cached text content for a DOCX func (h *Docx) readFromCache(ctx *agentContext.Context, url string) (string, bool, error) { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return "", false, nil } manager, exists := attachment.Managers[uploaderName] if !exists { return "", false, nil } text, err := manager.GetText(ctx.Context, fileID, false) if err == nil && text != "" { return text, true, nil } return "", false, nil } // saveToCache saves processed text to cache func (h *Docx) saveToCache(ctx *agentContext.Context, url string, text string) error { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return nil } manager, exists := attachment.Managers[uploaderName] if !exists { return nil } return manager.SaveText(ctx.Context, fileID, text) } ================================================ FILE: agent/content/docx/docx_test.go 
================================================ package docx_test import ( stdContext "context" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/content/docx" contentTypes "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) const testFilesDir = "assistants/tests/vision-helper/tests" func newTestContext() *agentContext.Context { authorized := &oauthTypes.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", } ctx := agentContext.New(stdContext.Background(), authorized, "test-chat") ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.IDGenerator = message.NewIDGenerator() return ctx } func newTestOptions() *contentTypes.Options { return &contentTypes.Options{ Capabilities: &openai.Capabilities{}, } } func getTestFilePath(filename string) string { yaoRoot := os.Getenv("YAO_TEST_APPLICATION") if yaoRoot == "" { yaoRoot = os.Getenv("YAO_ROOT") } return filepath.Join(yaoRoot, testFilesDir, filename) } // TestParseWithMissingURL tests parsing DOCX with missing URL func TestParseWithMissingURL(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: nil, } handler := docx.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "missing URL") } // TestParseWithLocalDocx tests parsing a local DOCX file func TestParseWithLocalDocx(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) docxPath := getTestFilePath("docx.docx") if _, err := os.Stat(docxPath); 
os.IsNotExist(err) { t.Skipf("Test DOCX file not found: %s", docxPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: docxPath, Filename: "docx.docx", }, } handler := docx.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) t.Logf("DOCX parse result (first 500 chars): %.500s...", result.Text) } // TestParseWithNonExistentFile tests parsing DOCX with non-existent file func TestParseWithNonExistentFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: "/non/existent/path/test.docx", Filename: "test.docx", }, } handler := docx.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "unsupported DOCX source") } ================================================ FILE: agent/content/image/image.go ================================================ package image import ( "encoding/base64" "fmt" "strings" "github.com/yaoapp/yao/agent/content/tools" "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/attachment" ) // Image handles image content type Image struct { options *types.Options } // New creates a new image handler func New(options *types.Options) *Image { return &Image{options: options} } // Parse parses image content // Logic: // 1. Check model capabilities first // 2. 
If forceUses is true and uses.Vision is specified -> use vision tool regardless of model capability // 3. If model supports vision -> pass through or convert to base64 format // 4. If model doesn't support vision -> use vision agent/MCP to extract text func (h *Image) Parse(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.ImageURL == nil || content.ImageURL.URL == "" { return content, nil, fmt.Errorf("image_url content missing URL") } // Check model capabilities first supportsVision, visionFormat := agentContext.GetVisionSupport(h.options.Capabilities) // Check if we should force using Uses tools forceUses := h.options.CompletionOptions != nil && h.options.CompletionOptions.ForceUses // If forceUses is true and uses.Vision is specified, use vision tool regardless of model capability if forceUses && h.options.CompletionOptions != nil && h.options.CompletionOptions.Uses != nil && h.options.CompletionOptions.Uses.Vision != "" { // Check cache first before calling agent cachedText, found, err := h.readFromCache(ctx, content.ImageURL.URL) if err == nil && found { return agentContext.ContentPart{ Type: agentContext.ContentText, Text: cachedText, }, nil, nil } return h.agent(ctx, content) } // If model supports vision if supportsVision { url := content.ImageURL.URL // If it's already a data URI (base64), pass through directly if strings.HasPrefix(url, "data:") { return content, nil, nil } // Convert to base64 format return h.base64(ctx, content, visionFormat) } // Model doesn't support vision - check cache first, then use vision agent/MCP // Try to get cached text (from attachment's content_preview) cachedText, found, err := h.readFromCache(ctx, content.ImageURL.URL) if err == nil && found { // Cache hit! 
Return as text content return agentContext.ContentPart{ Type: agentContext.ContentText, Text: cachedText, }, nil, nil } // No cache, try to use vision agent/MCP if h.options.CompletionOptions != nil && h.options.CompletionOptions.Uses != nil && h.options.CompletionOptions.Uses.Vision != "" { return h.agent(ctx, content) } // No vision support and no vision tool specified, return error return content, nil, fmt.Errorf("model doesn't support vision and no vision tool specified in uses.Vision") } // base64 encodes image content to base64 (for vision support) func (h *Image) base64(ctx *agentContext.Context, content agentContext.ContentPart, format agentContext.VisionFormat) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.ImageURL == nil || content.ImageURL.URL == "" { return content, nil, fmt.Errorf("image_url content missing URL") } url := content.ImageURL.URL // Read image data from source data, contentType, err := h.read(ctx, url) if err != nil { return content, nil, fmt.Errorf("failed to read image: %w", err) } // Encode to base64 data URI base64Data := EncodeToBase64DataURI(data, contentType) // Return as image_url ContentPart return agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: base64Data, Detail: content.ImageURL.Detail, }, }, nil, nil } // read reads image content from various sources func (h *Image) read(ctx *agentContext.Context, url string) ([]byte, string, error) { // Determine source type and read accordingly if strings.HasPrefix(url, "data:") { // Data URI format: data:image/png;base64,xxxxx return h.readFromDataURI(url) } if strings.HasPrefix(url, "__") { // Uploader wrapper format: __uploader://fileid return h.readFromUploader(ctx, url) } if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { // HTTP URL - for now return error, can be implemented later return nil, "", fmt.Errorf("HTTP URL fetch not implemented yet: %s", url) } // Unknown source return 
nil, "", fmt.Errorf("unsupported image source: %s", url) } // readFromDataURI reads image content from a data URI func (h *Image) readFromDataURI(dataURI string) ([]byte, string, error) { // Parse data URI: data:image/png;base64,xxxxx if !strings.HasPrefix(dataURI, "data:") { return nil, "", fmt.Errorf("invalid data URI format") } // Find the comma separator commaIndex := strings.Index(dataURI, ",") if commaIndex == -1 { return nil, "", fmt.Errorf("invalid data URI: missing comma separator") } // Extract metadata part (e.g., "image/png;base64") metadata := dataURI[5:commaIndex] // Skip "data:" base64Data := dataURI[commaIndex+1:] // Parse content type contentType := "image/png" // default if strings.Contains(metadata, ";") { parts := strings.Split(metadata, ";") if len(parts) > 0 && parts[0] != "" { contentType = parts[0] } } else if metadata != "" && metadata != "base64" { contentType = metadata } // Decode base64 data data, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { return nil, "", fmt.Errorf("failed to decode base64 data: %w", err) } return data, contentType, nil } // readFromUploader reads image content from file uploader __uploader://fileid func (h *Image) readFromUploader(ctx *agentContext.Context, wrapper string) ([]byte, string, error) { // Parse wrapper to get uploader name and file ID uploaderName, fileID, ok := attachment.Parse(wrapper) if !ok { return nil, "", fmt.Errorf("invalid uploader wrapper format: %s", wrapper) } // Get attachment manager manager, exists := attachment.Managers[uploaderName] if !exists { return nil, "", fmt.Errorf("uploader '%s' not found", uploaderName) } // Get file info file, err := manager.Info(ctx.Context, fileID) if err != nil { return nil, "", fmt.Errorf("failed to get file info: %w", err) } // Read file content data, err := manager.Read(ctx.Context, fileID) if err != nil { return nil, "", fmt.Errorf("failed to read file: %w", err) } return data, file.ContentType, nil } // readFromCache reads cached 
text content for an image func (h *Image) readFromCache(ctx *agentContext.Context, url string) (string, bool, error) { // Parse URL to check if it's an uploader wrapper uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return "", false, nil // Not an uploader wrapper, no cache } // Try attachment manager's content_preview (cross-call cache) manager, exists := attachment.Managers[uploaderName] if !exists { return "", false, nil } // GetText with fullContent=false to get preview (default) text, err := manager.GetText(ctx.Context, fileID, false) if err == nil && text != "" { return text, true, nil } // No cache found return "", false, nil } // saveToCache saves processed text to cache func (h *Image) saveToCache(ctx *agentContext.Context, url string, text string) error { // Parse URL to get uploader name and file ID uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return nil // Not an uploader wrapper, nothing to cache } // Save to attachment manager for future calls manager, exists := attachment.Managers[uploaderName] if !exists { return nil } return manager.SaveText(ctx.Context, fileID, text) } // agent calls image agent to parse image content // Note: Cache check is done in Parse() before calling this method func (h *Image) agent(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.ImageURL == nil || content.ImageURL.URL == "" { return content, nil, fmt.Errorf("image_url content missing URL") } url := content.ImageURL.URL // Get vision tool from options visionTool := "" if h.options.CompletionOptions != nil && h.options.CompletionOptions.Uses != nil { visionTool = h.options.CompletionOptions.Uses.Vision } if visionTool == "" { return content, nil, fmt.Errorf("no vision tool specified in uses.Vision") } // Parse vision tool format // Format can be: // - "agent_id" (call agent) // - "mcp:server_id" (call MCP tool) var text string var err 
error if strings.HasPrefix(visionTool, "mcp:") { // MCP tool serverID := strings.TrimPrefix(visionTool, "mcp:") text, err = h.callMCPVisionTool(ctx, serverID, content) } else { // Agent call text, err = h.callVisionAgent(ctx, visionTool, content) } if err != nil { return content, nil, fmt.Errorf("failed to process image with vision tool: %w", err) } // Cache the result if cacheErr := h.saveToCache(ctx, url, text); cacheErr != nil { // Log error but don't fail the request fmt.Printf("Warning: failed to cache processed text: %v\n", cacheErr) } // Return as text content return agentContext.ContentPart{ Type: agentContext.ContentText, Text: text, }, nil, nil } // callVisionAgent calls a vision agent to describe the image func (h *Image) callVisionAgent(ctx *agentContext.Context, agentID string, content agentContext.ContentPart) (string, error) { // Read image data and convert to base64 data, contentType, err := h.read(ctx, content.ImageURL.URL) if err != nil { return "", fmt.Errorf("failed to read image: %w", err) } base64Data := EncodeToBase64DataURI(data, contentType) // Prepare message with image message := agentContext.Message{ Role: agentContext.RoleUser, Content: []agentContext.ContentPart{ { Type: agentContext.ContentText, Text: "Please analyze this image.", }, { Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: base64Data, Detail: agentContext.DetailAuto, }, }, }, } // Send loading message loadingID := h.sendLoading(ctx, i18n.T(ctx.Locale, "content.image.analyzing")) // Call agent using the tools package result, err := tools.CallAgent(ctx, agentID, message) // Send done message h.sendLoadingDone(ctx, loadingID) return result, err } // callMCPVisionTool calls an MCP vision tool to describe the image func (h *Image) callMCPVisionTool(ctx *agentContext.Context, serverID string, content agentContext.ContentPart) (string, error) { // Read image data and convert to base64 data, contentType, err := h.read(ctx, content.ImageURL.URL) if err != 
// EncodeToBase64DataURI renders data as a "data:<contentType>;base64,<payload>"
// URI using standard (padded) base64. An empty contentType falls back to
// "image/png".
func EncodeToBase64DataURI(data []byte, contentType string) string {
	mimeType := contentType
	if mimeType == "" {
		mimeType = "image/png" // default for images
	}

	payload := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
	base64.StdEncoding.Encode(payload, data)

	return fmt.Sprintf("data:%s;base64,%s", mimeType, payload)
}
"testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/content/image" contentTypes "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) // newTestContext creates a Context for testing with commonly used fields pre-populated func newTestContext(capabilities *openai.Capabilities) *agentContext.Context { authorized := &oauthTypes.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", } ctx := agentContext.New(stdContext.Background(), authorized, "test-chat") ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = agentContext.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = agentContext.RefererAPI ctx.Accept = agentContext.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) ctx.Capabilities = capabilities ctx.IDGenerator = message.NewIDGenerator() return ctx } // newTestOptions creates test options with the given capabilities func newTestOptions(capabilities *openai.Capabilities, completionOptions *agentContext.CompletionOptions) *contentTypes.Options { return &contentTypes.Options{ Capabilities: capabilities, CompletionOptions: completionOptions, } } // TestParseWithVisionSupport tests parsing image when model supports vision func TestParseWithVisionSupport(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create capabilities with vision support capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create test image content with data URI base64Data := "data:image/png;base64," + 
base64.StdEncoding.EncodeToString(createTestPNG()) content := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: base64Data, Detail: agentContext.DetailAuto, }, } handler := image.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentImageURL, result.Type) assert.NotNil(t, result.ImageURL) assert.Equal(t, base64Data, result.ImageURL.URL) // Should pass through unchanged } // TestParseWithoutVisionSupport tests parsing image when model doesn't support vision func TestParseWithoutVisionSupport(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create capabilities WITHOUT vision support capabilities := &openai.Capabilities{ Vision: nil, } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create test image content base64Data := "data:image/png;base64," + base64.StdEncoding.EncodeToString(createTestPNG()) content := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: base64Data, Detail: agentContext.DetailAuto, }, } handler := image.New(options) _, _, err := handler.Parse(ctx, content) // Should return error because no vision support and no vision tool specified assert.Error(t, err) assert.Contains(t, err.Error(), "no vision tool specified") } // TestParseWithEmptyURL tests parsing image with empty URL func TestParseWithEmptyURL(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with empty URL content := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: "", }, } handler := image.New(options) _, _, err := 
handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "missing URL") } // TestParseWithNilImageURL tests parsing image with nil ImageURL func TestParseWithNilImageURL(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with nil ImageURL content := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: nil, } handler := image.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "missing URL") } // TestEncodeToBase64DataURI tests base64 encoding func TestEncodeToBase64DataURI(t *testing.T) { tests := []struct { name string data []byte contentType string wantPrefix string }{ { name: "PNG image", data: []byte{0x89, 0x50, 0x4E, 0x47}, contentType: "image/png", wantPrefix: "data:image/png;base64,", }, { name: "JPEG image", data: []byte{0xFF, 0xD8, 0xFF}, contentType: "image/jpeg", wantPrefix: "data:image/jpeg;base64,", }, { name: "Empty content type defaults to PNG", data: []byte{0x01, 0x02, 0x03}, contentType: "", wantPrefix: "data:image/png;base64,", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := image.EncodeToBase64DataURI(tt.data, tt.contentType) // Check prefix assert.True(t, strings.HasPrefix(result, tt.wantPrefix)) // Verify base64 encoding by decoding base64Part := result[len(tt.wantPrefix):] decoded, err := base64.StdEncoding.DecodeString(base64Part) assert.NoError(t, err) // Verify decoded data matches original assert.Equal(t, tt.data, decoded) }) } } // TestParseDataURIPassthrough tests that data URI images pass through unchanged when vision is supported func TestParseDataURIPassthrough(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) 
capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create test image content with data URI originalURL := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" content := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: originalURL, Detail: agentContext.DetailHigh, }, } handler := image.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentImageURL, result.Type) assert.NotNil(t, result.ImageURL) assert.Equal(t, originalURL, result.ImageURL.URL) assert.Equal(t, agentContext.DetailHigh, result.ImageURL.Detail) } // TestParseWithVisionAgent tests parsing image using a vision agent when model doesn't support vision func TestParseWithVisionAgent(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create capabilities WITHOUT vision support capabilities := &openai.Capabilities{ Vision: nil, // Model doesn't support vision } // Configure to use vision agent completionOptions := &agentContext.CompletionOptions{ Uses: &agentContext.Uses{ Vision: "tests.vision-test", // Use our test vision agent }, } options := newTestOptions(capabilities, completionOptions) ctx := newTestContext(capabilities) // Create test image content with data URI (1x1 red PNG) base64Data := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" content := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: base64Data, Detail: agentContext.DetailAuto, }, } handler := image.New(options) result, refs, err := handler.Parse(ctx, content) // Should succeed and return text content (image description from agent) assert.NoError(t, err) 
assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) // Agent should return some description t.Logf("Vision agent response: %s", result.Text) } // TestParseWithForceUsesVisionAgent tests forceUses flag with vision agent func TestParseWithForceUsesVisionAgent(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create capabilities WITH vision support capabilities := &openai.Capabilities{ Vision: "openai", // Model supports vision } // Configure to FORCE use vision agent (even though model supports vision) completionOptions := &agentContext.CompletionOptions{ ForceUses: true, // Force using the vision tool Uses: &agentContext.Uses{ Vision: "tests.vision-test", // Use our test vision agent }, } options := newTestOptions(capabilities, completionOptions) ctx := newTestContext(capabilities) // Create test image content with data URI base64Data := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" content := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: base64Data, Detail: agentContext.DetailAuto, }, } handler := image.New(options) result, refs, err := handler.Parse(ctx, content) // Should succeed and return text content (forced to use agent even though model supports vision) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) // Agent should return some description t.Logf("Vision agent (forced) response: %s", result.Text) } // createTestPNG creates a minimal valid PNG image (1x1 red pixel) func createTestPNG() []byte { return []byte{ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1 dimensions 0x08, 0x02, 0x00, 0x00, 
0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, // IDAT chunk 0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x03, 0x01, 0x01, 0x00, 0x18, 0xDD, 0x8D, 0xB4, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, // IEND chunk 0x44, 0xAE, 0x42, 0x60, 0x82, } } ================================================ FILE: agent/content/link/link.go ================================================ package link ================================================ FILE: agent/content/pdf/pdf.go ================================================ package pdf import ( "fmt" "os" "path/filepath" "strings" "time" goupdf "github.com/yaoapp/gou/pdf" "github.com/yaoapp/yao/agent/content/image" "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/attachment" kbTypes "github.com/yaoapp/yao/kb/types" ) // PDF handles PDF content type PDF struct { options *types.Options } // New creates a new PDF handler func New(options *types.Options) *PDF { return &PDF{options: options} } // Parse parses PDF content by converting to images and processing each page // Returns multiple ContentPart (one text part per page) combined into a single text part func (h *PDF) Parse(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return content, nil, fmt.Errorf("file content missing URL") } url := content.File.URL // Check cache first cachedText, found, err := h.readFromCache(ctx, url) if err == nil && found { return agentContext.ContentPart{ Type: agentContext.ContentText, Text: cachedText, }, nil, nil } // Convert PDF to images and process each page return h.asImages(ctx, content) } // ParseMulti parses PDF content and returns multiple ContentParts (one per page) // This is 
useful when you need separate parts for each page func (h *PDF) ParseMulti(ctx *agentContext.Context, content agentContext.ContentPart) ([]agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return nil, nil, fmt.Errorf("file content missing URL") } url := content.File.URL // Check cache first - if cached, return as single text part cachedText, found, err := h.readFromCache(ctx, url) if err == nil && found { return []agentContext.ContentPart{ { Type: agentContext.ContentText, Text: cachedText, }, }, nil, nil } // Convert PDF to images and process each page return h.asImagesMulti(ctx, content) } // asImages converts PDF to images and processes each page, returning combined result func (h *PDF) asImages(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { parts, refs, err := h.asImagesMulti(ctx, content) if err != nil { return content, nil, err } if len(parts) == 0 { return content, nil, fmt.Errorf("no pages extracted from PDF") } // Check if any parts are text (vision agent was used) or image_url (model supports vision) hasTextParts := false hasImageParts := false for _, part := range parts { if part.Type == agentContext.ContentText { hasTextParts = true } else if part.Type == agentContext.ContentImageURL { hasImageParts = true } } // If all parts are image_url (model supports vision), return the first image // The caller should use ParseMulti to get all images if hasImageParts && !hasTextParts { return parts[0], refs, nil } // Combine all text parts into one var combinedText strings.Builder pageNum := 0 for _, part := range parts { if part.Type == agentContext.ContentText && part.Text != "" { pageNum++ if pageNum > 1 { combinedText.WriteString("\n\n---\n\n") // Page separator } combinedText.WriteString(fmt.Sprintf("## Page %d\n\n", pageNum)) combinedText.WriteString(part.Text) } } result := agentContext.ContentPart{ Type: 
agentContext.ContentText, Text: combinedText.String(), } // Cache the combined result if content.File != nil && content.File.URL != "" && combinedText.Len() > 0 { h.saveToCache(ctx, content.File.URL, combinedText.String()) } return result, refs, nil } // asImagesMulti converts PDF to images and processes each page separately func (h *PDF) asImagesMulti(ctx *agentContext.Context, content agentContext.ContentPart) ([]agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return nil, nil, fmt.Errorf("file content missing URL") } url := content.File.URL // Read PDF file pdfData, err := h.readPDF(ctx, url) if err != nil { return nil, nil, fmt.Errorf("failed to read PDF: %w", err) } // Create temporary file for PDF tempDir := os.TempDir() pdfPath := filepath.Join(tempDir, fmt.Sprintf("pdf_%d.pdf", time.Now().UnixNano())) if err := os.WriteFile(pdfPath, pdfData, 0644); err != nil { return nil, nil, fmt.Errorf("failed to write temp PDF: %w", err) } defer os.Remove(pdfPath) // Get PDF processor with global config processor, err := h.getPDFProcessor() if err != nil { return nil, nil, fmt.Errorf("failed to create PDF processor: %w", err) } // Create output directory for images imagesDir := filepath.Join(tempDir, fmt.Sprintf("pdf_images_%d", time.Now().UnixNano())) if err := os.MkdirAll(imagesDir, 0755); err != nil { return nil, nil, fmt.Errorf("failed to create images directory: %w", err) } defer os.RemoveAll(imagesDir) // Convert PDF to images convertConfig := goupdf.ConvertConfig{ OutputDir: imagesDir, OutputPrefix: "page", Format: "png", DPI: 150, Quality: 90, PageRange: "all", } imageFiles, err := processor.Convert(ctx.Context, pdfPath, convertConfig) if err != nil { return nil, nil, fmt.Errorf("failed to convert PDF to images: %w", err) } if len(imageFiles) == 0 { return nil, nil, fmt.Errorf("no pages extracted from PDF") } // Process each image using the image handler (with SilentLoading to suppress image loading 
messages) imageOptions := *h.options // Copy options imageOptions.SilentLoading = true imageHandler := image.New(&imageOptions) var parts []agentContext.ContentPart var allRefs []*searchTypes.Reference for i, imageFile := range imageFiles { // Send loading message for this page loadingMsg := fmt.Sprintf(i18n.T(ctx.Locale, "content.pdf.analyzing_page"), i+1, len(imageFiles)) loadingID := h.sendLoading(ctx, loadingMsg) // Read image file imageData, err := os.ReadFile(imageFile) if err != nil { h.sendLoadingDone(ctx, loadingID) continue } // Convert to base64 data URI base64Data := image.EncodeToBase64DataURI(imageData, "image/png") // Create image content part imagePart := agentContext.ContentPart{ Type: agentContext.ContentImageURL, ImageURL: &agentContext.ImageURL{ URL: base64Data, Detail: agentContext.DetailAuto, }, } // Parse image using image handler parsedPart, refs, err := imageHandler.Parse(ctx, imagePart) // Mark loading as done h.sendLoadingDone(ctx, loadingID) if err != nil { // If parsing fails, skip this page continue } parts = append(parts, parsedPart) if refs != nil { allRefs = append(allRefs, refs...) 
} } if len(parts) == 0 { return nil, nil, fmt.Errorf("failed to process any PDF pages") } return parts, allRefs, nil } // readPDF reads PDF content from various sources func (h *PDF) readPDF(ctx *agentContext.Context, url string) ([]byte, error) { if strings.HasPrefix(url, "__") { // Uploader wrapper format: __uploader://fileid return h.readFromUploader(ctx, url) } if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { return nil, fmt.Errorf("HTTP URL fetch not implemented yet: %s", url) } // Try to read as local file path if _, err := os.Stat(url); err == nil { return os.ReadFile(url) } return nil, fmt.Errorf("unsupported PDF source: %s", url) } // readFromUploader reads PDF content from file uploader func (h *PDF) readFromUploader(ctx *agentContext.Context, wrapper string) ([]byte, error) { uploaderName, fileID, ok := attachment.Parse(wrapper) if !ok { return nil, fmt.Errorf("invalid uploader wrapper format: %s", wrapper) } manager, exists := attachment.Managers[uploaderName] if !exists { return nil, fmt.Errorf("uploader '%s' not found", uploaderName) } data, err := manager.Read(ctx.Context, fileID) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } return data, nil } // readFromCache reads cached text content for a PDF func (h *PDF) readFromCache(ctx *agentContext.Context, url string) (string, bool, error) { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return "", false, nil } manager, exists := attachment.Managers[uploaderName] if !exists { return "", false, nil } text, err := manager.GetText(ctx.Context, fileID, false) if err == nil && text != "" { return text, true, nil } return "", false, nil } // saveToCache saves processed text to cache func (h *PDF) saveToCache(ctx *agentContext.Context, url string, text string) error { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return nil } manager, exists := attachment.Managers[uploaderName] if !exists { return nil } 
return manager.SaveText(ctx.Context, fileID, text) } // getPDFProcessor creates a PDF processor using global KB config func (h *PDF) getPDFProcessor() (*goupdf.PDF, error) { globalPDF := kbTypes.GetGlobalPDF() opts := goupdf.Options{ ConvertTool: goupdf.ToolPdftoppm, // default ToolPath: "", } if globalPDF != nil { if globalPDF.ConvertTool != "" { switch globalPDF.ConvertTool { case "pdftoppm": opts.ConvertTool = goupdf.ToolPdftoppm case "mutool": opts.ConvertTool = goupdf.ToolMutool case "imagemagick", "convert": opts.ConvertTool = goupdf.ToolImageMagick } } if globalPDF.ToolPath != "" { opts.ToolPath = globalPDF.ToolPath } } return goupdf.New(opts), nil } // sendLoading sends a loading message and returns the message ID func (h *PDF) sendLoading(ctx *agentContext.Context, msg string) string { loadingMsg := &message.Message{ Type: message.TypeLoading, Props: map[string]interface{}{ "message": msg, }, } msgID, err := ctx.SendStream(loadingMsg) if err != nil { return "" } return msgID } // sendLoadingDone marks the loading message as done func (h *PDF) sendLoadingDone(ctx *agentContext.Context, loadingID string) { if loadingID == "" { return } doneMsg := &message.Message{ MessageID: loadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]interface{}{ "done": true, }, } ctx.Send(doneMsg) } ================================================ FILE: agent/content/pdf/pdf_test.go ================================================ package pdf_test import ( stdContext "context" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/content/pdf" contentTypes "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) // Test files directory (relative to yao-dev-app) const 
testFilesDir = "assistants/tests/vision-helper/tests" // newTestContext creates a Context for testing with commonly used fields pre-populated func newTestContext(capabilities *openai.Capabilities) *agentContext.Context { authorized := &oauthTypes.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", } ctx := agentContext.New(stdContext.Background(), authorized, "test-chat") ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = agentContext.Client{ Type: "web", UserAgent: "TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = agentContext.RefererAPI ctx.Accept = agentContext.AcceptWebCUI ctx.Route = "" ctx.Metadata = make(map[string]interface{}) ctx.Capabilities = capabilities ctx.IDGenerator = message.NewIDGenerator() return ctx } // newTestOptions creates test options with the given capabilities func newTestOptions(capabilities *openai.Capabilities, completionOptions *agentContext.CompletionOptions) *contentTypes.Options { return &contentTypes.Options{ Capabilities: capabilities, CompletionOptions: completionOptions, } } // getTestFilePath returns the full path to a test file func getTestFilePath(filename string) string { yaoRoot := os.Getenv("YAO_TEST_APPLICATION") if yaoRoot == "" { yaoRoot = os.Getenv("YAO_ROOT") } return filepath.Join(yaoRoot, testFilesDir, filename) } // TestParseWithMissingURL tests parsing PDF with missing URL func TestParseWithMissingURL(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with nil File content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: nil, } handler := pdf.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, 
err.Error(), "missing URL") } // TestParseWithEmptyURL tests parsing PDF with empty URL func TestParseWithEmptyURL(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with empty URL content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: "", Filename: "test.pdf", }, } handler := pdf.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "missing URL") } // TestParseWithLocalPDFAndVisionSupport tests parsing a local PDF file when model supports vision func TestParseWithLocalPDFAndVisionSupport(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Check if test file exists pdfPath := getTestFilePath("test.pdf") if _, err := os.Stat(pdfPath); os.IsNotExist(err) { t.Skipf("Test PDF file not found: %s", pdfPath) } // Create capabilities with vision support capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with local file path content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: pdfPath, Filename: "test.pdf", }, } handler := pdf.New(options) result, refs, err := handler.Parse(ctx, content) // Should succeed - PDF converted to images assert.NoError(t, err) assert.Nil(t, refs) // When model supports vision, Parse returns the first image_url part // Use ParseMulti to get all pages as separate image_url parts assert.Equal(t, agentContext.ContentImageURL, result.Type) assert.NotNil(t, result.ImageURL) assert.NotEmpty(t, result.ImageURL.URL) t.Logf("PDF parse result type: %s, URL prefix: %s...", result.Type, 
result.ImageURL.URL[:50]) } // TestParseWithLocalPDFAndVisionAgent tests parsing a local PDF file using vision agent func TestParseWithLocalPDFAndVisionAgent(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Check if test file exists pdfPath := getTestFilePath("test.pdf") if _, err := os.Stat(pdfPath); os.IsNotExist(err) { t.Skipf("Test PDF file not found: %s", pdfPath) } // Create capabilities WITHOUT vision support capabilities := &openai.Capabilities{ Vision: nil, } // Configure to use vision agent completionOptions := &agentContext.CompletionOptions{ Uses: &agentContext.Uses{ Vision: "tests.vision-test", // Use our test vision agent }, } options := newTestOptions(capabilities, completionOptions) ctx := newTestContext(capabilities) // Create content with local file path content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: pdfPath, Filename: "test.pdf", }, } handler := pdf.New(options) result, refs, err := handler.Parse(ctx, content) // Should succeed - PDF converted to images and processed by vision agent assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) t.Logf("PDF parse result (via vision agent): %s", result.Text) } // TestParseMultiWithLocalPDF tests ParseMulti which returns separate parts for each page func TestParseMultiWithLocalPDF(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Check if test file exists pdfPath := getTestFilePath("test.pdf") if _, err := os.Stat(pdfPath); os.IsNotExist(err) { t.Skipf("Test PDF file not found: %s", pdfPath) } // Create capabilities with vision support capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with local file path content := 
agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: pdfPath, Filename: "test.pdf", }, } handler := pdf.New(options) parts, refs, err := handler.ParseMulti(ctx, content) // Should succeed and return at least one part (one per page) assert.NoError(t, err) assert.Nil(t, refs) assert.NotEmpty(t, parts) t.Logf("PDF ParseMulti returned %d parts", len(parts)) // When model supports vision, each part should be image_url type for i, part := range parts { assert.Equal(t, agentContext.ContentImageURL, part.Type) assert.NotNil(t, part.ImageURL) t.Logf(" Part %d: type=%s, has URL=%v", i+1, part.Type, part.ImageURL != nil && part.ImageURL.URL != "") } } // TestParseWithUnsupportedSource tests parsing PDF with unsupported source func TestParseWithUnsupportedSource(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with HTTP URL (not implemented) content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: "https://example.com/test.pdf", Filename: "test.pdf", }, } handler := pdf.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "HTTP URL fetch not implemented") } // TestParseWithNonExistentFile tests parsing PDF with non-existent file func TestParseWithNonExistentFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) capabilities := &openai.Capabilities{ Vision: "openai", } options := newTestOptions(capabilities, nil) ctx := newTestContext(capabilities) // Create content with non-existent file content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: "/non/existent/path/test.pdf", 
Filename: "test.pdf", }, } handler := pdf.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "unsupported PDF source") } // TestSilentLoadingOption tests that SilentLoading option is respected func TestSilentLoadingOption(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) capabilities := &openai.Capabilities{ Vision: "openai", } // Create options with SilentLoading enabled options := &contentTypes.Options{ Capabilities: capabilities, SilentLoading: true, } // This test just verifies the option can be set // The actual behavior is tested in the image handler tests handler := pdf.New(options) assert.NotNil(t, handler) assert.True(t, options.SilentLoading) } ================================================ FILE: agent/content/pptx/pptx.go ================================================ package pptx import ( "fmt" "os" "strings" "github.com/yaoapp/gou/office" "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/attachment" ) // Pptx handles PPTX content type Pptx struct { options *types.Options } // New creates a new PPTX handler func New(options *types.Options) *Pptx { return &Pptx{options: options} } // Parse parses PPTX content and returns text func (h *Pptx) Parse(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return content, nil, fmt.Errorf("file content missing URL") } url := content.File.URL // Check cache first cachedText, found, err := h.readFromCache(ctx, url) if err == nil && found { return agentContext.ContentPart{ Type: agentContext.ContentText, Text: cachedText, }, nil, nil } // Read PPTX file data, err := h.readFile(ctx, url) if err != nil { return content, nil, fmt.Errorf("failed to 
read PPTX: %w", err) } // Parse PPTX using gou/office parser := office.NewParser() result, err := parser.Parse(data) if err != nil { return content, nil, fmt.Errorf("failed to parse PPTX: %w", err) } text := result.Markdown if text == "" { return content, nil, fmt.Errorf("no text content extracted from PPTX") } // Cache the result if err := h.saveToCache(ctx, url, text); err != nil { // Log warning but don't fail fmt.Printf("Warning: failed to cache PPTX text: %v\n", err) } return agentContext.ContentPart{ Type: agentContext.ContentText, Text: text, }, nil, nil } // readFile reads PPTX content from various sources func (h *Pptx) readFile(ctx *agentContext.Context, url string) ([]byte, error) { if strings.HasPrefix(url, "__") { return h.readFromUploader(ctx, url) } if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { return nil, fmt.Errorf("HTTP URL fetch not implemented yet: %s", url) } // Try to read as local file path if _, err := os.Stat(url); err == nil { return os.ReadFile(url) } return nil, fmt.Errorf("unsupported PPTX source: %s", url) } // readFromUploader reads PPTX content from file uploader func (h *Pptx) readFromUploader(ctx *agentContext.Context, wrapper string) ([]byte, error) { uploaderName, fileID, ok := attachment.Parse(wrapper) if !ok { return nil, fmt.Errorf("invalid uploader wrapper format: %s", wrapper) } manager, exists := attachment.Managers[uploaderName] if !exists { return nil, fmt.Errorf("uploader '%s' not found", uploaderName) } data, err := manager.Read(ctx.Context, fileID) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } return data, nil } // readFromCache reads cached text content for a PPTX func (h *Pptx) readFromCache(ctx *agentContext.Context, url string) (string, bool, error) { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return "", false, nil } manager, exists := attachment.Managers[uploaderName] if !exists { return "", false, nil } text, err := 
manager.GetText(ctx.Context, fileID, false) if err == nil && text != "" { return text, true, nil } return "", false, nil } // saveToCache saves processed text to cache func (h *Pptx) saveToCache(ctx *agentContext.Context, url string, text string) error { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return nil } manager, exists := attachment.Managers[uploaderName] if !exists { return nil } return manager.SaveText(ctx.Context, fileID, text) } ================================================ FILE: agent/content/pptx/pptx_test.go ================================================ package pptx_test import ( stdContext "context" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/content/pptx" contentTypes "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) const testFilesDir = "assistants/tests/vision-helper/tests" func newTestContext() *agentContext.Context { authorized := &oauthTypes.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", } ctx := agentContext.New(stdContext.Background(), authorized, "test-chat") ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.IDGenerator = message.NewIDGenerator() return ctx } func newTestOptions() *contentTypes.Options { return &contentTypes.Options{ Capabilities: &openai.Capabilities{}, } } func getTestFilePath(filename string) string { yaoRoot := os.Getenv("YAO_TEST_APPLICATION") if yaoRoot == "" { yaoRoot = os.Getenv("YAO_ROOT") } return filepath.Join(yaoRoot, testFilesDir, filename) } // TestParseWithMissingURL tests parsing PPTX with missing URL func TestParseWithMissingURL(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer 
testutils.Clean(t) options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: nil, } handler := pptx.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "missing URL") } // TestParseWithLocalPptx tests parsing a local PPTX file func TestParseWithLocalPptx(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) pptxPath := getTestFilePath("pptx.pptx") if _, err := os.Stat(pptxPath); os.IsNotExist(err) { t.Skipf("Test PPTX file not found: %s", pptxPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: pptxPath, Filename: "pptx.pptx", }, } handler := pptx.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) t.Logf("PPTX parse result (first 500 chars): %.500s...", result.Text) } // TestParseWithNonExistentFile tests parsing PPTX with non-existent file func TestParseWithNonExistentFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: "/non/existent/path/test.pptx", Filename: "test.pptx", }, } handler := pptx.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "unsupported PPTX source") } ================================================ FILE: agent/content/text/text.go ================================================ package text import ( "fmt" "os" "path/filepath" "strings" "github.com/yaoapp/yao/agent/content/types" agentContext 
"github.com/yaoapp/yao/agent/context" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/attachment" ) // SupportedExtensions text file extensions var SupportedExtensions = map[string]bool{ // Markdown ".md": true, ".markdown": true, // Plain text ".txt": true, // Code files ".go": true, ".ts": true, ".tsx": true, ".js": true, ".jsx": true, ".py": true, ".java": true, ".c": true, ".cpp": true, ".h": true, ".hpp": true, ".rs": true, ".rb": true, ".php": true, ".swift": true, ".kt": true, ".scala": true, ".sh": true, ".bash": true, ".zsh": true, ".fish": true, ".ps1": true, ".bat": true, ".cmd": true, ".sql": true, ".r": true, ".lua": true, ".perl": true, ".pl": true, ".groovy": true, ".dart": true, ".elm": true, ".ex": true, ".exs": true, ".erl": true, ".hs": true, ".clj": true, ".lisp": true, ".vim": true, // Config files ".json": true, ".jsonc": true, ".yaml": true, ".yml": true, ".toml": true, ".ini": true, ".conf": true, ".cfg": true, ".env": true, ".yao": true, // Web files ".html": true, ".htm": true, ".css": true, ".scss": true, ".sass": true, ".less": true, ".xml": true, ".svg": true, // Documentation ".rst": true, ".tex": true, ".latex": true, ".org": true, ".adoc": true, // Data files ".csv": true, ".tsv": true, // Log files ".log": true, } // Text handles text file content type Text struct { options *types.Options } // New creates a new text handler func New(options *types.Options) *Text { return &Text{options: options} } // IsSupportedExtension checks if a file extension is supported func IsSupportedExtension(filename string) bool { ext := strings.ToLower(filepath.Ext(filename)) return SupportedExtensions[ext] } // Parse parses text file content and returns text func (h *Text) Parse(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return content, nil, fmt.Errorf("file content missing URL") } url := 
content.File.URL filename := content.File.Filename // Check cache first cachedText, found, err := h.readFromCache(ctx, url) if err == nil && found { return agentContext.ContentPart{ Type: agentContext.ContentText, Text: cachedText, }, nil, nil } // Read text file data, err := h.readFile(ctx, url) if err != nil { return content, nil, fmt.Errorf("failed to read text file: %w", err) } // Convert to string text := string(data) // Add file type context if it's a code file ext := strings.ToLower(filepath.Ext(filename)) if isCodeFile(ext) { // Wrap in markdown code block with language hint lang := getLanguageFromExt(ext) text = fmt.Sprintf("```%s\n%s\n```", lang, text) } // Cache the result if err := h.saveToCache(ctx, url, text); err != nil { // Log warning but don't fail fmt.Printf("Warning: failed to cache text: %v\n", err) } return agentContext.ContentPart{ Type: agentContext.ContentText, Text: text, }, nil, nil } // ParseRaw parses any file as raw text content without code block wrapping // This is used as a fallback for unsupported file types func (h *Text) ParseRaw(ctx *agentContext.Context, content agentContext.ContentPart) (agentContext.ContentPart, []*searchTypes.Reference, error) { if content.File == nil || content.File.URL == "" { return content, nil, fmt.Errorf("file content missing URL") } url := content.File.URL filename := content.File.Filename // Check cache first cachedText, found, err := h.readFromCache(ctx, url) if err == nil && found { return agentContext.ContentPart{ Type: agentContext.ContentText, Text: cachedText, }, nil, nil } // Read file data, err := h.readFile(ctx, url) if err != nil { return content, nil, fmt.Errorf("failed to read file: %w", err) } // Convert to string directly (no code block wrapping) text := string(data) // Add filename as context if filename != "" { text = fmt.Sprintf("File: %s\n\n%s", filename, text) } // Cache the result if err := h.saveToCache(ctx, url, text); err != nil { // Log warning but don't fail 
fmt.Printf("Warning: failed to cache text: %v\n", err) } return agentContext.ContentPart{ Type: agentContext.ContentText, Text: text, }, nil, nil } // readFile reads text content from various sources func (h *Text) readFile(ctx *agentContext.Context, url string) ([]byte, error) { if strings.HasPrefix(url, "__") { return h.readFromUploader(ctx, url) } if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { return nil, fmt.Errorf("HTTP URL fetch not implemented yet: %s", url) } // Try to read as local file path if _, err := os.Stat(url); err == nil { return os.ReadFile(url) } return nil, fmt.Errorf("unsupported text file source: %s", url) } // readFromUploader reads text content from file uploader func (h *Text) readFromUploader(ctx *agentContext.Context, wrapper string) ([]byte, error) { uploaderName, fileID, ok := attachment.Parse(wrapper) if !ok { return nil, fmt.Errorf("invalid uploader wrapper format: %s", wrapper) } manager, exists := attachment.Managers[uploaderName] if !exists { return nil, fmt.Errorf("uploader '%s' not found", uploaderName) } data, err := manager.Read(ctx.Context, fileID) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } return data, nil } // readFromCache reads cached text content func (h *Text) readFromCache(ctx *agentContext.Context, url string) (string, bool, error) { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return "", false, nil } manager, exists := attachment.Managers[uploaderName] if !exists { return "", false, nil } text, err := manager.GetText(ctx.Context, fileID, false) if err == nil && text != "" { return text, true, nil } return "", false, nil } // saveToCache saves processed text to cache func (h *Text) saveToCache(ctx *agentContext.Context, url string, text string) error { uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { return nil } manager, exists := attachment.Managers[uploaderName] if !exists { return nil } return 
manager.SaveText(ctx.Context, fileID, text) } // isCodeFile checks if the extension represents a code file func isCodeFile(ext string) bool { codeExts := map[string]bool{ ".go": true, ".ts": true, ".tsx": true, ".js": true, ".jsx": true, ".py": true, ".java": true, ".c": true, ".cpp": true, ".h": true, ".hpp": true, ".rs": true, ".rb": true, ".php": true, ".swift": true, ".kt": true, ".scala": true, ".sh": true, ".bash": true, ".zsh": true, ".sql": true, ".r": true, ".lua": true, ".perl": true, ".pl": true, ".groovy": true, ".dart": true, ".elm": true, ".ex": true, ".exs": true, ".erl": true, ".hs": true, ".clj": true, ".lisp": true, ".vim": true, } return codeExts[ext] } // getLanguageFromExt returns the language name for markdown code block func getLanguageFromExt(ext string) string { langMap := map[string]string{ ".go": "go", ".ts": "typescript", ".tsx": "tsx", ".js": "javascript", ".jsx": "jsx", ".py": "python", ".java": "java", ".c": "c", ".cpp": "cpp", ".h": "c", ".hpp": "cpp", ".rs": "rust", ".rb": "ruby", ".php": "php", ".swift": "swift", ".kt": "kotlin", ".scala": "scala", ".sh": "bash", ".bash": "bash", ".zsh": "zsh", ".fish": "fish", ".ps1": "powershell", ".bat": "batch", ".cmd": "batch", ".sql": "sql", ".r": "r", ".lua": "lua", ".perl": "perl", ".pl": "perl", ".groovy": "groovy", ".dart": "dart", ".elm": "elm", ".ex": "elixir", ".exs": "elixir", ".erl": "erlang", ".hs": "haskell", ".clj": "clojure", ".lisp": "lisp", ".vim": "vim", ".json": "json", ".jsonc": "jsonc", ".yaml": "yaml", ".yml": "yaml", ".toml": "toml", ".xml": "xml", ".html": "html", ".htm": "html", ".css": "css", ".scss": "scss", ".sass": "sass", ".less": "less", ".svg": "svg", ".yao": "json", } if lang, ok := langMap[ext]; ok { return lang } return "" } ================================================ FILE: agent/content/text/text_test.go ================================================ package text_test import ( stdContext "context" "os" "path/filepath" "strings" "testing" 
"github.com/stretchr/testify/assert" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/content/text" contentTypes "github.com/yaoapp/yao/agent/content/types" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) const testFilesDir = "assistants/tests/vision-helper/tests" func newTestContext() *agentContext.Context { authorized := &oauthTypes.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", } ctx := agentContext.New(stdContext.Background(), authorized, "test-chat") ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.IDGenerator = message.NewIDGenerator() return ctx } func newTestOptions() *contentTypes.Options { return &contentTypes.Options{ Capabilities: &openai.Capabilities{}, } } func getTestFilePath(filename string) string { yaoRoot := os.Getenv("YAO_TEST_APPLICATION") if yaoRoot == "" { yaoRoot = os.Getenv("YAO_ROOT") } return filepath.Join(yaoRoot, testFilesDir, filename) } // TestIsSupportedExtension tests the IsSupportedExtension function func TestIsSupportedExtension(t *testing.T) { // Supported extensions assert.True(t, text.IsSupportedExtension("test.md")) assert.True(t, text.IsSupportedExtension("test.txt")) assert.True(t, text.IsSupportedExtension("test.go")) assert.True(t, text.IsSupportedExtension("test.ts")) assert.True(t, text.IsSupportedExtension("test.json")) assert.True(t, text.IsSupportedExtension("test.jsonc")) assert.True(t, text.IsSupportedExtension("test.yao")) assert.True(t, text.IsSupportedExtension("test.yaml")) assert.True(t, text.IsSupportedExtension("test.yml")) assert.True(t, text.IsSupportedExtension("test.py")) assert.True(t, text.IsSupportedExtension("test.js")) assert.True(t, text.IsSupportedExtension("test.css")) assert.True(t, text.IsSupportedExtension("test.html")) // Unsupported extensions assert.False(t, 
text.IsSupportedExtension("test.docx")) assert.False(t, text.IsSupportedExtension("test.pptx")) assert.False(t, text.IsSupportedExtension("test.pdf")) assert.False(t, text.IsSupportedExtension("test.png")) assert.False(t, text.IsSupportedExtension("test.jpg")) assert.False(t, text.IsSupportedExtension("test.exe")) assert.False(t, text.IsSupportedExtension("test.zip")) } // TestParseWithMissingURL tests parsing text with missing URL func TestParseWithMissingURL(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: nil, } handler := text.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "missing URL") } // TestParseWithLocalTextFile tests parsing a local text file func TestParseWithLocalTextFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) txtPath := getTestFilePath("text.txt") if _, err := os.Stat(txtPath); os.IsNotExist(err) { t.Skipf("Test text file not found: %s", txtPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: txtPath, Filename: "text.txt", }, } handler := text.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) t.Logf("Text parse result: %s", result.Text) } // TestParseWithLocalMarkdownFile tests parsing a local markdown file func TestParseWithLocalMarkdownFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) mdPath := getTestFilePath("test.md") if _, err := os.Stat(mdPath); os.IsNotExist(err) { t.Skipf("Test 
markdown file not found: %s", mdPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: mdPath, Filename: "test.md", }, } handler := text.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) t.Logf("Markdown parse result: %s", result.Text) } // TestParseWithLocalCodeFile tests parsing a local code file (TypeScript) func TestParseWithLocalCodeFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) tsPath := getTestFilePath("code.ts") if _, err := os.Stat(tsPath); os.IsNotExist(err) { t.Skipf("Test TypeScript file not found: %s", tsPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: tsPath, Filename: "code.ts", }, } handler := text.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) // Code files should be wrapped in markdown code blocks assert.True(t, strings.HasPrefix(result.Text, "```typescript")) assert.True(t, strings.HasSuffix(strings.TrimSpace(result.Text), "```")) t.Logf("Code parse result (first 500 chars): %.500s...", result.Text) } // TestParseWithLocalYaoFile tests parsing a local .yao file func TestParseWithLocalYaoFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) yaoPath := getTestFilePath("hero.mod.yao") if _, err := os.Stat(yaoPath); os.IsNotExist(err) { t.Skipf("Test .yao file not found: %s", yaoPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: 
agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: yaoPath, Filename: "hero.mod.yao", }, } handler := text.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) t.Logf("Yao file parse result: %s", result.Text) } // TestParseWithLocalJsonFile tests parsing a local JSON file func TestParseWithLocalJsonFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) jsonPath := getTestFilePath("test.json") if _, err := os.Stat(jsonPath); os.IsNotExist(err) { t.Skipf("Test JSON file not found: %s", jsonPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: jsonPath, Filename: "test.json", }, } handler := text.New(options) result, refs, err := handler.Parse(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) t.Logf("JSON parse result: %s", result.Text) } // TestParseWithNonExistentFile tests parsing text with non-existent file func TestParseWithNonExistentFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: "/non/existent/path/test.txt", Filename: "test.txt", }, } handler := text.New(options) _, _, err := handler.Parse(ctx, content) assert.Error(t, err) assert.Contains(t, err.Error(), "unsupported text file source") } // TestParseRawWithLocalFile tests ParseRaw with a local file func TestParseRawWithLocalFile(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer 
testutils.Clean(t) txtPath := getTestFilePath("text.txt") if _, err := os.Stat(txtPath); os.IsNotExist(err) { t.Skipf("Test text file not found: %s", txtPath) } options := newTestOptions() ctx := newTestContext() content := agentContext.ContentPart{ Type: agentContext.ContentFile, File: &agentContext.FileAttachment{ URL: txtPath, Filename: "text.txt", }, } handler := text.New(options) result, refs, err := handler.ParseRaw(ctx, content) assert.NoError(t, err) assert.Nil(t, refs) assert.Equal(t, agentContext.ContentText, result.Type) assert.NotEmpty(t, result.Text) // ParseRaw should include filename as context assert.True(t, strings.HasPrefix(result.Text, "File: text.txt")) t.Logf("ParseRaw result: %s", result.Text) } ================================================ FILE: agent/content/tools/tools.go ================================================ package tools import ( "context" "fmt" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/mcp" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/caller" agentContext "github.com/yaoapp/yao/agent/context" ) // CallAgent calls an agent to process content (vision, audio, etc.) 
func CallAgent(ctx *agentContext.Context, agentID string, message agentContext.Message) (string, error) { if caller.AgentGetterFunc == nil { return "", fmt.Errorf("AgentGetterFunc not initialized") } // Load the agent by ID using the injected function agent, err := caller.AgentGetterFunc(agentID) if err != nil { return "", fmt.Errorf("failed to load agent %s: %w", agentID, err) } // Call the agent with the message messages := []agentContext.Message{message} // For A2A calls, skip history and output (we only need the response data) opts := &agentContext.Options{Skip: &agentContext.Skip{History: true, Output: true}} response, err := agent.Stream(ctx, messages, opts) if err != nil { return "", fmt.Errorf("failed to call agent %s: %w", agentID, err) } // Extract text from agent response return ExtractTextFromResponse(response) } // ExtractTextFromResponse extracts text from agent response func ExtractTextFromResponse(response *agentContext.Response) (string, error) { if response == nil { return "", fmt.Errorf("agent returned nil response") } // Priority 1: Check Next field (custom hook data) if response.Next != nil { if nextStr, ok := response.Next.(string); ok { return nextStr, nil } // Otherwise, JSON stringify to preserve complete structure jsonBytes, err := jsoniter.Marshal(response.Next) if err != nil { return "", fmt.Errorf("failed to serialize next hook data: %w", err) } return string(jsonBytes), nil } // Priority 2: Check Completion field (standard LLM response) if response.Completion != nil { switch v := response.Completion.Content.(type) { case string: return v, nil case []interface{}: // Multimodal content array - extract all text parts var text string for _, part := range v { if partMap, ok := part.(map[string]interface{}); ok { if partType, _ := partMap["type"].(string); partType == "text" { if textContent, ok := partMap["text"].(string); ok { text += textContent } } } } if text != "" { return text, nil } return "", fmt.Errorf("no text content found in 
completion content parts") } } return "", fmt.Errorf("no content found in agent response") } // CallMCPTool calls an MCP tool to process content func CallMCPTool(ctx *agentContext.Context, serverID string, toolName string, arguments map[string]interface{}) (string, error) { // Get MCP context for cancellation/timeout control mcpCtx := ctx.Context if mcpCtx == nil { mcpCtx = context.Background() } // Get MCP client client, err := mcp.Select(serverID) if err != nil { return "", fmt.Errorf("failed to select MCP client '%s': %w", serverID, err) } // Call the tool log.Trace("[Content] Calling MCP tool: %s (server: %s)", toolName, serverID) callResult, err := client.CallTool(mcpCtx, toolName, arguments) if err != nil { return "", fmt.Errorf("MCP tool call failed: %w", err) } // Check if result is an error if callResult.IsError { return "", fmt.Errorf("MCP tool returned error: %v", callResult.Content) } // Extract text content from result var text string for _, content := range callResult.Content { if content.Type == "text" { text += content.Text } } if text == "" { return "", fmt.Errorf("MCP tool returned no text content") } return text, nil } ================================================ FILE: agent/content/types/types.go ================================================ package types import ( "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/connector/openai" agentContext "github.com/yaoapp/yao/agent/context" ) // Options represents the options for the content type Options struct { // Connector, Current connector instance Connector connector.Connector // Capabilities, Current capabilities instance Capabilities *openai.Capabilities // CompletionOptions, Current completion options instance CompletionOptions *agentContext.CompletionOptions // StreamOptions, Current stream options instance StreamOptions *agentContext.StreamOptions // SilentLoading, if true, suppress loading messages (used when called from parent handler) SilentLoading bool } 
================================================ FILE: agent/context/JSAPI.md ================================================ # Context JavaScript API Documentation ## Overview The Context JavaScript API provides a comprehensive interface for interacting with the Yao Agent system from JavaScript/TypeScript hooks (Create, Next). The Context object exposes agent state, configuration, messaging capabilities, trace operations, and MCP (Model Context Protocol) integrations. ## Context Object The Context object is automatically passed to hook functions and provides access to the agent's execution environment. ### Basic Properties ```typescript interface Context { // Identifiers chat_id: string; // Current chat session ID assistant_id: string; // Assistant identifier // Configuration locale: string; // User locale (e.g., "en", "zh-cn") theme: string; // UI theme preference accept: string; // Output format ("standard", "cui-web", "cui-native", etc.) route: string; // Request route path referer: string; // Request referer // Client Information client: { type: string; // Client type user_agent: string; // User agent string ip: string; // Client IP address }; // Dynamic Data metadata: Record<string, any>; // Custom metadata (empty object if not set) authorized: Record<string, any>; // Authorization data (empty object if not set) // Objects memory: Memory; // Agent memory with four namespaces: user, team, chat, context trace: Trace; // Trace object for debugging and monitoring mcp: MCP; // MCP object for external tool/resource access agent: Agent; // Agent-to-Agent calls (A2A) llm: LLM; // Direct LLM connector calls sandbox?: Sandbox; // Sandbox operations (only when sandbox configured) } ``` ## Methods ### Send Messages The Context provides several methods for sending messages to the client: | Method | Description | Auto `message_end` | Updatable | | ------------------------------------ | --------------------------- | ------------------ | --------- | | `Send(message, block_id?)` | Send a complete 
message | ✅ Yes | ❌ No | | `SendStream(message, block_id?)` | Start a streaming message | ❌ No | ✅ Yes | | `Append(message_id, content, path?)` | Append content to a message | - | - | | `Replace(message_id, message)` | Replace message content | - | - | | `Merge(message_id, data, path?)` | Merge data into message | - | - | | `Set(message_id, data, path)` | Set a field in message | - | - | | `End(message_id, final_content?)` | Finalize streaming message | ✅ Yes | - | | `EndBlock(block_id)` | End a message block | - | - | | `MessageID()` | Generate unique message ID | - | - | | `BlockID()` | Generate unique block ID | - | - | | `ThreadID()` | Generate unique thread ID | - | - | > **Note:** `Append`, `Replace`, `Merge`, and `Set` only work with messages started via `SendStream()`. Messages sent via `Send()` are immediately finalized and cannot be updated. #### `ctx.Send(message, block_id?): string` Sends a message to the client and automatically flushes the output. **Parameters:** - `message`: Message object or string - `block_id`: String (optional) - Block ID to send this message in. If omitted, no block ID is assigned. **Returns:** - `string`: The message ID (auto-generated if not provided in the message object) **Message Object Structure:** ```typescript interface Message { // Required type: string; // Message type: "text", "tool", "image", etc. // Common fields props?: Record<string, any>; // Message properties (passed to frontend component) message_id?: string; // Message ID (auto-generated if omitted) block_id?: string; // Block ID (NOT auto-generated, has priority over block_id parameter) thread_id?: string; // Thread ID (auto-set from Stack for nested agents) // Metadata (optional) metadata?: Record<string, any>; // Custom metadata } ``` **Examples:** ```javascript // Send text message (object format) and capture message ID const message_id = ctx.Send({ type: "text", props: { content: "Hello, World!" 
}, }); console.log("Sent message:", message_id); // Send text message (shorthand) - no block ID by default const text_id = ctx.Send("Hello, World!"); // Send multiple messages in the same block (same bubble/card in UI) const block_id = ctx.BlockID(); // Generate block ID first const msg1 = ctx.Send("Step 1: Analyzing...", block_id); const msg2 = ctx.Send("Step 2: Processing...", block_id); const msg3 = ctx.Send("Step 3: Complete!", block_id); // Specify block_id in message object (highest priority) const msg4 = ctx.Send({ type: "text", props: { content: "In specific block" }, block_id: "B2", // This takes priority over second parameter }); // Send tool message with custom IDs const tool_id = ctx.Send({ type: "tool", message_id: "custom-tool-msg-1", block_id: "B_tools", props: { name: "calculator", result: { sum: 42 }, }, }); // Send image message const image_id = ctx.Send({ type: "image", props: { url: "https://example.com/image.png", alt: "Example Image", }, }); ``` **Block Management:** ```javascript // Scenario 1: Simple message (most common) function Next(ctx, payload) { const { completion } = payload; // Send a complete message ctx.Send({ type: "text", props: { content: completion.content }, }); } // Scenario 2: Loading indicator before slow operation function Next(ctx, payload) { // Start a streaming message for loading const loading_id = ctx.SendStream({ type: "loading", props: { message: "Fetching data..." 
}, }); // Do slow operation (e.g., external API call) const result = fetchExternalData(); // Replace loading with result ctx.Replace(loading_id, { type: "text", props: { content: result }, }); ctx.End(loading_id); } // Scenario 3: Grouping messages in one block (special case) function Create(ctx, messages) { // Generate a block ID for grouping const block_id = ctx.BlockID(); // "B1" ctx.Send("# Analysis Results", block_id); ctx.Send("- Finding 1: ...", block_id); ctx.Send("- Finding 2: ...", block_id); ctx.Send("- Finding 3: ...", block_id); // All messages appear in the same card/bubble in the UI } // Scenario 4: LLM response + follow-up card in same block function Next(ctx, payload) { const { completion } = payload; const block_id = ctx.BlockID(); // LLM response ctx.Send({ type: "text", props: { content: completion.content }, block_id: block_id, }); // Action card (grouped with LLM response) ctx.Send({ type: "card", props: { title: "Related Actions", actions: ["action1", "action2"], }, block_id: block_id, }); } ``` **Notes:** - **Message ID** is automatically generated if not provided - **Block ID** is NOT auto-generated by default (remains empty unless manually specified) - Most messages don't need a Block ID (each message is independent) - Only specify Block ID in special cases (e.g., grouping LLM output with a follow-up card) - **Block ID priority**: message.block_id > block_id parameter > empty - **Thread ID** is automatically set from Stack for non-root calls (nested agents) - Returns the message ID for reference in subsequent operations - Output is automatically flushed after sending - Throws exception on failure - `Send()` automatically sends `message_end` event - the message is complete and cannot be updated - **For updatable messages**, use `ctx.SendStream()` instead (see below) #### `ctx.SendStream(message, block_id?): string` Sends a streaming message that can be appended to later. Unlike `Send()`, this does NOT automatically send `message_end` event. 
Use `ctx.Append()` to add content, then `ctx.End()` to finalize. **Parameters:** - `message`: Message object or string - `block_id`: String (optional) - Block ID to send this message in **Returns:** - `string`: The message ID (for use with `Append` and `End`) **Examples:** ```javascript // Start a streaming message const msg_id = ctx.SendStream({ type: "text", props: { content: "# Title\n\n" }, }); // Append content in chunks (simulating streaming) ctx.Append(msg_id, "First paragraph. "); ctx.Append(msg_id, "Second sentence. "); ctx.Append(msg_id, "Third sentence.\n\n"); // Finalize the message (sends message_end event) ctx.End(msg_id); ``` **String Shorthand:** ```javascript // SendStream with string shorthand const msg_id = ctx.SendStream("Starting analysis..."); ctx.Append(msg_id, " processing..."); ctx.Append(msg_id, " done!"); ctx.End(msg_id); // Final content: "Starting analysis... processing... done!" ``` **With Block ID:** ```javascript const block_id = ctx.BlockID(); const msg_id = ctx.SendStream("Step 1: ", block_id); ctx.Append(msg_id, "Analyzing data..."); ctx.End(msg_id); ``` **Notes:** - Returns the message ID immediately for use with `Append` and `End` - Sends `message_start` event but NOT `message_end` (unlike `Send`) - Must call `ctx.End(msg_id)` to finalize the message - Content appended via `ctx.Append()` is accumulated for storage - Ideal for streaming text output where you control the timing #### `ctx.End(message_id, final_content?): string` Finalizes a streaming message started with `SendStream()`. Sends `message_end` event with the complete accumulated content. 
**Parameters:** - `message_id`: String - The message ID returned by `SendStream()` - `final_content`: String (optional) - Final content to append before ending **Returns:** - `string`: The message ID **Examples:** ```javascript // Basic usage const msg_id = ctx.SendStream("Hello"); ctx.Append(msg_id, " World"); ctx.End(msg_id); // Final: "Hello World" // End with final content const msg_id2 = ctx.SendStream("Processing"); ctx.Append(msg_id2, "..."); ctx.End(msg_id2, " Complete!"); // Final: "Processing... Complete!" ``` **Notes:** - Must be called after `SendStream()` to send `message_end` event - Optional `final_content` is appended before sending `message_end` - The complete accumulated content is included in `message_end.extra.content` - Throws exception if `message_id` is not a string **Send vs SendStream Comparison:** | Feature | `Send()` | `SendStream()` | | --------------------- | ----------------- | ------------------- | | `message_start` event | ✅ Auto | ✅ Auto | | `message_end` event | ✅ Auto | ❌ Manual (`End()`) | | Use case | Complete messages | Streaming output | | Content accumulation | N/A | Via `Append()` | | Storage | Immediate | On `End()` | **Streaming Workflow Example:** ```javascript function Create(ctx, messages) { // Start streaming output const msg_id = ctx.SendStream({ type: "text", props: { content: "# Analysis Report\n\n" }, }); // Simulate streaming chunks ctx.Append(msg_id, "## Section 1\n"); ctx.Append(msg_id, "Processing data...\n\n"); // Do some work const result = analyzeData(); ctx.Append(msg_id, "## Section 2\n"); ctx.Append(msg_id, `Found ${result.count} items.\n\n`); // Finalize with conclusion ctx.End(msg_id, "## Conclusion\nAnalysis complete."); return { messages }; } ``` #### `ctx.Replace(message_id, message): string` Replaces the content of a streaming message. **Only works with messages started via `SendStream()`**. 
**Parameters:** - `message_id`: String - The ID of the streaming message (returned by `SendStream()`) - `message`: Message object or string - The new message content **Returns:** - `string`: The message ID (same as the provided message_id) **Examples:** ```javascript // Start a streaming message const msg_id = ctx.SendStream({ type: "loading", props: { message: "Loading..." }, }); // Replace with new content ctx.Replace(msg_id, { type: "text", props: { content: "Data loaded successfully!" }, }); // Finalize the message ctx.End(msg_id); ``` **Use Cases:** ```javascript // Progress updates with replacement function Next(ctx, payload) { const msg_id = ctx.SendStream("Step 1/3: Starting..."); // ... do work ... ctx.Replace(msg_id, "Step 2/3: Processing..."); // ... do more work ... ctx.Replace(msg_id, "Step 3/3: Finalizing..."); // ... finish ... ctx.Replace(msg_id, "Complete! ✓"); ctx.End(msg_id); } // Loading to result transition function Next(ctx, payload) { const msg_id = ctx.SendStream({ type: "loading", props: { message: "Fetching results..." }, }); const results = fetchData(); ctx.Replace(msg_id, { type: "text", props: { content: `Found ${results.length} results` }, }); ctx.End(msg_id); } ``` **Notes:** - **Only works with `SendStream()` messages** - `Send()` messages cannot be replaced - Replaces the entire message content, not just specific fields - Must call `ctx.End(msg_id)` after all updates to finalize the message - Output is automatically flushed after replacing - Throws exception on failure #### `ctx.Append(message_id, content, path?): string` Appends content to a streaming message. **Only works with messages started via `SendStream()`**. 
**Parameters:** - `message_id`: String - The ID of the streaming message (returned by `SendStream()`) - `content`: Message object or string - The content to append - `path`: String (optional) - The delta path to append to (e.g., "props.content", "props.data") **Returns:** - `string`: The message ID (same as the provided message_id) **Examples:** ```javascript // Start a streaming message const msg_id = ctx.SendStream("Starting"); // Append more text (default path) ctx.Append(msg_id, "... processing"); ctx.Append(msg_id, "... done!"); // Finalize the message ctx.End(msg_id); // Final content: "Starting... processing... done!" // Append to specific path const data_id = ctx.SendStream({ type: "data", props: { content: "Item 1\n", status: "loading", }, }); ctx.Append(data_id, "Item 2\n", "props.content"); ctx.Append(data_id, "Item 3\n", "props.content"); ctx.End(data_id); // Final: props.content = "Item 1\nItem 2\nItem 3\n" ``` **Use Cases:** ```javascript // Streaming text output (simulating LLM-like output) function Create(ctx, messages) { const msg_id = ctx.SendStream(""); ctx.Append(msg_id, "The"); ctx.Append(msg_id, " quick"); ctx.Append(msg_id, " brown"); ctx.Append(msg_id, " fox"); ctx.End(msg_id); // Final: "The quick brown fox" return { messages }; } // Progress logs function Next(ctx, payload) { const log_id = ctx.SendStream({ type: "log", props: { content: "Starting process\n" }, }); // Step 1 doStep1(); ctx.Append(log_id, "Step 1 complete\n", "props.content"); // Step 2 doStep2(); ctx.Append(log_id, "Step 2 complete\n", "props.content"); // Finish ctx.Append(log_id, "All done!\n", "props.content"); ctx.End(log_id); } ``` **Notes:** - **Only works with `SendStream()` messages** - `Send()` messages cannot be appended to - Uses delta append operation (adds to existing content, doesn't replace) - If `path` is omitted, appends to the default content location (`props.content`) - Must call `ctx.End(msg_id)` after all appends to finalize the message - Output is 
automatically flushed after appending - Throws exception on failure - block_id and ThreadID are inherited from the original message #### `ctx.Merge(message_id, data, path?): string` Merges data into a streaming message object. **Only works with messages started via `SendStream()`**. **Parameters:** - `message_id`: String - The ID of the streaming message (returned by `SendStream()`) - `data`: Object - The data to merge (should be an object) - `path`: String (optional) - The delta path to merge into (e.g., "props", "props.metadata") **Returns:** - `string`: The message ID (same as the provided message_id) **Examples:** ```javascript // Start a streaming message with object data const msg_id = ctx.SendStream({ type: "status", props: { status: "running", progress: 0, started: true, }, }); // Merge updates into props (adds/updates fields, keeps others unchanged) ctx.Merge(msg_id, { progress: 50 }, "props"); // Result: props = { status: "running", progress: 50, started: true } ctx.Merge(msg_id, { progress: 100, status: "completed" }, "props"); // Result: props = { status: "completed", progress: 100, started: true } // Finalize the message ctx.End(msg_id); ``` **Use Cases:** ```javascript // Updating task progress function Next(ctx, payload) { const task_id = ctx.SendStream({ type: "task", props: { name: "Data Processing", status: "pending", progress: 0, }, }); ctx.Merge(task_id, { status: "running" }, "props"); doStep1(); ctx.Merge(task_id, { progress: 25 }, "props"); doStep2(); ctx.Merge(task_id, { progress: 50 }, "props"); doStep3(); ctx.Merge(task_id, { progress: 100, status: "completed" }, "props"); ctx.End(task_id); } // Building metadata incrementally function Create(ctx, messages) { const data_id = ctx.SendStream({ type: "data", props: { content: "Result data" }, }); ctx.Merge(data_id, { metadata: { source: "api" } }, "props"); ctx.Merge(data_id, { metadata: { timestamp: Date.now() } }, "props"); // metadata fields are merged together ctx.End(data_id); return { 
messages }; } ``` **Notes:** - **Only works with `SendStream()` messages** - `Send()` messages cannot be merged into - Uses delta merge operation (merges objects, doesn't replace) - Only works with object data (for merging key-value pairs) - Existing fields not in the merge data remain unchanged - If `path` is omitted, merges into the default object location - Must call `ctx.End(msg_id)` after all merges to finalize the message - Output is automatically flushed after merging - Throws exception on failure - block_id and ThreadID are inherited from the original message #### `ctx.Set(message_id, data, path): string` Sets a new field or value in a streaming message. **Only works with messages started via `SendStream()`**. **Parameters:** - `message_id`: String - The ID of the streaming message (returned by `SendStream()`) - `data`: Any - The value to set - `path`: String (required) - The delta path where to set the value (e.g., "props.newField", "props.metadata.key") **Returns:** - `string`: The message ID (same as the provided message_id) **Examples:** ```javascript // Start a streaming message const msg_id = ctx.SendStream({ type: "result", props: { content: "Initial content", }, }); // Set a new field ctx.Set(msg_id, "success", "props.status"); // Result: props.status = "success" // Set a nested object ctx.Set(msg_id, { duration: 1500, cached: true }, "props.metadata"); // Result: props.metadata = { duration: 1500, cached: true } // Finalize the message ctx.End(msg_id); ``` **Use Cases:** ```javascript // Adding computed metadata after initial send function Next(ctx, payload) { const result_id = ctx.SendStream({ type: "search_result", props: { results: search_results }, }); ctx.Set(result_id, search_results.length, "props.count"); ctx.Set(result_id, Date.now(), "props.timestamp"); ctx.Set(result_id, "relevance", "props.sort_by"); ctx.End(result_id); } // Conditionally adding fields function Create(ctx, messages) { const msg_id = ctx.SendStream({ type: "operation", 
props: { name: "Process Data" }, }); try { const result = processData(); ctx.Set(msg_id, "success", "props.status"); ctx.Set(msg_id, result, "props.data"); } catch (e) { ctx.Set(msg_id, e.message, "props.error"); ctx.Set(msg_id, "error", "props.status"); } ctx.End(msg_id); return { messages }; } ``` **Notes:** - **Only works with `SendStream()` messages** - `Send()` messages cannot be modified - Uses delta set operation (creates/sets new fields) - The `path` parameter is **required** (must specify where to set the value) - Creates the path if it doesn't exist - Use for adding new fields or completely replacing a field's value - For updating existing object fields, consider using `Merge` instead - Must call `ctx.End(msg_id)` after all sets to finalize the message - Output is automatically flushed after setting - Throws exception on failure - block_id and ThreadID are inherited from the original message ### ID Generators These methods generate unique IDs for manual message management. Useful when you need to specify IDs before sending messages or for advanced Block/Thread management. #### `ctx.MessageID(): string` Generates a unique message ID. **Returns:** - `string`: Message ID in format "M1", "M2", "M3"... **Example:** ```javascript // Generate IDs manually const id_1 = ctx.MessageID(); // "M1" const id_2 = ctx.MessageID(); // "M2" // Use custom ID ctx.Send({ type: "text", message_id: id_1, props: { content: "Hello" }, }); ``` #### `ctx.BlockID(): string` Generates a unique block ID for grouping messages. **Returns:** - `string`: Block ID in format "B1", "B2", "B3"... 
**Example:** ```javascript // Generate block ID for grouping messages const block_id = ctx.BlockID(); // "B1" // Send multiple messages in the same block ctx.Send("Step 1: Analyzing...", block_id); ctx.Send("Step 2: Processing...", block_id); ctx.Send("Step 3: Complete!", block_id); // All three messages appear in the same card/bubble in UI ``` **Use Cases:** ```javascript // Scenario: LLM output + follow-up card in same block const block_id = ctx.BlockID(); // LLM response const llm_result = Process("llms.chat", {...}); ctx.Send({ type: "text", props: { content: llm_result.content }, block_id: block_id, }); // Follow-up action card (grouped with LLM output) ctx.Send({ type: "card", props: { title: "Related Actions", actions: [...] }, block_id: block_id, }); ``` #### `ctx.ThreadID(): string` Generates a unique thread ID for concurrent operations. **Returns:** - `string`: Thread ID in format "T1", "T2", "T3"... **Example:** ```javascript // For advanced parallel processing scenarios const thread_id = ctx.ThreadID(); // "T1" // Send messages in a specific thread ctx.Send({ type: "text", props: { content: "Parallel task 1" }, thread_id: thread_id, }); ``` **Notes:** - IDs are generated sequentially within each context - Each context has its own ID counter (starts from 1) - IDs are guaranteed to be unique within the same request/stream - ThreadID is usually auto-managed by Stack, manual generation is for advanced use cases ### Lifecycle Management #### `ctx.EndBlock(block_id): void` Manually sends a `block_end` event for the specified block. Use this to explicitly mark the end of a block. 
**Parameters:** - `block_id`: String - The block ID to end **Returns:** - `void` **Example:** ```javascript // Create a block for grouped messages const block_id = ctx.BlockID(); // "B1" // Send messages in the block ctx.Send("Analyzing data...", block_id); ctx.Send("Processing results...", block_id); ctx.Send("Complete!", block_id); // Manually end the block ctx.EndBlock(block_id); ``` **Block Lifecycle Events:** When you send messages with a `block_id`: 1. **First message**: Automatically sends `block_start` event 2. **Subsequent messages**: No additional block events 3. **Manual end**: Call `ctx.EndBlock(block_id)` to send `block_end` event **block_end Event Format:** ```json { "type": "event", "props": { "event": "block_end", "message": "Block ended", "data": { "block_id": "B1", "timestamp": 1764483531624, "duration_ms": 1523, "message_count": 3, "status": "completed" } } } ``` **Notes:** - `block_start` is sent automatically when the first message with a new `block_id` is sent - `block_end` must be called manually via `ctx.EndBlock()` - You can track multiple blocks simultaneously (each has independent lifecycle) - Automatically flushes output after sending the event **Use Cases:** ```javascript // Use case 1: Progress reporting in a block function Create(ctx, messages) { const block_id = ctx.BlockID(); ctx.Send("Step 1: Analyzing data...", block_id); // ... analysis logic ... ctx.Send("Step 2: Processing results...", block_id); // ... processing logic ... 
ctx.Send("Step 3: Complete!", block_id); // Mark the block as complete ctx.EndBlock(block_id); return { messages }; } // Use case 2: Multiple parallel blocks function Create(ctx, messages) { const llm_block = ctx.BlockID(); // "B1" const mcp_block = ctx.BlockID(); // "B2" // LLM output block ctx.Send("Thinking...", llm_block); const response = callLLM(); ctx.Send(response, llm_block); ctx.EndBlock(llm_block); // MCP tool call block ctx.Send("Fetching data...", mcp_block); const data = ctx.mcp.CallTool("tool", "method", {}); ctx.Send(`Found ${data.length} results`, mcp_block); ctx.EndBlock(mcp_block); return { messages }; } ``` ### Resource Cleanup #### `ctx.Release()` Manually releases Context resources. > **Note:** In Hook functions (`Create`, `Next`), you do **NOT** need to call `Release()` - the system handles cleanup automatically. Only call `Release()` when you create a new Context manually (e.g., via `new Context()`). **Example (only for manually created Context):** ```javascript // Only needed when creating Context manually, NOT in hooks const ctx = new Context(options); try { ctx.Send("Processing..."); } finally { ctx.Release(); // Required for manually created Context } ``` ## Trace API The `ctx.trace` object provides tracing capabilities for: 1. **User Transparency** - Expose the agent's working and thinking process to users. The frontend will render these trace nodes to show users what the agent is doing. 2. **Developer Debugging** - Help developers debug agent execution by recording detailed steps and data. > **Note:** Trace is primarily designed for developers to expose the agent's process to users. The frontend has corresponding UI components to render these trace nodes. 
### Properties - `ctx.trace.id`: String - The unique identifier of the trace ### Methods Summary | Method | Description | | ------------------------- | ------------------------------- | | `Add(input, option)` | Create a sequential trace node | | `Parallel(inputs)` | Create parallel trace nodes | | `Info(message)` | Add info log to current node | | `Debug(message)` | Add debug log to current node | | `Warn(message)` | Add warning log to current node | | `Error(message)` | Add error log to current node | | `SetOutput(output)` | Set output for current node | | `SetMetadata(key, value)` | Set metadata for current node | | `Complete(output?)` | Mark current node as completed | | `Fail(error)` | Mark current node as failed | | `MarkComplete()` | Mark entire trace as complete | | `IsComplete()` | Check if trace is complete | | `CreateSpace(option)` | Create a visual space container | | `GetSpace(id)` | Get a trace space by ID | | `Release()` | Release trace resources | ### Node Operations #### `ctx.trace.Add(input, options)` Creates a new trace node (sequential step). **Parameters:** - `input`: Input data for the node - `options`: Node configuration object **Options Structure:** ```typescript interface TraceNodeOption { label: string; // Display label in UI type?: string; // Node type identifier icon?: string; // Icon identifier description?: string; // Node description metadata?: Record<string, any>; // Additional metadata autoCompleteParent?: boolean; // Auto-complete parent node(s) when this node is created (default: true) } ``` **Example:** ```javascript const search_node = ctx.trace.Add( { query: "What is AI?" }, { label: "Search Query", type: "search", icon: "search", description: "Searching for AI information", } ); ``` #### `ctx.trace.Parallel(inputs)` Creates multiple parallel trace nodes for concurrent operations. 
**Parameters:** - `inputs`: Array of parallel input objects **Input Structure:** ```typescript interface ParallelInput { input: any; // Input data option: TraceNodeOption; // Node configuration } ``` **Example:** ```javascript const parallel_nodes = ctx.trace.Parallel([ { input: { url: "https://api1.com" }, option: { label: "API Call 1", type: "api", icon: "cloud", description: "Fetching from API 1", }, }, { input: { url: "https://api2.com" }, option: { label: "API Call 2", type: "api", icon: "cloud", description: "Fetching from API 2", }, }, ]); ``` ### Logging Methods Add log entries to the current trace node. Each method takes a single string message and returns the trace object for chaining. ```javascript // Information logs ctx.trace.Info("Processing started"); // Debug logs ctx.trace.Debug("Variable value: 42"); // Warning logs ctx.trace.Warn("Deprecated feature used"); // Error logs ctx.trace.Error("Operation failed: timeout"); ``` ### Trace-Level Operations These methods operate on the current trace node (managed by the trace manager). #### `ctx.trace.SetOutput(output)` Sets the output data for the current trace node. ```javascript ctx.trace.SetOutput({ result: "success", data: [...] }); ``` #### `ctx.trace.SetMetadata(key, value)` Sets metadata for the current trace node. ```javascript ctx.trace.SetMetadata("duration", 1500); ctx.trace.SetMetadata("source", "cache"); ``` #### `ctx.trace.Complete(output?)` Marks the current trace node as completed (optionally with output). ```javascript ctx.trace.Complete({ status: "done" }); ``` #### `ctx.trace.Fail(error)` Marks the current trace node as failed with an error message. ```javascript ctx.trace.Fail("Connection timeout"); ``` ### Node Object The `ctx.trace.Add()` and `ctx.trace.Parallel()` methods return Node objects. Each node has the following properties and methods: #### Properties - `id`: String - The unique identifier of the node #### `node.Add(input, option)` Creates a child node under this node. 
```javascript const parent_node = ctx.trace.Add({ step: "process" }, { label: "Process" }); const child_node = parent_node.Add( { action: "validate" }, { label: "Validate Input", type: "validation" } ); ``` #### `node.Parallel(inputs)` Creates multiple parallel child nodes under this node. ```javascript const parent_node = ctx.trace.Add({ step: "fetch" }, { label: "Fetch Data" }); const child_nodes = parent_node.Parallel([ { input: { source: "db" }, option: { label: "Database Query" } }, { input: { source: "api" }, option: { label: "API Call" } }, ]); ``` #### `node.Info(message)`, `node.Debug(message)`, `node.Warn(message)`, `node.Error(message)` Add log entries to the node. All methods return the node for chaining. ```javascript const search_node = ctx.trace.Add({ query: "search" }, { label: "Search" }); search_node .Info("Starting search") .Debug("Query parameters validated") .Warn("Cache miss, fetching from source"); ``` #### `node.SetOutput(output)` Sets the output data for a node. Returns the node for chaining. ```javascript const search_node = ctx.trace.Add({ query: "search" }, { label: "Search" }); search_node.SetOutput({ results: [...], count: 10 }); ``` #### `node.SetMetadata(key, value)` Sets metadata for a node. Returns the node for chaining. ```javascript search_node.SetMetadata("duration", 1500).SetMetadata("cache_hit", true); ``` #### `node.Complete(output?)` Marks a node as completed (optionally with output). Returns the node for chaining. ```javascript search_node.Complete({ status: "success", data: [...] }); ``` #### `node.Fail(error)` Marks a node as failed with an error message. Returns the node for chaining. ```javascript try { // Operation } catch (error) { search_node.Fail(error.message); } ``` ### Trace Lifecycle #### `ctx.trace.IsComplete()` Checks if the trace is complete. ```javascript if (ctx.trace.IsComplete()) { console.log("Trace completed"); } ``` #### `ctx.trace.MarkComplete()` Marks the entire trace as complete. 
```javascript ctx.trace.MarkComplete(); ``` #### `ctx.trace.Release()` Releases trace resources. > **Note:** In Hook functions, you do **NOT** need to call `Release()` - the system handles cleanup automatically. Only call this when you create a Trace manually (e.g., via `new Trace()`). ### Trace Space Operations Trace spaces are visual containers for organizing trace nodes in the frontend UI. They help group related operations together for better presentation to users. > **Note:** Trace spaces are purely for visual organization and presentation. They do not store data - use `ctx.memory` for data storage between hooks. #### `ctx.trace.CreateSpace(option)` Creates a visual space container for grouping trace nodes. **Option Structure:** ```typescript interface TraceSpaceOption { label: string; // Display label in UI type?: string; // Space type identifier icon?: string; // Icon identifier description?: string; // Space description ttl?: number; // Time to live in seconds (for display only) metadata?: Record<string, any>; // Additional metadata } ``` **Example:** ```javascript const visual_space = ctx.trace.CreateSpace({ label: "Search Results", type: "search", icon: "search", description: "Knowledge base search operations", }); ``` #### `ctx.trace.GetSpace(id)` Retrieves a trace space by ID. ```javascript const search_space = ctx.trace.GetSpace("search-space-id"); ``` ## Memory API The `ctx.memory` object provides a four-level hierarchical memory system for agent state management. Each level has different persistence and scope characteristics.
### Memory Namespaces | Namespace | Scope | Persistence | Use Case | | -------------------- | ------------------- | ----------- | ------------------------------------------- | | `ctx.memory.user` | Per user | Persistent | User preferences, settings, long-term state | | `ctx.memory.team` | Per team | Persistent | Team-wide settings, shared configurations | | `ctx.memory.chat` | Per chat session | Persistent | Chat-specific context, conversation state | | `ctx.memory.context` | Per request context | Temporary | Request-scoped data, cleared on release | ### Namespace Interface Each namespace (`user`, `team`, `chat`, `context`) provides the same interface: ```typescript interface MemoryNamespace { // Basic KV operations Get(key: string): any; // Get a value Set(key: string, value: any, ttl?: number): void; // Set a value with optional TTL (seconds) Del(key: string): void; // Delete a key (supports wildcards: "prefix:*") Has(key: string): boolean; // Check if key exists GetDel(key: string): any; // Get and delete atomically // Collection operations Keys(): string[]; // Get all keys Len(): number; // Get number of keys Clear(): void; // Delete all keys // Atomic counter operations Incr(key: string, delta?: number): number; // Increment (default delta=1) Decr(key: string, delta?: number): number; // Decrement (default delta=1) // List operations Push(key: string, values: any[]): number; // Append to list, returns new length Pop(key: string): any; // Remove and return last element Pull(key: string, count: number): any[]; // Remove and return last N elements PullAll(key: string): any[]; // Remove and return all elements AddToSet(key: string, values: any[]): number; // Add unique values to set // Array access operations ArrayLen(key: string): number; // Get array length ArrayGet(key: string, index: number): any; // Get element at index ArraySet(key: string, index: number, value: any): void; // Set element at index ArraySlice(key: string, start: number, end: number): any[]; 
// Get slice ArrayPage(key: string, page: number, size: number): any[]; // Paginated access ArrayAll(key: string): any[]; // Get all elements // Metadata id: string; // Namespace ID space: string; // Space type: "user", "team", "chat", or "context" } ``` ### Basic KV Operations #### `Get(key): any` Gets a value from the namespace. ```javascript // User preferences const theme = ctx.memory.user.Get("theme"); if (theme) { console.log("User prefers:", theme); } // Chat context const topic = ctx.memory.chat.Get("current_topic"); ``` #### `Set(key, value, ttl?): void` Sets a value with optional TTL (time-to-live in seconds). ```javascript // Persistent user setting ctx.memory.user.Set("language", "en"); // Team configuration ctx.memory.team.Set("api_key", "sk-xxx"); // Chat state ctx.memory.chat.Set("last_query", "What is AI?"); // Temporary context data with 5 minute TTL ctx.memory.context.Set("temp_result", { data: "..." }, 300); ``` #### `Del(key): void` Deletes a key. Supports wildcard patterns with `*`. ```javascript // Delete single key ctx.memory.user.Del("old_setting"); // Delete with wildcard pattern ctx.memory.chat.Del("cache:*"); // Deletes all keys starting with "cache:" ``` #### `Has(key): boolean` Checks if a key exists. ```javascript if (ctx.memory.user.Has("onboarding_complete")) { // Skip onboarding } ``` #### `GetDel(key): any` Atomically gets and deletes a value. Useful for one-time tokens. ```javascript const token = ctx.memory.context.GetDel("one_time_token"); if (token) { // Use token (it's now deleted) } ``` ### Collection Operations #### `Keys(): string[]` Returns all keys in the namespace. ```javascript const userKeys = ctx.memory.user.Keys(); console.log("User has", userKeys.length, "stored values"); ``` #### `Len(): number` Returns the number of keys. ```javascript const count = ctx.memory.chat.Len(); console.log("Chat has", count, "stored values"); ``` #### `Clear(): void` Deletes all keys in the namespace. 
```javascript // Clear temporary context data ctx.memory.context.Clear(); ``` ### Atomic Counter Operations #### `Incr(key, delta?): number` Atomically increments a counter. Returns the new value. ```javascript // Simple counter const views = ctx.memory.user.Incr("page_views"); console.log("Total views:", views); // Increment by custom amount const points = ctx.memory.user.Incr("points", 10); ``` #### `Decr(key, delta?): number` Atomically decrements a counter. Returns the new value. ```javascript const remaining = ctx.memory.user.Decr("credits"); if (remaining < 0) { throw new Error("Insufficient credits"); } ``` ### List Operations #### `Push(key, values): number` Appends values to a list. Returns new length. ```javascript const len = ctx.memory.chat.Push("history", [ { role: "user", content: "Hello" }, { role: "assistant", content: "Hi there!" }, ]); ``` #### `Pop(key): any` Removes and returns the last element. ```javascript const lastItem = ctx.memory.chat.Pop("pending_tasks"); ``` #### `Pull(key, count): any[]` Removes and returns the last N elements. ```javascript const recentItems = ctx.memory.chat.Pull("notifications", 5); ``` #### `PullAll(key): any[]` Removes and returns all elements. ```javascript const allTasks = ctx.memory.context.PullAll("batch_queue"); // Process all tasks, queue is now empty ``` #### `AddToSet(key, values): number` Adds unique values to a set (no duplicates). Returns new size. ```javascript ctx.memory.user.AddToSet("visited_pages", ["/home", "/about"]); ctx.memory.user.AddToSet("visited_pages", ["/home", "/contact"]); // "/home" not added again ``` ### Array Access Operations #### `ArrayLen(key): number` Gets the length of an array. ```javascript const historyLen = ctx.memory.chat.ArrayLen("messages"); ``` #### `ArrayGet(key, index): any` Gets an element at a specific index. 
```javascript const firstMessage = ctx.memory.chat.ArrayGet("messages", 0); const lastMessage = ctx.memory.chat.ArrayGet("messages", -1); // Negative index ``` #### `ArraySet(key, index, value): void` Sets an element at a specific index. ```javascript ctx.memory.chat.ArraySet("messages", 0, { role: "system", content: "Updated" }); ``` #### `ArraySlice(key, start, end): any[]` Gets a slice of the array. ```javascript const recent = ctx.memory.chat.ArraySlice("messages", -10, -1); // Last 10 messages ``` #### `ArrayPage(key, page, size): any[]` Gets a page of elements (1-indexed pages). ```javascript const page1 = ctx.memory.chat.ArrayPage("messages", 1, 20); // First 20 messages const page2 = ctx.memory.chat.ArrayPage("messages", 2, 20); // Next 20 messages ``` #### `ArrayAll(key): any[]` Gets all elements of the array. ```javascript const allMessages = ctx.memory.chat.ArrayAll("messages"); ``` ### Use Cases ```javascript // Use case 1: User preferences (persistent across sessions) function Create(ctx, messages) { // Load user preferences const locale = ctx.memory.user.Get("preferred_locale") || "en"; const style = ctx.memory.user.Get("response_style") || "concise"; return { messages, locale: locale, metadata: { style: style }, }; } // Use case 2: Chat context (persistent within chat session) function Next(ctx, payload) { // Track conversation topics const topics = ctx.memory.chat.Get("discussed_topics") || []; const newTopic = extractTopic(payload.completion.content); if (newTopic && !topics.includes(newTopic)) { topics.push(newTopic); ctx.memory.chat.Set("discussed_topics", topics); } } // Use case 3: Request-scoped data (cleared on context release) function Create(ctx, messages) { // Store temporary processing data ctx.memory.context.Set("request_start", Date.now()); ctx.memory.context.Set("original_query", messages[0]?.content); return { messages }; } function Next(ctx, payload) { // Retrieve temporary data const startTime = 
ctx.memory.context.Get("request_start"); const duration = Date.now() - startTime; console.log("Request took", duration, "ms"); // context memory is automatically cleared when ctx.Release() is called } // Use case 4: Team-wide settings function Create(ctx, messages) { // Check team quota const used = ctx.memory.team.Incr("monthly_requests"); const limit = ctx.memory.team.Get("monthly_limit") || 10000; if (used > limit) { throw new Error("Team quota exceeded"); } return { messages }; } // Use case 5: Rate limiting with counters function Create(ctx, messages) { const key = `rate:${new Date().toISOString().slice(0, 13)}`; // Hourly bucket const count = ctx.memory.user.Incr(key); if (count > 100) { throw new Error("Rate limit exceeded"); } return { messages }; } ``` ### Memory Lifecycle | Namespace | Created When | Cleared When | | --------- | ---------------- | --------------- | | `user` | First access | Manual only | | `team` | First access | Manual only | | `chat` | First access | Manual only | | `context` | Context creation | `ctx.Release()` | **Notes:** - `user`, `team`, `chat` namespaces are persistent (backed by database) - `context` namespace is temporary and cleared when the request context is released - All namespaces support TTL for automatic expiration - Wildcard deletion (`Del("prefix:*")`) works on all namespaces - Counter operations (`Incr`, `Decr`) are atomic ## MCP API The `ctx.mcp` object provides access to Model Context Protocol operations for interacting with external tools, resources, and prompts. 
### Methods Summary | Method | Description | | ------------------------------------ | -------------------------------- | | `ListResources(client, cursor?)` | List available resources | | `ReadResource(client, uri)` | Read a specific resource | | `ListTools(client, cursor?)` | List available tools | | `CallTool(client, name, args?)` | Call a single tool | | `CallTools(client, tools)` | Call multiple tools sequentially | | `CallToolsParallel(client, tools)` | Call multiple tools in parallel | | `All(requests)` | Call tools across servers, wait for all | | `Any(requests)` | Call tools across servers, first success wins | | `Race(requests)` | Call tools across servers, first complete wins | | `ListPrompts(client, cursor?)` | List available prompts | | `GetPrompt(client, name, args?)` | Get a specific prompt | | `ListSamples(client, type, name)` | List samples for a tool/resource | | `GetSample(client, type, name, idx)` | Get a specific sample by index | ### Resource Operations #### `ctx.mcp.ListResources(client, cursor?)` Lists available resources from an MCP client. **Parameters:** - `client`: String - MCP client ID - `cursor`: String (optional) - Pagination cursor ```javascript const resources = ctx.mcp.ListResources("echo", ""); console.log(resources.resources); // Array of resources ``` #### `ctx.mcp.ReadResource(client, uri)` Reads a specific resource. **Parameters:** - `client`: String - MCP client ID - `uri`: String - Resource URI ```javascript const info = ctx.mcp.ReadResource("echo", "echo://info"); console.log(info.contents); // Array of content items ``` ### Tool Operations #### `ctx.mcp.ListTools(client, cursor?)` Lists available tools from an MCP client. 
**Parameters:** - `client`: String - MCP client ID - `cursor`: String (optional) - Pagination cursor ```javascript const tools = ctx.mcp.ListTools("echo", ""); console.log(tools.tools); // Array of tools ``` #### `ctx.mcp.CallTool(client, name, arguments?)` Calls a single tool and returns the parsed result directly. **Parameters:** - `client`: String - MCP client ID - `name`: String - Tool name - `arguments`: Object (optional) - Tool arguments **Returns:** Parsed result directly (automatically extracts and parses JSON from tool response) ```javascript // Result is returned directly - no wrapper object needed const result = ctx.mcp.CallTool("echo", "echo", { message: "hello" }); console.log(result.echo); // "hello" - directly access parsed data! // Another example const status = ctx.mcp.CallTool("echo", "status", { verbose: true }); console.log(status.status); // "online" console.log(status.uptime); // 3600 ``` #### `ctx.mcp.CallTools(client, tools)` Calls multiple tools sequentially and returns array of parsed results. **Parameters:** - `client`: String - MCP client ID - `tools`: Array - Array of tool call objects **Returns:** Array of parsed results (same order as input tools) ```javascript const results = ctx.mcp.CallTools("echo", [ { name: "ping", arguments: { count: 1 } }, { name: "echo", arguments: { message: "hello" } }, ]); // Results are directly accessible console.log(results[0].message); // "pong" console.log(results[1].echo); // "hello" ``` #### `ctx.mcp.CallToolsParallel(client, tools)` Calls multiple tools in parallel and returns array of parsed results. 
**Parameters:** - `client`: String - MCP client ID - `tools`: Array - Array of tool call objects **Returns:** Array of parsed results (same order as input tools) ```javascript const results = ctx.mcp.CallToolsParallel("echo", [ { name: "ping", arguments: { count: 1 } }, { name: "echo", arguments: { message: "hello" } }, ]); // Results are directly accessible (order matches input order) console.log(results[0].message); // "pong" (ping result) console.log(results[1].echo); // "hello" (echo result) ``` ### Prompt Operations #### `ctx.mcp.ListPrompts(client, cursor?)` Lists available prompts from an MCP client. **Parameters:** - `client`: String - MCP client ID - `cursor`: String (optional) - Pagination cursor ```javascript const prompts = ctx.mcp.ListPrompts("echo", ""); console.log(prompts.prompts); // Array of prompts ``` #### `ctx.mcp.GetPrompt(client, name, arguments?)` Retrieves a specific prompt with optional arguments. **Parameters:** - `client`: String - MCP client ID - `name`: String - Prompt name - `arguments`: Object (optional) - Prompt arguments ```javascript const prompt = ctx.mcp.GetPrompt("echo", "test_connection", { detailed: "true", }); console.log(prompt.messages); // Array of prompt messages ``` ### Sample Operations #### `ctx.mcp.ListSamples(client, type, name)` Lists available samples for a tool or resource. **Parameters:** - `client`: String - MCP client ID - `type`: String - Sample type ("tool" or "resource") - `name`: String - Tool or resource name ```javascript const samples = ctx.mcp.ListSamples("echo", "tool", "ping"); console.log(samples.samples); // Array of samples ``` #### `ctx.mcp.GetSample(client, type, name, index)` Gets a specific sample by index. 
**Parameters:** - `client`: String - MCP client ID - `type`: String - Sample type ("tool" or "resource") - `name`: String - Tool or resource name - `index`: Number - Sample index (0-based) ```javascript const sample = ctx.mcp.GetSample("echo", "tool", "ping", 0); console.log(sample.name, sample.input); // Sample name and input data ``` ### Cross-Server Tool Operations These methods enable calling tools across multiple MCP servers concurrently, similar to JavaScript Promise patterns. This is useful for: - **Parallel data fetching**: Query multiple data sources simultaneously - **Redundancy/Fallback**: Try multiple servers, use first successful result - **Load balancing**: Distribute load across servers #### `ctx.mcp.All(requests)` Calls tools on multiple MCP servers concurrently and waits for all to complete (like `Promise.all`). **Parameters:** - `requests`: Array of request objects with `mcp`, `tool`, and optional `arguments` **Returns:** Array of `MCPToolResult` objects in the same order as requests ```javascript const results = ctx.mcp.All([ { mcp: "server1", tool: "search", arguments: { query: "topic" } }, { mcp: "server2", tool: "fetch", arguments: { id: 123 } }, { mcp: "server3", tool: "analyze", arguments: { data: "input" } } ]); // Process all results results.forEach((r, i) => { if (r.error) { console.log(`Request ${i} failed: ${r.error}`); } else { console.log(`Request ${i} result:`, r.result); } }); ``` #### `ctx.mcp.Any(requests)` Calls tools on multiple MCP servers concurrently and returns when any succeeds (like `Promise.any`). Useful for redundancy/fallback scenarios. 
**Parameters:** - `requests`: Array of request objects **Returns:** Array of `MCPToolResult` objects (only contains results received before first success) ```javascript // Try multiple search providers, use first successful result const results = ctx.mcp.Any([ { mcp: "search-primary", tool: "search", arguments: { q: "query" } }, { mcp: "search-backup", tool: "search", arguments: { q: "query" } } ]); const success = results.find(r => r && !r.error); if (success) { console.log("Search result:", success.result); } ``` #### `ctx.mcp.Race(requests)` Calls tools on multiple MCP servers concurrently and returns when any completes (like `Promise.race`). Returns immediately with first completion, regardless of success or failure. **Parameters:** - `requests`: Array of request objects **Returns:** Array of `MCPToolResult` objects (only first completed result is populated) ```javascript // Get fastest response const results = ctx.mcp.Race([ { mcp: "region-us", tool: "ping", arguments: {} }, { mcp: "region-eu", tool: "ping", arguments: {} }, { mcp: "region-asia", tool: "ping", arguments: {} } ]); const first = results.find(r => r !== undefined && r !== null); console.log(`Fastest server: ${first.mcp}`); ``` #### MCPToolRequest Structure ```typescript interface MCPToolRequest { mcp: string; // MCP server ID (required) tool: string; // Tool name (required) arguments?: any; // Tool arguments (optional) } ``` #### MCPToolResult Structure ```typescript interface MCPToolResult { mcp: string; // MCP server ID tool: string; // Tool name result?: any; // Parsed result content (directly usable) error?: string; // Error message (on failure) } ``` The `result` field contains the automatically parsed content from the MCP response: - For text content: JSON parsed if valid JSON, otherwise plain string - For image content: `{ type: "image", data: "...", mimeType: "..." 
}` - For resource content: The resource object directly - If only one content item exists, returns it directly (not as array) **Example using parsed result:** ```javascript // Single server - direct result const result = ctx.mcp.CallTool("echo", "echo", { message: "hello" }); console.log(result.echo); // Directly access parsed data // Cross-server - results array with MCPToolResult objects const results = ctx.mcp.All([ { mcp: "echo", tool: "echo", arguments: { message: "hello" } } ]); console.log(results[0].result.echo); // Access via .result field ``` ## Agent API The `ctx.agent` object provides methods to call other agents from within hooks, enabling agent-to-agent communication (A2A). This allows building complex multi-agent workflows where agents can delegate tasks, consult specialists, or orchestrate parallel operations. ### Methods Summary | Method | Description | | -------------------------------- | ---------------------------------------- | | `Call(agentID, messages, opts?)` | Call a single agent | | `All(requests, opts?)` | Call multiple agents, wait for all | | `Any(requests, opts?)` | Call multiple agents, first success wins | | `Race(requests, opts?)` | Call multiple agents, first complete wins| ### Single Agent Call #### `ctx.agent.Call(agentID, messages, options?)` Calls a single agent and streams the response to the current context's output. **Parameters:** - `agentID`: String - The target agent/assistant ID - `messages`: Array - Messages to send to the agent - `options`: Object (optional) - Call options including callback **Options:** ```typescript interface AgentCallOptions { connector?: string; // Override LLM connector mode?: string; // Agent mode ("chat", "task", etc.)
metadata?: Record<string, any>; // Custom metadata passed to hooks skip?: { history?: boolean; // Skip loading chat history trace?: boolean; // Skip trace recording output?: boolean; // Skip output to client keyword?: boolean; // Skip keyword extraction search?: boolean; // Skip search content_parsing?: boolean; // Skip content parsing }; onChunk?: (msg: Message) => number; // Callback for each message chunk } ``` **Example:** ```javascript // Basic call const result = ctx.agent.Call("specialist.agent", [ { role: "user", content: "Analyze this data" } ]); // With callback const result = ctx.agent.Call("specialist.agent", messages, { connector: "gpt-4o", onChunk: (msg) => { console.log("Received:", msg.type, msg.props?.content); return 0; // 0 = continue, non-zero = stop } }); ``` **Returns:** ```typescript interface AgentResult { agent_id: string; // Agent ID that was called response?: Response; // Full agent response content?: string; // Extracted text content error?: string; // Error message if failed } ``` **Message Object (received in onChunk callback):** The `onChunk` callback receives a `Message` object with the following structure: ```typescript interface Message { type: string; // Message type: "text", "thinking", "tool_call", "error", etc. props?: Record<string, any>; // Message properties (e.g., { content: "Hello" }) // Streaming identifiers chunk_id?: string; // Unique chunk ID (C1, C2, ...) message_id?: string; // Logical message ID (M1, M2, ...) block_id?: string; // Output block ID (B1, B2, ...) thread_id?: string; // Thread ID for concurrent calls (T1, T2, ...)
// Delta control delta?: boolean; // Whether this is an incremental update delta_path?: string; // Update path (e.g., "content") delta_action?: string; // Update action: "append", "replace", "merge", "set" } ``` Common message types: - `"text"` - Text content (`props.content` contains the text) - `"thinking"` - Reasoning/thinking content (o1, DeepSeek R1 models) - `"tool_call"` - Tool/function call - `"error"` - Error message (`props.error` contains error details) ### Parallel Agent Calls The parallel methods allow calling multiple agents concurrently, similar to JavaScript Promise patterns. > **Important: SSE Output is Automatically Disabled** > > For all batch calls (`All`, `Any`, `Race`), SSE output is **automatically disabled** (`skip.output = true`). > This prevents multiple agents from writing to the same SSE stream simultaneously, which would cause > client disconnection and message corruption. Use the `onChunk` callback to receive streaming messages > if needed. #### `ctx.agent.All(requests, options?)` Executes all agent calls and waits for all to complete (like `Promise.all`). **Parameters:** - `requests`: Array of request objects - `options`: Object (optional) - Global options including callback **Request Structure:** ```typescript interface AgentRequest { agent: string; // Target agent ID messages: Message[]; // Messages to send options?: AgentCallOptions; // Per-request options (excluding onChunk) } // Note: Per-request onChunk is NOT supported in batch calls. // Use the global onChunk callback in the second argument instead. // Note: skip.output is automatically set to true for all batch calls. 
``` **Example:** ```javascript // Call multiple agents in parallel const results = ctx.agent.All([ { agent: "analyzer", messages: [{ role: "user", content: "Analyze X" }] }, { agent: "summarizer", messages: [{ role: "user", content: "Summarize Y" }] } ]); // Results array matches request order results.forEach((r, i) => { if (r.error) { console.log(`Agent ${r.agent_id} failed:`, r.error); } else { console.log(`Agent ${r.agent_id} response:`, r.content); } }); // With global callback for all responses const results = ctx.agent.All([ { agent: "agent-1", messages: [...] }, { agent: "agent-2", messages: [...] } ], { onChunk: (agentId, index, msg) => { console.log(`Agent ${agentId} [${index}]:`, msg.type, msg.props?.content); return 0; } }); ``` #### `ctx.agent.Any(requests, options?)` Returns as soon as any agent call succeeds (like `Promise.any`). Other calls continue in background. **Example:** ```javascript // Try multiple agents, use first successful response const results = ctx.agent.Any([ { agent: "primary.agent", messages: [...] }, { agent: "fallback.agent", messages: [...] } ]); // First successful result is returned const success = results.find(r => !r.error); if (success) { console.log("Got response from:", success.agent_id); } ``` #### `ctx.agent.Race(requests, options?)` Returns as soon as any agent call completes, regardless of success/failure (like `Promise.race`). **Example:** ```javascript // Race multiple agents for fastest response const results = ctx.agent.Race([ { agent: "fast.agent", messages: [...] }, { agent: "slow.agent", messages: [...] 
} ]); // First completed result (may be error or success) const first = results.find(r => r !== null); console.log("Fastest agent:", first.agent_id); ``` ### Use Cases ```javascript // Use case 1: Specialist consultation function Next(ctx, payload) { const { completion } = payload; if (completion?.content?.includes("complex analysis")) { // Delegate to specialist const result = ctx.agent.Call("specialist.analyzer", [ { role: "user", content: completion.content } ]); return { data: { status: "delegated", specialist_response: result.content } }; } return null; } // Use case 2: Parallel processing function Create(ctx, messages) { const userQuery = messages[messages.length - 1]?.content; // Query multiple knowledge sources in parallel const results = ctx.agent.All([ { agent: "kb.technical", messages: [{ role: "user", content: userQuery }] }, { agent: "kb.business", messages: [{ role: "user", content: userQuery }] }, { agent: "kb.legal", messages: [{ role: "user", content: userQuery }] } ]); // Combine results const combinedKnowledge = results .filter(r => !r.error) .map(r => r.content) .join("\n\n"); // Add to messages return { messages: [ ...messages, { role: "system", content: `Relevant knowledge:\n${combinedKnowledge}` } ] }; } // Use case 3: Fallback strategy function Next(ctx, payload) { if (payload.error) { // Try backup agents const results = ctx.agent.Any([ { agent: "backup.gpt4", messages: payload.messages }, { agent: "backup.claude", messages: payload.messages } ]); const success = results.find(r => !r.error); if (success) { return { data: { recovered: true, content: success.content } }; } } return null; } ``` ## Sandbox API The `ctx.sandbox` object provides access to sandbox operations when the assistant is configured with a sandbox executor (e.g., Claude CLI, Cursor CLI). The sandbox allows hooks to interact with an isolated Docker container environment for file operations and command execution. 
> **Note:** `ctx.sandbox` is only available when the assistant has `sandbox` configuration in `package.yao`. If no sandbox is configured, `ctx.sandbox` will be `null`. ### Properties - `ctx.sandbox.workdir`: String - The workspace directory path inside the container (e.g., `/workspace`) ### Methods Summary | Method | Description | | ----------------------------- | ---------------------------------------- | | `ReadFile(path)` | Read a file from the container | | `WriteFile(path, content)` | Write content to a file in the container | | `ListDir(path)` | List directory contents | | `Exec(command)` | Execute a command in the container | ### File Operations #### `ctx.sandbox.ReadFile(path): string` Reads a file from the sandbox container. **Parameters:** - `path`: String - File path (relative to workdir or absolute) **Returns:** - `string`: File contents as string **Example:** ```javascript // Read a file from workspace const content = ctx.sandbox.ReadFile("config.json"); console.log(content); // Read with absolute path const readme = ctx.sandbox.ReadFile("/workspace/README.md"); ``` #### `ctx.sandbox.WriteFile(path, content): void` Writes content to a file in the sandbox container. **Parameters:** - `path`: String - File path (relative to workdir or absolute) - `content`: String - Content to write **Example:** ```javascript // Write a configuration file ctx.sandbox.WriteFile("config.json", JSON.stringify({ debug: true })); // Write a script ctx.sandbox.WriteFile("script.sh", "#!/bin/bash\necho 'Hello'"); ``` #### `ctx.sandbox.ListDir(path): FileInfo[]` Lists the contents of a directory in the sandbox container. 
**Parameters:** - `path`: String - Directory path (relative to workdir or absolute) **Returns:** - `FileInfo[]`: Array of file information objects **FileInfo Structure:** ```typescript interface FileInfo { name: string; // File or directory name size: number; // Size in bytes is_dir: boolean; // True if directory } ``` **Example:** ```javascript // List workspace contents const files = ctx.sandbox.ListDir("."); files.forEach(f => { console.log(`${f.is_dir ? "DIR" : "FILE"} ${f.name} (${f.size} bytes)`); }); // List specific directory const srcFiles = ctx.sandbox.ListDir("src"); ``` ### Command Execution #### `ctx.sandbox.Exec(command): string` Executes a command in the sandbox container and returns the output. **Parameters:** - `command`: String[] - Command and arguments as an array **Returns:** - `string`: Command stdout output **Throws:** - Error if command exits with non-zero code (includes stderr in error message) **Example:** ```javascript // Run a simple command const output = ctx.sandbox.Exec(["echo", "Hello, World!"]); console.log(output); // "Hello, World!\n" // Run git commands const status = ctx.sandbox.Exec(["git", "status"]); console.log(status); // Run npm install try { const result = ctx.sandbox.Exec(["npm", "install"]); console.log("Install complete:", result); } catch (e) { console.error("Install failed:", e.message); } // Run shell script ctx.sandbox.WriteFile("test.sh", "#!/bin/bash\necho 'Running script'\nls -la"); ctx.sandbox.Exec(["chmod", "+x", "test.sh"]); const scriptOutput = ctx.sandbox.Exec(["./test.sh"]); ``` ### Use Cases ```javascript // Use case 1: Prepare workspace before Claude CLI execution function Create(ctx, messages) { if (ctx.sandbox) { // Create project structure ctx.sandbox.WriteFile("package.json", JSON.stringify({ name: "project", version: "1.0.0" }, null, 2)); // Write initial code ctx.sandbox.WriteFile("src/index.ts", "console.log('Hello');"); ctx.trace.Info("Workspace prepared"); } return { messages }; } // Use case 2: 
Post-process sandbox results function Next(ctx, payload) { if (ctx.sandbox && !payload.error) { // Read generated files try { const files = ctx.sandbox.ListDir("output"); const results = files.map(f => ({ name: f.name, content: ctx.sandbox.ReadFile(`output/${f.name}`) })); return { data: { status: "success", generated_files: results } }; } catch (e) { ctx.trace.Warn("No output directory found"); } } return null; } // Use case 3: Run tests after code generation function Next(ctx, payload) { if (ctx.sandbox && payload.completion) { try { // Run tests const testOutput = ctx.sandbox.Exec(["npm", "test"]); ctx.trace.Info("Tests passed"); return { data: { status: "success", test_output: testOutput } }; } catch (e) { ctx.trace.Error("Tests failed: " + e.message); return { data: { status: "test_failed", error: e.message } }; } } return null; } ``` ### Sandbox Configuration The sandbox is configured in the assistant's `package.yao`: ```jsonc { "name": "Coder Assistant", "connector": "deepseek.v3", "sandbox": { "command": "claude", // claude | cursor (future) "image": "yaoapp/sandbox-claude:latest", // Optional, auto-selected by command "max_memory": "4g", // Memory limit (optional) "max_cpu": 2.0, // CPU limit (optional) "timeout": "10m", // Execution timeout "arguments": { // Command-specific arguments "max_turns": 20, "permission_mode": "acceptEdits" } } } ``` ### Notes - Sandbox operations are **synchronous** - they block until complete - File paths can be relative (to workdir) or absolute - Relative paths are resolved against the `workdir` directory - The sandbox container is created at the start of the request and removed when the request completes - Commands are executed with the sandbox user's permissions - Errors throw JavaScript exceptions - use try/catch for error handling - Large file operations may timeout - use appropriate timeout settings ## LLM API The `ctx.llm` object provides direct access to LLM connectors for streaming completions. 
This allows calling LLM models directly without going through the full agent pipeline, useful for quick completions, model comparisons, or building custom workflows. ### Methods Summary | Method | Description | | --------------------------------- | -------------------------------------- | | `Stream(connector, messages, opts)` | Stream LLM completion | | `All(requests, opts?)` | Call multiple LLMs, wait for all | | `Any(requests, opts?)` | Call multiple LLMs, first success wins | | `Race(requests, opts?)` | Call multiple LLMs, first complete wins| ### Single LLM Call #### `ctx.llm.Stream(connector, messages, options?)` Calls an LLM connector with streaming output to the current context's writer. **Parameters:** - `connector`: String - The LLM connector ID (e.g., "gpt-4o", "claude-3") - `messages`: Array - Messages to send to the LLM - `options`: Object (optional) - LLM options including callback **Options:** ```typescript interface LlmOptions { temperature?: number; // Sampling temperature (0-2) max_tokens?: number; // Max tokens (legacy, use max_completion_tokens) max_completion_tokens?: number; // Max completion tokens top_p?: number; // Nucleus sampling presence_penalty?: number; // Presence penalty (-2 to 2) frequency_penalty?: number; // Frequency penalty (-2 to 2) stop?: string | string[]; // Stop sequences user?: string; // User identifier for tracking seed?: number; // Random seed for reproducibility tools?: object[]; // Function/tool definitions tool_choice?: string | object; // Tool choice strategy response_format?: { // Response format type: string; // "text" | "json_object" | "json_schema" json_schema?: { name: string; description?: string; schema: object; strict?: boolean; }; }; reasoning_effort?: string; // For reasoning models (e.g., "low", "medium", "high") onChunk?: (msg: Message) => number; // Callback for each chunk } ``` **Example:** ```javascript // Basic streaming call const result = ctx.llm.Stream("gpt-4o", [ { role: "system", content: "You 
are a helpful assistant." }, { role: "user", content: "Explain quantum computing" } ]); // With options and callback const result = ctx.llm.Stream("gpt-4o", messages, { temperature: 0.7, max_tokens: 2000, onChunk: (msg) => { console.log("Chunk:", msg.type, msg.props?.content); return 0; // 0 = continue, non-zero = stop } }); console.log("Full response:", result.content); ``` **Returns:** ```typescript interface LlmResult { connector: string; // Connector ID used response?: CompletionResponse; // Full completion response content?: string; // Extracted text content error?: string; // Error message if failed } ``` ### Parallel LLM Calls The parallel methods allow calling multiple LLM connectors concurrently, useful for model comparison, ensemble methods, or fallback strategies. #### `ctx.llm.All(requests, options?)` Executes all LLM calls and waits for all to complete (like `Promise.all`). **Request Structure:** ```typescript interface LlmRequest { connector: string; // LLM connector ID messages: Message[]; // Messages to send options?: LlmOptions; // Per-request options (excluding onChunk) } ``` **Example:** ```javascript // Compare responses from multiple models const results = ctx.llm.All([ { connector: "gpt-4o", messages: [...], options: { temperature: 0.7 } }, { connector: "claude-3", messages: [...], options: { temperature: 0.7 } }, { connector: "gemini-pro", messages: [...] } ]); results.forEach((r) => { console.log(`${r.connector}: ${r.content?.substring(0, 100)}...`); }); // With global callback const results = ctx.llm.All([ { connector: "gpt-4o", messages: [...] }, { connector: "claude-3", messages: [...] } ], { onChunk: (connectorId, index, msg) => { console.log(`LLM ${connectorId} [${index}]:`, msg.props?.content); return 0; } }); ``` #### `ctx.llm.Any(requests, options?)` Returns as soon as any LLM call succeeds (like `Promise.any`). 
**Example:** ```javascript // Use first successful response from any model const results = ctx.llm.Any([ { connector: "gpt-4o", messages: [...] }, { connector: "gpt-4o-mini", messages: [...] } ]); const success = results.find(r => !r.error); if (success) { ctx.Send(success.content); } ``` #### `ctx.llm.Race(requests, options?)` Returns as soon as any LLM call completes (like `Promise.race`). **Example:** ```javascript // Get fastest response const results = ctx.llm.Race([ { connector: "gpt-4o-mini", messages: [...] }, // Usually faster { connector: "gpt-4o", messages: [...] } // Usually slower ]); const first = results.find(r => r !== null); console.log("Fastest model:", first.connector); ``` ### Use Cases ```javascript // Use case 1: Quick classification without full agent pipeline function Create(ctx, messages) { const userMessage = messages[messages.length - 1]?.content; // Quick intent classification const result = ctx.llm.Stream("gpt-4o-mini", [ { role: "system", content: "Classify intent as: question, command, or chat" }, { role: "user", content: userMessage } ], { temperature: 0, max_tokens: 10 }); const intent = result.content?.toLowerCase(); ctx.memory.context.Set("intent", intent); return { messages }; } // Use case 2: Model comparison for quality assurance function Next(ctx, payload) { const { completion } = payload; // Get second opinion from different model const results = ctx.llm.All([ { connector: "gpt-4o", messages: payload.messages }, { connector: "claude-3-opus", messages: payload.messages } ]); // Compare responses const gptResponse = results[0].content; const claudeResponse = results[1].content; return { data: { primary: completion.content, comparisons: { gpt4o: gptResponse, claude: claudeResponse } } }; } // Use case 3: Ensemble with voting function Create(ctx, messages) { // Get multiple model opinions for important decisions const results = ctx.llm.All([ { connector: "gpt-4o", messages: [...] }, { connector: "claude-3", messages: [...] 
}, { connector: "gemini-pro", messages: [...] } ]); // Simple majority voting (in real use, implement proper consensus) const responses = results.filter(r => !r.error).map(r => r.content); return { messages: [ ...messages, { role: "system", content: `Multiple model opinions:\n${responses.map((r, i) => `Model ${i+1}: ${r}`).join('\n')}` } ] }; } // Use case 4: Fallback with latency optimization function Next(ctx, payload) { if (payload.error) { // Race multiple fallback models const results = ctx.llm.Race([ { connector: "gpt-4o-mini", messages: payload.messages }, { connector: "claude-3-haiku", messages: payload.messages } ]); const fastest = results.find(r => r !== null); if (fastest && !fastest.error) { ctx.Send(fastest.content); return { data: { recovered: true, model: fastest.connector } }; } } return null; } ``` ## Hooks The Agent system supports two hooks that can be defined in the assistant's `index.ts` file: `Create` and `Next`. ### Agent Execution Lifecycle ```mermaid flowchart TD A[User Input] --> B[Load History] B --> C{Create Hook?} C -->|Yes| D[Execute Create Hook] C -->|No| E{Has Prompts/MCP?} D --> E E -->|Yes| F[Build LLM Request] E -->|No| K F --> G[LLM Stream Call] G --> H{Tool Calls?} H -->|Yes| I[Execute Tools] I --> J{Tool Errors?} J -->|Yes, Retry| G J -->|No| K H -->|No| K K{Next Hook?} K -->|Yes| L[Execute Next Hook] K -->|No| M[Return Response] L --> N{Delegate?} N -->|Yes| O[Call Target Agent] O --> M N -->|No| M M --> P[End] style D fill:#e1f5fe style L fill:#e1f5fe style G fill:#fff3e0 style I fill:#f3e5f5 ``` > **Note:** LLM call is optional. If the assistant has no prompts and no MCP servers configured, the LLM call is skipped. Hooks can be used independently to implement custom logic without LLM involvement. ### Create Hook Called at the beginning of agent execution, before any LLM call. Use this to preprocess messages, add context, configure the request, or implement custom logic. 
**Signature:** ```typescript function Create( ctx: Context, messages: Message[], options?: Record<string, any> ): HookCreateResponse | null; ``` **Parameters:** - `ctx`: Context object - `messages`: Array of input messages (including chat history if enabled) - `options`: Optional call-level options (see below) **Options Structure:** ```typescript interface Options { skip?: { history?: boolean; // Skip loading/saving chat history trace?: boolean; // Skip trace recording output?: boolean; // Skip output to client (for internal A2A calls that only need response data) }; connector?: string; // Override LLM connector ID disable_global_prompts?: boolean; // Disable global prompts for this request search?: boolean; // Enable/disable search mode mode?: string; // Agent mode (default: "chat") } ``` **Return Value (`HookCreateResponse`):** ```typescript interface HookCreateResponse { // Messages to be sent to the assistant (can modify/replace input messages) messages?: Message[]; // Audio configuration (for models that support audio output) audio?: AudioConfig; // Generation parameters (override assistant defaults) temperature?: number; max_tokens?: number; max_completion_tokens?: number; // MCP configuration - add/override MCP servers for this request mcp_servers?: MCPServerConfig[]; // Prompt configuration prompt_preset?: string; // Select prompt preset (e.g., "chat.friendly", "task.analysis") disable_global_prompts?: boolean; // Temporarily disable global prompts for this request // Context adjustments - allow hook to modify context fields connector?: string; // Override connector (call-level) locale?: string; // Override locale (session-level) theme?: string; // Override theme (session-level) route?: string; // Override route (session-level) metadata?: Record<string, any>; // Override or merge metadata (session-level) // Uses configuration - allow hook to override wrapper configurations uses?: UsesConfig; // Override wrapper configurations for vision, audio, search, and fetch force_uses?:
boolean; // Force using Uses tools regardless of model capabilities } // Audio output configuration interface AudioConfig { voice: string; // Voice to use (e.g., "alloy", "echo", "fable", "onyx", "nova", "shimmer") format: string; // Audio format (e.g., "wav", "mp3", "flac", "opus", "pcm16") } // MCP server configuration interface MCPServerConfig { server_id: string; // MCP server ID (required) tools?: string[]; // Tool name filter (empty = all tools) resources?: string[]; // Resource URI filter (empty = all resources) } // Uses wrapper configuration interface UsesConfig { vision?: string; // Vision processing tool. Format: "agent" or "mcp:server_id" audio?: string; // Audio processing tool. Format: "agent" or "mcp:server_id" search?: string; // Search tool. Format: "agent" or "mcp:server_id" fetch?: string; // Fetch/retrieval tool. Format: "agent" or "mcp:server_id" } ``` **Example:** ```javascript function Create(ctx, messages) { // Store data for Next hook ctx.memory.context.Set("user_query", messages[0]?.content); // Modify messages const enhanced_messages = messages.map((msg) => ({ ...msg, content: msg.content + "\n\nPlease be concise.", })); // Return configuration return { messages: enhanced_messages, temperature: 0.7, max_tokens: 2000, }; } ``` ### Next Hook Called after the LLM response and tool calls (if any), or directly after Create Hook if no LLM call is configured. Use this to post-process the response, send custom messages, delegate to another agent, or implement custom response logic. 
**Signature:** ```typescript function Next( ctx: Context, payload: NextHookPayload, options?: Record<string, any> ): NextHookResponse | null; ``` **Parameters:** - `ctx`: Context object - `payload`: Object containing the fields of `NextHookPayload` (defined below) - `options`: Optional call-level options (same structure as Create Hook options) ```typescript interface NextHookPayload { messages: Message[]; // Messages sent to the assistant completion?: CompletionResponse; // LLM response tools?: ToolCallResponse[]; // Tool call results (if any) error?: string; // Error message if LLM call failed } interface CompletionResponse { content: string; // LLM text response tool_calls?: ToolCall[]; // Tool calls requested by LLM usage?: UsageInfo; // Token usage statistics } interface ToolCallResponse { toolcall_id: string; server: string; // MCP server name tool: string; // Tool name arguments?: any; // Arguments passed to tool result?: any; // Tool execution result error?: string; // Error if tool failed } ``` **Return Value (`NextHookResponse`):** ```typescript interface NextHookResponse { // Delegate to another agent (recursive call) // If provided, the current agent will call the target agent delegate?: { agent_id: string; // Required: target agent ID messages: Message[]; // Messages to send to target agent options?: Record<string, any>; // Optional: call-level options for delegation }; // Custom response data // Will be placed in Response.next field and returned to user // If both delegate and data are null/undefined, standard Response is returned data?: any; // Metadata for debugging and logging metadata?: Record<string, any>; } ``` **Agent Response Structure:** The agent's `Stream()` method returns a `Response` object: ```typescript interface Response { request_id: string; // Request ID context_id: string; // Context ID trace_id: string; // Trace ID chat_id: string; // Chat ID assistant_id: string; // Assistant ID create?: HookCreateResponse; // Create hook response next?: any; // See below for what this contains completion?: CompletionResponse; //
LLM completion response } ``` **Response.next field logic:** - If `NextHookResponse.data` is provided → `Response.next` = custom data - If `NextHookResponse.data` is null/undefined → `Response.next` = entire `NextHookResponse` object - If no Next hook defined → `Response.next` = null **Example:** ```javascript /** * Next Hook - Process LLM response * @param {Context} ctx - Agent context * @param {NextHookPayload} payload - Contains messages, completion, tools, error * @returns {NextHookResponse | null} - Return null for standard response */ function Next(ctx, payload) { const { messages, completion, tools, error } = payload; // Handle errors gracefully if (error) { return { data: { status: "error", message: error, recovery: "Please try again", }, metadata: { error_handled: true }, }; } // Process tool results if any if (tools && tools.length > 0) { const successful = tools.filter((t) => !t.error); const failed = tools.filter((t) => t.error); return { data: { status: "tools_processed", total: tools.length, successful: successful.length, failed: failed.length, results: successful.map((t) => t.result), }, metadata: { has_failures: failed.length > 0 }, }; } // Return custom data based on completion if (completion && completion.content) { return { data: { status: "success", response: completion.content, processed: true, }, metadata: { source: "next_hook" }, }; } // Return null to use standard response return null; } ``` ### Hook Execution Flow See the [Agent Execution Lifecycle](#agent-execution-lifecycle) diagram above for a visual representation. **Key Points:** - **Hooks are optional** - if not defined, the agent uses default behavior - **LLM call is optional** - only executed if the assistant has prompts or MCP servers configured - **Return `null` or `undefined`** from hooks to use default behavior - **Hooks can send messages directly** via `ctx.Send()`, `ctx.SendStream()`, etc. 
- **Create Hook** runs before LLM call (if any), can modify messages and configure the request - **Next Hook** runs after LLM call and tool execution (if any), can post-process or delegate - Use `ctx.memory.context` to pass data between Create and Next hooks within a request ## Complete Example Here's a comprehensive example demonstrating Create and Next hooks with various Context API features: ```javascript /** * Create Hook - Preprocess messages and configure the request * * @param {Context} ctx - Agent context object * @param {Message[]} messages - Input messages (including history if enabled) * @returns {HookCreateResponse | null} - Configuration for LLM call, or null for defaults */ function Create(ctx, messages) { // Extract user query from the last message const user_query = messages[messages.length - 1]?.content || ""; // Store data in context memory for use in Next hook ctx.memory.context.Set("original_query", user_query); ctx.memory.context.Set("request_time", Date.now()); // Add trace node to show processing in UI const create_node = ctx.trace.Add( { query: user_query }, { label: "Create Hook", type: "preprocessing", icon: "play", description: "Analyzing user request", } ); // Check if user needs search functionality const needs_search = user_query.toLowerCase().includes("search") || user_query.toLowerCase().includes("find"); if (needs_search) { create_node.Info("Search mode enabled"); // Configure MCP servers for search return { messages: messages, mcp_servers: [{ server_id: "search_engine" }], prompt_preset: "search.assistant", metadata: { mode: "search" }, }; } create_node.Complete({ mode: "standard" }); // Return modified messages or configuration return { messages: messages, temperature: 0.7, max_tokens: 2000, }; } /** * Next Hook - Process LLM response and optionally customize output * * @param {Context} ctx - Agent context object * @param {NextHookPayload} payload - Contains messages, completion, tools, error * @returns {NextHookResponse | null} - 
Custom response, delegation, or null for standard */ function Next(ctx, payload) { const { messages, completion, tools, error } = payload; // Retrieve data from Create hook via context memory const original_query = ctx.memory.context.Get("original_query"); const request_time = ctx.memory.context.Get("request_time"); const duration = Date.now() - request_time; // Create trace node for Next hook processing const next_node = ctx.trace.Add( { completion_length: completion?.content?.length || 0 }, { label: "Next Hook", type: "postprocessing", icon: "check", description: "Processing LLM response", } ); // Handle errors if (error) { next_node.Fail(error); return { data: { status: "error", message: "An error occurred while processing your request", error: error, }, }; } // Process tool call results if (tools && tools.length > 0) { next_node.Info(`Processing ${tools.length} tool results`); const successful = tools.filter((t) => !t.error); const results = successful.map((t) => ({ tool: t.tool, server: t.server, result: t.result, })); // Send streaming message with results const msg_id = ctx.SendStream("## Tool Results\n\n"); results.forEach((r, i) => { ctx.Append(msg_id, `**${i + 1}. 
${r.tool}**\n`); ctx.Append(msg_id, `${JSON.stringify(r.result, null, 2)}\n\n`); }); ctx.End(msg_id); next_node.SetMetadata("tools_processed", tools.length); next_node.Complete({ status: "tools_processed" }); return { data: { status: "success", tool_results: results, duration_ms: duration, }, metadata: { processed_by: "next_hook" }, }; } // Check if delegation is needed based on completion content if (completion?.content?.toLowerCase().includes("delegate to specialist")) { next_node.Info("Delegating to specialist agent"); return { delegate: { agent_id: "specialist.agent", messages: [ { role: "system", content: "Handle this specialized request" }, { role: "user", content: original_query }, ], options: { priority: "high" }, }, metadata: { reason: "specialist_needed" }, }; } // Standard processing - add metadata and return next_node.SetMetadata("duration_ms", duration); next_node.Complete({ status: "success" }); // Return null to use standard LLM response // Or return custom data to override return null; } ``` ## Best Practices 1. **Error Handling**: Always wrap Context operations in try-catch blocks 2. **Resource Cleanup**: Only call `ctx.Release()` for manually created Context, not in hooks 3. **Trace Organization**: Create meaningful trace nodes with descriptive labels 4. **Logging Levels**: Use appropriate log levels (Debug for development, Info for progress, Error for failures) 5. **Message IDs**: Let the system auto-generate message IDs unless you need specific tracking 6. **Parallel Operations**: Use `Trace.Parallel()` for concurrent operations to maintain trace clarity 7. **Memory Usage**: Use `ctx.memory.context` for request-scoped data, `ctx.memory.chat` for chat state, `ctx.memory.user` for user preferences 8. **Streaming Messages**: Use `SendStream()` + `Append()` + `End()` for streaming output; use `Send()` for complete messages 9. 
**Block Grouping**: Only use Block IDs when you need to group multiple messages together (e.g., LLM output + follow-up card) ## Error Handling All Context methods throw exceptions on failure. Always handle errors appropriately: ```javascript try { ctx.Send(message); } catch (error) { ctx.trace.Error("Failed to send message", { error: error.message }); throw error; } ``` ## TypeScript Support For TypeScript projects, the Context types are automatically inferred. You can also import explicit types: ```typescript import { Context, Message, TraceNodeOption } from "@yao/runtime"; interface NextPayload { messages: Message[]; completion: any; tools: any[]; error?: string; } function Next( ctx: Context, payload: NextPayload, options?: Record<string, any> ): NextHookResponse | null { // Your code with full type checking const { messages, completion, tools, error } = payload; // ... } ``` ## See Also - [Agent Hooks Documentation](../hooks/README.md) - [MCP Protocol Specification](../mcp/README.md) - [Trace System Documentation](../../trace/README.md) - [Message Format Specification](../message/README.md) ================================================ FILE: agent/context/RESOURCE_MANAGEMENT.md ================================================ # Context Resource Management This document explains the resource management strategy for Context and Trace objects in JavaScript. ## Overview Both `Context` and `Trace` objects provide two cleanup methods: - **`__release()`** - Internal method called automatically by: - V8 garbage collector (when object is collected) - `Use()` function (immediate cleanup after callback) - **`Release()`** - Public method for explicit manual cleanup: - Called in `try-finally` blocks - Provides immediate resource cleanup - Same implementation as `__release()` - they do the same thing ## Resource Hierarchy When `Context.Release()` is called, it automatically releases: 1.
**Trace object** - If present, calls `Trace.__release()` to cleanup: - Go bridge registry entries - Trace manager resources - Background goroutines 2. **Context object** - Releases: - Go bridge registry entry for the Context itself This ensures proper cleanup of the entire resource tree. ## Usage Patterns ### Pattern 1: Automatic Cleanup with `Use()` (Recommended) **Best for**: Most cases, clean code, automatic resource management ```javascript // Context is released automatically after callback Use(Context, contextData, (ctx) => { // Access Trace (released automatically with context) const trace = ctx.Trace const node = trace.Add({ type: "step" }, { label: "Processing" }) trace.Info("Doing work") node.Complete({ result: "done" }) return result }) // ctx.Release() called automatically, which also releases Trace ``` ### Pattern 2: Manual Cleanup with `try-finally` **Best for**: Explicit control, critical memory scenarios ```javascript const ctx = getContext() // or passed as parameter const trace = ctx.Trace try { const node = trace.Add({ type: "step" }, { label: "Processing" }) trace.Info("Doing work") node.Complete({ result: "done" }) return result } finally { // Explicit cleanup (also releases Trace) ctx.Release() } ``` ### Pattern 3: Separate Trace Cleanup **Best for**: When you want to release Trace independently ```javascript const ctx = getContext() const trace = ctx.Trace try { const node = trace.Add({ type: "step" }, { label: "Processing" }) trace.Info("Doing work") node.Complete({ result: "done" }) // Release trace early if needed trace.Release() // Continue using ctx... 
return result } finally { // Release context (Trace already released, safe to call again) ctx.Release() } ``` ### Pattern 4: No Explicit Cleanup (Not Recommended) **Avoid in production**: Relies on GC, unpredictable timing ```javascript function processData(ctx) { const trace = ctx.Trace const node = trace.Add({ type: "step" }, { label: "Processing" }) trace.Info("Doing work") node.Complete({ result: "done" }) return result // Waits for V8 GC to call __release() - SLOW! } ``` ## No-op Trace Handling When Trace is not initialized, `ctx.Trace` returns a no-op object: - All methods are no-ops (do nothing) - `Release()` is safe to call (no-op) - No errors are thrown - Provides consistent API regardless of trace initialization ```javascript // Works even if Trace is not initialized const ctx = getContext() const trace = ctx.Trace // might be no-op trace.Info("Message") // safe even if no-op trace.Release() // safe even if no-op ctx.Release() // always safe ``` ## Error Handling Cleanup happens even when errors occur: ```javascript const ctx = getContext() try { const trace = ctx.Trace const node = trace.Add({ type: "step" }, { label: "Processing" }) throw new Error("Something went wrong") } finally { // Cleanup still happens ctx.Release() // also releases Trace } ``` With `Use()`: ```javascript try { Use(Context, contextData, (ctx) => { throw new Error("Something went wrong") }) } catch (error) { // Error is caught // ctx.Release() was already called automatically } ``` ## Memory Management ### ✅ Good: Immediate Cleanup ```javascript // Loop with immediate cleanup for (let i = 0; i < 10000; i++) { Use(Context, data, (ctx) => { const trace = ctx.Trace trace.Info(`Processing item ${i}`) // Released immediately after each iteration }) } ``` ### ❌ Bad: Waiting for GC ```javascript // Memory accumulates until GC runs for (let i = 0; i < 10000; i++) { const ctx = getContext() const trace = ctx.Trace trace.Info(`Processing item ${i}`) // No cleanup - may run out of memory! 
} ``` ## Implementation Details ### Context.Release() / Context.__release() 1. Checks if `ctx.Trace` exists 2. If yes, calls `trace.__release()` to cleanup Trace resources 3. Releases Context from bridge registry 4. Safe to call multiple times (idempotent) 5. Errors in cleanup are silently ignored ### Trace.Release() / Trace.__release() 1. Releases Go manager object from bridge registry 2. Calls `trace.Release(traceID)` to cleanup: - Remove from global trace registry - Stop background goroutines - Free associated resources 3. Safe to call multiple times (idempotent) ### No-op Objects Both no-op Trace and no-op Node provide: - All methods as no-ops - `Release()` and `__release()` methods - Consistent API for error-free operation - Zero memory overhead ## Best Practices 1. **✅ Use `Use()` for automatic cleanup** in most cases 2. **✅ Use `try-finally` with `Release()`** when you need explicit control 3. **✅ Release Context** (which also releases Trace) rather than releasing each separately 4. **✅ Release resources in loops** to prevent memory accumulation 5. **❌ Don't rely on GC** for resource cleanup in production code 6. 
**✅ Calling `Release()` twice is safe** - it's idempotent ## Testing See `jsapi_release_test.go` for comprehensive tests of: - Context Release - Trace Release - Cascading cleanup (Context → Trace) - try-finally pattern - No-op object Release - Error handling with cleanup Run tests: ```bash cd yao go test -v ./agent/context -run Release ``` ================================================ FILE: agent/context/authorized_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" "github.com/yaoapp/yao/trace" ) func TestContextNew_PreservesAuthorizedInfo(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create authorized info authInfo := &types.AuthorizedInfo{ UserID: "716942074991", TeamID: "565955042879", TenantID: "tenant-001", } // Create context using New() ctx := context.New(stdContext.Background(), authInfo, "test-chat-123") defer ctx.Release() // Verify authorized info is preserved assert.NotNil(t, ctx) assert.NotNil(t, ctx.Authorized) assert.Equal(t, "716942074991", ctx.Authorized.UserID) assert.Equal(t, "565955042879", ctx.Authorized.TeamID) assert.Equal(t, "tenant-001", ctx.Authorized.TenantID) assert.Equal(t, "test-chat-123", ctx.ChatID) } func TestContextTrace_SavesAuthorizedInfo(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create authorized info authInfo := &types.AuthorizedInfo{ UserID: "716942074991", TeamID: "565955042879", TenantID: "tenant-001", } // Create context using New ctx := context.New(stdContext.Background(), authInfo, "test-chat-456") ctx.AssistantID = "test-assistant" ctx.Referer = context.RefererAPI // Initialize stack (required for trace) stack, _, done := context.EnterStack(ctx, "test-assistant", &context.Options{}) ctx.Stack = stack defer done() //
Initialize trace manager, err := ctx.Trace() assert.NoError(t, err) assert.NotNil(t, manager) // Get trace info info, err := manager.GetTraceInfo() assert.NoError(t, err) assert.NotNil(t, info) // Verify auth info is saved in trace assert.Equal(t, "716942074991", info.CreatedBy) assert.Equal(t, "565955042879", info.TeamID) assert.Equal(t, "tenant-001", info.TenantID) // Clean up if ctx.Stack != nil && ctx.Stack.TraceID != "" { trace.Release(ctx.Stack.TraceID) trace.Remove(stdContext.Background(), trace.Local, ctx.Stack.TraceID) } } func TestContextNew_NilAuthorized(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create context with nil authorized info (should not panic) ctx := context.New(stdContext.Background(), nil, "test-chat-789") defer ctx.Release() assert.NotNil(t, ctx) assert.Nil(t, ctx.Authorized) assert.Equal(t, "test-chat-789", ctx.ChatID) } ================================================ FILE: agent/context/buffer.go ================================================ package context import ( "sync" "time" "github.com/google/uuid" ) // ============================================================================= // Chat Buffer - Buffers messages and steps during execution for batch saving // ============================================================================= // ChatBuffer buffers messages and resume steps during agent execution // All data is held in memory and batch-written at the end of Stream() type ChatBuffer struct { // Identity chatID string requestID string assistantID string connector string // Current connector ID (for data analysis) mode string // Current chat mode (chat or task) // Message buffer messages []*BufferedMessage msgSequence int // Step buffer (for Resume) steps []*BufferedStep currentStep *BufferedStep stepSequence int // Space snapshot (captured when step starts, for recovery) spaceSnapshot map[string]interface{} mu sync.Mutex } // BufferedMessage represents a message waiting to be saved type 
BufferedMessage struct {
	MessageID   string                 `json:"message_id"`
	ChatID      string                 `json:"chat_id"`
	RequestID   string                 `json:"request_id,omitempty"`
	Role        string                 `json:"role"` // "user" or "assistant"
	Type        string                 `json:"type"` // "text", "image", "loading", "tool_call", "retrieval", etc.
	Props       map[string]interface{} `json:"props"`
	BlockID     string                 `json:"block_id,omitempty"`
	ThreadID    string                 `json:"thread_id,omitempty"`
	AssistantID string                 `json:"assistant_id,omitempty"`
	Connector   string                 `json:"connector,omitempty"` // Connector ID used for this message
	Mode        string                 `json:"mode,omitempty"`      // Chat mode used for this message (chat or task)
	Sequence    int                    `json:"sequence"`
	Metadata    map[string]interface{} `json:"metadata,omitempty"`
	CreatedAt   time.Time              `json:"created_at"`
	IsStreaming bool                   `json:"-"` // Internal flag: true if message is still streaming (not saved until End)
}

// BufferedStep represents an execution step waiting to be saved (for Resume)
// Only saved when request is interrupted or failed
type BufferedStep struct {
	ResumeID      string                 `json:"resume_id"`
	ChatID        string                 `json:"chat_id"`
	RequestID     string                 `json:"request_id"`
	AssistantID   string                 `json:"assistant_id"`
	StackID       string                 `json:"stack_id"`
	StackParentID string                 `json:"stack_parent_id,omitempty"`
	StackDepth    int                    `json:"stack_depth"`
	Type          string                 `json:"type"`   // "input", "hook_create", "llm", "tool", "hook_next", "delegate"
	Status        string                 `json:"status"` // "running", "completed", "failed", "interrupted"
	Input         map[string]interface{} `json:"input,omitempty"`
	Output        map[string]interface{} `json:"output,omitempty"`
	SpaceSnapshot map[string]interface{} `json:"space_snapshot,omitempty"`
	Error         string                 `json:"error,omitempty"`
	Sequence      int                    `json:"sequence"`
	Metadata      map[string]interface{} `json:"metadata,omitempty"`
	CreatedAt     time.Time              `json:"created_at"`
}

// Step status constants (internal use only, not stored in database)
const (
	StepStatusRunning   = "running"
	StepStatusCompleted = "completed"
)

// Step type constants
const (
	StepTypeInput      = "input"
	StepTypeHookCreate = "hook_create"
	StepTypeLLM        = "llm"
	StepTypeTool       = "tool"
	StepTypeHookNext   = "hook_next"
	StepTypeDelegate   = "delegate"
)

// Resume status constants (for database storage)
const (
	ResumeStatusFailed      = "failed"
	ResumeStatusInterrupted = "interrupted"
)

// NewChatBuffer creates a new chat buffer
// The message and step slices start empty (non-nil); sequences start at zero.
func NewChatBuffer(chatID, requestID, assistantID, connector, mode string) *ChatBuffer {
	return &ChatBuffer{
		chatID:      chatID,
		requestID:   requestID,
		assistantID: assistantID,
		connector:   connector,
		mode:        mode,
		messages:    make([]*BufferedMessage, 0),
		steps:       make([]*BufferedStep, 0),
	}
}

// =============================================================================
// Message Buffer Methods
// =============================================================================

// AddMessage adds a message to the buffer
// A nil message is ignored. Empty MessageID, ChatID, RequestID, CreatedAt and
// Mode fields are filled in on msg itself, and Sequence is always overwritten
// with the next message sequence number.
func (b *ChatBuffer) AddMessage(msg *BufferedMessage) {
	if msg == nil {
		return
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	// Auto-generate IDs if not provided
	if msg.MessageID == "" {
		msg.MessageID = uuid.New().String()
	}
	if msg.ChatID == "" {
		msg.ChatID = b.chatID
	}
	if msg.RequestID == "" {
		msg.RequestID = b.requestID
	}
	if msg.CreatedAt.IsZero() {
		msg.CreatedAt = time.Now()
	}

	// Set mode from buffer if not provided
	if msg.Mode == "" && b.mode != "" {
		msg.Mode = b.mode
	}

	// Auto-increment sequence
	b.msgSequence++
	msg.Sequence = b.msgSequence

	b.messages = append(b.messages, msg)
}

// AddUserInput adds user input message to the buffer
// The message has role "user" and type "user_input"; name is stored in the
// props only when it is non-empty.
func (b *ChatBuffer) AddUserInput(content interface{}, name string) {
	props := map[string]interface{}{
		"content": content,
		"role":    "user",
	}
	if name != "" {
		props["name"] = name
	}

	b.AddMessage(&BufferedMessage{
		Role:  "user",
		Type:  "user_input",
		Props: props,
	})
}

// connectorSnapshot returns the current connector ID, read while holding the
// mutex so the read cannot race with SetConnector. It must not be called with
// b.mu already held.
func (b *ChatBuffer) connectorSnapshot() string {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.connector
}

// AddAssistantMessage adds an assistant message to the buffer
// This is called by ctx.Send() to buffer messages for batch saving
// Messages of type "event" are dropped. The props map is stored as given
// (not copied).
func (b *ChatBuffer) AddAssistantMessage(messageID, msgType string, props map[string]interface{}, blockID, threadID, assistantID string, metadata map[string]interface{}) {
	// Skip event type messages (transient, not stored)
	if msgType == "event" {
		return
	}

	b.AddMessage(&BufferedMessage{
		MessageID:   messageID, // Use the same MessageID as sent to client
		Role:        "assistant",
		Type:        msgType,
		Props:       props,
		BlockID:     blockID,
		ThreadID:    threadID,
		AssistantID: assistantID,
		Connector:   b.connectorSnapshot(), // Use current connector (read under the lock)
		Metadata:    metadata,
	})
}

// AddStreamingMessage adds a streaming message to the buffer
// Streaming messages are not saved until CompleteStreamingMessage is called
// This is called by ctx.SendStream() to start a streaming message
// Messages of type "event" are dropped.
func (b *ChatBuffer) AddStreamingMessage(messageID, msgType string, props map[string]interface{}, blockID, threadID, assistantID string, metadata map[string]interface{}) {
	// Skip event type messages (transient, not stored)
	if msgType == "event" {
		return
	}

	// Shallow copy props so later appends to the buffered message do not
	// mutate the caller's map (nested values are still shared). A nil props
	// yields an empty, non-nil map.
	propsCopy := make(map[string]interface{}, len(props))
	for k, v := range props {
		propsCopy[k] = v
	}

	b.AddMessage(&BufferedMessage{
		MessageID:   messageID, // Use provided message ID
		Role:        "assistant",
		Type:        msgType,
		Props:       propsCopy,
		BlockID:     blockID,
		ThreadID:    threadID,
		AssistantID: assistantID,
		Connector:   b.connectorSnapshot(), // Use current connector (read under the lock)
		Metadata:    metadata,
		IsStreaming: true, // Mark as streaming
	})
}

// AppendMessageContent appends content to a streaming message
// This is called by ctx.Append() to accumulate content
// Returns false when no message with this ID is currently streaming. If the
// existing "content" prop is missing or not a string, it is replaced by content.
func (b *ChatBuffer) AppendMessageContent(messageID string, content string) bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Find the message by ID
	for _, msg := range b.messages {
		if msg.MessageID == messageID && msg.IsStreaming {
			// Append to existing content
			if msg.Props == nil {
				msg.Props = make(map[string]interface{})
			}
			if existing, ok := msg.Props["content"].(string); ok {
				msg.Props["content"] = existing + content
			} else {
				msg.Props["content"] = content
			}
			return true
		}
	}
	return false
}

// CompleteStreamingMessage marks a streaming message as complete
// This is called by
ctx.End() to finalize the message // Returns the complete content for the message_end event func (b *ChatBuffer) CompleteStreamingMessage(messageID string) (string, bool) { b.mu.Lock() defer b.mu.Unlock() // Find the message by ID for _, msg := range b.messages { if msg.MessageID == messageID && msg.IsStreaming { msg.IsStreaming = false // Return the accumulated content if content, ok := msg.Props["content"].(string); ok { return content, true } return "", true } } return "", false } // GetStreamingMessage returns a streaming message by ID func (b *ChatBuffer) GetStreamingMessage(messageID string) *BufferedMessage { b.mu.Lock() defer b.mu.Unlock() for _, msg := range b.messages { if msg.MessageID == messageID && msg.IsStreaming { return msg } } return nil } // GetMessages returns all buffered messages func (b *ChatBuffer) GetMessages() []*BufferedMessage { b.mu.Lock() defer b.mu.Unlock() result := make([]*BufferedMessage, len(b.messages)) copy(result, b.messages) return result } // GetMessageCount returns the number of buffered messages func (b *ChatBuffer) GetMessageCount() int { b.mu.Lock() defer b.mu.Unlock() return len(b.messages) } // ============================================================================= // Step Buffer Methods (for Resume) // ============================================================================= // BeginStep starts tracking a new execution step // Returns the step for further updates func (b *ChatBuffer) BeginStep(stepType string, input map[string]interface{}, stack *Stack) *BufferedStep { b.mu.Lock() defer b.mu.Unlock() b.stepSequence++ step := &BufferedStep{ ResumeID: uuid.New().String(), ChatID: b.chatID, RequestID: b.requestID, AssistantID: b.assistantID, Type: stepType, Status: StepStatusRunning, Input: input, Sequence: b.stepSequence, CreatedAt: time.Now(), } // Set stack information if available if stack != nil { step.StackID = stack.ID step.StackParentID = stack.ParentID step.StackDepth = stack.Depth } // Capture current 
space snapshot if b.spaceSnapshot != nil { step.SpaceSnapshot = copyMap(b.spaceSnapshot) } b.steps = append(b.steps, step) b.currentStep = step return step } // CompleteStep marks the current step as completed func (b *ChatBuffer) CompleteStep(output map[string]interface{}) { b.mu.Lock() defer b.mu.Unlock() if b.currentStep != nil { b.currentStep.Output = output b.currentStep.Status = StepStatusCompleted b.currentStep = nil } } // FailCurrentStep marks the current step as failed or interrupted func (b *ChatBuffer) FailCurrentStep(status string, err error) { b.mu.Lock() defer b.mu.Unlock() if b.currentStep != nil && b.currentStep.Status == StepStatusRunning { b.currentStep.Status = status if err != nil { b.currentStep.Error = err.Error() } } } // GetCurrentStep returns the current running step func (b *ChatBuffer) GetCurrentStep() *BufferedStep { b.mu.Lock() defer b.mu.Unlock() return b.currentStep } // GetStepsForResume returns steps that need to be saved for resume // Only returns steps with failed or interrupted status func (b *ChatBuffer) GetStepsForResume(finalStatus string) []*BufferedStep { b.mu.Lock() defer b.mu.Unlock() // If completed successfully, no steps need to be saved if finalStatus == StepStatusCompleted { return nil } // Mark current running step with final status if b.currentStep != nil && b.currentStep.Status == StepStatusRunning { b.currentStep.Status = finalStatus } // Return all steps (they will all have the context for recovery) result := make([]*BufferedStep, len(b.steps)) copy(result, b.steps) return result } // GetAllSteps returns all buffered steps (for debugging/testing) func (b *ChatBuffer) GetAllSteps() []*BufferedStep { b.mu.Lock() defer b.mu.Unlock() result := make([]*BufferedStep, len(b.steps)) copy(result, b.steps) return result } // ============================================================================= // Space Snapshot Methods // ============================================================================= // 
SetSpaceSnapshot sets the space snapshot for recovery // Should be called when space data changes func (b *ChatBuffer) SetSpaceSnapshot(snapshot map[string]interface{}) { b.mu.Lock() defer b.mu.Unlock() b.spaceSnapshot = copyMap(snapshot) } // GetSpaceSnapshot returns the current space snapshot func (b *ChatBuffer) GetSpaceSnapshot() map[string]interface{} { b.mu.Lock() defer b.mu.Unlock() return copyMap(b.spaceSnapshot) } // ============================================================================= // Identity Methods // ============================================================================= // ChatID returns the chat ID func (b *ChatBuffer) ChatID() string { return b.chatID } // RequestID returns the request ID func (b *ChatBuffer) RequestID() string { return b.requestID } // AssistantID returns the assistant ID func (b *ChatBuffer) AssistantID() string { return b.assistantID } // SetAssistantID updates the assistant ID (for A2A calls) func (b *ChatBuffer) SetAssistantID(assistantID string) { b.mu.Lock() defer b.mu.Unlock() b.assistantID = assistantID } // Connector returns the current connector ID func (b *ChatBuffer) Connector() string { return b.connector } // SetConnector updates the connector ID (when user switches connector) func (b *ChatBuffer) SetConnector(connector string) { b.mu.Lock() defer b.mu.Unlock() b.connector = connector } // Mode returns the current chat mode func (b *ChatBuffer) Mode() string { return b.mode } // SetMode updates the chat mode (when user switches mode) func (b *ChatBuffer) SetMode(mode string) { b.mu.Lock() defer b.mu.Unlock() b.mode = mode } // ============================================================================= // Helper Functions // ============================================================================= // copyMap creates a shallow copy of a map func copyMap(src map[string]interface{}) map[string]interface{} { if src == nil { return nil } dst := make(map[string]interface{}, len(src)) for k, v := range 
src {
		dst[k] = v
	}
	return dst
}

================================================
FILE: agent/context/buffer_test.go
================================================
package context_test

import (
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/yaoapp/yao/agent/context"
)

// =============================================================================
// ChatBuffer Creation Tests
// =============================================================================

// TestBufferNewChatBuffer checks that NewChatBuffer stores the identity values
// it is given and starts with no messages and no steps.
func TestBufferNewChatBuffer(t *testing.T) {
	t.Run("CreateWithAllFields", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-123", "req-456", "assistant-789", "", "")

		assert.NotNil(t, buffer)
		assert.Equal(t, "chat-123", buffer.ChatID())
		assert.Equal(t, "req-456", buffer.RequestID())
		assert.Equal(t, "assistant-789", buffer.AssistantID())
		assert.Empty(t, buffer.GetMessages())
		assert.Empty(t, buffer.GetAllSteps())
		assert.Equal(t, 0, buffer.GetMessageCount())
	})

	t.Run("CreateWithEmptyFields", func(t *testing.T) {
		buffer := context.NewChatBuffer("", "", "", "", "")

		assert.NotNil(t, buffer)
		assert.Empty(t, buffer.ChatID())
		assert.Empty(t, buffer.RequestID())
		assert.Empty(t, buffer.AssistantID())
	})
}

// =============================================================================
// Message Buffer Tests
// =============================================================================

// TestBufferAddMessage covers AddMessage: auto-filled fields, sequence
// numbering, nil input, and preservation of caller-supplied ID and timestamp.
func TestBufferAddMessage(t *testing.T) {
	t.Run("AddSingleMessage", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		msg := &context.BufferedMessage{
			Role:  "assistant",
			Type:  "text",
			Props: map[string]interface{}{"content": "Hello"},
		}
		buffer.AddMessage(msg)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "assistant", messages[0].Role)
		assert.Equal(t, "text", messages[0].Type)
		assert.Equal(t, 1, messages[0].Sequence)
		assert.NotEmpty(t, messages[0].MessageID) // Auto-generated
		assert.Equal(t, "chat-1", messages[0].ChatID)
		assert.Equal(t, "req-1", messages[0].RequestID)
		assert.False(t, messages[0].CreatedAt.IsZero())
	})

	t.Run("AddMultipleMessages", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		for i := 0; i < 5; i++ {
			buffer.AddMessage(&context.BufferedMessage{
				Role:  "assistant",
				Type:  "text",
				Props: map[string]interface{}{"content": fmt.Sprintf("Message %d", i+1)},
			})
		}

		messages := buffer.GetMessages()
		require.Len(t, messages, 5)

		// Verify sequence numbers
		for i, msg := range messages {
			assert.Equal(t, i+1, msg.Sequence)
		}
	})

	t.Run("AddNilMessage", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		buffer.AddMessage(nil)
		assert.Equal(t, 0, buffer.GetMessageCount())
	})

	t.Run("AddMessageWithExistingID", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-4", "req-4", "assistant-4", "", "")

		msg := &context.BufferedMessage{
			MessageID: "custom-id-123",
			Role:      "assistant",
			Type:      "text",
		}
		buffer.AddMessage(msg)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "custom-id-123", messages[0].MessageID) // Preserved
	})

	t.Run("AddMessageWithExistingTimestamp", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-5", "req-5", "assistant-5", "", "")

		customTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
		msg := &context.BufferedMessage{
			Role:      "assistant",
			Type:      "text",
			CreatedAt: customTime,
		}
		buffer.AddMessage(msg)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, customTime, messages[0].CreatedAt) // Preserved
	})
}

// TestBufferAddUserInput checks the message produced by AddUserInput for plain
// string content, a named user, and structured (multi-part) content.
func TestBufferAddUserInput(t *testing.T) {
	t.Run("AddStringContent", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		buffer.AddUserInput("What is the weather?", "")

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "user", messages[0].Role)
		assert.Equal(t, "user_input", messages[0].Type)
		assert.Equal(t, "What is the weather?", messages[0].Props["content"])
		assert.Equal(t, "user", messages[0].Props["role"])
	})

	t.Run("AddUserInputWithName", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		buffer.AddUserInput("Hello", "John")

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "John", messages[0].Props["name"])
	})

	t.Run("AddComplexContent", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		complexContent := []map[string]interface{}{
			{"type": "text", "text": "Look at this image"},
			{"type": "image_url", "image_url": map[string]string{"url": "https://example.com/image.jpg"}},
		}
		buffer.AddUserInput(complexContent, "")

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		// The content is stored as given, so the original slice type is recoverable.
		content, ok := messages[0].Props["content"].([]map[string]interface{})
		require.True(t, ok)
		assert.Len(t, content, 2)
	})
}

// TestBufferAddAssistantMessage checks that AddAssistantMessage buffers
// messages of several types with the supplied fields and skips "event" messages.
func TestBufferAddAssistantMessage(t *testing.T) {
	t.Run("AddTextMessage", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		buffer.AddAssistantMessage(
			"M1",
			"text",
			map[string]interface{}{"content": "Hello, how can I help?"},
			"block-1",
			"thread-1",
			"assistant-1",
			map[string]interface{}{"model": "gpt-4"},
		)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "M1", messages[0].MessageID)
		assert.Equal(t, "assistant", messages[0].Role)
		assert.Equal(t, "text", messages[0].Type)
		assert.Equal(t, "block-1", messages[0].BlockID)
		assert.Equal(t, "thread-1", messages[0].ThreadID)
		assert.Equal(t, "assistant-1", messages[0].AssistantID)
		assert.Equal(t, "gpt-4", messages[0].Metadata["model"])
	})

	t.Run("SkipEventMessage", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		buffer.AddAssistantMessage(
			"E1",
			"event",
			map[string]interface{}{"event": "message_start"},
			"", "", "", nil,
		)

		// Event messages should be skipped
		assert.Equal(t, 0, buffer.GetMessageCount())
	})

	t.Run("AddRetrievalMessage", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		buffer.AddAssistantMessage(
			"M2",
			"retrieval",
			map[string]interface{}{
				"sources": []map[string]interface{}{
					{"title": "Doc 1", "score": 0.95},
					{"title": "Doc 2", "score": 0.87},
				},
			},
			"block-1", "", "assistant-3", nil,
		)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "retrieval", messages[0].Type)
	})

	t.Run("AddToolCallMessage", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-4", "req-4", "assistant-4", "", "")

		buffer.AddAssistantMessage(
			"M3",
			"tool_call",
			map[string]interface{}{
				"name":      "get_weather",
				"arguments": `{"location": "San Francisco"}`,
			},
			"block-1", "", "assistant-4", nil,
		)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "tool_call", messages[0].Type)
		assert.Equal(t, "get_weather", messages[0].Props["name"])
	})

	t.Run("AddCustomTypeMessage", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-5", "req-5", "assistant-5", "", "")

		buffer.AddAssistantMessage(
			"M4",
			"custom_chart",
			map[string]interface{}{
				"chart_type": "bar",
				"data":       []int{1, 2, 3, 4, 5},
			},
			"block-1", "", "assistant-5", nil,
		)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "custom_chart", messages[0].Type)
	})
}

// TestBufferGetMessages checks the length of the slices returned by
// GetMessages and that an empty buffer yields a non-nil, empty slice.
func TestBufferGetMessages(t *testing.T) {
	t.Run("GetMessagesReturnsSliceCopy", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")
		buffer.AddUserInput("Hello", "")

		messages1 := buffer.GetMessages()
		messages2 := buffer.GetMessages()

		// Slices should be different (copy of slice)
		// But pointers point to same underlying objects (shallow copy)
		// NOTE(review): only the lengths are asserted here; the copy itself is not verified.
		assert.Len(t, messages1, 1)
		assert.Len(t, messages2, 1)
	})

	t.Run("GetEmptyMessages", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		messages := buffer.GetMessages()
		assert.NotNil(t, messages)
		assert.Empty(t, messages)
	})
}

// The declaration that follows checks that GetMessageCount tracks each
// message added to the buffer.
func
TestBufferGetMessageCount(t *testing.T) {
	buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

	assert.Equal(t, 0, buffer.GetMessageCount())

	buffer.AddUserInput("Message 1", "")
	assert.Equal(t, 1, buffer.GetMessageCount())

	buffer.AddAssistantMessage("M1", "text", map[string]interface{}{"content": "Reply"}, "", "", "", nil)
	assert.Equal(t, 2, buffer.GetMessageCount())
}

// =============================================================================
// Step Buffer Tests (for Resume)
// =============================================================================

// TestBufferBeginStep checks the fields of steps created by BeginStep: stack
// details, sequence numbering, and the captured space snapshot.
func TestBufferBeginStep(t *testing.T) {
	t.Run("BeginStepWithStack", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		stack := &context.Stack{
			ID:       "stack-123",
			ParentID: "stack-parent-456",
			Depth:    2,
		}

		step := buffer.BeginStep(context.StepTypeLLM, map[string]interface{}{"prompt": "Hello"}, stack)

		require.NotNil(t, step)
		assert.NotEmpty(t, step.ResumeID)
		assert.Equal(t, "chat-1", step.ChatID)
		assert.Equal(t, "req-1", step.RequestID)
		assert.Equal(t, "assistant-1", step.AssistantID)
		assert.Equal(t, "stack-123", step.StackID)
		assert.Equal(t, "stack-parent-456", step.StackParentID)
		assert.Equal(t, 2, step.StackDepth)
		assert.Equal(t, context.StepTypeLLM, step.Type)
		assert.Equal(t, context.StepStatusRunning, step.Status)
		assert.Equal(t, 1, step.Sequence)
		assert.Equal(t, "Hello", step.Input["prompt"])
		assert.False(t, step.CreatedAt.IsZero())
	})

	t.Run("BeginStepWithNilStack", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		step := buffer.BeginStep(context.StepTypeInput, nil, nil)

		require.NotNil(t, step)
		assert.Empty(t, step.StackID)
		assert.Empty(t, step.StackParentID)
		assert.Equal(t, 0, step.StackDepth)
	})

	t.Run("BeginMultipleSteps", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		step1 := buffer.BeginStep(context.StepTypeInput, nil, nil)
		step2 := buffer.BeginStep(context.StepTypeHookCreate, nil, nil)
		step3 := buffer.BeginStep(context.StepTypeLLM, nil, nil)

		assert.Equal(t, 1, step1.Sequence)
		assert.Equal(t, 2, step2.Sequence)
		assert.Equal(t, 3, step3.Sequence)

		steps := buffer.GetAllSteps()
		require.Len(t, steps, 3)
	})

	t.Run("BeginStepWithSpaceSnapshot", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-4", "req-4", "assistant-4", "", "")

		// Set space snapshot before beginning step
		buffer.SetSpaceSnapshot(map[string]interface{}{
			"key1": "value1",
			"key2": 42,
		})

		step := buffer.BeginStep(context.StepTypeLLM, nil, nil)

		require.NotNil(t, step.SpaceSnapshot)
		assert.Equal(t, "value1", step.SpaceSnapshot["key1"])
		assert.Equal(t, 42, step.SpaceSnapshot["key2"])
	})
}

// TestBufferCompleteStep checks that CompleteStep records the output, marks
// the step completed, clears the current step, and tolerates having no
// current step.
func TestBufferCompleteStep(t *testing.T) {
	t.Run("CompleteCurrentStep", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		buffer.BeginStep(context.StepTypeLLM, map[string]interface{}{"prompt": "Hello"}, nil)
		buffer.CompleteStep(map[string]interface{}{"response": "Hi there!"})

		steps := buffer.GetAllSteps()
		require.Len(t, steps, 1)
		assert.Equal(t, context.StepStatusCompleted, steps[0].Status)
		assert.Equal(t, "Hi there!", steps[0].Output["response"])
		assert.Nil(t, buffer.GetCurrentStep()) // Current step cleared
	})

	t.Run("CompleteWithNoCurrentStep", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		// Should not panic
		buffer.CompleteStep(map[string]interface{}{"response": "test"})
		assert.Nil(t, buffer.GetCurrentStep())
	})

	t.Run("CompleteMultipleStepsSequentially", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		buffer.BeginStep(context.StepTypeInput, nil, nil)
		buffer.CompleteStep(map[string]interface{}{"done": true})

		buffer.BeginStep(context.StepTypeHookCreate, nil, nil)
		buffer.CompleteStep(map[string]interface{}{"hook_result": "ok"})

		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		buffer.CompleteStep(map[string]interface{}{"llm_response": "hello"})

		steps := buffer.GetAllSteps()
		require.Len(t, steps, 3)
		for _, step := range steps {
			assert.Equal(t, context.StepStatusCompleted, step.Status)
		}
	})
}

// TestBufferFailCurrentStep checks that FailCurrentStep records the status and
// error on a running step and is a no-op for completed or absent steps.
func TestBufferFailCurrentStep(t *testing.T) {
	t.Run("FailWithError", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		buffer.FailCurrentStep(context.ResumeStatusFailed, fmt.Errorf("API error: rate limit exceeded"))

		steps := buffer.GetAllSteps()
		require.Len(t, steps, 1)
		assert.Equal(t, context.ResumeStatusFailed, steps[0].Status)
		assert.Equal(t, "API error: rate limit exceeded", steps[0].Error)
	})

	t.Run("FailWithInterrupted", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		buffer.FailCurrentStep(context.ResumeStatusInterrupted, nil)

		steps := buffer.GetAllSteps()
		require.Len(t, steps, 1)
		assert.Equal(t, context.ResumeStatusInterrupted, steps[0].Status)
		assert.Empty(t, steps[0].Error)
	})

	t.Run("FailAlreadyCompletedStep", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		buffer.CompleteStep(map[string]interface{}{"done": true})

		// Try to fail completed step (should be no-op since currentStep is nil)
		buffer.FailCurrentStep(context.ResumeStatusFailed, fmt.Errorf("late error"))

		steps := buffer.GetAllSteps()
		require.Len(t, steps, 1)
		assert.Equal(t, context.StepStatusCompleted, steps[0].Status) // Still completed
	})

	t.Run("FailWithNoCurrentStep", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-4", "req-4", "assistant-4", "", "")

		// Should not panic
		buffer.FailCurrentStep(context.ResumeStatusFailed, fmt.Errorf("error"))
	})
}

// TestBufferGetCurrentStep checks GetCurrentStep before any step begins, while
// a step is running, and after it completes.
func TestBufferGetCurrentStep(t *testing.T) {
	t.Run("NoCurrentStep", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		assert.Nil(t, buffer.GetCurrentStep())
	})

	t.Run("HasCurrentStep", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		buffer.BeginStep(context.StepTypeLLM, nil, nil)

		current := buffer.GetCurrentStep()
		require.NotNil(t, current)
		assert.Equal(t, context.StepTypeLLM, current.Type)
	})

	t.Run("CurrentStepClearedAfterComplete", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		buffer.CompleteStep(nil)

		assert.Nil(t, buffer.GetCurrentStep())
	})
}

// TestBufferGetStepsForResume checks that GetStepsForResume returns nil for a
// completed request and otherwise returns every step, with the running step
// stamped with the final status.
func TestBufferGetStepsForResume(t *testing.T) {
	t.Run("CompletedSuccessfully", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		buffer.BeginStep(context.StepTypeInput, nil, nil)
		buffer.CompleteStep(nil)
		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		buffer.CompleteStep(nil)

		// Completed successfully - no steps need to be saved
		steps := buffer.GetStepsForResume(context.StepStatusCompleted)
		assert.Nil(t, steps)
	})

	t.Run("FailedRequest", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		buffer.BeginStep(context.StepTypeInput, nil, nil)
		buffer.CompleteStep(nil)
		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		// Step still running when failure occurs

		steps := buffer.GetStepsForResume(context.ResumeStatusFailed)
		require.NotNil(t, steps)
		assert.Len(t, steps, 2)

		// Current step should be marked as failed
		assert.Equal(t, context.ResumeStatusFailed, steps[1].Status)
	})

	t.Run("InterruptedRequest", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		buffer.BeginStep(context.StepTypeInput, nil, nil)
		buffer.CompleteStep(nil)
		buffer.BeginStep(context.StepTypeHookCreate, nil, nil)
		buffer.CompleteStep(nil)
		buffer.BeginStep(context.StepTypeLLM, nil, nil)
		// Interrupted during LLM

		steps := buffer.GetStepsForResume(context.ResumeStatusInterrupted)
		require.NotNil(t, steps)
		assert.Len(t, steps, 3)
		assert.Equal(t, context.ResumeStatusInterrupted, steps[2].Status)
	})
}

// TestBufferGetAllSteps checks the length of the slices returned by
// GetAllSteps and that an empty buffer yields a non-nil, empty slice.
func TestBufferGetAllSteps(t *testing.T) {
	t.Run("GetStepsReturnsSliceCopy", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")
		buffer.BeginStep(context.StepTypeLLM, nil, nil)

		steps1 := buffer.GetAllSteps()
		steps2 := buffer.GetAllSteps()

		// Slices should be different (copy of slice)
		// NOTE(review): only the lengths are asserted here; the copy itself is not verified.
		assert.Len(t, steps1, 1)
		assert.Len(t, steps2, 1)
	})

	t.Run("GetEmptySteps", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		steps := buffer.GetAllSteps()
		assert.NotNil(t, steps)
		assert.Empty(t, steps)
	})
}

// =============================================================================
// Space Snapshot Tests
// =============================================================================

// TestBufferSpaceSnapshot checks that the snapshot is copied on both set and
// get, and that a nil snapshot is stored and returned as nil.
func TestBufferSpaceSnapshot(t *testing.T) {
	t.Run("SetAndGetSnapshot", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "")

		snapshot := map[string]interface{}{
			"user_id":   "user-123",
			"session":   map[string]interface{}{"token": "abc"},
			"counter":   42,
			"is_active": true,
		}
		buffer.SetSpaceSnapshot(snapshot)

		retrieved := buffer.GetSpaceSnapshot()
		assert.Equal(t, "user-123", retrieved["user_id"])
		assert.Equal(t, 42, retrieved["counter"])
		assert.Equal(t, true, retrieved["is_active"])
	})

	t.Run("SnapshotIsCopy", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-2", "req-2", "assistant-2", "", "")

		original := map[string]interface{}{"key": "original"}
		buffer.SetSpaceSnapshot(original)

		// Modify original
		original["key"] = "modified"

		// Buffer should have original value
		retrieved := buffer.GetSpaceSnapshot()
		assert.Equal(t, "original", retrieved["key"])
	})

	t.Run("GetSnapshotReturnsCopy", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-3", "req-3", "assistant-3", "", "")

		buffer.SetSpaceSnapshot(map[string]interface{}{"key": "value"})

		retrieved1 := buffer.GetSpaceSnapshot()
		retrieved1["key"] = "modified"

		retrieved2 := buffer.GetSpaceSnapshot()
		assert.Equal(t, "value", retrieved2["key"]) // Original unchanged
	})

	t.Run("GetNilSnapshot", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-4", "req-4", "assistant-4", "", "")

		snapshot := buffer.GetSpaceSnapshot()
		assert.Nil(t, snapshot)
	})

	t.Run("SetNilSnapshot", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-5", "req-5", "assistant-5", "", "")

		buffer.SetSpaceSnapshot(map[string]interface{}{"key": "value"})
		buffer.SetSpaceSnapshot(nil)

		snapshot := buffer.GetSpaceSnapshot()
		assert.Nil(t, snapshot)
	})
}

// =============================================================================
// Identity Methods Tests
// =============================================================================

// TestBufferIdentityMethods checks the identity getters and the
// SetAssistantID / SetConnector setters.
func TestBufferIdentityMethods(t *testing.T) {
	t.Run("SetAssistantID", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-original", "", "")

		assert.Equal(t, "assistant-original", buffer.AssistantID())

		buffer.SetAssistantID("assistant-new")
		assert.Equal(t, "assistant-new", buffer.AssistantID())
	})

	t.Run("ChatID", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-test", "req-test", "assistant-test", "", "")
		assert.Equal(t, "chat-test", buffer.ChatID())
	})

	t.Run("RequestID", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-test", "req-test", "assistant-test", "", "")
		assert.Equal(t, "req-test", buffer.RequestID())
	})

	t.Run("Connector", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-test", "req-test", "assistant-test", "openai", "")
		assert.Equal(t, "openai", buffer.Connector())
	})

	t.Run("SetConnector", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "")

		assert.Equal(t, "openai", buffer.Connector())

		// Simulate user switching connector mid-conversation
		buffer.SetConnector("anthropic")
		assert.Equal(t, "anthropic", buffer.Connector())
	})

	t.Run("EmptyConnector",
		func(t *testing.T) {
			buffer := context.NewChatBuffer("chat-test", "req-test", "assistant-test", "", "")
			assert.Equal(t, "", buffer.Connector())
		})
}

// TestBufferConnectorInMessages checks that assistant messages record the
// buffer's connector at the time they are added, and that user input messages
// carry no connector.
func TestBufferConnectorInMessages(t *testing.T) {
	t.Run("MessageInheritsConnector", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "")

		// Add assistant message - should inherit connector from buffer
		buffer.AddAssistantMessage(
			"M1",
			"text",
			map[string]interface{}{"content": "Hello"},
			"block-1",
			"thread-1",
			"assistant-1",
			nil,
		)

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		assert.Equal(t, "openai", messages[0].Connector)
	})

	t.Run("MessageConnectorUpdatesWithBuffer", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "")

		// First message with openai
		buffer.AddAssistantMessage(
			"M1",
			"text",
			map[string]interface{}{"content": "Using OpenAI"},
			"", "", "assistant-1", nil,
		)

		// User switches connector
		buffer.SetConnector("anthropic")

		// Second message with anthropic
		buffer.AddAssistantMessage(
			"M2",
			"text",
			map[string]interface{}{"content": "Now using Claude"},
			"", "", "assistant-1", nil,
		)

		messages := buffer.GetMessages()
		require.Len(t, messages, 2)
		assert.Equal(t, "openai", messages[0].Connector, "First message should use openai")
		assert.Equal(t, "anthropic", messages[1].Connector, "Second message should use anthropic")
	})

	t.Run("UserInputDoesNotSetConnector", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "")

		// User input doesn't have connector (it's set by the system based on which model processes it)
		buffer.AddUserInput("Hello", "")

		messages := buffer.GetMessages()
		require.Len(t, messages, 1)
		// User input messages don't have connector field set by AddUserInput
		// Connector is only set for assistant messages
		assert.Equal(t, "", messages[0].Connector)
	})

	t.Run("MultipleConnectorSwitches", func(t *testing.T) {
		buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "")

		// Simulate a conversation with multiple connector switches
		connectors := []string{"openai", "anthropic", "openai", "google"}

		for i, conn := range connectors {
			buffer.SetConnector(conn)
			buffer.AddAssistantMessage(
				fmt.Sprintf("M%d", i+1),
				"text",
				map[string]interface{}{"content": fmt.Sprintf("Message %d", i+1)},
				"", "", "assistant-1", nil,
			)
		}

		messages := buffer.GetMessages()
		require.Len(t, messages, 4)

		for i, msg := range messages {
			assert.Equal(t, connectors[i], msg.Connector, "Message %d should have connector %s", i+1, connectors[i])
		}
	})
}

// =============================================================================
// Concurrency Tests
// =============================================================================

// TestBufferConcurrentMessageOperations adds messages from 100 goroutines and
// checks that none are lost and that every sequence number is unique.
func TestBufferConcurrentMessageOperations(t *testing.T) {
	buffer := context.NewChatBuffer("chat-concurrent", "req-concurrent", "assistant-concurrent", "", "")

	var wg sync.WaitGroup
	numGoroutines := 100

	// Concurrent writes
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			buffer.AddMessage(&context.BufferedMessage{
				Role:  "assistant",
				Type:  "text",
				Props: map[string]interface{}{"content": fmt.Sprintf("Message %d", idx)},
			})
		}(i)
	}

	wg.Wait()

	// Verify all messages were added
	messages := buffer.GetMessages()
	assert.Len(t, messages, numGoroutines)

	// Verify sequences are unique
	sequences := make(map[int]bool)
	for _, msg := range messages {
		assert.False(t, sequences[msg.Sequence], "Duplicate sequence found: %d", msg.Sequence)
		sequences[msg.Sequence] = true
	}
}

// TestBufferConcurrentStepOperations begins and completes steps from 50
// goroutines and checks that every begun step was recorded. Only the step
// count is asserted; which step each CompleteStep call lands on is not checked.
func TestBufferConcurrentStepOperations(t *testing.T) {
	buffer := context.NewChatBuffer("chat-concurrent", "req-concurrent", "assistant-concurrent", "", "")

	var wg sync.WaitGroup
	numGoroutines := 50

	// Concurrent step operations
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			buffer.BeginStep(context.StepTypeLLM, map[string]interface{}{"idx": idx}, nil)
			time.Sleep(time.Millisecond) // Simulate some work
			buffer.CompleteStep(map[string]interface{}{"result": idx})
		}(i)
	}

	wg.Wait()

	// Verify all steps were recorded
	steps := buffer.GetAllSteps()
	assert.Len(t, steps, numGoroutines)
}

// TestBufferConcurrentReadWrite runs one writer goroutine (100 messages) next
// to one reader goroutine and checks that all 100 messages end up buffered.
func TestBufferConcurrentReadWrite(t *testing.T) {
	buffer := context.NewChatBuffer("chat-rw", "req-rw", "assistant-rw", "", "")

	var wg sync.WaitGroup
	done := make(chan bool)

	// Writer goroutine
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < 100; i++ {
			buffer.AddMessage(&context.BufferedMessage{
				Role:  "assistant",
				Type:  "text",
				Props: map[string]interface{}{"content": fmt.Sprintf("Message %d", i)},
			})
			time.Sleep(time.Microsecond)
		}
	}()

	// Reader goroutine
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-done:
				return
			default:
				_ = buffer.GetMessages()
				_ = buffer.GetMessageCount()
				time.Sleep(time.Microsecond)
			}
		}
	}()

	// Let it run for a bit
	time.Sleep(50 * time.Millisecond)
	// Closing done stops the reader; wg.Wait also waits for the writer to finish.
	close(done)
	wg.Wait()

	// Should complete without race conditions
	assert.Equal(t, 100, buffer.GetMessageCount())
}

// =============================================================================
// Step Type Constants Tests
// =============================================================================

// TestBufferStepTypeConstants pins the string values of the step type constants.
func TestBufferStepTypeConstants(t *testing.T) {
	// Verify all step types are defined
	assert.Equal(t, "input", context.StepTypeInput)
	assert.Equal(t, "hook_create", context.StepTypeHookCreate)
	assert.Equal(t, "llm", context.StepTypeLLM)
	assert.Equal(t, "tool", context.StepTypeTool)
	assert.Equal(t, "hook_next", context.StepTypeHookNext)
	assert.Equal(t, "delegate", context.StepTypeDelegate)
}

// TestBufferResumeStatusConstants pins the string values of the resume status constants.
func TestBufferResumeStatusConstants(t *testing.T) {
	assert.Equal(t, "failed", context.ResumeStatusFailed)
	assert.Equal(t, "interrupted", context.ResumeStatusInterrupted)
}

// TestBufferStepStatusConstants pins the string values of the step status constants.
func TestBufferStepStatusConstants(t *testing.T) {
	assert.Equal(t, "running", context.StepStatusRunning)
	assert.Equal(t, "completed", context.StepStatusCompleted)
}

// =============================================================================
// Edge Cases and Error Handling Tests // ============================================================================= func TestBufferEdgeCases(t *testing.T) { t.Run("LargeNumberOfMessages", func(t *testing.T) { buffer := context.NewChatBuffer("chat-large", "req-large", "assistant-large", "", "") // Add 10000 messages for i := 0; i < 10000; i++ { buffer.AddMessage(&context.BufferedMessage{ Role: "assistant", Type: "text", Props: map[string]interface{}{"content": fmt.Sprintf("Message %d", i)}, }) } assert.Equal(t, 10000, buffer.GetMessageCount()) messages := buffer.GetMessages() assert.Len(t, messages, 10000) }) t.Run("MessageWithEmptyProps", func(t *testing.T) { buffer := context.NewChatBuffer("chat-empty", "req-empty", "assistant-empty", "", "") buffer.AddMessage(&context.BufferedMessage{ Role: "assistant", Type: "text", Props: nil, }) messages := buffer.GetMessages() require.Len(t, messages, 1) assert.Nil(t, messages[0].Props) }) t.Run("StepWithEmptyInput", func(t *testing.T) { buffer := context.NewChatBuffer("chat-step", "req-step", "assistant-step", "", "") step := buffer.BeginStep(context.StepTypeLLM, nil, nil) assert.Nil(t, step.Input) buffer.CompleteStep(nil) steps := buffer.GetAllSteps() assert.Nil(t, steps[0].Output) }) t.Run("AllMessageTypes", func(t *testing.T) { buffer := context.NewChatBuffer("chat-types", "req-types", "assistant-types", "", "") messageTypes := []string{ "text", "image", "loading", "tool_call", "tool_result", "retrieval", "thinking", "action", "chart", "table", "custom_type_1", "custom_type_2", } for i, msgType := range messageTypes { buffer.AddAssistantMessage(fmt.Sprintf("M%d", i+1), msgType, map[string]interface{}{"type": msgType}, "", "", "", nil) } assert.Equal(t, len(messageTypes), buffer.GetMessageCount()) }) t.Run("AllStepTypes", func(t *testing.T) { buffer := context.NewChatBuffer("chat-step-types", "req-step-types", "assistant-step-types", "", "") stepTypes := []string{ context.StepTypeInput, context.StepTypeHookCreate, 
context.StepTypeLLM, context.StepTypeTool, context.StepTypeHookNext, context.StepTypeDelegate, } for _, stepType := range stepTypes { buffer.BeginStep(stepType, nil, nil) buffer.CompleteStep(nil) } steps := buffer.GetAllSteps() assert.Len(t, steps, len(stepTypes)) }) } // ============================================================================= // Integration-like Tests (Simulating Real Workflow) // ============================================================================= func TestBufferCompleteWorkflow(t *testing.T) { t.Run("SuccessfulChatFlow", func(t *testing.T) { buffer := context.NewChatBuffer("chat-workflow", "req-workflow", "assistant-main", "", "") // 1. User input buffer.AddUserInput("What's the weather in San Francisco?", "John") buffer.BeginStep(context.StepTypeInput, map[string]interface{}{"content": "What's the weather in San Francisco?"}, nil) buffer.CompleteStep(nil) // 2. Create hook buffer.BeginStep(context.StepTypeHookCreate, nil, nil) buffer.AddAssistantMessage("M1", "thinking", map[string]interface{}{"content": "Processing your request..."}, "block-1", "", "assistant-main", nil) buffer.CompleteStep(nil) // 3. LLM call with tool buffer.BeginStep(context.StepTypeLLM, map[string]interface{}{"model": "gpt-4"}, nil) buffer.AddAssistantMessage("M2", "tool_call", map[string]interface{}{ "name": "get_weather", "arguments": `{"location":"San Francisco"}`, }, "block-2", "", "assistant-main", nil) buffer.CompleteStep(map[string]interface{}{"tool_calls": 1}) // 4. Tool execution buffer.BeginStep(context.StepTypeTool, map[string]interface{}{"tool": "get_weather"}, nil) buffer.AddAssistantMessage("M3", "tool_result", map[string]interface{}{ "result": "72°F, Sunny", }, "block-2", "", "assistant-main", nil) buffer.CompleteStep(map[string]interface{}{"result": "72°F, Sunny"}) // 5. 
Final LLM response buffer.BeginStep(context.StepTypeLLM, nil, nil) buffer.AddAssistantMessage("M4", "text", map[string]interface{}{ "content": "The weather in San Francisco is currently 72°F and sunny.", }, "block-3", "", "assistant-main", nil) buffer.CompleteStep(nil) // Verify: 1 user_input + 4 assistant messages (thinking, tool_call, tool_result, text) assert.Equal(t, 5, buffer.GetMessageCount()) assert.Len(t, buffer.GetAllSteps(), 5) // 5 steps (no hook_next in this flow) // All steps should be completed steps := buffer.GetStepsForResume(context.StepStatusCompleted) assert.Nil(t, steps) }) t.Run("InterruptedChatFlow", func(t *testing.T) { buffer := context.NewChatBuffer("chat-interrupted", "req-interrupted", "assistant-main", "", "") // Set space snapshot buffer.SetSpaceSnapshot(map[string]interface{}{ "user_context": "previous conversation", "preferences": map[string]interface{}{"language": "en"}, }) // 1. User input buffer.AddUserInput("Generate a long story", "") buffer.BeginStep(context.StepTypeInput, nil, nil) buffer.CompleteStep(nil) // 2. LLM starts generating buffer.BeginStep(context.StepTypeLLM, map[string]interface{}{"model": "gpt-4"}, nil) buffer.AddAssistantMessage("M1", "text", map[string]interface{}{"content": "Once upon a time..."}, "block-1", "", "assistant-main", nil) // User interrupts here! 
// Get steps for resume steps := buffer.GetStepsForResume(context.ResumeStatusInterrupted) require.NotNil(t, steps) assert.Len(t, steps, 2) // Last step should be interrupted with space snapshot lastStep := steps[len(steps)-1] assert.Equal(t, context.ResumeStatusInterrupted, lastStep.Status) assert.NotNil(t, lastStep.SpaceSnapshot) assert.Equal(t, "previous conversation", lastStep.SpaceSnapshot["user_context"]) }) t.Run("A2ACallWithDelegation", func(t *testing.T) { buffer := context.NewChatBuffer("chat-a2a", "req-a2a", "assistant-main", "", "") mainStack := &context.Stack{ID: "stack-main", Depth: 0} childStack := &context.Stack{ID: "stack-child", ParentID: "stack-main", Depth: 1} // Main assistant starts buffer.BeginStep(context.StepTypeInput, nil, mainStack) buffer.CompleteStep(nil) // Delegate to child assistant buffer.SetAssistantID("assistant-child") buffer.BeginStep(context.StepTypeDelegate, map[string]interface{}{"delegate_to": "assistant-child"}, childStack) // Child assistant messages buffer.AddAssistantMessage("M1", "text", map[string]interface{}{"content": "Child assistant responding"}, "block-child", "", "assistant-child", nil) buffer.CompleteStep(map[string]interface{}{"delegate_result": "success"}) // Return to main assistant buffer.SetAssistantID("assistant-main") buffer.BeginStep(context.StepTypeLLM, nil, mainStack) buffer.AddAssistantMessage("M2", "text", map[string]interface{}{"content": "Main assistant continuing"}, "block-main", "", "assistant-main", nil) buffer.CompleteStep(nil) // Verify messages := buffer.GetMessages() assert.Len(t, messages, 2) assert.Equal(t, "assistant-child", messages[0].AssistantID) assert.Equal(t, "assistant-main", messages[1].AssistantID) steps := buffer.GetAllSteps() assert.Len(t, steps, 3) assert.Equal(t, "stack-child", steps[1].StackID) assert.Equal(t, "stack-main", steps[1].StackParentID) }) t.Run("ConcurrentAgentCalls", func(t *testing.T) { buffer := context.NewChatBuffer("chat-concurrent-a2a", 
"req-concurrent-a2a", "assistant-main", "", "") // Main assistant spawns multiple concurrent calls buffer.BeginStep(context.StepTypeInput, nil, nil) buffer.CompleteStep(nil) // Simulate concurrent responses with thread IDs var wg sync.WaitGroup for i := 0; i < 3; i++ { wg.Add(1) go func(idx int) { defer wg.Done() threadID := fmt.Sprintf("thread-%d", idx) buffer.AddAssistantMessage( fmt.Sprintf("M%d", idx), "text", map[string]interface{}{"content": fmt.Sprintf("Response from thread %d", idx)}, "block-concurrent", threadID, fmt.Sprintf("assistant-%d", idx), nil, ) }(i) } wg.Wait() messages := buffer.GetMessages() assert.Len(t, messages, 3) // Verify all have same block ID but different thread IDs threadIDs := make(map[string]bool) for _, msg := range messages { assert.Equal(t, "block-concurrent", msg.BlockID) assert.False(t, threadIDs[msg.ThreadID], "Duplicate thread ID") threadIDs[msg.ThreadID] = true } }) } // ============================================================================= // Message Sequence Tests // ============================================================================= func TestBufferMessageSequence(t *testing.T) { t.Run("SequenceAutoIncrement", func(t *testing.T) { buffer := context.NewChatBuffer("chat-seq", "req-seq", "assistant-seq", "", "") for i := 0; i < 10; i++ { buffer.AddMessage(&context.BufferedMessage{ Role: "assistant", Type: "text", }) } messages := buffer.GetMessages() for i, msg := range messages { assert.Equal(t, i+1, msg.Sequence) } }) t.Run("MixedMessageTypes", func(t *testing.T) { buffer := context.NewChatBuffer("chat-mixed", "req-mixed", "assistant-mixed", "", "") buffer.AddUserInput("Hello", "") buffer.AddAssistantMessage("M1", "text", nil, "", "", "", nil) buffer.AddUserInput("Follow up", "") buffer.AddAssistantMessage("M2", "tool_call", nil, "", "", "", nil) messages := buffer.GetMessages() assert.Len(t, messages, 4) for i, msg := range messages { assert.Equal(t, i+1, msg.Sequence) } }) } // 
============================================================================= // Step Sequence Tests // ============================================================================= func TestBufferStepSequence(t *testing.T) { t.Run("SequenceAutoIncrement", func(t *testing.T) { buffer := context.NewChatBuffer("chat-step-seq", "req-step-seq", "assistant-step-seq", "", "") for i := 0; i < 5; i++ { buffer.BeginStep(context.StepTypeLLM, nil, nil) buffer.CompleteStep(nil) } steps := buffer.GetAllSteps() for i, step := range steps { assert.Equal(t, i+1, step.Sequence) } }) } // ============================================================================= // Buffer Reset/Clear Tests (if needed in future) // ============================================================================= func TestBufferMultipleRequests(t *testing.T) { t.Run("NewBufferPerRequest", func(t *testing.T) { // Simulate multiple requests with separate buffers buffer1 := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "", "") buffer1.AddUserInput("Request 1", "") buffer2 := context.NewChatBuffer("chat-1", "req-2", "assistant-1", "", "") buffer2.AddUserInput("Request 2", "") // Buffers should be independent assert.Equal(t, 1, buffer1.GetMessageCount()) assert.Equal(t, 1, buffer2.GetMessageCount()) msg1 := buffer1.GetMessages()[0] msg2 := buffer2.GetMessages()[0] assert.Equal(t, "req-1", msg1.RequestID) assert.Equal(t, "req-2", msg2.RequestID) }) } // ============================================================================= // Streaming Message Tests // ============================================================================= func TestBufferStreamingMessage(t *testing.T) { t.Run("AddStreamingMessage", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") buffer.AddStreamingMessage( "msg-stream-1", "text", map[string]interface{}{"content": "# Title\n\n"}, "block-1", "thread-1", "assistant-1", nil, ) assert.Equal(t, 1, buffer.GetMessageCount()) 
// Verify streaming message is added msg := buffer.GetStreamingMessage("msg-stream-1") assert.NotNil(t, msg) assert.Equal(t, "msg-stream-1", msg.MessageID) assert.Equal(t, "text", msg.Type) assert.Equal(t, "# Title\n\n", msg.Props["content"]) assert.True(t, msg.IsStreaming) }) t.Run("AppendMessageContent", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") // Add streaming message buffer.AddStreamingMessage( "msg-stream-2", "text", map[string]interface{}{"content": "Initial "}, "", "", "", nil, ) // Append content ok := buffer.AppendMessageContent("msg-stream-2", "Line 1\n") assert.True(t, ok) ok = buffer.AppendMessageContent("msg-stream-2", "Line 2\n") assert.True(t, ok) // Verify accumulated content msg := buffer.GetStreamingMessage("msg-stream-2") assert.NotNil(t, msg) assert.Equal(t, "Initial Line 1\nLine 2\n", msg.Props["content"]) }) t.Run("AppendToNonExistentMessage", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") // Try to append to non-existent message ok := buffer.AppendMessageContent("non-existent", "content") assert.False(t, ok) }) t.Run("AppendToCompletedMessage", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") // Add and complete streaming message buffer.AddStreamingMessage( "msg-stream-3", "text", map[string]interface{}{"content": "Initial"}, "", "", "", nil, ) buffer.CompleteStreamingMessage("msg-stream-3") // Try to append to completed message (should fail) ok := buffer.AppendMessageContent("msg-stream-3", " more") assert.False(t, ok) }) t.Run("CompleteStreamingMessage", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") // Add streaming message buffer.AddStreamingMessage( "msg-stream-4", "text", map[string]interface{}{"content": "Hello "}, "", "", "", nil, ) // Append content buffer.AppendMessageContent("msg-stream-4", "World!") // Complete the 
message content, ok := buffer.CompleteStreamingMessage("msg-stream-4") assert.True(t, ok) assert.Equal(t, "Hello World!", content) // Message should no longer be streaming msg := buffer.GetStreamingMessage("msg-stream-4") assert.Nil(t, msg) // But should still exist in messages messages := buffer.GetMessages() assert.Equal(t, 1, len(messages)) assert.False(t, messages[0].IsStreaming) }) t.Run("CompleteNonExistentMessage", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") content, ok := buffer.CompleteStreamingMessage("non-existent") assert.False(t, ok) assert.Empty(t, content) }) t.Run("StreamingMessageWorkflow", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "deepseek", "") // Simulate a typical streaming workflow: // 1. SendStream sends initial content buffer.AddStreamingMessage( "msg-workflow", "text", map[string]interface{}{"content": "# Available Tests\n\n"}, "block-main", "", "assistant-1", nil, ) // 2. Multiple Append calls add content buffer.AppendMessageContent("msg-workflow", "Send one of these keywords:\n\n") buffer.AppendMessageContent("msg-workflow", "- **basic** - Basic tests\n") buffer.AppendMessageContent("msg-workflow", "- **advanced** - Advanced tests\n") // 3. 
End completes the message finalContent, ok := buffer.CompleteStreamingMessage("msg-workflow") assert.True(t, ok) expectedContent := "# Available Tests\n\nSend one of these keywords:\n\n- **basic** - Basic tests\n- **advanced** - Advanced tests\n" assert.Equal(t, expectedContent, finalContent) // Verify final message state messages := buffer.GetMessages() assert.Equal(t, 1, len(messages)) assert.Equal(t, "msg-workflow", messages[0].MessageID) assert.Equal(t, "deepseek", messages[0].Connector) // Connector should be set assert.False(t, messages[0].IsStreaming) }) t.Run("MixedStreamingAndRegularMessages", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") // Add user input (regular) buffer.AddUserInput("Hello", "user1") // Add streaming assistant message buffer.AddStreamingMessage( "msg-stream", "text", map[string]interface{}{"content": "Hi "}, "", "", "", nil, ) buffer.AppendMessageContent("msg-stream", "there!") buffer.CompleteStreamingMessage("msg-stream") // Add regular assistant message buffer.AddAssistantMessage("M3", "text", map[string]interface{}{"content": "How can I help?"}, "", "", "", nil) // Verify all messages messages := buffer.GetMessages() assert.Equal(t, 3, len(messages)) // Check sequence assert.Equal(t, 1, messages[0].Sequence) assert.Equal(t, 2, messages[1].Sequence) assert.Equal(t, 3, messages[2].Sequence) // Check content assert.Equal(t, "user", messages[0].Role) assert.Equal(t, "Hi there!", messages[1].Props["content"]) assert.Equal(t, "How can I help?", messages[2].Props["content"]) }) t.Run("StreamingMessageWithEmptyInitialContent", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") // Add streaming message with nil props buffer.AddStreamingMessage( "msg-empty", "text", nil, "", "", "", nil, ) // Append content buffer.AppendMessageContent("msg-empty", "Content") // Complete content, ok := buffer.CompleteStreamingMessage("msg-empty") 
assert.True(t, ok) assert.Equal(t, "Content", content) }) t.Run("ConcurrentStreamingOperations", func(t *testing.T) { buffer := context.NewChatBuffer("chat-1", "req-1", "assistant-1", "openai", "") // Add streaming message buffer.AddStreamingMessage( "msg-concurrent", "text", map[string]interface{}{"content": ""}, "", "", "", nil, ) // Concurrent appends with fixed-length content var wg sync.WaitGroup for i := 0; i < 100; i++ { wg.Add(1) go func() { defer wg.Done() buffer.AppendMessageContent("msg-concurrent", "x") }() } wg.Wait() // Complete content, ok := buffer.CompleteStreamingMessage("msg-concurrent") assert.True(t, ok) // Content should have 100 'x' characters assert.Equal(t, 100, len(content)) }) } ================================================ FILE: agent/context/chat.go ================================================ package context import ( "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "time" gonanoid "github.com/matoous/go-nanoid/v2" "github.com/yaoapp/gou/store" ) const ( chatCachePrefix = "chat:messages:" chatCacheTTL = time.Hour * 24 * 7 // 7 days ) // filterNonAssistantMessages returns messages excluding assistant messages func filterNonAssistantMessages(messages []Message) []Message { var filtered []Message for _, msg := range messages { if msg.Role != RoleAssistant { filtered = append(filtered, msg) } } return filtered } // countUserMessages returns the number of user role messages func countUserMessages(messages []Message) int { count := 0 for _, msg := range messages { if msg.Role == RoleUser { count++ } } return count } // GetChatIDByMessages gets or generates a chat ID based on message content // Matching strategy: // - Only non-assistant messages (system, developer, user, tool) are used for matching // - User adds a new message at the end each time // - To detect continuation, we match messages BEFORE the last non-assistant message // - If previous non-assistant messages match cached conversation → same chat // - For first user 
message (even with system/developer messages): always generate new chat ID func GetChatIDByMessages(cache store.Store, messages []Message) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("messages cannot be empty") } // Filter out assistant messages for matching nonAssistantMessages := filterNonAssistantMessages(messages) // Count user messages to determine matching strategy userMessageCount := countUserMessages(nonAssistantMessages) var chatID string var matched bool // Matching strategy based on user message count: // - 1 user message: generate new chat ID (cannot determine continuation) // - 2+ user messages: match all except last (which is the new user input) if userMessageCount >= 2 { // Match previous messages (all except last non-assistant message) matchMessages := nonAssistantMessages[:len(nonAssistantMessages)-1] hash, err := hashMessages(matchMessages) if err == nil { key := getKey(hash) if cachedID, ok := cache.Get(key); ok { if chatIDStr, ok := cachedID.(string); ok && chatIDStr != "" { chatID = chatIDStr matched = true } } } } // If no match, generate new chat ID if !matched { chatID = GenChatID() } // Cache the current messages for future matching // CacheChatID will handle filtering assistant messages // Next request will have one more message and will try to match current messages _ = CacheChatID(cache, messages, chatID) return chatID, nil } // CacheChatID cache the chat ID with all message prefixes for future matching // It caches ALL prefixes of the message array to enable conversation continuation detection // Assistant messages are automatically filtered out before caching // Example: For messages [A,B,C], it caches hashes for [A], [A,B], and [A,B,C] func CacheChatID(cache store.Store, messages []Message, chatID string) error { if len(messages) == 0 { return fmt.Errorf("messages cannot be empty") } if chatID == "" { return fmt.Errorf("chatID cannot be empty") } // Filter out assistant messages nonAssistantMessages := 
filterNonAssistantMessages(messages) if len(nonAssistantMessages) == 0 { return fmt.Errorf("no non-assistant messages to cache") } // Cache all prefixes of the non-assistant messages array // This allows detecting conversation continuation when new messages are added for length := 1; length <= len(nonAssistantMessages); length++ { prefix := nonAssistantMessages[:length] hash, err := hashMessages(prefix) if err != nil { continue // Skip this prefix if hashing fails } key := getKey(hash) // Ignore errors for individual cache sets _ = cache.Set(key, chatID, chatCacheTTL) } return nil } // GenChatID generate a new chat ID using NanoID algorithm // safe: optional parameter, reserved for future safe mode implementation (collision detection) func GenChatID(safe ...bool) string { // TODO: Implement safe mode with collision detection when needed // For now, NanoID provides sufficient uniqueness without collision checking // URL-safe alphabet (no ambiguous characters like 0/O, 1/l/I) const alphabet = "23456789ABCDEFGHJKMNPQRSTUVWXYZabcdefghijkmnpqrstuvwxyz" const length = 16 // 16 characters provides good balance of uniqueness and readability id, err := gonanoid.Generate(alphabet, length) if err != nil { // Fallback to timestamp-based ID if NanoID generation fails return fmt.Sprintf("%d", time.Now().UnixNano()) } return id } // getKey generates a cache key for messages func getKey(messageHash string) string { return chatCachePrefix + messageHash } // hashMessage generates a hash for a single message func hashMessage(msg Message) (string, error) { data, err := json.Marshal(msg) if err != nil { return "", err } hash := sha256.Sum256(data) return hex.EncodeToString(hash[:]), nil } // hashMessages generates a combined hash for a slice of messages // Note: Caller is responsible for filtering messages (e.g., removing assistant messages) func hashMessages(messages []Message) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("messages cannot be empty") } var hashes 
string for _, msg := range messages { hash, err := hashMessage(msg) if err != nil { return "", err } hashes += hash } // Generate final hash from combined hashes finalHash := sha256.Sum256([]byte(hashes)) return hex.EncodeToString(finalHash[:]), nil } ================================================ FILE: agent/context/chat_test.go ================================================ package context_test import ( "testing" "github.com/yaoapp/gou/store" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func getTestCache(t *testing.T) store.Store { cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache store: %v", err) } cache.Clear() // Clean before test return cache } func TestGetChatIDByMessages_NewConversation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := getTestCache(t) messages := []context.Message{ { Role: context.RoleUser, Content: "Hello, how are you?", }, } // First request - should generate new chat ID chatID1, err := context.GetChatIDByMessages(cache, messages) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID1 == "" { t.Fatal("Expected non-empty chat ID") } // Second request with same single user message - should generate DIFFERENT chat ID // (single user message always generates new chat ID to avoid false matches) chatID2, err := context.GetChatIDByMessages(cache, messages) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID2 == "" { t.Fatal("Expected non-empty chat ID") } // Both should be valid but different (single user message = new conversation each time) if chatID1 == chatID2 { t.Errorf("Expected different chat IDs for single user message, got same ID: %s", chatID1) } } func TestGetChatIDByMessages_ContinuousConversation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := getTestCache(t) // Scenario: User conversation with incrementally added messages // Request 1: [user1] 
messages1 := []context.Message{ {Role: context.RoleUser, Content: "First message"}, } chatID1, err := context.GetChatIDByMessages(cache, messages1) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } // Request 2: [user1, user2] // For 2 messages, matches last 1 message // Should match chatID1 because last message is cached messages2 := []context.Message{ {Role: context.RoleUser, Content: "First message"}, {Role: context.RoleUser, Content: "Second message"}, } chatID2, err := context.GetChatIDByMessages(cache, messages2) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID1 != chatID2 { t.Errorf("Expected chatID2 to match chatID1, got %s and %s", chatID2, chatID1) } // Request 3: [user1, user2, user3] // For 3+ messages, matches last 2 messages // Should match chatID2 because last 2 messages are cached messages3 := []context.Message{ {Role: context.RoleUser, Content: "First message"}, {Role: context.RoleUser, Content: "Second message"}, {Role: context.RoleUser, Content: "Third message"}, } chatID3, err := context.GetChatIDByMessages(cache, messages3) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID2 != chatID3 { t.Errorf("Expected chatID3 to match chatID2, got %s and %s", chatID3, chatID2) } // Request 4: [user1, user2, user3, user4] // Should match chatID3 because last 2 messages are cached messages4 := []context.Message{ {Role: context.RoleUser, Content: "First message"}, {Role: context.RoleUser, Content: "Second message"}, {Role: context.RoleUser, Content: "Third message"}, {Role: context.RoleUser, Content: "Fourth message"}, } chatID4, err := context.GetChatIDByMessages(cache, messages4) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID3 != chatID4 { t.Errorf("Expected chatID4 to match chatID3, got %s and %s", chatID4, chatID3) } // All should be the same conversation if chatID1 != chatID4 { t.Errorf("Expected all chat IDs to be the same, got %s and %s", chatID1, chatID4) } } func 
TestGetChatIDByMessages_DifferentConversations(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := getTestCache(t) // First conversation messages1 := []context.Message{ { Role: context.RoleUser, Content: "Hello", }, } chatID1, err := context.GetChatIDByMessages(cache, messages1) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } err = context.CacheChatID(cache, messages1, chatID1) if err != nil { t.Fatalf("Failed to cache chat ID: %v", err) } // Different conversation messages2 := []context.Message{ { Role: context.RoleUser, Content: "Goodbye", }, } chatID2, err := context.GetChatIDByMessages(cache, messages2) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID1 == chatID2 { t.Errorf("Expected different chat IDs for different conversations, got %s", chatID1) } } func TestGetChatIDByMessages_MultiModalContent(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := getTestCache(t) // First request with multimodal content messages1 := []context.Message{ { Role: context.RoleUser, Content: []context.ContentPart{ { Type: context.ContentText, Text: "What's in this image?", }, { Type: context.ContentImageURL, ImageURL: &context.ImageURL{ URL: "https://example.com/image.jpg", Detail: context.DetailHigh, }, }, }, }, } chatID1, err := context.GetChatIDByMessages(cache, messages1) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } // Second request - add another message to continue conversation messages2 := append(messages1, context.Message{ Role: context.RoleUser, Content: "Tell me more details", }) chatID2, err := context.GetChatIDByMessages(cache, messages2) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } // Should get same chat ID (continuation) if chatID1 != chatID2 { t.Errorf("Expected same chat ID for multimodal continuation, got %s and %s", chatID1, chatID2) } } func TestGetChatIDByMessages_WithToolCalls(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := 
getTestCache(t) // First request with user message messages1 := []context.Message{ { Role: context.RoleUser, Content: "What's the weather in Tokyo?", }, } chatID1, err := context.GetChatIDByMessages(cache, messages1) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } // Second request - add assistant response and another user message messages2 := []context.Message{ { Role: context.RoleUser, Content: "What's the weather in Tokyo?", }, { Role: context.RoleAssistant, Content: nil, ToolCalls: []context.ToolCall{ { ID: "call_123", Type: context.ToolTypeFunction, Function: context.Function{ Name: "get_weather", Arguments: `{"location":"Tokyo"}`, }, }, }, }, { Role: context.RoleUser, Content: "How about tomorrow?", }, } chatID2, err := context.GetChatIDByMessages(cache, messages2) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } // Should get same chat ID (assistant messages are ignored, so it matches the first user message) if chatID1 != chatID2 { t.Errorf("Expected same chat ID for messages with tool calls, got %s and %s", chatID1, chatID2) } } func TestCacheChatID_EmptyMessages(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := getTestCache(t) err := context.CacheChatID(cache, []context.Message{}, "chat_123") if err == nil { t.Error("Expected error for empty messages") } } func TestCacheChatID_EmptyChatID(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := getTestCache(t) messages := []context.Message{ { Role: context.RoleUser, Content: "Hello", }, } err := context.CacheChatID(cache, messages, "") if err == nil { t.Error("Expected error for empty chat ID") } } func TestGetChatIDByMessages_EmptyMessages(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cache := getTestCache(t) _, err := context.GetChatIDByMessages(cache, []context.Message{}) if err == nil { t.Error("Expected error for empty messages") } } func TestGenChatID(t *testing.T) { id1 := context.GenChatID() if id1 == "" { 
t.Error("Expected non-empty chat ID") } // Check length - NanoID with length 16 should produce 16 character strings if len(id1) < 10 { t.Errorf("Expected chat ID to have reasonable length, got %d characters: %s", len(id1), id1) } // Note: We don't test uniqueness here because nano timestamp-based IDs // can occasionally be the same when generated in rapid succession. // The uniqueness is good enough for production use. } ================================================ FILE: agent/context/context.go ================================================ package context import ( "context" "fmt" "sync" "github.com/google/uuid" "github.com/yaoapp/yao/agent/memory" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/trace" traceTypes "github.com/yaoapp/yao/trace/types" ) // Global context registry for interrupt management var ( contextRegistry = &sync.Map{} // map[contextID]*Context ) // New create a new context with basic initialization func New(parent context.Context, authorized *types.AuthorizedInfo, chatID string) *Context { if parent == nil { parent = context.Background() } contextID := generateContextID() // Extract user and team IDs from authorized info var userID, teamID string if authorized != nil { userID = authorized.UserID teamID = authorized.TeamID } // Create memory instance using global manager mem, _ := memory.GetMemory(userID, teamID, chatID, contextID) ctx := &Context{ Context: parent, ID: contextID, // Generate unique ID for the context Authorized: authorized, // Set authorized info Memory: mem, ChatID: chatID, IDGenerator: message.NewIDGenerator(), // Initialize ID generator for this context messageMetadata: newMessageMetadataStore(), // Initialize message metadata store Logger: NewRequestLogger("", chatID, contextID), // Initialize logger (assistantID set later) } return ctx } // Release the context and clean up all resources including stacks and trace func (ctx 
*Context) Release() { if ctx.Logger != nil { ctx.Logger.Release() } // Unregister from global registry if ctx.ID != "" { Unregister(ctx.ID) } // Stop interrupt controller if ctx.Interrupt != nil { if ctx.Logger != nil { ctx.Logger.Cleanup("Interrupt controller") } ctx.Interrupt.Stop() ctx.Interrupt = nil } // Complete and release trace if exists. // Only the root context (non-forked) owns the trace lifecycle. // Forked contexts share the same trace manager but must not release it. if ctx.trace != nil && ctx.Stack != nil && ctx.Stack.TraceID != "" { if ctx.ForkParent == nil { if ctx.Logger != nil { ctx.Logger.Cleanup("Trace: " + ctx.Stack.TraceID) } if ctx.Context != nil && ctx.Context.Err() != nil { trace.MarkCancelled(ctx.Stack.TraceID, ctx.Context.Err().Error()) trace.Release(ctx.Stack.TraceID) } else { ctx.trace.MarkComplete() trace.Release(ctx.Stack.TraceID) } } ctx.trace = nil } // Clear context-level memory only (request-scoped temporary data) // User, Team, Chat level memory is persistent and should NOT be cleared if ctx.Memory != nil && ctx.Memory.Context != nil { if ctx.Logger != nil { ctx.Logger.Cleanup("Memory.Context") } ctx.Memory.Context.Clear() } ctx.Memory = nil // Clear stacks if ctx.Stacks != nil { if ctx.Logger != nil { ctx.Logger.Cleanup(fmt.Sprintf("Stacks (%d)", len(ctx.Stacks))) } for k := range ctx.Stacks { delete(ctx.Stacks, k) } ctx.Stacks = nil } // Clear current stack reference ctx.Stack = nil // Close SafeWriter if exists (must be before setting Writer to nil) // This ensures the background goroutine is properly stopped ctx.CloseSafeWriter() // Clear writer reference ctx.Writer = nil // Close logger (MUST be last) if ctx.Logger != nil { ctx.Logger.Close() ctx.Logger = nil } } // GetAuthorizedMap returns the authorized information as a map // This implements the AuthorizedProvider interface for MCP process calls // Allows MCP tools to receive authorization context when called via Process transport func (ctx *Context) GetAuthorizedMap() 
map[string]interface{} { if ctx.Authorized == nil { return nil } return ctx.Authorized.AuthorizedToMap() } // Fork creates a child context for concurrent agent/LLM calls // The forked context shares read-only resources (Authorized, Cache, Writer) // but has its own independent Stack, Logger, and Memory.Context namespace // to avoid race conditions and state sharing issues. // // This is essential for batch operations (All/Any/Race) where multiple goroutines // need to execute concurrently without interfering with each other's state. // // Key behavior: // - Memory.User, Memory.Team, Memory.Chat are shared (cross-request state) // - Memory.Context is INDEPENDENT (request-scoped state, isolated per fork) // // The forked context does NOT need to be released separately - the parent context // manages shared resources. However, the child's Stack will be collected in parent's Stacks map. func (ctx *Context) Fork() *Context { childID := generateContextID() // Fork memory with independent Context namespace // This prevents parallel sub-agents from sharing ctx.memory.context state var forkedMemory *memory.Memory if ctx.Memory != nil { var err error forkedMemory, err = ctx.Memory.Fork(childID) if err != nil { // Fallback to shared memory if fork fails (log warning) forkedMemory = ctx.Memory } } child := &Context{ // Inherit parent's standard context Context: ctx.Context, // New unique ID for this forked context ID: childID, // Memory with independent Context namespace (see above) Memory: forkedMemory, // Share read-only/thread-safe resources with parent Cache: ctx.Cache, // Cache store is thread-safe Writer: ctx.Writer, // Output writer is thread-safe (output module handles concurrency) Authorized: ctx.Authorized, // Read-only auth info Capabilities: ctx.Capabilities, // Read-only model capabilities // Share reference to parent's Stacks map for trace collection // Child stacks will be added here by EnterStack Stacks: ctx.Stacks, // Stack is nil for forked contexts - will be 
set by EnterStack // ForkParent stores parent stack info so EnterStack can create child stack Stack: nil, // Create independent resources to avoid race conditions IDGenerator: message.NewIDGenerator(), Logger: NewRequestLogger(ctx.AssistantID, ctx.ChatID, childID, WithParentID(ctx.ID)), messageMetadata: newMessageMetadataStore(), // Inherit context metadata ChatID: ctx.ChatID, AssistantID: ctx.AssistantID, Locale: ctx.Locale, Theme: ctx.Theme, Client: ctx.Client, Referer: ctx.Referer, Accept: ctx.Accept, Route: ctx.Route, Metadata: ctx.Metadata, // Don't inherit these - they are request-specific Buffer: nil, // Buffer belongs to root context Interrupt: nil, // Interrupt controller belongs to root context trace: nil, // Trace will be inherited via TraceID in Stack } // Set ForkParent info if parent has a Stack // This allows EnterStack to create a child stack instead of root stack if ctx.Stack != nil { child.ForkParent = &ForkParentInfo{ StackID: ctx.Stack.ID, TraceID: ctx.Stack.TraceID, Depth: ctx.Stack.Depth, Path: append([]string{}, ctx.Stack.Path...), // Copy path slice } } return child } // Send sends data to the context's writer // This is used by the output module to send messages to the client // func (ctx *Context) Send(data []byte) error { // if ctx.Writer == nil { // return nil // No writer, silently ignore // } // _, err := ctx.Writer.Write(data) // return err // } // Trace returns the trace manager for this context, lazily initialized on first call // Uses the TraceID from ctx.Stack if available, or generates a new one func (ctx *Context) Trace() (traceTypes.Manager, error) { // Return trace if already initialized if ctx.trace != nil { return ctx.trace, nil } // Get TraceID from Stack or generate new one var traceID string if ctx.Stack != nil && ctx.Stack.TraceID != "" { traceID = ctx.Stack.TraceID // Try to load existing trace first manager, err := trace.Load(traceID) if err == nil { // Found in registry, reuse it ctx.trace = manager return manager, 
nil } } // Get trace configuration from global config cfg := config.Conf // Prepare driver options var driverOptions []any var driverType string switch cfg.Trace.Driver { case "store": driverType = trace.Store if cfg.Trace.Store == "" { return nil, fmt.Errorf("trace store ID not configured") } driverOptions = []any{cfg.Trace.Store, cfg.Trace.Prefix} case "local", "": driverType = trace.Local driverOptions = []any{cfg.Trace.Path} default: return nil, fmt.Errorf("unsupported trace driver: %s", cfg.Trace.Driver) } // Prepare trace options traceOption := &traceTypes.TraceOption{ID: traceID, AutoArchive: config.Conf.Mode == "production"} // Set trace options from authorized information if ctx.Authorized != nil { traceOption.CreatedBy = ctx.Authorized.UserID traceOption.TeamID = ctx.Authorized.TeamID traceOption.TenantID = ctx.Authorized.TenantID } // Create trace using trace.New (handles registry) createdTraceID, manager, err := trace.New(ctx.Context, driverType, traceOption, driverOptions...) 
if err != nil { return nil, fmt.Errorf("failed to create trace: %w", err) } // Update Stack with the created TraceID if needed if ctx.Stack != nil && ctx.Stack.TraceID == "" { ctx.Stack.TraceID = createdTraceID } // Store for future calls ctx.trace = manager return manager, nil } // Map the context to a map func (ctx *Context) Map() map[string]interface{} { data := map[string]interface{}{} // Authorized information if ctx.Authorized != nil { data["authorized"] = ctx.Authorized } if ctx.ChatID != "" { data["chat_id"] = ctx.ChatID } if ctx.AssistantID != "" { data["assistant_id"] = ctx.AssistantID } // Locale information if ctx.Locale != "" { data["locale"] = ctx.Locale } if ctx.Theme != "" { data["theme"] = ctx.Theme } // Request information if ctx.Client.Type != "" || ctx.Client.UserAgent != "" || ctx.Client.IP != "" { data["client"] = map[string]interface{}{ "type": ctx.Client.Type, "user_agent": ctx.Client.UserAgent, "ip": ctx.Client.IP, } } if ctx.Referer != "" { data["referer"] = ctx.Referer } if ctx.Accept != "" { data["accept"] = ctx.Accept } // CUI Context information if ctx.Route != "" { data["route"] = ctx.Route } if len(ctx.Metadata) > 0 { data["metadata"] = ctx.Metadata } return data } // Global Registry Functions // =================================== // Register registers a context to the global registry func Register(ctx *Context) error { if ctx == nil { return fmt.Errorf("context is nil") } if ctx.ID == "" { return fmt.Errorf("context ID is empty") } contextRegistry.Store(ctx.ID, ctx) return nil } // Unregister removes a context from the global registry func Unregister(contextID string) { contextRegistry.Delete(contextID) } // Get retrieves a context from the global registry by ID func Get(contextID string) (*Context, error) { value, ok := contextRegistry.Load(contextID) if !ok { return nil, fmt.Errorf("context not found: %s", contextID) } ctx, ok := value.(*Context) if !ok { return nil, fmt.Errorf("invalid context type") } return ctx, nil } // 
// RequestID returns the context ID, which doubles as the request ID.
// Nothing is generated here: the same value is returned on every call.
// Contexts built by New or Fork take their ID from generateContextID,
// which produces a UUID string.
func (ctx *Context) RequestID() string {
	return ctx.ID
}
input message func (ctx *Context) BufferUserInput(messages []Message) { if ctx.Buffer == nil { return } for _, msg := range messages { if msg.Role == RoleUser { // Get name if available var name string if msg.Name != nil { name = *msg.Name } ctx.Buffer.AddUserInput(msg.Content, name) } } } // BufferAssistantMessage adds an assistant message to the buffer // Called by ctx.Send() to buffer messages for batch saving func (ctx *Context) BufferAssistantMessage(messageID, msgType string, props map[string]interface{}, blockID, threadID string, metadata map[string]interface{}) { if ctx.Buffer == nil { return } ctx.Buffer.AddAssistantMessage(messageID, msgType, props, blockID, threadID, ctx.AssistantID, metadata) } // BeginStep starts tracking a new execution step // Returns the step for further updates func (ctx *Context) BeginStep(stepType string, input map[string]interface{}) *BufferedStep { if ctx.Buffer == nil { return nil } // Update context memory snapshot before starting step (for recovery) if ctx.Memory != nil && ctx.Memory.Context != nil { ctx.Buffer.SetSpaceSnapshot(ctx.Memory.Context.Snapshot()) } return ctx.Buffer.BeginStep(stepType, input, ctx.Stack) } // CompleteStep marks the current step as completed func (ctx *Context) CompleteStep(output map[string]interface{}) { if ctx.Buffer == nil { return } ctx.Buffer.CompleteStep(output) } // FailCurrentStep marks the current step as failed or interrupted func (ctx *Context) FailCurrentStep(status string, err error) { if ctx.Buffer == nil { return } ctx.Buffer.FailCurrentStep(status, err) } // shouldSkipHistory checks if history saving should be skipped // Returns true if Skip.History is set in the current stack options func (ctx *Context) shouldSkipHistory() bool { if ctx.Stack == nil || ctx.Stack.Options == nil || ctx.Stack.Options.Skip == nil { return false } return ctx.Stack.Options.Skip.History } // IsA2ACall returns true if this is any Agent-to-Agent call (delegate or fork) // A2A calls are identified by 
Referer being "agent" or "agent_fork": // - ctx.agent.Call/All/Any/Race uses RefererAgentFork (forked context, skips history) // - delegate uses RefererAgent (same context flow, saves history) func (ctx *Context) IsA2ACall() bool { return ctx.Referer == RefererAgent || ctx.Referer == RefererAgentFork } // IsForkedA2ACall returns true if this is a forked A2A call (ctx.agent.Call/All/Any/Race) // Forked calls use RefererAgentFork, while delegate calls use RefererAgent. // This is used to skip history saving for forked sub-agent calls, // while allowing delegate calls to save history normally. func (ctx *Context) IsForkedA2ACall() bool { return ctx.Referer == RefererAgentFork } // MergeMetadata merges the given metadata into ctx.Metadata. // Existing keys are overwritten by incoming values. // This enables A2A callers to pass custom metadata (e.g. oneshot, async) // to sub-agent hooks via ctx.metadata in JavaScript. func (ctx *Context) MergeMetadata(metadata map[string]interface{}) { if len(metadata) == 0 { return } if ctx.Metadata == nil { ctx.Metadata = make(map[string]interface{}, len(metadata)) } for k, v := range metadata { ctx.Metadata[k] = v } } ================================================ FILE: agent/context/context_test.go ================================================ package context_test import ( "bytes" stdContext "context" "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/store" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestGetCompletionRequest(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } tests := []struct { name string requestBody map[string]interface{} queryParams map[string]string headers map[string]string expectedModel string expectedMsgCount 
int expectedTemp *float64 expectedStream *bool expectedLocale string expectedTheme string expectedReferer string expectedAccept context.Accept expectedAssistantID string expectError bool }{ { name: "Complete request from body with metadata", requestBody: map[string]interface{}{ "model": "gpt-4-yao_assistant123", "messages": []map[string]interface{}{ {"role": "user", "content": "Hello"}, }, "temperature": 0.7, "stream": true, "metadata": map[string]string{ "locale": "zh-cn", "theme": "dark", "referer": "process", "accept": "cui-web", "chat_id": "chat-from-metadata", }, }, expectedModel: "gpt-4-yao_assistant123", expectedMsgCount: 1, expectedTemp: floatPtr(0.7), expectedStream: boolPtr(true), expectedLocale: "zh-cn", expectedTheme: "dark", expectedReferer: context.RefererProcess, expectedAccept: context.AcceptWebCUI, expectedAssistantID: "assistant123", expectError: false, }, { name: "Query params override payload metadata", requestBody: map[string]interface{}{ "model": "gpt-4-yao_test456", "messages": []map[string]interface{}{ {"role": "user", "content": "Test"}, }, "metadata": map[string]string{ "locale": "en-us", "theme": "light", }, }, queryParams: map[string]string{ "locale": "fr-FR", "theme": "auto", }, expectedModel: "gpt-4-yao_test456", expectedMsgCount: 1, expectedLocale: "fr-fr", expectedTheme: "auto", expectedReferer: context.RefererAPI, expectedAccept: context.AcceptStandard, expectedAssistantID: "test456", expectError: false, }, { name: "Headers override payload metadata", requestBody: map[string]interface{}{ "model": "gpt-3.5-turbo-yao_header789", "messages": []map[string]interface{}{ {"role": "user", "content": "Test"}, }, "metadata": map[string]string{ "referer": "process", "accept": "cui-web", }, }, headers: map[string]string{ "X-Yao-Referer": "mcp", "X-Yao-Accept": "cui-desktop", }, expectedModel: "gpt-3.5-turbo-yao_header789", expectedMsgCount: 1, expectedLocale: "", expectedTheme: "", expectedReferer: context.RefererMCP, expectedAccept: 
context.AcceptDesktopCUI, expectedAssistantID: "header789", expectError: false, }, { name: "Minimal request without metadata", requestBody: map[string]interface{}{ "model": "gpt-4o-yao_minimal", "messages": []map[string]interface{}{ {"role": "user", "content": "Hello"}, }, }, expectedModel: "gpt-4o-yao_minimal", expectedMsgCount: 1, expectedLocale: "", expectedTheme: "", expectedReferer: context.RefererAPI, expectedAccept: context.AcceptStandard, expectedAssistantID: "minimal", expectError: false, }, { name: "Missing model", requestBody: map[string]interface{}{ "messages": []map[string]interface{}{ {"role": "user", "content": "Hello"}, }, }, expectError: true, }, { name: "Missing messages", requestBody: map[string]interface{}{ "model": "gpt-4", }, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) // Build request bodyBytes, _ := json.Marshal(tt.requestBody) req, _ := http.NewRequest("POST", "http://example.com/chat/completions", bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") // Add query params q := req.URL.Query() for key, value := range tt.queryParams { q.Add(key, value) } req.URL.RawQuery = q.Encode() // Add headers for key, value := range tt.headers { req.Header.Set(key, value) } c.Request = req // Call GetCompletionRequest completionReq, ctx, opts, err := context.GetCompletionRequest(c, cache) if tt.expectError { assert.Error(t, err) return } assert.NoError(t, err) assert.NotNil(t, completionReq) assert.NotNil(t, ctx) assert.NotNil(t, opts) // Verify CompletionRequest assert.Equal(t, tt.expectedModel, completionReq.Model) assert.Equal(t, tt.expectedMsgCount, len(completionReq.Messages)) if tt.expectedTemp != nil { assert.NotNil(t, completionReq.Temperature) assert.Equal(t, *tt.expectedTemp, *completionReq.Temperature) } if tt.expectedStream != nil { assert.NotNil(t, completionReq.Stream) assert.Equal(t, *tt.expectedStream, 
// GRPCAgentInput holds the raw inputs from a gRPC AgentStream request.
// GetGRPCAgentRequest consumes it to build the message list, the Context and
// the Options for an agent call.
type GRPCAgentInput struct {
	// AssistantID identifies the target assistant. It is required:
	// GetGRPCAgentRequest fails when it is empty.
	AssistantID string

	// Messages is the JSON-encoded message array. It is required and must
	// decode to at least one message.
	Messages []byte

	// Options is an optional JSON object. GetGRPCAgentRequest reads the keys
	// chat_id, locale, theme, referer, accept, route, metadata, skip, mode
	// and connector from it.
	Options []byte

	// AuthInfo is passed to New as the authorized info of the created Context.
	AuthInfo *types.AuthorizedInfo

	// Cache is assigned to the created Context's Cache.
	Cache store.Store

	// Writer is assigned to the created Context's Writer.
	Writer http.ResponseWriter
}
// // Flow: validate → parse messages → parse options → build Context → build Options → register interrupt func GetGRPCAgentRequest(parent context.Context, input GRPCAgentInput) ([]Message, *Context, *Options, error) { if input.AssistantID == "" { return nil, nil, nil, fmt.Errorf("assistant_id is required") } messages, err := parseGRPCMessages(input.Messages) if err != nil { return nil, nil, nil, err } var rawOpts map[string]interface{} if len(input.Options) > 0 { if err := json.Unmarshal(input.Options, &rawOpts); err != nil { return nil, nil, nil, fmt.Errorf("invalid options JSON: %w", err) } } chatID := getChatIDFromOpts(rawOpts) ctx := New(parent, input.AuthInfo, chatID) ctx.Cache = input.Cache ctx.Writer = input.Writer ctx.AssistantID = input.AssistantID ctx.Locale = getStringOpt(rawOpts, "locale") ctx.Theme = getStringOpt(rawOpts, "theme") ctx.Referer = getRefererOpt(rawOpts) ctx.Accept = getAcceptOpt(rawOpts) ctx.Route = getStringOpt(rawOpts, "route") ctx.Metadata = getMapOpt(rawOpts, "metadata") ctx.Client = Client{Type: "grpc"} opts := &Options{ Context: parent, Skip: getSkipOpt(rawOpts), Mode: getStringOpt(rawOpts, "mode"), } if connectorID := getStringOpt(rawOpts, "connector"); connectorID != "" { if _, err := connector.Select(connectorID); err == nil { opts.Connector = connectorID } } ctx.Interrupt = NewInterruptController() if err := Register(ctx); err != nil { return nil, nil, nil, fmt.Errorf("failed to register context: %w", err) } ctx.Interrupt.Start(ctx.ID) return messages, ctx, opts, nil } func parseGRPCMessages(raw []byte) ([]Message, error) { if len(raw) == 0 { return nil, fmt.Errorf("messages are required") } var messages []Message if err := json.Unmarshal(raw, &messages); err != nil { return nil, fmt.Errorf("invalid messages JSON: %w", err) } if len(messages) == 0 { return nil, fmt.Errorf("messages must not be empty") } return messages, nil } func getChatIDFromOpts(opts map[string]interface{}) string { if opts != nil { if v, ok := 
// getStringOpt returns opts[key] when it holds a string, and "" in every
// other case: nil map, missing key, or a value of another type.
func getStringOpt(opts map[string]interface{}, key string) string {
	// Indexing a nil map is safe and yields the zero value, so no separate
	// nil check is needed.
	if s, ok := opts[key].(string); ok {
		return s
	}
	return ""
}
ChunkError StreamChunkType = "error" // Error chunk ChunkUnknown StreamChunkType = "unknown" // Unknown/unrecognized chunk type // Lifecycle event types - stream and message boundaries ChunkStreamStart StreamChunkType = "stream_start" // Stream begins (entire request starts) ChunkStreamEnd StreamChunkType = "stream_end" // Stream ends (entire request completes) ChunkMessageStart StreamChunkType = "message_start" // Message begins (text/tool_call/thinking message starts) ChunkMessageEnd StreamChunkType = "message_end" // Message ends (text/tool_call/thinking message completes) ) // Writer is an alias for http.ResponseWriter interface used by an agent to construct a response. // A Writer may not be used after the agent execution has completed. type Writer = http.ResponseWriter // Agent the agent interface type Agent interface { // Stream stream the agent Stream(ctx *Context, messages []Message, handler message.StreamFunc) error // Run run the agent Run(ctx *Context, messages []Message) (*Response, error) } ================================================ FILE: agent/context/interrupt.go ================================================ package context import ( "context" "fmt" "time" "github.com/yaoapp/kun/log" ) // NewInterruptController creates a new interrupt controller func NewInterruptController() *InterruptController { ctrl := &InterruptController{ queue: make(chan *InterruptSignal, 10), // Buffer for 10 interrupts pending: make([]*InterruptSignal, 0), } ctrl.ctx, ctrl.cancel = context.WithCancel(context.Background()) return ctrl } // Start starts the interrupt listener goroutine func (ic *InterruptController) Start(contextID string) { if ic.listenerStarted { return } ic.mutex.Lock() ic.listenerStarted = true ic.contextID = contextID ic.mutex.Unlock() go ic.listen() } // SetHandler sets the handler for interrupt signals func (ic *InterruptController) SetHandler(handler InterruptHandler) { if ic == nil { return } ic.handler = handler } // listen is the main 
listener goroutine that processes interrupt signals func (ic *InterruptController) listen() { for { select { case signal := <-ic.queue: // Handle user interrupt signal (stop button, for appending messages) ic.handleSignal(signal) case <-ic.ctx.Done(): // Internal context cancelled, stop listening return } } } // handleSignal processes an interrupt signal func (ic *InterruptController) handleSignal(signal *InterruptSignal) { if signal == nil { return } log.Trace("[INTERRUPT] Signal received: type=%s, messages=%d, timestamp=%d", signal.Type, len(signal.Messages), signal.Timestamp) ic.mutex.Lock() // If no current interrupt, set it as current if ic.current == nil { ic.current = signal } else { // If there's already a current interrupt, add to pending queue ic.pending = append(ic.pending, signal) } // For force interrupt with no messages (pure cancellation), cancel the interrupt context // This allows LLM streaming and other operations to check and stop if signal.Type == InterruptForce && len(signal.Messages) == 0 { if ic.cancel != nil { ic.cancel() // Create a new context for potential future operations ic.ctx, ic.cancel = context.WithCancel(context.Background()) } } ic.mutex.Unlock() // Call the registered handler if available (outside lock to avoid deadlock) if ic.handler != nil && ic.contextID != "" { go func() { // Retrieve the parent context from global registry ctx, err := Get(ic.contextID) if err != nil { fmt.Printf("Failed to get context for interrupt handler: %v\n", err) return } // Call the handler if err := ic.handler(ctx, signal); err != nil { fmt.Printf("Interrupt handler error: %v\n", err) } }() } } // Check checks for current interrupt signal (non-blocking) // Returns the current interrupt and moves to next one if available func (ic *InterruptController) Check() *InterruptSignal { if ic == nil { return nil } ic.mutex.Lock() defer ic.mutex.Unlock() if ic.current == nil { return nil } // Get current interrupt signal := ic.current // Move to next interrupt 
in queue if len(ic.pending) > 0 { ic.current = ic.pending[0] ic.pending = ic.pending[1:] } else { ic.current = nil } return signal } // CheckWithMerge checks for interrupts and merges all pending messages // This is useful when multiple interrupts should be handled together func (ic *InterruptController) CheckWithMerge() *InterruptSignal { if ic == nil { return nil } ic.mutex.Lock() defer ic.mutex.Unlock() if ic.current == nil { return nil } // If there are pending interrupts, merge all messages if len(ic.pending) > 0 { // Collect all messages allMessages := append([]Message{}, ic.current.Messages...) for _, pending := range ic.pending { allMessages = append(allMessages, pending.Messages...) } // Create merged signal mergedSignal := &InterruptSignal{ Type: ic.current.Type, // Use first signal's type Messages: allMessages, Timestamp: time.Now().UnixMilli(), Metadata: map[string]interface{}{ "merged": true, "merged_count": len(ic.pending) + 1, "original_time": ic.current.Timestamp, }, } // Clear all interrupts ic.current = nil ic.pending = make([]*InterruptSignal, 0) return mergedSignal } // No pending interrupts, return current signal := ic.current ic.current = nil return signal } // Peek returns the current interrupt without removing it func (ic *InterruptController) Peek() *InterruptSignal { if ic == nil { return nil } ic.mutex.RLock() defer ic.mutex.RUnlock() return ic.current } // IsInterrupted checks if interrupt context is cancelled (force interrupt) func (ic *InterruptController) IsInterrupted() bool { if ic == nil || ic.ctx == nil { return false } select { case <-ic.ctx.Done(): return true default: return false } } // Context returns the interrupt control context // This can be used in select statements to check for force interrupts func (ic *InterruptController) Context() context.Context { if ic == nil { return context.Background() } return ic.ctx } // GetPendingCount returns the number of pending interrupts func (ic *InterruptController) GetPendingCount() 
int { if ic == nil { return 0 } ic.mutex.RLock() defer ic.mutex.RUnlock() count := len(ic.pending) if ic.current != nil { count++ } return count } // Clear clears all interrupts (current and pending) func (ic *InterruptController) Clear() { if ic == nil { return } ic.mutex.Lock() defer ic.mutex.Unlock() ic.current = nil ic.pending = make([]*InterruptSignal, 0) } // Stop stops the interrupt controller and cleans up resources func (ic *InterruptController) Stop() { if ic == nil { return } // Cancel context to stop listener if ic.cancel != nil { ic.cancel() } // Close channel if ic.queue != nil { close(ic.queue) } // Clear interrupts ic.Clear() } // SendSignal sends an interrupt signal to the controller // This is called from external sources (e.g., another HTTP request) func (ic *InterruptController) SendSignal(signal *InterruptSignal) error { if ic == nil { return fmt.Errorf("interrupt controller is nil") } if ic.queue == nil { return fmt.Errorf("interrupt queue is not initialized") } // Non-blocking send select { case ic.queue <- signal: return nil case <-time.After(500 * time.Millisecond): return fmt.Errorf("failed to send interrupt: timeout") } } ================================================ FILE: agent/context/interrupt_test.go ================================================ package context_test import ( stdContext "context" "fmt" "testing" "time" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/openapi/oauth/types" ) // newTestContextWithInterrupt creates a Context with interrupt controller for testing func newTestContextWithInterrupt(chatID, assistantID string) *context.Context { ctx := context.New(stdContext.Background(), &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: "test-session-id", }, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: 
"TestAgent/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Route = "/test/route" ctx.Metadata = map[string]interface{}{ "test": "context_metadata", } // Initialize interrupt controller ctx.Interrupt = context.NewInterruptController() // Register context globally if err := context.Register(ctx); err != nil { panic(fmt.Sprintf("Failed to register context: %v", err)) } // Start interrupt listener ctx.Interrupt.Start(ctx.ID) return ctx } // TestInterruptBasic tests basic interrupt signal sending and receiving func TestInterruptBasic(t *testing.T) { // Create context with interrupt support ctx := newTestContextWithInterrupt("chat-test-interrupt", "test-assistant") defer ctx.Release() t.Run("SendGracefulInterrupt", func(t *testing.T) { // Create a graceful interrupt signal signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{ {Role: context.RoleUser, Content: "This is a graceful interrupt"}, }, Timestamp: time.Now().UnixMilli(), } // Send interrupt signal err := context.SendInterrupt(ctx.ID, signal) if err != nil { t.Fatalf("Failed to send interrupt signal: %v", err) } // Wait a bit for the signal to be processed time.Sleep(100 * time.Millisecond) // Check if signal was received receivedSignal := ctx.Interrupt.Peek() if receivedSignal == nil { t.Fatal("Expected interrupt signal to be received, got nil") } if receivedSignal.Type != context.InterruptGraceful { t.Errorf("Expected interrupt type 'graceful', got: %s", receivedSignal.Type) } if len(receivedSignal.Messages) != 1 { t.Errorf("Expected 1 message, got: %d", len(receivedSignal.Messages)) } if receivedSignal.Messages[0].Content != "This is a graceful interrupt" { t.Errorf("Expected message content 'This is a graceful interrupt', got: %s", receivedSignal.Messages[0].Content) } t.Log("✓ Graceful interrupt signal sent and received successfully") }) t.Run("SendForceInterrupt", func(t *testing.T) { // Clear previous signals 
ctx.Interrupt.Clear() // Create a force interrupt signal signal := &context.InterruptSignal{ Type: context.InterruptForce, Messages: []context.Message{ {Role: context.RoleUser, Content: "This is a force interrupt"}, }, Timestamp: time.Now().UnixMilli(), } // Send interrupt signal err := context.SendInterrupt(ctx.ID, signal) if err != nil { t.Fatalf("Failed to send interrupt signal: %v", err) } // Wait a bit for the signal to be processed time.Sleep(100 * time.Millisecond) // Check if signal was received receivedSignal := ctx.Interrupt.Peek() if receivedSignal == nil { t.Fatal("Expected interrupt signal to be received, got nil") } if receivedSignal.Type != context.InterruptForce { t.Errorf("Expected interrupt type 'force', got: %s", receivedSignal.Type) } t.Log("✓ Force interrupt signal sent and received successfully") }) t.Run("MultipleInterrupts", func(t *testing.T) { // Clear previous signals ctx.Interrupt.Clear() // Send multiple interrupt signals for i := 0; i < 3; i++ { signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{ {Role: context.RoleUser, Content: fmt.Sprintf("Message %d", i+1)}, }, Timestamp: time.Now().UnixMilli(), } err := context.SendInterrupt(ctx.ID, signal) if err != nil { t.Fatalf("Failed to send interrupt signal %d: %v", i+1, err) } } // Wait a bit for signals to be processed time.Sleep(100 * time.Millisecond) // Check pending count pendingCount := ctx.Interrupt.GetPendingCount() if pendingCount != 3 { t.Errorf("Expected 3 pending interrupts, got: %d", pendingCount) } // Check merged signal mergedSignal := ctx.Interrupt.CheckWithMerge() if mergedSignal == nil { t.Fatal("Expected merged signal, got nil") } if len(mergedSignal.Messages) != 3 { t.Errorf("Expected 3 merged messages, got: %d", len(mergedSignal.Messages)) } // Verify all messages are present for i := 0; i < 3; i++ { expectedContent := fmt.Sprintf("Message %d", i+1) if mergedSignal.Messages[i].Content != expectedContent { t.Errorf("Expected 
message %d content '%s', got: %s", i+1, expectedContent, mergedSignal.Messages[i].Content) } } t.Log("✓ Multiple interrupt signals merged successfully") }) } // TestInterruptHandler tests interrupt handler invocation func TestInterruptHandler(t *testing.T) { // Create context with interrupt support ctx := newTestContextWithInterrupt("chat-test-interrupt-handler", "test-assistant") defer ctx.Release() t.Run("HandlerInvocation", func(t *testing.T) { // Track if handler was called handlerCalled := false var receivedSignal *context.InterruptSignal // Set up handler ctx.Interrupt.SetHandler(func(c *context.Context, signal *context.InterruptSignal) error { handlerCalled = true receivedSignal = signal t.Logf("Handler called with signal type: %s, messages: %d", signal.Type, len(signal.Messages)) return nil }) // Send interrupt signal signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{ {Role: context.RoleUser, Content: "Test handler invocation"}, }, Timestamp: time.Now().UnixMilli(), } err := context.SendInterrupt(ctx.ID, signal) if err != nil { t.Fatalf("Failed to send interrupt signal: %v", err) } // Wait for handler to be called time.Sleep(200 * time.Millisecond) // Verify handler was called if !handlerCalled { t.Error("Expected handler to be called, but it wasn't") } if receivedSignal == nil { t.Fatal("Expected signal in handler, got nil") } if receivedSignal.Type != context.InterruptGraceful { t.Errorf("Expected graceful interrupt in handler, got: %s", receivedSignal.Type) } if len(receivedSignal.Messages) != 1 { t.Errorf("Expected 1 message in handler, got: %d", len(receivedSignal.Messages)) } t.Log("✓ Interrupt handler invoked successfully") }) t.Run("HandlerWithError", func(t *testing.T) { // Create new context ctx2 := newTestContextWithInterrupt("chat-test-handler-error", "test-assistant") defer ctx2.Release() // Set up handler that returns error handlerCalled := false ctx2.Interrupt.SetHandler(func(c *context.Context, 
signal *context.InterruptSignal) error { handlerCalled = true return fmt.Errorf("test error from handler") }) // Send interrupt signal signal := &context.InterruptSignal{ Type: context.InterruptForce, Messages: []context.Message{ {Role: context.RoleUser, Content: "Test error handling"}, }, Timestamp: time.Now().UnixMilli(), } err := context.SendInterrupt(ctx2.ID, signal) if err != nil { t.Fatalf("Failed to send interrupt signal: %v", err) } // Wait for handler to be called time.Sleep(200 * time.Millisecond) // Handler should still be called even if it returns error if !handlerCalled { t.Error("Expected handler to be called even with error") } t.Log("✓ Handler error handling works correctly") }) } // TestInterruptContextLifecycle tests context registration and cleanup func TestInterruptContextLifecycle(t *testing.T) { t.Run("RegisterAndRetrieve", func(t *testing.T) { ctx := newTestContextWithInterrupt("chat-test-lifecycle", "test-assistant") // Verify context can be retrieved retrievedCtx, err := context.Get(ctx.ID) if err != nil { t.Fatalf("Failed to retrieve context: %v", err) } if retrievedCtx.ID != ctx.ID { t.Errorf("Expected context ID %s, got: %s", ctx.ID, retrievedCtx.ID) } ctx.Release() // After release, context should be removed _, err = context.Get(ctx.ID) if err == nil { t.Error("Expected error when retrieving released context") } t.Log("✓ Context registration and cleanup works correctly") }) t.Run("SendToNonExistentContext", func(t *testing.T) { signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "test"}}, Timestamp: time.Now().UnixMilli(), } err := context.SendInterrupt("non-existent-id", signal) if err == nil { t.Error("Expected error when sending to non-existent context") } t.Log("✓ Sending to non-existent context returns error") }) } // TestInterruptCheckMethods tests different check methods func TestInterruptCheckMethods(t *testing.T) { ctx := 
newTestContextWithInterrupt("chat-test-check-methods", "test-assistant") defer ctx.Release() t.Run("PeekDoesNotRemove", func(t *testing.T) { // Send signal signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "peek test"}}, Timestamp: time.Now().UnixMilli(), } context.SendInterrupt(ctx.ID, signal) time.Sleep(100 * time.Millisecond) // Peek should return signal but not remove it peeked1 := ctx.Interrupt.Peek() if peeked1 == nil { t.Fatal("Expected signal from first peek") } peeked2 := ctx.Interrupt.Peek() if peeked2 == nil { t.Fatal("Expected signal from second peek") } if peeked1.Messages[0].Content != peeked2.Messages[0].Content { t.Error("Peek should return the same signal") } t.Log("✓ Peek does not remove signal") }) t.Run("CheckRemovesSignal", func(t *testing.T) { ctx.Interrupt.Clear() // Send signal signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "check test"}}, Timestamp: time.Now().UnixMilli(), } context.SendInterrupt(ctx.ID, signal) time.Sleep(100 * time.Millisecond) // Check should return and remove signal checked := ctx.Interrupt.Check() if checked == nil { t.Fatal("Expected signal from check") } // Second check should return nil checked2 := ctx.Interrupt.Check() if checked2 != nil { t.Error("Expected nil from second check after removal") } t.Log("✓ Check removes signal after retrieval") }) t.Run("CheckWithMergeMultipleSignals", func(t *testing.T) { ctx.Interrupt.Clear() // Send 5 signals with different messages messages := []string{ "First message", "Second message", "Third message", "Fourth message", "Fifth message", } for i, msg := range messages { signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{ {Role: context.RoleUser, Content: msg}, }, Timestamp: time.Now().UnixMilli(), Metadata: map[string]interface{}{ "sequence": i + 1, }, } err := 
context.SendInterrupt(ctx.ID, signal) if err != nil { t.Fatalf("Failed to send signal %d: %v", i+1, err) } time.Sleep(10 * time.Millisecond) // Small delay between signals } time.Sleep(100 * time.Millisecond) // Verify all signals are queued pendingCount := ctx.Interrupt.GetPendingCount() if pendingCount != 5 { t.Errorf("Expected 5 pending signals, got: %d", pendingCount) } // CheckWithMerge should merge all messages into one signal merged := ctx.Interrupt.CheckWithMerge() if merged == nil { t.Fatal("Expected merged signal, got nil") } // Verify all messages are merged if len(merged.Messages) != 5 { t.Errorf("Expected 5 merged messages, got: %d", len(merged.Messages)) } // Verify message order for i, msg := range messages { if merged.Messages[i].Content != msg { t.Errorf("Message %d mismatch: expected '%s', got '%s'", i+1, msg, merged.Messages[i].Content) } } // Verify metadata indicates merge if merged.Metadata["merged"] != true { t.Error("Expected merged metadata to be true") } if merged.Metadata["merged_count"] != 5 { t.Errorf("Expected merged_count 5, got: %v", merged.Metadata["merged_count"]) } // After merge, queue should be empty if ctx.Interrupt.GetPendingCount() != 0 { t.Errorf("Expected empty queue after merge, got: %d", ctx.Interrupt.GetPendingCount()) } t.Log("✓ CheckWithMerge correctly merged 5 signals into one") }) t.Run("CheckWithMergeSingleSignal", func(t *testing.T) { ctx.Interrupt.Clear() // Send single signal signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "single signal"}}, Timestamp: time.Now().UnixMilli(), } context.SendInterrupt(ctx.ID, signal) time.Sleep(100 * time.Millisecond) // CheckWithMerge with single signal should return it without merge metadata merged := ctx.Interrupt.CheckWithMerge() if merged == nil { t.Fatal("Expected signal, got nil") } if len(merged.Messages) != 1 { t.Errorf("Expected 1 message, got: %d", len(merged.Messages)) } // Single signal 
should not have merge metadata if merged.Metadata != nil && merged.Metadata["merged"] == true { t.Error("Single signal should not have merge metadata") } t.Log("✓ CheckWithMerge handles single signal correctly") }) } // TestInterruptContext tests interrupt context methods func TestInterruptContext(t *testing.T) { ctx := newTestContextWithInterrupt("chat-test-interrupt-context", "test-assistant") defer ctx.Release() t.Run("InterruptContextMethod", func(t *testing.T) { // Get interrupt context interruptCtx := ctx.Interrupt.Context() if interruptCtx == nil { t.Fatal("Expected interrupt context, got nil") } // Context should not be done initially select { case <-interruptCtx.Done(): t.Error("Interrupt context should not be done initially") default: t.Log("✓ Interrupt context is not done initially") } }) t.Run("IsInterruptedFalseInitially", func(t *testing.T) { // Should not be interrupted initially if ctx.Interrupt.IsInterrupted() { t.Error("Should not be interrupted initially") } t.Log("✓ IsInterrupted returns false initially") }) t.Run("ForceInterruptCancelsContext", func(t *testing.T) { // Get context before interrupt interruptCtx := ctx.Interrupt.Context() // Send force interrupt with empty messages (pure cancellation) // This is the pattern for stopping streaming without appending messages signal := &context.InterruptSignal{ Type: context.InterruptForce, Messages: []context.Message{}, // Empty messages = pure cancellation Timestamp: time.Now().UnixMilli(), } err := context.SendInterrupt(ctx.ID, signal) if err != nil { t.Fatalf("Failed to send interrupt: %v", err) } time.Sleep(100 * time.Millisecond) // The OLD context should be cancelled select { case <-interruptCtx.Done(): t.Log("✓ Force interrupt with empty messages cancelled the old context") case <-time.After(200 * time.Millisecond): t.Error("Old context was not cancelled after force interrupt with empty messages") } // Note: IsInterrupted() checks the NEW context (which was recreated) // So it will return 
false. This is expected behavior. // The key is that the old context was cancelled (checked above) t.Log("✓ Context was recreated after force interrupt (expected behavior)") }) t.Run("GracefulInterruptDoesNotCancelContext", func(t *testing.T) { // Create new context for this test ctx2 := newTestContextWithInterrupt("chat-test-graceful-no-cancel", "test-assistant") defer ctx2.Release() interruptCtx := ctx2.Interrupt.Context() // Send graceful interrupt signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "graceful"}}, Timestamp: time.Now().UnixMilli(), } context.SendInterrupt(ctx2.ID, signal) time.Sleep(100 * time.Millisecond) // Context should NOT be cancelled for graceful interrupt select { case <-interruptCtx.Done(): t.Error("Graceful interrupt should not cancel context") default: t.Log("✓ Graceful interrupt does not cancel context") } // IsInterrupted should still return false for graceful if ctx2.Interrupt.IsInterrupted() { t.Error("IsInterrupted should return false for graceful interrupt") } else { t.Log("✓ IsInterrupted returns false for graceful interrupt") } }) } // TestInterruptSendSignalDirectly tests SendSignal method directly func TestInterruptSendSignalDirectly(t *testing.T) { ctx := newTestContextWithInterrupt("chat-test-send-signal", "test-assistant") defer ctx.Release() t.Run("SendSignalSuccess", func(t *testing.T) { signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "direct send"}}, Timestamp: time.Now().UnixMilli(), } err := ctx.Interrupt.SendSignal(signal) if err != nil { t.Fatalf("SendSignal failed: %v", err) } time.Sleep(100 * time.Millisecond) // Verify signal was received received := ctx.Interrupt.Peek() if received == nil { t.Fatal("Signal not received") } if received.Messages[0].Content != "direct send" { t.Errorf("Expected 'direct send', got: %s", received.Messages[0].Content) } 
t.Log("✓ SendSignal directly works") }) t.Run("SendSignalToNilController", func(t *testing.T) { var nilController *context.InterruptController signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "test"}}, Timestamp: time.Now().UnixMilli(), } err := nilController.SendSignal(signal) if err == nil { t.Error("Expected error when sending to nil controller") } else { t.Logf("✓ Correctly returned error for nil controller: %v", err) } }) t.Run("SendSignalTimeout", func(t *testing.T) { // Create controller but don't start listener testCtrl := context.NewInterruptController() // Don't call Start(), so channel won't be read // Fill the buffer (capacity is 10) for i := 0; i < 10; i++ { signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: fmt.Sprintf("msg %d", i)}}, Timestamp: time.Now().UnixMilli(), } testCtrl.SendSignal(signal) } // This should timeout since buffer is full and no listener signal := &context.InterruptSignal{ Type: context.InterruptGraceful, Messages: []context.Message{{Role: context.RoleUser, Content: "overflow"}}, Timestamp: time.Now().UnixMilli(), } err := testCtrl.SendSignal(signal) if err == nil { t.Error("Expected timeout error when buffer is full") } else { t.Logf("✓ SendSignal correctly times out when buffer full: %v", err) } }) } ================================================ FILE: agent/context/jsapi.go ================================================ package context import ( "time" "github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/memory" "github.com/yaoapp/yao/agent/output" "github.com/yaoapp/yao/agent/output/message" traceJsapi "github.com/yaoapp/yao/trace/jsapi" "rogchap.com/v8go" ) // JsValue return the JavaScript value of the context func (ctx *Context) JsValue(v8ctx *v8go.Context) (*v8go.Value, error) { return ctx.NewObject(v8ctx) } // NewObject Create a new 
JavaScript object from the context func (ctx *Context) NewObject(v8ctx *v8go.Context) (*v8go.Value, error) { jsObject := v8go.NewObjectTemplate(v8ctx.Isolate()) // Set internal field count to 1 to store the goValueID // Internal fields are not accessible from JavaScript, providing better security jsObject.SetInternalFieldCount(1) // Register context in global bridge registry for efficient Go object retrieval // The goValueID will be stored in internal field (index 0) after instance creation goValueID := bridge.RegisterGoObject(ctx) // Set release function (both __release and Release do the same thing) // __release: Internal cleanup (called by GC or Use()) // Release: Public method for manual cleanup (try-finally pattern) releaseFunc := ctx.objectRelease(v8ctx.Isolate(), goValueID) jsObject.Set("__release", releaseFunc) jsObject.Set("Release", releaseFunc) // Set primitive fields in template jsObject.Set("chat_id", ctx.ChatID) jsObject.Set("assistant_id", ctx.AssistantID) jsObject.Set("locale", ctx.Locale) jsObject.Set("theme", ctx.Theme) jsObject.Set("referer", ctx.Referer) jsObject.Set("accept", string(ctx.Accept)) jsObject.Set("route", ctx.Route) // Set methods jsObject.Set("Send", ctx.sendMethod(v8ctx.Isolate())) jsObject.Set("SendStream", ctx.sendStreamMethod(v8ctx.Isolate())) jsObject.Set("Replace", ctx.replaceMethod(v8ctx.Isolate())) jsObject.Set("Append", ctx.appendMethod(v8ctx.Isolate())) jsObject.Set("Merge", ctx.mergeMethod(v8ctx.Isolate())) jsObject.Set("Set", ctx.setMethod(v8ctx.Isolate())) jsObject.Set("End", ctx.endMethod(v8ctx.Isolate())) // Set ID generator methods jsObject.Set("MessageID", ctx.messageIDMethod(v8ctx.Isolate())) jsObject.Set("BlockID", ctx.blockIDMethod(v8ctx.Isolate())) jsObject.Set("ThreadID", ctx.threadIDMethod(v8ctx.Isolate())) // Lifecycle methods jsObject.Set("EndBlock", ctx.endBlockMethod(v8ctx.Isolate())) // Set mcp object jsObject.Set("mcp", ctx.newMCPObject(v8ctx.Isolate())) // Set search object jsObject.Set("search", 
ctx.newSearchObject(v8ctx.Isolate())) // Set agent object for calling other agents jsObject.Set("agent", ctx.newAgentObject(v8ctx.Isolate())) // Set llm object for direct LLM calls jsObject.Set("llm", ctx.newLlmObject(v8ctx.Isolate())) // Note: Space object will be set after instance creation (requires v8ctx) // Create instance instance, err := jsObject.NewInstance(v8ctx) if err != nil { // Clean up: release from global registry if instance creation failed bridge.ReleaseGoObject(goValueID) return nil, err } // Store the goValueID in internal field (index 0) // This is not accessible from JavaScript, providing better security obj, err := instance.Value.AsObject() if err != nil { bridge.ReleaseGoObject(goValueID) return nil, err } err = obj.SetInternalField(0, goValueID) if err != nil { bridge.ReleaseGoObject(goValueID) return nil, err } // Set trace object (property, not method) // If trace is not initialized, use no-op object traceObj := ctx.createTraceObject(v8ctx) if traceObj != nil { obj.Set("trace", traceObj) } // Set complex objects (maps, arrays) after instance creation using bridge // Client object clientData := map[string]interface{}{ "type": ctx.Client.Type, "user_agent": ctx.Client.UserAgent, "ip": ctx.Client.IP, } clientVal, err := bridge.JsValue(v8ctx, clientData) if err == nil { obj.Set("client", clientVal) clientVal.Release() // Release Go-side Persistent handle, V8 internal reference remains } // Metadata object - always set to empty map if nil metadataData := ctx.Metadata if metadataData == nil { metadataData = map[string]interface{}{} } metadataVal, err := bridge.JsValue(v8ctx, metadataData) if err == nil { obj.Set("metadata", metadataVal) metadataVal.Release() // Release Go-side Persistent handle, V8 internal reference remains } // Authorized object - pass the complete structure if ctx.Authorized != nil { authorizedVal, err := bridge.JsValue(v8ctx, ctx.Authorized) if err == nil { obj.Set("authorized", authorizedVal) authorizedVal.Release() // 
Release Go-side Persistent handle, V8 internal reference remains } } else { // Set to empty object when nil emptyObj, err := bridge.JsValue(v8ctx, map[string]interface{}{}) if err == nil { obj.Set("authorized", emptyObj) emptyObj.Release() } } // Memory object - create a JavaScript object with User/Team/Chat/Context namespaces if ctx.Memory != nil { memoryObj := ctx.createMemoryObject(v8ctx) obj.Set("memory", memoryObj) memoryObj.Release() } // Sandbox object - only set if sandbox executor is available (V1) if ctx.sandboxExecutor != nil { sandboxObj := ctx.createSandboxInstance(v8ctx) if sandboxObj != nil { obj.Set("sandbox", sandboxObj) sandboxObj.Release() } } // Computer object - only set if V2 computer is available if ctx.computer != nil { computerObj := ctx.createComputerInstance(v8ctx) if computerObj != nil { obj.Set("computer", computerObj) computerObj.Release() } } // Workspace object - only set if V2 workspace is available if ctx.workspace != nil { wsObj := ctx.createWorkspaceInstance(v8ctx) if wsObj != nil { obj.Set("workspace", wsObj) wsObj.Release() } } return instance.Value, nil } // objectRelease releases the Go object from the global bridge registry // It retrieves the goValueID from internal field (index 0) and releases the Go object // Also releases associated Trace object if present func (ctx *Context) objectRelease(iso *v8go.Isolate, goValueID string) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { // Get the context object (this) thisObj, err := info.This().AsObject() if err == nil { // NOTE: We do NOT automatically release Trace object here // // Rationale: // 1. Each Hook execution creates a new V8 script context (scriptCtx) // 2. The agent Context (ctx) is passed to the Hook as a parameter // 3. When scriptCtx.Close() is called (via defer), V8 cleanup triggers ctx.__release() // 4. If we release Trace here, it gets released after EVERY Hook execution // 5. 
This causes "context canceled" errors in subsequent operations // // Trace lifecycle: // - Trace is created when agent.Stream() starts (in Context.Trace()) // - Trace should persist across ALL Hook executions (Create, Next, Done) // - Trace is released when agent Context.Release() is called (after agent.Stream() completes) // // Memory management: // - If JS code explicitly calls trace.Release(), it will work (trace/jsapi/trace.go:traceGoRelease) // - If not explicitly called, Context.Release() will clean it up (context/context.go:Release) // - This is the correct lifecycle: one Context -> one Trace -> multiple Hook executions // Release Context Go object from bridge registry if thisObj.InternalFieldCount() > 0 { // Get goValueID from internal field (index 0) goValueID := thisObj.GetInternalField(0) if goValueID != nil && goValueID.IsString() { // Release from global bridge registry bridge.ReleaseGoObject(goValueID.String()) } } } return v8go.Undefined(info.Context().Isolate()) }) } // createTraceObject creates a Trace object instance // Returns a no-op Trace object if trace is not initialized func (ctx *Context) createTraceObject(v8ctx *v8go.Context) *v8go.Value { // Try to get trace manager manager, err := ctx.Trace() if err != nil || manager == nil { // Return no-op trace object if initialization fails noOpTrace, _ := traceJsapi.NewNoOpTraceObject(v8ctx) return noOpTrace } // Get trace ID traceID := "" if ctx.Stack != nil { traceID = ctx.Stack.TraceID } // Create JavaScript Trace object traceObj, err := traceJsapi.NewTraceObject(v8ctx, traceID, manager) if err != nil { // Return no-op trace object if creation fails noOpTrace, _ := traceJsapi.NewNoOpTraceObject(v8ctx) return noOpTrace } return traceObj } // sendMethod implements ctx.Send(message, blockId?) 
// Usage: const messageId = ctx.Send({ type: "text", props: { content: "Hello" } }) // Usage: const messageId = ctx.Send("Hello") // shorthand for text message // Usage: const messageId = ctx.Send("Hello", "B1") // specify block ID // Automatically generates MessageID and BlockID (if not specified), flushes output // Returns: message_id (string) func (ctx *Context) sendMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "Send requires a message argument") } // Parse message argument msg, err := parseMessage(v8ctx, args[0]) if err != nil { return bridge.JsException(v8ctx, "invalid message: "+err.Error()) } // Get optional blockId argument (second argument) // Note: message object's block_id has higher priority if len(args) >= 2 && args[1].IsString() && msg.BlockID == "" { msg.BlockID = args[1].String() } // Generate unique MessageID if not provided if msg.MessageID == "" { if ctx.IDGenerator != nil { msg.MessageID = ctx.IDGenerator.GenerateMessageID() } else { msg.MessageID = output.GenerateID() } } // Call ctx.Send (will auto-generate BlockID if still empty) if err := ctx.Send(msg); err != nil { return bridge.JsException(v8ctx, "Send failed: "+err.Error()) } // Automatically flush after sending if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } // Return the message ID messageID, err := v8go.NewValue(iso, msg.MessageID) if err != nil { return bridge.JsException(v8ctx, "Failed to create return value: "+err.Error()) } return messageID }) } // sendStreamMethod implements ctx.SendStream(message) // Usage: const msgId = ctx.SendStream({ type: "text", props: { content: "Initial content" } }) // Starts a streaming message that can be appended to with ctx.Append() // Must be finalized with ctx.End(msgId) or ctx.End(msgId, "final content") // 
Unlike Send(), this does NOT automatically send message_end event // Returns: message_id (string) func (ctx *Context) sendStreamMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "SendStream requires a message argument") } // Parse message argument msg, err := parseMessage(v8ctx, args[0]) if err != nil { return bridge.JsException(v8ctx, "invalid message: "+err.Error()) } // Get optional blockId argument (second argument) if len(args) >= 2 && args[1].IsString() && msg.BlockID == "" { msg.BlockID = args[1].String() } // Call ctx.SendStream messageID, err := ctx.SendStream(msg) if err != nil { return bridge.JsException(v8ctx, "SendStream failed: "+err.Error()) } // Automatically flush after sending if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } // Return the message ID returnID, err := v8go.NewValue(iso, messageID) if err != nil { return bridge.JsException(v8ctx, "Failed to create return value: "+err.Error()) } return returnID }) } // endMethod implements ctx.End(messageId, finalContent?) 
// Usage: ctx.End(msgId) or ctx.End(msgId, "final content to append") // Finalizes a streaming message started with SendStream() // Sends message_end event with the complete accumulated content // Returns: message_id (string) func (ctx *Context) endMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "End requires a messageId argument") } // Get message ID (first argument) if !args[0].IsString() { return bridge.JsException(v8ctx, "messageId must be a string") } messageID := args[0].String() // Get optional final content (second argument) var finalContent string if len(args) >= 2 && args[1].IsString() { finalContent = args[1].String() } // Call ctx.End var err error if finalContent != "" { err = ctx.End(messageID, finalContent) } else { err = ctx.End(messageID) } if err != nil { return bridge.JsException(v8ctx, "End failed: "+err.Error()) } // Automatically flush after sending if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } // Return the message ID returnID, err := v8go.NewValue(iso, messageID) if err != nil { return bridge.JsException(v8ctx, "Failed to create return value: "+err.Error()) } return returnID }) } // replaceMethod implements ctx.Replace(messageId, message) // Usage: ctx.Replace(messageId, { type: "text", props: { content: "Updated content" } }) // Replaces the entire message content with the specified message_id // Automatically flushes output // Returns: message_id (string) func (ctx *Context) replaceMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 2 { return bridge.JsException(v8ctx, "Replace requires messageId and message arguments") } // Get 
message ID (first argument) if !args[0].IsString() { return bridge.JsException(v8ctx, "messageId must be a string") } messageID := args[0].String() // Parse message argument (second argument) msg, err := parseMessage(v8ctx, args[1]) if err != nil { return bridge.JsException(v8ctx, "invalid message: "+err.Error()) } // Set message ID to the provided ID msg.MessageID = messageID // Set delta mode for replacement msg.Delta = true msg.DeltaAction = message.DeltaReplace msg.DeltaPath = "" // Empty path means replace entire message // Call ctx.Send if err := ctx.Send(msg); err != nil { return bridge.JsException(v8ctx, "Replace failed: "+err.Error()) } // Automatically flush after sending if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } // Return the message ID returnID, err := v8go.NewValue(iso, messageID) if err != nil { return bridge.JsException(v8ctx, "Failed to create return value: "+err.Error()) } return returnID }) } // appendMethod implements ctx.Append(messageId, content, path?) 
// Usage: ctx.Append(messageId, "more text") // append to default content path // Usage: ctx.Append(messageId, "more text", "props.content") // append to specific path // Usage: ctx.Append(messageId, { type: "text", props: { content: "more text" } }) // Usage: ctx.Append(messageId, { props: { content: "more text" } }, "props.data") // append to custom path // Appends content to an existing message (delta append operation) // Automatically flushes output // Returns: message_id (string) func (ctx *Context) appendMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 2 { return bridge.JsException(v8ctx, "Append requires messageId and content arguments") } // Get message ID (first argument) if !args[0].IsString() { return bridge.JsException(v8ctx, "messageId must be a string") } messageID := args[0].String() // Parse content argument (second argument) msg, err := parseMessage(v8ctx, args[1]) if err != nil { return bridge.JsException(v8ctx, "invalid content: "+err.Error()) } // Get optional path argument (third argument) deltaPath := "" if len(args) >= 3 && args[2].IsString() { deltaPath = args[2].String() } // Set message ID to the provided ID msg.MessageID = messageID // Set delta mode for append msg.Delta = true msg.DeltaAction = message.DeltaAppend msg.DeltaPath = deltaPath // Empty path means append to default content, or specify custom path // Call ctx.Send if err := ctx.Send(msg); err != nil { return bridge.JsException(v8ctx, "Append failed: "+err.Error()) } // Automatically flush after sending if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } // Return the message ID returnID, err := v8go.NewValue(iso, messageID) if err != nil { return bridge.JsException(v8ctx, "Failed to create return value: "+err.Error()) } return returnID }) } // mergeMethod 
implements ctx.Merge(messageId, data, path?) // Usage: ctx.Merge(messageId, { key: "value" }) // merge to default object path // Usage: ctx.Merge(messageId, { status: "done" }, "props") // merge to specific path // Usage: ctx.Merge(messageId, { props: { status: "done", progress: 100 } }) // Merges data into an existing message object (delta merge operation) // Automatically flushes output // Returns: message_id (string) func (ctx *Context) mergeMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 2 { return bridge.JsException(v8ctx, "Merge requires messageId and data arguments") } // Get message ID (first argument) if !args[0].IsString() { return bridge.JsException(v8ctx, "messageId must be a string") } messageID := args[0].String() // Parse data argument (second argument) msg, err := parseMessage(v8ctx, args[1]) if err != nil { return bridge.JsException(v8ctx, "invalid data: "+err.Error()) } // Get optional path argument (third argument) deltaPath := "" if len(args) >= 3 && args[2].IsString() { deltaPath = args[2].String() } // Set message ID to the provided ID msg.MessageID = messageID // Set delta mode for merge msg.Delta = true msg.DeltaAction = message.DeltaMerge msg.DeltaPath = deltaPath // Empty path means merge to default object, or specify custom path // Call ctx.Send if err := ctx.Send(msg); err != nil { return bridge.JsException(v8ctx, "Merge failed: "+err.Error()) } // Automatically flush after sending if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } // Return the message ID returnID, err := v8go.NewValue(iso, messageID) if err != nil { return bridge.JsException(v8ctx, "Failed to create return value: "+err.Error()) } return returnID }) } // setMethod implements ctx.Set(messageId, data, path) // Usage: ctx.Set(messageId, "value", 
"props.newField") // set new field at specific path // Usage: ctx.Set(messageId, { newKey: "value" }, "props") // set new fields in props // Sets a new field or value in an existing message (delta set operation) // Automatically flushes output // Returns: message_id (string) func (ctx *Context) setMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments (path is required for Set operation) if len(args) < 3 { return bridge.JsException(v8ctx, "Set requires messageId, data, and path arguments") } // Get message ID (first argument) if !args[0].IsString() { return bridge.JsException(v8ctx, "messageId must be a string") } messageID := args[0].String() // Parse data argument (second argument) msg, err := parseMessage(v8ctx, args[1]) if err != nil { return bridge.JsException(v8ctx, "invalid data: "+err.Error()) } // Get path argument (third argument - required) if !args[2].IsString() { return bridge.JsException(v8ctx, "path must be a string") } deltaPath := args[2].String() if deltaPath == "" { return bridge.JsException(v8ctx, "path cannot be empty for Set operation") } // Set message ID to the provided ID msg.MessageID = messageID // Set delta mode for set msg.Delta = true msg.DeltaAction = message.DeltaSet msg.DeltaPath = deltaPath // Path is required for Set operation // Call ctx.Send if err := ctx.Send(msg); err != nil { return bridge.JsException(v8ctx, "Set failed: "+err.Error()) } // Automatically flush after sending if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } // Return the message ID returnID, err := v8go.NewValue(iso, messageID) if err != nil { return bridge.JsException(v8ctx, "Failed to create return value: "+err.Error()) } return returnID }) } // messageIDMethod implements ctx.MessageID() // Usage: const msgId = ctx.MessageID() // Returns: "M1", "M2", "M3"... 
// Generates a unique message ID for manual message management // Returns: message_id (string) func (ctx *Context) messageIDMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() var messageID string if ctx.IDGenerator != nil { messageID = ctx.IDGenerator.GenerateMessageID() } else { messageID = output.GenerateID() } // Return the generated ID id, err := v8go.NewValue(iso, messageID) if err != nil { return bridge.JsException(v8ctx, "Failed to generate message ID: "+err.Error()) } return id }) } // blockIDMethod implements ctx.BlockID() // Usage: const blockId = ctx.BlockID() // Returns: "B1", "B2", "B3"... // Generates a unique block ID for grouping messages // Returns: block_id (string) func (ctx *Context) blockIDMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() var blockID string if ctx.IDGenerator != nil { blockID = ctx.IDGenerator.GenerateBlockID() } else { blockID = output.GenerateID() } // Return the generated ID id, err := v8go.NewValue(iso, blockID) if err != nil { return bridge.JsException(v8ctx, "Failed to generate block ID: "+err.Error()) } return id }) } // threadIDMethod implements ctx.ThreadID() // Usage: const threadId = ctx.ThreadID() // Returns: "T1", "T2", "T3"... 
// Generates a unique thread ID for concurrent operations // Returns: thread_id (string) func (ctx *Context) threadIDMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() var threadID string if ctx.IDGenerator != nil { threadID = ctx.IDGenerator.GenerateThreadID() } else { threadID = output.GenerateID() } // Return the generated ID id, err := v8go.NewValue(iso, threadID) if err != nil { return bridge.JsException(v8ctx, "Failed to generate thread ID: "+err.Error()) } return id }) } // endBlockMethod implements ctx.EndBlock(block_id) // Usage: ctx.EndBlock("B1") // Sends a block_end event for the specified block // Returns: undefined func (ctx *Context) endBlockMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "EndBlock requires block_id argument") } if !args[0].IsString() { return bridge.JsException(v8ctx, "block_id must be a string") } blockID := args[0].String() // Call ctx.EndBlock if err := ctx.EndBlock(blockID); err != nil { return bridge.JsException(v8ctx, "EndBlock failed: "+err.Error()) } // Automatically flush after ending block if err := ctx.Flush(); err != nil { return bridge.JsException(v8ctx, "Flush failed: "+err.Error()) } return v8go.Undefined(iso) }) } // createMemoryObject creates a Memory object for JavaScript access // Memory provides four namespaces: User, Team, Chat, Context // Each namespace supports: Get, Set, Del, Has, Keys, Len, Clear, Incr, Decr func (ctx *Context) createMemoryObject(v8ctx *v8go.Context) *v8go.Value { iso := v8ctx.Isolate() objTpl := v8go.NewObjectTemplate(iso) obj, _ := objTpl.NewInstance(v8ctx) // Create namespace accessors if ctx.Memory.User != nil { userObj := ctx.createNamespaceObject(v8ctx, 
ctx.Memory.User) obj.Set("user", userObj) userObj.Release() } if ctx.Memory.Team != nil { teamObj := ctx.createNamespaceObject(v8ctx, ctx.Memory.Team) obj.Set("team", teamObj) teamObj.Release() } if ctx.Memory.Chat != nil { chatObj := ctx.createNamespaceObject(v8ctx, ctx.Memory.Chat) obj.Set("chat", chatObj) chatObj.Release() } if ctx.Memory.Context != nil { contextObj := ctx.createNamespaceObject(v8ctx, ctx.Memory.Context) obj.Set("context", contextObj) contextObj.Release() } return obj.Value } // createNamespaceObject creates a namespace object with KV store methods func (ctx *Context) createNamespaceObject(v8ctx *v8go.Context, ns *memory.Namespace) *v8go.Value { iso := v8ctx.Isolate() objTpl := v8go.NewObjectTemplate(iso) obj, _ := objTpl.NewInstance(v8ctx) // Get method: ns.Get(key) getFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { if len(info.Args()) < 1 { return bridge.JsException(info.Context(), "Get requires a key argument") } key := info.Args()[0].String() value, ok := ns.Get(key) if !ok { return v8go.Null(iso) } jsValue, err := bridge.JsValue(info.Context(), value) if err != nil { return v8go.Null(iso) } return jsValue }) getFuncVal := getFunc.GetFunction(v8ctx) obj.Set("Get", getFuncVal.Value) // Set method: ns.Set(key, value, ttl?) 
setFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { if len(info.Args()) < 2 { return bridge.JsException(info.Context(), "Set requires key and value arguments") } key := info.Args()[0].String() value, err := bridge.GoValue(info.Args()[1], info.Context()) if err != nil { return bridge.JsException(info.Context(), "Failed to convert value: "+err.Error()) } // Optional TTL in milliseconds (third argument) var ttl time.Duration if len(info.Args()) >= 3 && info.Args()[2].IsNumber() { ttlMs := info.Args()[2].Integer() ttl = time.Duration(ttlMs) * time.Millisecond } if err := ns.Set(key, value, ttl); err != nil { return bridge.JsException(info.Context(), "Failed to set value: "+err.Error()) } return v8go.Undefined(iso) }) setFuncVal := setFunc.GetFunction(v8ctx) obj.Set("Set", setFuncVal.Value) // Del method: ns.Del(key) delFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { if len(info.Args()) < 1 { return bridge.JsException(info.Context(), "Del requires a key argument") } key := info.Args()[0].String() if err := ns.Del(key); err != nil { return bridge.JsException(info.Context(), "Failed to delete key: "+err.Error()) } return v8go.Undefined(iso) }) delFuncVal := delFunc.GetFunction(v8ctx) obj.Set("Del", delFuncVal.Value) // Has method: ns.Has(key) hasFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { if len(info.Args()) < 1 { return bridge.JsException(info.Context(), "Has requires a key argument") } key := info.Args()[0].String() exists := ns.Has(key) jsValue, _ := v8go.NewValue(iso, exists) return jsValue }) hasFuncVal := hasFunc.GetFunction(v8ctx) obj.Set("Has", hasFuncVal.Value) // Keys method: ns.Keys() keysFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { keys := ns.Keys() jsValue, err := bridge.JsValue(info.Context(), keys) if err != nil { return bridge.JsException(info.Context(), "Failed to get keys: 
"+err.Error()) } return jsValue }) keysFuncVal := keysFunc.GetFunction(v8ctx) obj.Set("Keys", keysFuncVal.Value) // Len method: ns.Len() lenFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { length := ns.Len() // Use int32 for JavaScript Number (int64 becomes BigInt which is incompatible) jsValue, _ := v8go.NewValue(iso, int32(length)) return jsValue }) lenFuncVal := lenFunc.GetFunction(v8ctx) obj.Set("Len", lenFuncVal.Value) // Clear method: ns.Clear() clearFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { ns.Clear() return v8go.Undefined(iso) }) clearFuncVal := clearFunc.GetFunction(v8ctx) obj.Set("Clear", clearFuncVal.Value) // Incr method: ns.Incr(key, delta?) incrFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { if len(info.Args()) < 1 { return bridge.JsException(info.Context(), "Incr requires a key argument") } key := info.Args()[0].String() delta := int64(1) if len(info.Args()) >= 2 && info.Args()[1].IsNumber() { delta = info.Args()[1].Integer() } newValue, err := ns.Incr(key, delta) if err != nil { return bridge.JsException(info.Context(), "Failed to increment: "+err.Error()) } // Use int32 for JavaScript Number (int64 becomes BigInt which is incompatible with ===) // For counters, int32 range (-2^31 to 2^31-1) is sufficient jsValue, _ := v8go.NewValue(iso, int32(newValue)) return jsValue }) incrFuncVal := incrFunc.GetFunction(v8ctx) obj.Set("Incr", incrFuncVal.Value) // Decr method: ns.Decr(key, delta?) 
decrFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { if len(info.Args()) < 1 { return bridge.JsException(info.Context(), "Decr requires a key argument") } key := info.Args()[0].String() delta := int64(1) if len(info.Args()) >= 2 && info.Args()[1].IsNumber() { delta = info.Args()[1].Integer() } newValue, err := ns.Decr(key, delta) if err != nil { return bridge.JsException(info.Context(), "Failed to decrement: "+err.Error()) } // Use int32 for JavaScript Number (int64 becomes BigInt which is incompatible with ===) jsValue, _ := v8go.NewValue(iso, int32(newValue)) return jsValue }) decrFuncVal := decrFunc.GetFunction(v8ctx) obj.Set("Decr", decrFuncVal.Value) // GetDel method: ns.GetDel(key) - Get value and delete immediately getDelFunc := v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { if len(info.Args()) < 1 { return bridge.JsException(info.Context(), "GetDel requires a key argument") } key := info.Args()[0].String() value, ok := ns.GetDel(key) if !ok { return v8go.Null(iso) } jsValue, err := bridge.JsValue(info.Context(), value) if err != nil { return v8go.Null(iso) } return jsValue }) getDelFuncVal := getDelFunc.GetFunction(v8ctx) obj.Set("GetDel", getDelFuncVal.Value) return obj.Value } // sendGroupMethod implements ctx.SendGroup(group) // Usage: ctx.SendGroup({ id: "group1", messages: [...] }) // Automatically generates IDs, sends group_start/group_end events, and flushes output ================================================ FILE: agent/context/jsapi_agent.go ================================================ package context import ( "github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/output/message" "rogchap.com/v8go" ) // AgentAPI defines the agent JSAPI interface for ctx.agent.* // This interface is defined here to avoid circular dependency between context and caller packages. 
// The actual implementation is in agent/caller/jsapi.go type AgentAPI interface { // Call executes a single agent call // Returns *caller.Result or error information Call(agentID string, messages []interface{}, opts map[string]interface{}) interface{} // Parallel agent call methods - inspired by JavaScript Promise // All waits for all agent calls to complete (like Promise.all) All(requests []interface{}) []interface{} // Any returns when any agent call succeeds (like Promise.any) Any(requests []interface{}) []interface{} // Race returns when any agent call completes (like Promise.race) Race(requests []interface{}) []interface{} } // AgentAPIWithCallback extends AgentAPI with callback support // This interface provides methods that accept OnMessage handlers for real-time message processing type AgentAPIWithCallback interface { AgentAPI // CallWithHandler executes a single agent call with an OnMessage handler // handler receives SSE messages: func(msg *message.Message) int CallWithHandler(agentID string, messages []interface{}, opts map[string]interface{}, handler OnMessageFunc) interface{} // AllWithHandler executes all agent calls with handlers // globalHandler receives messages with agentID and index: func(agentID, index, msg) int // Individual request handlers (if set) take precedence over globalHandler AllWithHandler(requests []interface{}, globalHandler BatchOnMessageFunc) []interface{} // AnyWithHandler executes agent calls and returns on first success, with handlers AnyWithHandler(requests []interface{}, globalHandler BatchOnMessageFunc) []interface{} // RaceWithHandler executes agent calls and returns on first completion, with handlers RaceWithHandler(requests []interface{}, globalHandler BatchOnMessageFunc) []interface{} } // BatchOnMessageFunc is the OnMessage function for batch calls // It includes agentID and index to identify the source of each message type BatchOnMessageFunc func(agentID string, index int, msg *message.Message) int // AgentAPIFactory 
is a function type that creates an AgentAPI for a context // This is set by the caller package during initialization var AgentAPIFactory func(ctx *Context) AgentAPI // Agent returns the agent API for this context // Returns nil if AgentAPIFactory is not set func (ctx *Context) Agent() AgentAPI { if AgentAPIFactory == nil { return nil } return AgentAPIFactory(ctx) } // newAgentObject creates a new agent object with all agent methods // This is called from jsapi.go NewObject() to mount ctx.agent func (ctx *Context) newAgentObject(iso *v8go.Isolate) *v8go.ObjectTemplate { agentObj := v8go.NewObjectTemplate(iso) // Single agent call method agentObj.Set("Call", ctx.agentCallMethod(iso)) // Parallel agent call methods - inspired by JavaScript Promise agentObj.Set("All", ctx.agentAllMethod(iso)) agentObj.Set("Any", ctx.agentAnyMethod(iso)) agentObj.Set("Race", ctx.agentRaceMethod(iso)) return agentObj } // agentCallMethod implements ctx.agent.Call(agentID, messages, options?) // Usage: const result = ctx.agent.Call("assistant-id", [{ role: "user", content: "Hello" }], { connector: "gpt4", onChunk: (type, data) => 0 }) // Returns: { agent_id, response, content, error } func (ctx *Context) agentCallMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 2 { return bridge.JsException(v8ctx, "Call requires agentID and messages parameters") } // Get agent ID (first argument) if !args[0].IsString() { return bridge.JsException(v8ctx, "agentID must be a string") } agentID := args[0].String() // Parse messages (second argument) messagesVal, err := bridge.GoValue(args[1], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid messages: "+err.Error()) } messages, ok := messagesVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "messages must be an array") } // Parse options (optional third argument) 
- extract onChunk separately var opts map[string]interface{} var onChunkFn *v8go.Function if len(args) >= 3 && !args[2].IsUndefined() && !args[2].IsNull() { optsObj, err := args[2].AsObject() if err == nil && optsObj != nil { // Extract onChunk callback before converting to Go value onChunkVal, _ := optsObj.Get("onChunk") if onChunkVal != nil && onChunkVal.IsFunction() { onChunkFn, _ = onChunkVal.AsFunction() } // Convert the rest of options to Go map goVal, err := bridge.GoValue(args[2], v8ctx) if err == nil { if optsMap, ok := goVal.(map[string]interface{}); ok { // Remove onChunk from the map (it's handled separately) delete(optsMap, "onChunk") opts = optsMap } } } } // Get agent API agentAPI := ctx.Agent() if agentAPI == nil { return bridge.JsException(v8ctx, "agent API not available") } var result interface{} // If onChunk callback is provided and API supports it, use CallWithHandler if onChunkFn != nil { if apiWithCb, ok := agentAPI.(AgentAPIWithCallback); ok { // Create Go StreamFunc that calls JS callback handler := createJSStreamHandler(v8ctx, onChunkFn) result = apiWithCb.CallWithHandler(agentID, messages, opts, handler) } else { // Fallback: ignore callback if API doesn't support it result = agentAPI.Call(agentID, messages, opts) } } else { // No callback, use regular Call result = agentAPI.Call(agentID, messages, opts) } // Convert result to JS value jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, "failed to convert result: "+err.Error()) } return jsVal }) } // createJSOnMessageHandler creates a Go OnMessageFunc that calls a JS callback // JS callback signature: (msg: object) => number // msg contains: type, props, delta, message_id, chunk_id, etc. 
func createJSStreamHandler(v8ctx *v8go.Context, callback *v8go.Function) OnMessageFunc { return func(msg *message.Message) int { if callback == nil || v8ctx == nil || msg == nil { return 0 // Continue if no callback } // Convert message to JS value jsMsg, err := bridge.JsValue(v8ctx, msg) if err != nil { return 1 // Stop on error } // Call the JS callback with the message object result, err := callback.Call(v8ctx.Global(), jsMsg) if err != nil { return 1 // Stop on error } // Check return value (0 = continue, non-zero = stop) if result != nil && result.IsNumber() { ret := result.Integer() if ret != 0 { return int(ret) } } return 0 // Continue } } // agentAllMethod implements ctx.agent.All(requests, options?) // Waits for all agent calls to complete (like Promise.all) // Each request should have: // - agent: string - target agent ID // - messages: array - messages to send // - options?: object - call options // // Global options (second argument): // - onChunk?: (agentID, index, msg) => number - callback for all messages (uses channel for V8 safety) func (ctx *Context) agentAllMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "All requires requests parameter") } // Parse requests and extract global callback requests, globalCallback := ctx.parseRequestsForBatch(args, v8ctx) // Get agent API agentAPI := ctx.Agent() if agentAPI == nil { return bridge.JsException(v8ctx, "agent API not available") } // Execute with channel-based callback handling results := ctx.executeBatchWithCallback(BatchMethodAll, requests, globalCallback, v8ctx) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // agentAnyMethod implements 
ctx.agent.Any(requests, options?) // Returns when any agent call succeeds (like Promise.any) // Each request should have: // - agent: string - target agent ID // - messages: array - messages to send // - options?: object - call options // // Global options (second argument): // - onChunk?: (agentID, index, msg) => number - callback for all messages (uses channel for V8 safety) func (ctx *Context) agentAnyMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "Any requires requests parameter") } // Parse requests and extract global callback requests, globalCallback := ctx.parseRequestsForBatch(args, v8ctx) // Get agent API agentAPI := ctx.Agent() if agentAPI == nil { return bridge.JsException(v8ctx, "agent API not available") } // Execute with channel-based callback handling results := ctx.executeBatchWithCallback(BatchMethodAny, requests, globalCallback, v8ctx) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // agentRaceMethod implements ctx.agent.Race(requests, options?) 
// Returns when any agent call completes (like Promise.race) // Each request should have: // - agent: string - target agent ID // - messages: array - messages to send // - options?: object - call options // // Global options (second argument): // - onChunk?: (agentID, index, msg) => number - callback for all messages (uses channel for V8 safety) func (ctx *Context) agentRaceMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "Race requires requests parameter") } // Parse requests and extract global callback requests, globalCallback := ctx.parseRequestsForBatch(args, v8ctx) // Get agent API agentAPI := ctx.Agent() if agentAPI == nil { return bridge.JsException(v8ctx, "agent API not available") } // Execute with channel-based callback handling results := ctx.executeBatchWithCallback(BatchMethodRace, requests, globalCallback, v8ctx) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // batchMessage represents a message from a batch call for channel-based callback handling type batchMessage struct { AgentID string // Agent ID that generated this message Index int // Index of the request in the batch Message *message.Message // The message object } // parseRequestsForBatch parses the requests array and extracts global callback for batch calls // Returns the requests array and the global JS callback function (if any) func (ctx *Context) parseRequestsForBatch(args []*v8go.Value, v8ctx *v8go.Context) ([]interface{}, *v8go.Function) { var globalCallback *v8go.Function // Parse global options (second argument) for global onChunk if len(args) >= 2 && !args[1].IsUndefined() && !args[1].IsNull() { globalOptsObj, err := args[1].AsObject() if 
err == nil && globalOptsObj != nil { onChunkVal, _ := globalOptsObj.Get("onChunk") if onChunkVal != nil && onChunkVal.IsFunction() { globalCallback, _ = onChunkVal.AsFunction() } } } // Parse requests array if len(args) < 1 || args[0].IsUndefined() || args[0].IsNull() { return []interface{}{}, globalCallback } requestsObj, err := args[0].AsObject() if err != nil { return []interface{}{}, globalCallback } // Get array length lengthVal, err := requestsObj.Get("length") if err != nil { return []interface{}{}, globalCallback } length := int(lengthVal.Integer()) requests := make([]interface{}, 0, length) for i := 0; i < length; i++ { itemVal, err := requestsObj.GetIdx(uint32(i)) if err != nil || itemVal.IsUndefined() || itemVal.IsNull() { continue } // Convert to Go map goVal, err := bridge.GoValue(itemVal, v8ctx) if err != nil { continue } reqMap, ok := goVal.(map[string]interface{}) if !ok { continue } // Remove onChunk from per-request options (only global callback is supported) if opts, ok := reqMap["options"].(map[string]interface{}); ok { delete(opts, "onChunk") } requests = append(requests, reqMap) } return requests, globalCallback } // BatchMethod represents the type of batch operation type BatchMethod int const ( BatchMethodAll BatchMethod = iota BatchMethodAny BatchMethodRace ) // executeBatchWithCallback executes a batch operation with channel-based callback handling // This ensures V8 thread safety by processing all callbacks in the main goroutine func (ctx *Context) executeBatchWithCallback( method BatchMethod, requests []interface{}, callback *v8go.Function, v8ctx *v8go.Context, ) []interface{} { // Get agent API agentAPI := ctx.Agent() if agentAPI == nil { return []interface{}{} } // If no callback, just execute directly if callback == nil { switch method { case BatchMethodAll: return agentAPI.All(requests) case BatchMethodAny: return agentAPI.Any(requests) case BatchMethodRace: return agentAPI.Race(requests) } return []interface{}{} } // Check if API 
supports callbacks apiWithCb, ok := agentAPI.(AgentAPIWithCallback) if !ok { switch method { case BatchMethodAll: return agentAPI.All(requests) case BatchMethodAny: return agentAPI.Any(requests) case BatchMethodRace: return agentAPI.Race(requests) } return []interface{}{} } // Create message channel for callback handling // Use a large buffer (1000) to reduce blocking, with blocking send to guarantee no message loss msgChan := make(chan batchMessage, 1000) doneChan := make(chan []interface{}, 1) // Create Go handler that sends messages to channel // Blocking send ensures no message is lost (natural backpressure) goHandler := func(agentID string, index int, msg *message.Message) int { msgChan <- batchMessage{AgentID: agentID, Index: index, Message: msg} return 0 } // Start batch execution in background goroutine go func() { defer close(msgChan) var results []interface{} switch method { case BatchMethodAll: results = apiWithCb.AllWithHandler(requests, goHandler) case BatchMethodAny: results = apiWithCb.AnyWithHandler(requests, goHandler) case BatchMethodRace: results = apiWithCb.RaceWithHandler(requests, goHandler) } doneChan <- results }() // Process messages in main goroutine (V8 thread-safe) for msg := range msgChan { callJSBatchCallback(v8ctx, callback, msg.AgentID, msg.Index, msg.Message) } // Wait for results return <-doneChan } // callJSBatchCallback calls the JS callback with batch message parameters // Must be called from the main V8 goroutine func callJSBatchCallback(v8ctx *v8go.Context, callback *v8go.Function, agentID string, index int, msg *message.Message) { if callback == nil || v8ctx == nil || msg == nil { return } iso := v8ctx.Isolate() agentIDVal, err := v8go.NewValue(iso, agentID) if err != nil { return } indexVal, err := v8go.NewValue(iso, int32(index)) if err != nil { return } // Convert message to JS value jsMsg, err := bridge.JsValue(v8ctx, msg) if err != nil { return } callback.Call(v8ctx.Global(), agentIDVal, indexVal, jsMsg) } 
================================================ FILE: agent/context/jsapi_agent_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/context" ) func TestContext_Agent_NilFactory(t *testing.T) { // Reset factory context.AgentAPIFactory = nil ctx := context.New(stdContext.Background(), nil, "test-chat") agentAPI := ctx.Agent() assert.Nil(t, agentAPI) } func TestContext_Agent_WithFactory(t *testing.T) { // Set up a mock factory var capturedCtx *context.Context context.AgentAPIFactory = func(ctx *context.Context) context.AgentAPI { capturedCtx = ctx return &mockAgentAPI{} } defer func() { context.AgentAPIFactory = nil }() ctx := context.New(stdContext.Background(), nil, "test-chat") agentAPI := ctx.Agent() require.NotNil(t, agentAPI) assert.Equal(t, ctx, capturedCtx) } // mockAgentAPI implements context.AgentAPI for testing type mockAgentAPI struct{} func (m *mockAgentAPI) Call(agentID string, messages []interface{}, opts map[string]interface{}) interface{} { return map[string]interface{}{ "agent_id": agentID, "content": "mock response", } } func (m *mockAgentAPI) All(requests []interface{}) []interface{} { return []interface{}{} } func (m *mockAgentAPI) Any(requests []interface{}) []interface{} { return []interface{}{} } func (m *mockAgentAPI) Race(requests []interface{}) []interface{} { return []interface{}{} } ================================================ FILE: agent/context/jsapi_agent_v8_test.go ================================================ package context_test import ( stdContext "context" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" // Import assistant package to register 
AgentAPIFactory _ "github.com/yaoapp/yao/agent/assistant" ) // TestAgent_Call_V8 tests basic ctx.agent.Call() functionality with real V8 execution func TestAgent_Call_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Create authorized info for the context authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-call") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const result = ctx.agent.Call( "tests.simple-greeting", [{ role: "user", content: "Hello" }] ); return { success: true, agent_id: result.agent_id, has_content: result.content && result.content.length > 0, has_response: result.response !== undefined, error: result.error || "" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, "tests.simple-greeting", result["agent_id"]) hasContent, _ := result["has_content"].(bool) assert.True(t, hasContent, "Should have content in response") hasResponse, _ := result["has_response"].(bool) assert.True(t, hasResponse, "Should have response object") errorStr, _ := result["error"].(string) assert.Empty(t, errorStr, "Should not have error") } // TestAgent_Call_WithOptions_V8 tests ctx.agent.Call() with options func TestAgent_Call_WithOptions_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, 
"test-chat-v8-options") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const result = ctx.agent.Call( "tests.simple-greeting", [{ role: "user", content: "Hi there!" }], { skip: { history: true, trace: true } } ); return { success: true, agent_id: result.agent_id, content: result.content || "", error: result.error || "" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) if !result["success"].(bool) { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, "tests.simple-greeting", result["agent_id"]) assert.NotEmpty(t, result["content"], "Should have content") } // TestAgent_All_V8 tests ctx.agent.All() for parallel execution func TestAgent_All_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-all") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.agent.All([ { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hello from request 1" }] }, { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hello from request 2" }] } ]); return { success: true, count: results.length, first_agent: results[0] ? results[0].agent_id : "", second_agent: results[1] ? results[1].agent_id : "", first_has_content: results[0] && results[0].content && results[0].content.length > 0, second_has_content: results[1] && results[1].content && results[1].content.length > 0, first_error: results[0] ? (results[0].error || "") : "no result", second_error: results[1] ? 
(results[1].error || "") : "no result" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) if !result["success"].(bool) { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, float64(2), result["count"]) assert.Equal(t, "tests.simple-greeting", result["first_agent"]) assert.Equal(t, "tests.simple-greeting", result["second_agent"]) assert.True(t, result["first_has_content"].(bool), "First result should have content") assert.True(t, result["second_has_content"].(bool), "Second result should have content") assert.Empty(t, result["first_error"], "First result should not have error") assert.Empty(t, result["second_error"], "Second result should not have error") } // TestAgent_Any_V8 tests ctx.agent.Any() returns on first success func TestAgent_Any_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-any") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.agent.Any([ { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hello" }] }, { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hi" }] } ]); // At least one result should be successful let hasSuccess = false; for (const r of results) { if (r && r.content && !r.error) { hasSuccess = true; break; } } return { success: true, count: results.length, has_successful_result: hasSuccess }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) if !result["success"].(bool) { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, float64(2), 
result["count"]) assert.True(t, result["has_successful_result"].(bool), "Should have at least one successful result") } // TestAgent_Race_V8 tests ctx.agent.Race() returns on first completion func TestAgent_Race_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-race") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.agent.Race([ { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hello" }] }, { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hi" }] } ]); // At least one result should exist (first to complete) let hasResult = false; for (const r of results) { if (r && (r.content || r.error)) { hasResult = true; break; } } return { success: true, count: results.length, has_result: hasResult }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) if !result["success"].(bool) { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, float64(2), result["count"]) assert.True(t, result["has_result"].(bool), "Should have at least one result") } // TestAgent_ErrorHandling_V8 tests error handling when calling non-existent agent func TestAgent_ErrorHandling_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-error") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) 
{ try { const result = ctx.agent.Call( "non-existent-agent", [{ role: "user", content: "Hello" }] ); return { success: true, has_error: result.error && result.error.length > 0, error_message: result.error || "" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) // The call should succeed (no JS exception), but result should contain error assert.True(t, result["success"].(bool), "JS execution should succeed") assert.True(t, result["has_error"].(bool), "Result should have error for non-existent agent") assert.True(t, strings.Contains(result["error_message"].(string), "failed to get agent"), "Error should mention failed to get agent") } // TestAgent_EmptyRequests_V8 tests handling of empty requests array func TestAgent_EmptyRequests_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-empty") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.agent.All([]); return { success: true, count: results.length }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.Equal(t, float64(0), result["count"]) } // TestAgent_InvalidArguments_V8 tests error handling for invalid arguments func TestAgent_InvalidArguments_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), 
authorized, "test-chat-v8-invalid") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() // Test missing arguments res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Call with no arguments should throw ctx.agent.Call(); return { success: false, error: "Should have thrown" }; } catch (error) { return { success: true, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool), "Should catch the error") assert.Contains(t, result["error"].(string), "requires") } // ============================================================================ // Callback Tests // ============================================================================ // TestAgent_Call_WithCallback_V8 tests ctx.agent.Call() with onChunk callback func TestAgent_Call_WithCallback_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-callback") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const messages = []; let messageCount = 0; const result = ctx.agent.Call( "tests.simple-greeting", [{ role: "user", content: "Hello" }], { onChunk: (msg) => { // msg is the SSE message object messageCount++; messages.push({ type: msg.type, has_props: msg.props !== undefined }); return 0; // Continue } } ); return { success: true, agent_id: result.agent_id, has_content: result.content && result.content.length > 0, message_count: messageCount, received_messages: messages.slice(0, 5), // First 5 messages error: result.error || "" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) 
require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, "tests.simple-greeting", result["agent_id"]) // Should have received some messages via callback messageCount, _ := result["message_count"].(float64) t.Logf("Received %v messages via callback", messageCount) assert.Greater(t, messageCount, float64(0), "Should have received messages via callback") // Check that we received message objects with type and props receivedMsgs, _ := result["received_messages"].([]interface{}) if len(receivedMsgs) > 0 { firstMsg := receivedMsgs[0].(map[string]interface{}) t.Logf("First message type: %v", firstMsg["type"]) assert.NotEmpty(t, firstMsg["type"], "Message should have type") } } // TestAgent_Call_WithCallback_Stop_V8 tests that callback can stop streaming func TestAgent_Call_WithCallback_Stop_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-callback-stop") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { let messageCount = 0; const result = ctx.agent.Call( "tests.simple-greeting", [{ role: "user", content: "Hello" }], { onChunk: (msg) => { messageCount++; // Stop after receiving 3 messages if (messageCount >= 3) { return 1; // Stop } return 0; // Continue } } ); return { success: true, message_count: messageCount, stopped_early: messageCount <= 5 // Should have stopped early }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test 
failed: %v", result["error"]) } messageCount, _ := result["message_count"].(float64) t.Logf("Received %v messages before stopping", messageCount) // Note: The exact count may vary based on when the stop is processed } // TestAgent_All_WithGlobalCallback_V8 tests ctx.agent.All() with global onChunk callback // Uses channel-based callback handling for V8 thread safety func TestAgent_All_WithGlobalCallback_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-v8-all-callback") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const messagesByAgent = {}; const results = ctx.agent.All( [ { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hello from 1" }] }, { agent: "tests.simple-greeting", messages: [{ role: "user", content: "Hello from 2" }] } ], { // Global callback receives agentID, index, and message onChunk: (agentID, index, msg) => { const key = agentID + "_" + index; if (!messagesByAgent[key]) { messagesByAgent[key] = 0; } messagesByAgent[key]++; return 0; } } ); return { success: true, result_count: results.length, messages_by_agent: messagesByAgent }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, float64(2), result["result_count"]) // Should have received messages from both agents messagesByAgent, _ := result["messages_by_agent"].(map[string]interface{}) t.Logf("Messages by agent: %v", messagesByAgent) // At least one agent should have sent messages 
assert.Greater(t, len(messagesByAgent), 0, "Should have received messages from agents") } ================================================ FILE: agent/context/jsapi_computer.go ================================================ package context import ( "context" "fmt" "strings" "github.com/yaoapp/gou/runtime/v8/bridge" infraV2 "github.com/yaoapp/yao/sandbox/v2" "github.com/yaoapp/yao/tai/workspace" "rogchap.com/v8go" ) // SetComputer sets the V2 computer and its workspace for this context. // Should be called after Runner.Prepare succeeds in initSandboxV2. func (ctx *Context) SetComputer(computer infraV2.Computer) { ctx.computer = computer if computer != nil { ctx.workspace = computer.Workplace() } } // SetWorkspace sets the workspace FS directly without requiring a Computer. // Use this when the user selected a workspace but no sandbox is configured. func (ctx *Context) SetWorkspace(ws workspace.FS) { ctx.workspace = ws } // GetComputer returns the V2 computer if available. func (ctx *Context) GetComputer() infraV2.Computer { return ctx.computer } // GetWorkspace returns the V2 workspace FS if available. func (ctx *Context) GetWorkspace() workspace.FS { return ctx.workspace } // HasComputer returns true if V2 computer is available. func (ctx *Context) HasComputer() bool { return ctx.computer != nil } // HasWorkspace returns true if workspace FS is available. func (ctx *Context) HasWorkspace() bool { return ctx.workspace != nil } // createComputerInstance creates the ctx.computer JavaScript object. 
func (ctx *Context) createComputerInstance(v8ctx *v8go.Context) *v8go.Value { if ctx.computer == nil { return nil } iso := v8ctx.Isolate() objTpl := v8go.NewObjectTemplate(iso) info := ctx.computer.ComputerInfo() id := info.BoxID if id == "" { id = info.NodeID } objTpl.Set("id", id) objTpl.Set("Exec", ctx.computerExecMethod(iso)) objTpl.Set("VNC", ctx.computerVNCMethod(iso)) objTpl.Set("Proxy", ctx.computerProxyMethod(iso)) objTpl.Set("Info", ctx.computerInfoMethod(iso)) instance, err := objTpl.NewInstance(v8ctx) if err != nil { return nil } return instance.Value } // computerExecMethod implements ctx.computer.Exec(cmd) // cmd can be a string or an array of strings. // Returns: { stdout, stderr, exit_code } func (ctx *Context) computerExecMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.computer == nil { return bridge.JsException(v8ctx, "computer not available") } if len(args) < 1 { return bridge.JsException(v8ctx, "Exec requires a command argument") } cmd, err := parseCommandArg(v8ctx, args[0]) if err != nil { return bridge.JsException(v8ctx, err.Error()) } result, err := ctx.computer.Exec(context.Background(), cmd) if err != nil { return bridge.JsException(v8ctx, "Exec failed: "+err.Error()) } res := map[string]interface{}{ "stdout": result.Stdout, "stderr": result.Stderr, "exit_code": int32(result.ExitCode), } jsVal, err := bridge.JsValue(v8ctx, res) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // computerVNCMethod implements ctx.computer.VNC() // Returns the VNC URL string. 
func (ctx *Context) computerVNCMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() if ctx.computer == nil { return bridge.JsException(v8ctx, "computer not available") } url, err := ctx.computer.VNC(context.Background()) if err != nil { return bridge.JsException(v8ctx, "VNC failed: "+err.Error()) } jsVal, err := v8go.NewValue(iso, url) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // computerProxyMethod implements ctx.computer.Proxy(port, path?) // Returns the proxy URL string. func (ctx *Context) computerProxyMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.computer == nil { return bridge.JsException(v8ctx, "computer not available") } if len(args) < 1 || !args[0].IsNumber() { return bridge.JsException(v8ctx, "Proxy requires a port number") } port := int(args[0].Integer()) path := "" if len(args) >= 2 && args[1].IsString() { path = args[1].String() } url, err := ctx.computer.Proxy(context.Background(), port, path) if err != nil { return bridge.JsException(v8ctx, "Proxy failed: "+err.Error()) } jsVal, err := v8go.NewValue(iso, url) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // computerInfoMethod implements ctx.computer.Info() // Returns a JS object with computer identity and system information. 
func (ctx *Context) computerInfoMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() if ctx.computer == nil { return bridge.JsException(v8ctx, "computer not available") } ci := ctx.computer.ComputerInfo() result := map[string]interface{}{ "kind": ci.Kind, "node_id": ci.NodeID, "tai_id": ci.TaiID, "status": ci.Status, "system": map[string]interface{}{ "os": ci.System.OS, "arch": ci.System.Arch, "hostname": ci.System.Hostname, "num_cpu": int32(ci.System.NumCPU), "shell": ci.System.Shell, }, } if ci.BoxID != "" { result["box_id"] = ci.BoxID result["container_id"] = ci.ContainerID result["image"] = ci.Image result["policy"] = string(ci.Policy) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // parseCommandArg converts a JS value (string or string array) to []string. func parseCommandArg(v8ctx *v8go.Context, val *v8go.Value) ([]string, error) { if val.IsString() { raw := val.String() return strings.Fields(raw), nil } if val.IsArray() { obj, err := val.AsObject() if err != nil { return nil, err } lengthVal, err := obj.Get("length") if err != nil { return nil, err } length := int(lengthVal.Integer()) cmd := make([]string, length) for i := 0; i < length; i++ { item, err := obj.GetIdx(uint32(i)) if err != nil { return nil, err } cmd[i] = item.String() } return cmd, nil } return nil, fmt.Errorf("command must be a string or array of strings") } ================================================ FILE: agent/context/jsapi_helpers.go ================================================ package context import ( "fmt" "github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/output/message" "rogchap.com/v8go" ) // parseMessage parses a JavaScript value into a message.Message func parseMessage(v8ctx *v8go.Context, jsValue *v8go.Value) (*message.Message, error) { // Handle string 
shorthand: convert to text message if jsValue.IsString() { return &message.Message{ Type: message.TypeText, Props: map[string]interface{}{ "content": jsValue.String(), }, }, nil } // Handle object if !jsValue.IsObject() { return nil, fmt.Errorf("message must be a string or object") } // Convert to Go map goValue, err := bridge.GoValue(jsValue, v8ctx) if err != nil { return nil, fmt.Errorf("failed to convert message: %w", err) } msgMap, ok := goValue.(map[string]interface{}) if !ok { return nil, fmt.Errorf("message must be an object") } // Build message msg := &message.Message{} // Type field (required) if msgType, ok := msgMap["type"].(string); ok { msg.Type = msgType } else { return nil, fmt.Errorf("message.type is required and must be a string") } // Props field (optional) if props, ok := msgMap["props"].(map[string]interface{}); ok { msg.Props = props } // Optional fields - Streaming control if chunkID, ok := msgMap["chunk_id"].(string); ok { msg.ChunkID = chunkID } if messageID, ok := msgMap["message_id"].(string); ok { msg.MessageID = messageID } if blockID, ok := msgMap["block_id"].(string); ok { msg.BlockID = blockID } if threadID, ok := msgMap["thread_id"].(string); ok { msg.ThreadID = threadID } // Delta control if delta, ok := msgMap["delta"].(bool); ok { msg.Delta = delta } if deltaPath, ok := msgMap["delta_path"].(string); ok { msg.DeltaPath = deltaPath } if deltaAction, ok := msgMap["delta_action"].(string); ok { msg.DeltaAction = deltaAction } if typeChange, ok := msgMap["type_change"].(bool); ok { msg.TypeChange = typeChange } // Metadata (optional) if metadataMap, ok := msgMap["metadata"].(map[string]interface{}); ok { metadata := &message.Metadata{} if timestamp, ok := metadataMap["timestamp"].(float64); ok { metadata.Timestamp = int64(timestamp) } if sequence, ok := metadataMap["sequence"].(float64); ok { metadata.Sequence = int(sequence) } if traceID, ok := metadataMap["trace_id"].(string); ok { metadata.TraceID = traceID } msg.Metadata = 
metadata } return msg, nil } // parseGroup parses a JavaScript value into a message.Group func parseGroup(v8ctx *v8go.Context, jsValue *v8go.Value) (*message.Group, error) { // Must be an object if !jsValue.IsObject() { return nil, fmt.Errorf("group must be an object") } // Convert to Go map goValue, err := bridge.GoValue(jsValue, v8ctx) if err != nil { return nil, fmt.Errorf("failed to convert group: %w", err) } groupMap, ok := goValue.(map[string]interface{}) if !ok { return nil, fmt.Errorf("group must be an object") } // Build group group := &message.Group{} // ID field (required) if id, ok := groupMap["id"].(string); ok { group.ID = id } else { return nil, fmt.Errorf("group.id is required and must be a string") } // Messages field (required) if messagesArray, ok := groupMap["messages"].([]interface{}); ok { group.Messages = make([]*message.Message, 0, len(messagesArray)) for i, msgInterface := range messagesArray { // Convert to map msgMap, ok := msgInterface.(map[string]interface{}) if !ok { return nil, fmt.Errorf("group.messages[%d] must be an object", i) } // Convert map to Message msg := &message.Message{} // Type field (required) if msgType, ok := msgMap["type"].(string); ok { msg.Type = msgType } else { return nil, fmt.Errorf("group.messages[%d].type is required", i) } // Props field (optional) if props, ok := msgMap["props"].(map[string]interface{}); ok { msg.Props = props } // Optional fields - Streaming control if chunkID, ok := msgMap["chunk_id"].(string); ok { msg.ChunkID = chunkID } if messageID, ok := msgMap["message_id"].(string); ok { msg.MessageID = messageID } if blockID, ok := msgMap["block_id"].(string); ok { msg.BlockID = blockID } if threadID, ok := msgMap["thread_id"].(string); ok { msg.ThreadID = threadID } // Delta control if delta, ok := msgMap["delta"].(bool); ok { msg.Delta = delta } if deltaPath, ok := msgMap["delta_path"].(string); ok { msg.DeltaPath = deltaPath } if deltaAction, ok := msgMap["delta_action"].(string); ok { 
msg.DeltaAction = deltaAction } if typeChange, ok := msgMap["type_change"].(bool); ok { msg.TypeChange = typeChange } // Metadata (optional) if metadataMap, ok := msgMap["metadata"].(map[string]interface{}); ok { metadata := &message.Metadata{} if timestamp, ok := metadataMap["timestamp"].(float64); ok { metadata.Timestamp = int64(timestamp) } if sequence, ok := metadataMap["sequence"].(float64); ok { metadata.Sequence = int(sequence) } if traceID, ok := metadataMap["trace_id"].(string); ok { metadata.TraceID = traceID } msg.Metadata = metadata } group.Messages = append(group.Messages, msg) } } else { return nil, fmt.Errorf("group.messages is required and must be an array") } // Metadata (optional) if metadataMap, ok := groupMap["metadata"].(map[string]interface{}); ok { metadata := &message.Metadata{} if timestamp, ok := metadataMap["timestamp"].(float64); ok { metadata.Timestamp = int64(timestamp) } if sequence, ok := metadataMap["sequence"].(float64); ok { metadata.Sequence = int(sequence) } if traceID, ok := metadataMap["trace_id"].(string); ok { metadata.TraceID = traceID } group.Metadata = metadata } return group, nil } ================================================ FILE: agent/context/jsapi_llm.go ================================================ package context import ( "github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/output/message" "rogchap.com/v8go" ) // LlmAPI defines the LLM JSAPI interface for ctx.llm.* // This interface is defined here to avoid circular dependency between context and llm packages. 
// The actual implementation is in agent/llm/jsapi.go type LlmAPI interface { // Stream calls LLM with streaming output to ctx.Writer // Returns *llm.Result or error information Stream(connector string, messages []interface{}, opts map[string]interface{}) interface{} // Parallel LLM call methods - inspired by JavaScript Promise // All waits for all LLM calls to complete (like Promise.all) All(requests []interface{}) []interface{} // Any returns when any LLM call succeeds (like Promise.any) Any(requests []interface{}) []interface{} // Race returns when any LLM call completes (like Promise.race) Race(requests []interface{}) []interface{} } // LlmAPIWithCallback extends LlmAPI with callback support // This interface provides methods that accept OnMessage handlers for real-time message processing type LlmAPIWithCallback interface { LlmAPI // StreamWithHandler calls LLM with an OnMessage handler // handler receives SSE messages: func(msg *message.Message) int StreamWithHandler(connector string, messages []interface{}, opts map[string]interface{}, handler OnMessageFunc) interface{} // AllWithHandler executes all LLM calls with handlers // globalHandler receives messages with connectorID and index: func(connectorID, index, msg) int AllWithHandler(requests []interface{}, globalHandler LlmBatchOnMessageFunc) []interface{} // AnyWithHandler executes LLM calls and returns on first success, with handlers AnyWithHandler(requests []interface{}, globalHandler LlmBatchOnMessageFunc) []interface{} // RaceWithHandler executes LLM calls and returns on first completion, with handlers RaceWithHandler(requests []interface{}, globalHandler LlmBatchOnMessageFunc) []interface{} } // LlmBatchOnMessageFunc is the OnMessage function for batch LLM calls // It includes connectorID and index to identify the source of each message type LlmBatchOnMessageFunc func(connectorID string, index int, msg *message.Message) int // LlmAPIFactory is a function type that creates a LlmAPI for a context // This 
is set by the llm package during initialization var LlmAPIFactory func(ctx *Context) LlmAPI // Llm returns the LLM API for this context // Returns nil if LlmAPIFactory is not set func (ctx *Context) Llm() LlmAPI { if LlmAPIFactory == nil { return nil } return LlmAPIFactory(ctx) } // newLlmObject creates a new llm object with all llm methods // This is called from jsapi.go NewObject() to mount ctx.llm func (ctx *Context) newLlmObject(iso *v8go.Isolate) *v8go.ObjectTemplate { llmObj := v8go.NewObjectTemplate(iso) // Single LLM call method llmObj.Set("Stream", ctx.llmStreamMethod(iso)) // Parallel LLM call methods - inspired by JavaScript Promise llmObj.Set("All", ctx.llmAllMethod(iso)) llmObj.Set("Any", ctx.llmAnyMethod(iso)) llmObj.Set("Race", ctx.llmRaceMethod(iso)) return llmObj } // llmStreamMethod implements ctx.llm.Stream(connector, messages, options?) // Usage: const result = ctx.llm.Stream("gpt-4o", [{ role: "user", content: "Hello" }], { temperature: 0.7, onChunk: (msg) => 0 }) // Returns: { connector, response, content, error } func (ctx *Context) llmStreamMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 2 { return bridge.JsException(v8ctx, "Stream requires connector and messages parameters") } // Get connector ID (first argument) if !args[0].IsString() { return bridge.JsException(v8ctx, "connector must be a string") } connector := args[0].String() // Parse messages (second argument) messagesVal, err := bridge.GoValue(args[1], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid messages: "+err.Error()) } messages, ok := messagesVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "messages must be an array") } // Parse options (optional third argument) - extract onChunk separately var opts map[string]interface{} var onChunkFn *v8go.Function if len(args) >= 3 && 
!args[2].IsUndefined() && !args[2].IsNull() { optsObj, err := args[2].AsObject() if err == nil && optsObj != nil { // Extract onChunk callback before converting to Go value onChunkVal, _ := optsObj.Get("onChunk") if onChunkVal != nil && onChunkVal.IsFunction() { onChunkFn, _ = onChunkVal.AsFunction() } // Convert the rest of options to Go map goVal, err := bridge.GoValue(args[2], v8ctx) if err == nil { if optsMap, ok := goVal.(map[string]interface{}); ok { // Remove onChunk from the map (it's handled separately) delete(optsMap, "onChunk") opts = optsMap } } } } // Get LLM API llmAPI := ctx.Llm() if llmAPI == nil { return bridge.JsException(v8ctx, "LLM API not available") } var result interface{} // If onChunk callback is provided and API supports it, use StreamWithHandler if onChunkFn != nil { if apiWithCb, ok := llmAPI.(LlmAPIWithCallback); ok { // Create Go OnMessageFunc that calls JS callback handler := createJSStreamHandler(v8ctx, onChunkFn) result = apiWithCb.StreamWithHandler(connector, messages, opts, handler) } else { // Fallback: ignore callback if API doesn't support it result = llmAPI.Stream(connector, messages, opts) } } else { // No callback, use regular Stream result = llmAPI.Stream(connector, messages, opts) } // Convert result to JS value jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, "failed to convert result: "+err.Error()) } return jsVal }) } // llmAllMethod implements ctx.llm.All(requests, options?) // Usage: const results = ctx.llm.All([ // // { connector: "gpt-4o", messages: [...], options: {...} }, // { connector: "claude-3", messages: [...] } // // ], { onChunk: (connectorID, index, msg) => 0 }) // Returns: [{ connector, response, content, error }, ...] 
func (ctx *Context) llmAllMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { return ctx.executeLlmBatchMethod(info, LlmBatchMethodAll) }) } // llmAnyMethod implements ctx.llm.Any(requests, options?) // Returns first successful result func (ctx *Context) llmAnyMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { return ctx.executeLlmBatchMethod(info, LlmBatchMethodAny) }) } // llmRaceMethod implements ctx.llm.Race(requests, options?) // Returns first completed result (success or failure) func (ctx *Context) llmRaceMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { return ctx.executeLlmBatchMethod(info, LlmBatchMethodRace) }) } // LlmBatchMethod represents the type of batch LLM operation type LlmBatchMethod int const ( LlmBatchMethodAll LlmBatchMethod = iota LlmBatchMethodAny LlmBatchMethodRace ) // executeLlmBatchMethod handles All/Any/Race batch LLM calls func (ctx *Context) executeLlmBatchMethod(info *v8go.FunctionCallbackInfo, method LlmBatchMethod) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "requires requests array parameter") } // Parse requests array (first argument) requestsVal, err := bridge.GoValue(args[0], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid requests: "+err.Error()) } requests, ok := requestsVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "requests must be an array") } // Get LLM API llmAPI := ctx.Llm() if llmAPI == nil { return bridge.JsException(v8ctx, "LLM API not available") } // Parse optional global callback from second argument (options object) var globalCallback *v8go.Function if len(args) >= 2 && !args[1].IsUndefined() && !args[1].IsNull() { optsObj, err := 
args[1].AsObject() if err == nil && optsObj != nil { onChunkVal, _ := optsObj.Get("onChunk") if onChunkVal != nil && onChunkVal.IsFunction() { globalCallback, _ = onChunkVal.AsFunction() } } } var results []interface{} // If callback is provided and API supports it, use channel-based execution if globalCallback != nil { if apiWithCb, ok := llmAPI.(LlmAPIWithCallback); ok { results = ctx.executeLlmBatchWithCallback(method, requests, globalCallback, v8ctx, apiWithCb) } else { // Fallback: execute without callback results = ctx.executeLlmBatchWithoutCallback(method, requests, llmAPI) } } else { // No callback, use regular batch methods results = ctx.executeLlmBatchWithoutCallback(method, requests, llmAPI) } // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal } // executeLlmBatchWithoutCallback executes batch LLM calls without callback func (ctx *Context) executeLlmBatchWithoutCallback(method LlmBatchMethod, requests []interface{}, llmAPI LlmAPI) []interface{} { switch method { case LlmBatchMethodAll: return llmAPI.All(requests) case LlmBatchMethodAny: return llmAPI.Any(requests) case LlmBatchMethodRace: return llmAPI.Race(requests) default: return llmAPI.All(requests) } } // llmBatchMessage is used for channel communication in batch LLM calls type llmBatchMessage struct { ConnectorID string Index int Message *message.Message } // executeLlmBatchWithCallback executes batch LLM calls with callback using channel // This ensures V8 thread safety by serializing callback invocations func (ctx *Context) executeLlmBatchWithCallback(method LlmBatchMethod, requests []interface{}, callback *v8go.Function, v8ctx *v8go.Context, apiWithCb LlmAPIWithCallback) []interface{} { // Create a buffered channel for messages // Using blocking send to ensure all messages are delivered msgChan := make(chan llmBatchMessage, 1000) doneChan := make(chan []interface{}, 
1) // Create Go handler that sends to channel goHandler := func(connectorID string, index int, msg *message.Message) int { msgChan <- llmBatchMessage{ ConnectorID: connectorID, Index: index, Message: msg, } return 0 } // Execute batch calls in background goroutine go func() { defer close(msgChan) var results []interface{} switch method { case LlmBatchMethodAll: results = apiWithCb.AllWithHandler(requests, goHandler) case LlmBatchMethodAny: results = apiWithCb.AnyWithHandler(requests, goHandler) case LlmBatchMethodRace: results = apiWithCb.RaceWithHandler(requests, goHandler) default: results = apiWithCb.AllWithHandler(requests, goHandler) } doneChan <- results }() // Process messages in main goroutine (V8 thread) for msg := range msgChan { callJSLlmBatchCallback(v8ctx, callback, msg.ConnectorID, msg.Index, msg.Message) } // Wait for results return <-doneChan } // callJSLlmBatchCallback calls the JS callback function for batch LLM calls func callJSLlmBatchCallback(v8ctx *v8go.Context, callback *v8go.Function, connectorID string, index int, msg *message.Message) { if callback == nil || v8ctx == nil { return } iso := v8ctx.Isolate() // Create arguments: connectorID, index, message connectorVal, err := v8go.NewValue(iso, connectorID) if err != nil { return } indexVal, err := v8go.NewValue(iso, int32(index)) if err != nil { return } // Convert message to JS object msgVal, err := bridge.JsValue(v8ctx, msg) if err != nil { return } // Call the callback _, _ = callback.Call(v8go.Undefined(iso), connectorVal, indexVal, msgVal) } ================================================ FILE: agent/context/jsapi_llm_v8_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" // Import 
assistant package to register LlmAPIFactory _ "github.com/yaoapp/yao/agent/assistant" ) // TestLlm_Stream_V8 tests basic ctx.llm.Stream functionality with real V8 execution func TestLlm_Stream_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) // Create authorized info for the context authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } // Create a context ctx := context.New(stdContext.Background(), authorized, "test-chat-llm-stream") ctx.AssistantID = "tests.simple-greeting" defer ctx.Release() // Test basic Stream call with real connector res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const result = ctx.llm.Stream("gpt-4o-mini", [ { role: "user", content: "Say hello in one word" } ], { temperature: 0.1, max_tokens: 10 }); return { success: true, connector: result.connector, has_content: result.content && result.content.length > 0, has_response: result.response !== undefined, error: result.error || "" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) require.NotNil(t, res) result, ok := res.(map[string]interface{}) require.True(t, ok, "result should be a map") success, _ := result["success"].(bool) if !success { t.Logf("Test result: %v", result) } require.True(t, success, "Test should succeed, error: %v", result["error"]) assert.Equal(t, "gpt-4o-mini", result["connector"]) hasContent, _ := result["has_content"].(bool) assert.True(t, hasContent, "Should have content in response") hasResponse, _ := result["has_response"].(bool) assert.True(t, hasResponse, "Should have response object") } // TestLlm_Stream_WithCallback_V8 tests ctx.llm.Stream with onChunk callback func TestLlm_Stream_WithCallback_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := 
&types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-llm-callback") ctx.AssistantID = "tests.simple-greeting" defer ctx.Release() // Test Stream call with callback res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { let callbackCount = 0; let receivedTypes = []; const result = ctx.llm.Stream("gpt-4o-mini", [ { role: "user", content: "Say hi" } ], { temperature: 0.1, max_tokens: 10, onChunk: function(msg) { callbackCount++; if (msg && msg.type) { receivedTypes.push(msg.type); } return 0; // Continue } }); return { success: true, connector: result.connector, callbackCount: callbackCount, receivedTypes: receivedTypes, has_content: result.content && result.content.length > 0, error: result.error || "" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) require.NotNil(t, res) result, ok := res.(map[string]interface{}) require.True(t, ok, "result should be a map") success, _ := result["success"].(bool) if !success { t.Logf("Test result: %v", result) } require.True(t, success, "Test should succeed, error: %v", result["error"]) // Callback should have been called at least once callbackCount, _ := result["callbackCount"].(float64) assert.Greater(t, callbackCount, float64(0), "Callback should be called at least once") } // TestLlm_All_V8 tests ctx.llm.All with multiple connectors func TestLlm_All_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-llm-all") ctx.AssistantID = "tests.simple-greeting" defer ctx.Release() // Test All with multiple requests to same connector (different prompts) res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { 
try { const results = ctx.llm.All([ { connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'one'" }], options: { temperature: 0.1, max_tokens: 5 } }, { connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'two'" }], options: { temperature: 0.1, max_tokens: 5 } } ]); return { success: true, count: results.length, results: results.map(r => ({ connector: r.connector, has_content: r.content && r.content.length > 0, error: r.error || "" })) }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) require.NotNil(t, res) result, ok := res.(map[string]interface{}) require.True(t, ok, "result should be a map") success, _ := result["success"].(bool) if !success { t.Logf("Test result: %v", result) } require.True(t, success, "Test should succeed, error: %v", result["error"]) // Should have 2 results count, _ := result["count"].(float64) assert.Equal(t, float64(2), count, "Should have 2 results") // Check individual results results, _ := result["results"].([]interface{}) require.Len(t, results, 2) for i, r := range results { rMap, _ := r.(map[string]interface{}) hasContent, _ := rMap["has_content"].(bool) assert.True(t, hasContent, "Result %d should have content", i) errorStr, _ := rMap["error"].(string) assert.Empty(t, errorStr, "Result %d should not have error", i) } } // TestLlm_All_WithCallback_V8 tests ctx.llm.All with global callback func TestLlm_All_WithCallback_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-llm-all-callback") ctx.AssistantID = "tests.simple-greeting" defer ctx.Release() // Test All with global callback res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { let callbackCount = 0; let indexesSeen = new Set(); 
const results = ctx.llm.All([ { connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'A'" }], options: { temperature: 0.1, max_tokens: 5 } }, { connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'B'" }], options: { temperature: 0.1, max_tokens: 5 } } ], { onChunk: function(connectorID, index, msg) { callbackCount++; indexesSeen.add(index); return 0; } }); return { success: true, count: results.length, callbackCount: callbackCount, indexesSeen: indexesSeen.size }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) require.NotNil(t, res) result, ok := res.(map[string]interface{}) require.True(t, ok, "result should be a map") success, _ := result["success"].(bool) if !success { t.Logf("Test result: %v", result) } require.True(t, success, "Test should succeed, error: %v", result["error"]) // Callback should have been called callbackCount, _ := result["callbackCount"].(float64) assert.Greater(t, callbackCount, float64(0), "Callback should be called") // Should have seen at least one index (both requests may complete so fast that only one is tracked) // Note: Due to V8 thread safety with channel-based approach, callbacks are serialized indexesSeen, _ := result["indexesSeen"].(float64) assert.GreaterOrEqual(t, indexesSeen, float64(1), "Should have seen callbacks from at least one request") } // TestLlm_Any_V8 tests ctx.llm.Any - returns first success func TestLlm_Any_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-llm-any") ctx.AssistantID = "tests.simple-greeting" defer ctx.Release() // Test Any - should return first successful result res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.llm.Any([ { 
connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'hello'" }], options: { temperature: 0.1, max_tokens: 5 } }, { connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'world'" }], options: { temperature: 0.1, max_tokens: 5 } } ]); // Any returns array with single successful result return { success: true, count: results.length, first_has_content: results[0] && results[0].content && results[0].content.length > 0, first_error: results[0] ? (results[0].error || "") : "no result" }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) require.NotNil(t, res) result, ok := res.(map[string]interface{}) require.True(t, ok, "result should be a map") success, _ := result["success"].(bool) if !success { t.Logf("Test result: %v", result) } require.True(t, success, "Test should succeed, error: %v", result["error"]) // Any returns single result on success count, _ := result["count"].(float64) assert.Equal(t, float64(1), count, "Should have 1 result (first success)") firstHasContent, _ := result["first_has_content"].(bool) assert.True(t, firstHasContent, "First result should have content") firstError, _ := result["first_error"].(string) assert.Empty(t, firstError, "First result should not have error") } // TestLlm_Race_V8 tests ctx.llm.Race - returns first completion func TestLlm_Race_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-llm-race") ctx.AssistantID = "tests.simple-greeting" defer ctx.Release() // Test Race - should return first completed result res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.llm.Race([ { connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'fast'" }], options: { 
temperature: 0.1, max_tokens: 5 } }, { connector: "gpt-4o-mini", messages: [{ role: "user", content: "Say 'slow'" }], options: { temperature: 0.1, max_tokens: 5 } } ]); // Race returns array with single result (first to complete) return { success: true, count: results.length, has_result: results[0] !== undefined, first_connector: results[0] ? results[0].connector : "", first_has_content: results[0] && results[0].content && results[0].content.length > 0 }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) require.NotNil(t, res) result, ok := res.(map[string]interface{}) require.True(t, ok, "result should be a map") success, _ := result["success"].(bool) if !success { t.Logf("Test result: %v", result) } require.True(t, success, "Test should succeed, error: %v", result["error"]) // Race returns single result count, _ := result["count"].(float64) assert.Equal(t, float64(1), count, "Should have 1 result (first to complete)") hasResult, _ := result["has_result"].(bool) assert.True(t, hasResult, "Should have a result") firstConnector, _ := result["first_connector"].(string) assert.Equal(t, "gpt-4o-mini", firstConnector, "First result should have connector") } // TestLlm_Stream_InvalidConnector_V8 tests error handling for invalid connector func TestLlm_Stream_InvalidConnector_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-llm-invalid") ctx.AssistantID = "tests.simple-greeting" defer ctx.Release() // Test Stream with invalid connector res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const result = ctx.llm.Stream("invalid-connector-that-does-not-exist", [ { role: "user", content: "Hello" } ]); return { has_error: result.error && 
result.error.length > 0, error: result.error || "" }; } catch (error) { return { has_error: true, error: error.message }; } }`, ctx) require.NoError(t, err) require.NotNil(t, res) result, ok := res.(map[string]interface{}) require.True(t, ok, "result should be a map") hasError, _ := result["has_error"].(bool) assert.True(t, hasError, "Should have error for invalid connector") } ================================================ FILE: agent/context/jsapi_mcp.go ================================================ package context import ( "fmt" "github.com/yaoapp/gou/mcp/types" "github.com/yaoapp/gou/runtime/v8/bridge" "rogchap.com/v8go" ) // Suppress unused import warning - types is used in other functions var _ = types.ToolCall{} // MCP JavaScript API methods // These methods expose MCP functionality to JavaScript runtime // mcpListResourcesMethod implements ctx.MCP.ListResources(mcp, cursor) // Lists all available resources from an MCP client func (ctx *Context) mcpListResourcesMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "ListResources requires mcp parameter") } mcpID := args[0].String() cursor := "" if len(args) >= 2 && !args[1].IsUndefined() { cursor = args[1].String() } result, err := ctx.ListResources(mcpID, cursor) if err != nil { return bridge.JsException(v8ctx, err.Error()) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpReadResourceMethod implements ctx.MCP.ReadResource(mcp, uri) // Reads a specific resource from an MCP client func (ctx *Context) mcpReadResourceMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { return 
bridge.JsException(v8ctx, "ReadResource requires mcp and uri parameters") } mcpID := args[0].String() uri := args[1].String() result, err := ctx.ReadResource(mcpID, uri) if err != nil { return bridge.JsException(v8ctx, err.Error()) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpListToolsMethod implements ctx.MCP.ListTools(mcp, cursor) // Lists all available tools from an MCP client func (ctx *Context) mcpListToolsMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "ListTools requires mcp parameter") } mcpID := args[0].String() cursor := "" if len(args) >= 2 && !args[1].IsUndefined() { cursor = args[1].String() } result, err := ctx.ListTools(mcpID, cursor) if err != nil { return bridge.JsException(v8ctx, err.Error()) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpCallToolMethod implements ctx.MCP.CallTool(mcp, name, args) // Calls a specific tool from an MCP client // Returns CallToolResult with parsed 'result' field for convenience func (ctx *Context) mcpCallToolMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { return bridge.JsException(v8ctx, "CallTool requires mcp and name parameters") } mcpID := args[0].String() toolName := args[1].String() // Parse arguments (optional) var toolArgs map[string]interface{} if len(args) >= 3 && !args[2].IsUndefined() { goVal, err := bridge.GoValue(args[2], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid tool arguments: "+err.Error()) } if argsMap, ok := goVal.(map[string]interface{}); ok { toolArgs = argsMap } } 
response, err := ctx.CallTool(mcpID, toolName, toolArgs) if err != nil { return bridge.JsException(v8ctx, err.Error()) } // Return parsed result directly result := parseCallToolResponse(response) jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpCallToolsMethod implements ctx.MCP.CallTools(mcp, tools) // Calls multiple tools sequentially from an MCP client // Returns CallToolsResult with parsed 'result' field in each item func (ctx *Context) mcpCallToolsMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { return bridge.JsException(v8ctx, "CallTools requires mcp and tools parameters") } mcpID := args[0].String() // Parse tools array goVal, err := bridge.GoValue(args[1], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid tools parameter: "+err.Error()) } toolsArray, ok := goVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "tools parameter must be an array") } // Convert to ToolCall array tools := make([]types.ToolCall, 0, len(toolsArray)) for i, item := range toolsArray { toolMap, ok := item.(map[string]interface{}) if !ok { return bridge.JsException(v8ctx, "each tool must be an object") } name, ok := toolMap["name"].(string) if !ok { return bridge.JsException(v8ctx, "tool name is required") } toolCall := types.ToolCall{ Name: name, } if argsVal, exists := toolMap["arguments"]; exists && argsVal != nil { if argsMap, ok := argsVal.(map[string]interface{}); ok { toolCall.Arguments = argsMap } } tools = append(tools, toolCall) // Suppress unused variable warning _ = i } response, err := ctx.CallTools(mcpID, tools) if err != nil { return bridge.JsException(v8ctx, err.Error()) } // Return parsed results directly as array result := parseCallToolsResponse(response) jsVal, err := bridge.JsValue(v8ctx, result) if err != 
nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpCallToolsParallelMethod implements ctx.MCP.CallToolsParallel(mcp, tools) // Calls multiple tools in parallel from an MCP client // Returns CallToolsResult with parsed 'result' field in each item func (ctx *Context) mcpCallToolsParallelMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { return bridge.JsException(v8ctx, "CallToolsParallel requires mcp and tools parameters") } mcpID := args[0].String() // Parse tools array goVal, err := bridge.GoValue(args[1], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid tools parameter: "+err.Error()) } toolsArray, ok := goVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "tools parameter must be an array") } // Convert to ToolCall array tools := make([]types.ToolCall, 0, len(toolsArray)) for i, item := range toolsArray { toolMap, ok := item.(map[string]interface{}) if !ok { return bridge.JsException(v8ctx, "each tool must be an object") } name, ok := toolMap["name"].(string) if !ok { return bridge.JsException(v8ctx, "tool name is required") } toolCall := types.ToolCall{ Name: name, } if argsVal, exists := toolMap["arguments"]; exists && argsVal != nil { if argsMap, ok := argsVal.(map[string]interface{}); ok { toolCall.Arguments = argsMap } } tools = append(tools, toolCall) // Suppress unused variable warning _ = i } response, err := ctx.CallToolsParallel(mcpID, tools) if err != nil { return bridge.JsException(v8ctx, err.Error()) } // Return parsed results directly as array result := parseCallToolsResponse(response) jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpListPromptsMethod implements ctx.MCP.ListPrompts(mcp, cursor) // Lists all available prompts from an MCP client func (ctx *Context) 
mcpListPromptsMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "ListPrompts requires mcp parameter") } mcpID := args[0].String() cursor := "" if len(args) >= 2 && !args[1].IsUndefined() { cursor = args[1].String() } result, err := ctx.ListPrompts(mcpID, cursor) if err != nil { return bridge.JsException(v8ctx, err.Error()) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpGetPromptMethod implements ctx.MCP.GetPrompt(mcp, name, args) // Gets a specific prompt from an MCP client func (ctx *Context) mcpGetPromptMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { return bridge.JsException(v8ctx, "GetPrompt requires mcp and name parameters") } mcpID := args[0].String() promptName := args[1].String() // Parse arguments (optional) var promptArgs map[string]interface{} if len(args) >= 3 && !args[2].IsUndefined() { goVal, err := bridge.GoValue(args[2], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid prompt arguments: "+err.Error()) } if argsMap, ok := goVal.(map[string]interface{}); ok { promptArgs = argsMap } } result, err := ctx.GetPrompt(mcpID, promptName, promptArgs) if err != nil { return bridge.JsException(v8ctx, err.Error()) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpListSamplesMethod implements ctx.MCP.ListSamples(mcp, type, name) // Lists all available samples from an MCP client func (ctx *Context) mcpListSamplesMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) 
*v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 3 { return bridge.JsException(v8ctx, "ListSamples requires mcp, type, and name parameters") } mcpID := args[0].String() sampleType := types.SampleItemType(args[1].String()) name := args[2].String() result, err := ctx.ListSamples(mcpID, sampleType, name) if err != nil { return bridge.JsException(v8ctx, err.Error()) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpGetSampleMethod implements ctx.MCP.GetSample(mcp, type, name, index) // Gets a specific sample from an MCP client func (ctx *Context) mcpGetSampleMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 4 { return bridge.JsException(v8ctx, "GetSample requires mcp, type, name, and index parameters") } mcpID := args[0].String() sampleType := types.SampleItemType(args[1].String()) name := args[2].String() // Parse index indexVal, err := bridge.GoValue(args[3], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid index parameter: "+err.Error()) } var index int switch v := indexVal.(type) { case int: index = v case int32: index = int(v) case int64: index = int(v) case float64: index = int(v) default: return bridge.JsException(v8ctx, "index must be a number") } result, err := ctx.GetSample(mcpID, sampleType, name, index) if err != nil { return bridge.JsException(v8ctx, err.Error()) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // mcpAllMethod implements ctx.mcp.All(requests) // Calls tools on multiple MCP servers concurrently and waits for all to complete (like Promise.all) // Each request should have: { mcp: string, tool: string, arguments?: object } func (ctx *Context) mcpAllMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { 
return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "All requires requests parameter") } // Parse requests array requests, err := ctx.parseMCPToolRequests(args[0], v8ctx) if err != nil { return bridge.JsException(v8ctx, err.Error()) } // Execute all requests results := ctx.CallToolAll(requests) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // mcpAnyMethod implements ctx.mcp.Any(requests) // Calls tools on multiple MCP servers concurrently and returns when any succeeds (like Promise.any) // Each request should have: { mcp: string, tool: string, arguments?: object } func (ctx *Context) mcpAnyMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(v8ctx, "Any requires requests parameter") } // Parse requests array requests, err := ctx.parseMCPToolRequests(args[0], v8ctx) if err != nil { return bridge.JsException(v8ctx, err.Error()) } // Execute requests until any succeeds results := ctx.CallToolAny(requests) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // mcpRaceMethod implements ctx.mcp.Race(requests) // Calls tools on multiple MCP servers concurrently and returns when any completes (like Promise.race) // Each request should have: { mcp: string, tool: string, arguments?: object } func (ctx *Context) mcpRaceMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() 
if len(args) < 1 { return bridge.JsException(v8ctx, "Race requires requests parameter") } // Parse requests array requests, err := ctx.parseMCPToolRequests(args[0], v8ctx) if err != nil { return bridge.JsException(v8ctx, err.Error()) } // Execute requests and return first completion results := ctx.CallToolRace(requests) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // parseMCPToolRequests parses JS requests array into MCPToolRequest slice func (ctx *Context) parseMCPToolRequests(arg *v8go.Value, v8ctx *v8go.Context) ([]*MCPToolRequest, error) { goVal, err := bridge.GoValue(arg, v8ctx) if err != nil { return nil, fmt.Errorf("invalid requests: %w", err) } requestsArray, ok := goVal.([]interface{}) if !ok { return nil, fmt.Errorf("requests must be an array") } requests := make([]*MCPToolRequest, 0, len(requestsArray)) for _, item := range requestsArray { reqMap, ok := item.(map[string]interface{}) if !ok { return nil, fmt.Errorf("each request must be an object") } // Required: mcp mcpID, ok := reqMap["mcp"].(string) if !ok || mcpID == "" { return nil, fmt.Errorf("request.mcp is required and must be a string") } // Required: tool tool, ok := reqMap["tool"].(string) if !ok || tool == "" { return nil, fmt.Errorf("request.tool is required and must be a string") } req := &MCPToolRequest{ MCP: mcpID, Tool: tool, } // Optional: arguments if args, exists := reqMap["arguments"]; exists && args != nil { req.Arguments = args } requests = append(requests, req) } return requests, nil } // newMCPObject creates a new MCP object with all MCP methods func (ctx *Context) newMCPObject(iso *v8go.Isolate) *v8go.ObjectTemplate { mcpObj := v8go.NewObjectTemplate(iso) // Resource operations mcpObj.Set("ListResources", ctx.mcpListResourcesMethod(iso)) mcpObj.Set("ReadResource", ctx.mcpReadResourceMethod(iso)) // Tool operations mcpObj.Set("ListTools", 
ctx.mcpListToolsMethod(iso)) mcpObj.Set("CallTool", ctx.mcpCallToolMethod(iso)) mcpObj.Set("CallTools", ctx.mcpCallToolsMethod(iso)) mcpObj.Set("CallToolsParallel", ctx.mcpCallToolsParallelMethod(iso)) // Cross-server parallel tool operations (Promise-like patterns) mcpObj.Set("All", ctx.mcpAllMethod(iso)) mcpObj.Set("Any", ctx.mcpAnyMethod(iso)) mcpObj.Set("Race", ctx.mcpRaceMethod(iso)) // Prompt operations mcpObj.Set("ListPrompts", ctx.mcpListPromptsMethod(iso)) mcpObj.Set("GetPrompt", ctx.mcpGetPromptMethod(iso)) // Sample operations mcpObj.Set("ListSamples", ctx.mcpListSamplesMethod(iso)) mcpObj.Set("GetSample", ctx.mcpGetSampleMethod(iso)) return mcpObj } ================================================ FILE: agent/context/jsapi_mcp_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // newMCPTestContext creates a test context for MCP testing func newMCPTestContext() *context.Context { ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.AssistantID = "test-assistant-id" ctx.Locale = "en" ctx.Referer = context.RefererAPI stack, _, _ := context.EnterStack(ctx, "test-assistant", &context.Options{}) ctx.Stack = stack return ctx } // TestMCPListResources tests MCP.ListResources from JavaScript func TestMCPListResources(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // List resources from echo MCP const result = ctx.mcp.ListResources("echo", "") if (!result || !result.resources) { throw new Error("Expected resources") } return { count: result.resources.length, has_info: result.resources.some(r => r.name === "info"), has_health: result.resources.some(r => r.name === "health") } }`, ctx) if err != nil { 
t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(2), result["count"], "should have 2 resources") assert.Equal(t, true, result["has_info"], "should have info resource") assert.Equal(t, true, result["has_health"], "should have health resource") } // TestMCPReadResource tests MCP.ReadResource from JavaScript func TestMCPReadResource(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Read info resource const result = ctx.mcp.ReadResource("echo", "echo://info") if (!result || !result.contents) { throw new Error("Expected contents") } return { count: result.contents.length, has_content: result.contents.length > 0 } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(1), result["count"], "should have 1 content") assert.Equal(t, true, result["has_content"], "should have content") } // TestMCPListTools tests MCP.ListTools from JavaScript func TestMCPListTools(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // List tools from echo MCP const result = ctx.mcp.ListTools("echo", "") if (!result || !result.tools) { throw new Error("Expected tools") } return { count: result.tools.length, has_ping: result.tools.some(t => t.name === "ping"), has_status: result.tools.some(t => t.name === "status"), has_echo: result.tools.some(t => t.name === "echo") } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(3), result["count"], "should have 3 tools") assert.Equal(t, true, result["has_ping"], "should have ping tool") 
assert.Equal(t, true, result["has_status"], "should have status tool") assert.Equal(t, true, result["has_echo"], "should have echo tool") } // TestMCPCallTool tests MCP.CallTool from JavaScript func TestMCPCallTool(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Call ping tool - returns parsed result directly const result = ctx.mcp.CallTool("echo", "ping", { count: 3, message: "test" }) if (result === undefined || result === null) { throw new Error("Expected result") } return { has_result: true, message: result.message } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["has_result"], "should have result") assert.Equal(t, "test", result["message"], "should have message") } // TestMCPCallTools tests MCP.CallTools from JavaScript func TestMCPCallTools(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Call multiple tools sequentially - returns array of parsed results const tools = [ { name: "ping", arguments: { count: 1 } }, { name: "status", arguments: { verbose: false } } ] const results = ctx.mcp.CallTools("echo", tools) if (!Array.isArray(results)) { throw new Error("Expected array of results") } return { count: results.length, ping_message: results[0]?.message, status_online: results[1]?.status === "online" } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(2), result["count"], "should have 2 results") assert.Equal(t, "pong", result["ping_message"], "ping should return pong") assert.Equal(t, true, result["status_online"], "status should be online") } // TestMCPCallToolsParallel tests 
MCP.CallToolsParallel from JavaScript func TestMCPCallToolsParallel(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Call multiple tools in parallel - returns array of parsed results const tools = [ { name: "ping", arguments: { count: 1 } }, { name: "status", arguments: { verbose: true } } ] const results = ctx.mcp.CallToolsParallel("echo", tools) if (!Array.isArray(results)) { throw new Error("Expected array of results") } return { count: results.length, ping_message: results[0]?.message, status_online: results[1]?.status === "online" } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(2), result["count"], "should have 2 results") assert.Equal(t, "pong", result["ping_message"], "ping should return pong") assert.Equal(t, true, result["status_online"], "status should be online") } // TestMCPListPrompts tests MCP.ListPrompts from JavaScript func TestMCPListPrompts(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // List prompts from echo MCP const result = ctx.mcp.ListPrompts("echo", "") if (!result || !result.prompts) { throw new Error("Expected prompts") } return { count: result.prompts.length, has_test_connection: result.prompts.some(p => p.name === "test_connection"), has_test_echo: result.prompts.some(p => p.name === "test_echo") } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(2), result["count"], "should have 2 prompts") assert.Equal(t, true, result["has_test_connection"], "should have test_connection prompt") assert.Equal(t, true, result["has_test_echo"], "should have test_echo prompt") } 
// TestMCPGetPrompt tests MCP.GetPrompt from JavaScript func TestMCPGetPrompt(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Get test_connection prompt const result = ctx.mcp.GetPrompt("echo", "test_connection", { detailed: "true" }) if (!result || !result.messages) { throw new Error("Expected messages") } return { count: result.messages.length, has_messages: result.messages.length > 0 } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(1), result["count"], "should have 1 message") assert.Equal(t, true, result["has_messages"], "should have messages") } // TestMCPListSamples tests MCP.ListSamples from JavaScript func TestMCPListSamples(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // List samples for ping tool const result = ctx.mcp.ListSamples("echo", "tool", "ping") if (!result || !result.samples) { throw new Error("Expected samples") } return { count: result.samples.length, has_samples: result.samples.length > 0 } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, float64(3), result["count"], "should have 3 samples") assert.Equal(t, true, result["has_samples"], "should have samples") } // TestMCPGetSample tests MCP.GetSample from JavaScript func TestMCPGetSample(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Get first sample for ping tool const result = ctx.mcp.GetSample("echo", "tool", "ping", 0) if (!result) { throw new Error("Expected sample") } return { has_name: !!result.name, has_input: 
!!result.input, name: result.name } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["has_name"], "should have name") assert.Equal(t, true, result["has_input"], "should have input") assert.Equal(t, "single_ping", result["name"], "name should be single_ping") } // TestMCPJsApiWithTrace tests MCP operations with trace from JavaScript func TestMCPJsApiWithTrace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newMCPTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Get trace (property, not method call) const trace = ctx.trace // Call MCP tool - returns parsed result directly const result = ctx.mcp.CallTool("echo", "ping", { count: 5 }) // Verify trace and result exist return { has_trace: !!trace, has_result: result !== undefined && result !== null, ping_message: result?.message } }`, ctx) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["has_trace"], "should have trace") assert.Equal(t, true, result["has_result"], "should have result") assert.Equal(t, "pong", result["ping_message"], "should have ping response") } ================================================ FILE: agent/context/jsapi_mcp_v8_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/openapi/oauth/types" ) // TestMCP_All_V8 tests ctx.mcp.All() with real V8 execution func TestMCP_All_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) 
authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-all") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.mcp.All([ { mcp: "echo", tool: "ping", arguments: { count: 1 } }, { mcp: "echo", tool: "status", arguments: { verbose: false } }, { mcp: "echo", tool: "echo", arguments: { message: "hello" } } ]); return { success: true, count: results.length, // Each result has mcp, tool, result (parsed), error results: results.map(r => ({ mcp: r.mcp, tool: r.tool, has_result: r.result !== undefined && r.result !== null, error: r.error || "" })) }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } // Should have 3 results - handle different integer types var count int switch v := result["count"].(type) { case int: count = v case int32: count = int(v) case int64: count = int(v) case float64: count = int(v) default: t.Logf("Unexpected count type: %T, value: %v", result["count"], result["count"]) } assert.Equal(t, 3, count, "Should have 3 results") // Check each result results, ok := result["results"].([]interface{}) require.True(t, ok, "Results should be an array") require.Len(t, results, 3) for i, r := range results { resMap, ok := r.(map[string]interface{}) require.True(t, ok, "Result %d should be a map", i) hasResult, _ := resMap["has_result"].(bool) assert.True(t, hasResult, "Result %d should have parsed result", i) errorStr, _ := resMap["error"].(string) assert.Empty(t, errorStr, "Result %d should not have error", i) } } // TestMCP_All_WithError_V8 tests ctx.mcp.All() with some failing requests func 
TestMCP_All_WithError_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-all-error") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.mcp.All([ { mcp: "echo", tool: "ping", arguments: { count: 1 } }, { mcp: "nonexistent-mcp", tool: "some-tool", arguments: {} }, { mcp: "echo", tool: "status", arguments: {} } ]); return { success: true, count: results.length, results: results.map(r => ({ mcp: r.mcp, tool: r.tool, has_result: r.result !== undefined && r.result !== null, has_error: r.error !== undefined && r.error !== "" && r.error !== null })) }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } // Should have 3 results - handle different integer types var count int switch v := result["count"].(type) { case int: count = v case int32: count = int(v) case int64: count = int(v) case float64: count = int(v) } assert.Equal(t, 3, count, "Should have 3 results") // Check results results, ok := result["results"].([]interface{}) require.True(t, ok, "Results should be an array") // First result (ping) should succeed r0, _ := results[0].(map[string]interface{}) assert.True(t, r0["has_result"].(bool), "Ping should have result") assert.False(t, r0["has_error"].(bool), "Ping should not have error") // Second result (nonexistent) should fail r1, _ := results[1].(map[string]interface{}) assert.False(t, r1["has_result"].(bool), "Nonexistent should not have result") assert.True(t, 
r1["has_error"].(bool), "Nonexistent should have error") // Third result (status) should succeed r2, _ := results[2].(map[string]interface{}) assert.True(t, r2["has_result"].(bool), "Status should have result") assert.False(t, r2["has_error"].(bool), "Status should not have error") } // TestMCP_Any_V8 tests ctx.mcp.Any() with real V8 execution func TestMCP_Any_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-any") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.mcp.Any([ { mcp: "echo", tool: "ping", arguments: { count: 1 } }, { mcp: "echo", tool: "status", arguments: {} } ]); // Find success results (has result field, no error) const successResults = results.filter(r => r && r.result && !r.error); return { success: true, total_count: results.length, success_count: successResults.length, has_at_least_one_success: successResults.length >= 1 }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } hasAtLeastOne, _ := result["has_at_least_one_success"].(bool) assert.True(t, hasAtLeastOne, "Should have at least one successful result") } // TestMCP_Any_AllFail_V8 tests ctx.mcp.Any() when all requests fail func TestMCP_Any_AllFail_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := 
context.New(stdContext.Background(), authorized, "test-chat-mcp-any-fail") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.mcp.Any([ { mcp: "nonexistent-1", tool: "tool1", arguments: {} }, { mcp: "nonexistent-2", tool: "tool2", arguments: {} } ]); // All should fail const failedResults = results.filter(r => r && r.error); return { success: true, total_count: results.length, failed_count: failedResults.length, all_failed: failedResults.length === results.length }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } allFailed, _ := result["all_failed"].(bool) assert.True(t, allFailed, "All requests should fail") } // TestMCP_Race_V8 tests ctx.mcp.Race() with real V8 execution func TestMCP_Race_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-race") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.mcp.Race([ { mcp: "echo", tool: "ping", arguments: { count: 1 } }, { mcp: "echo", tool: "status", arguments: {} } ]); // Find completed results (could be success or error) const completedResults = results.filter(r => r !== undefined && r !== null); return { success: true, total_count: results.length, completed_count: completedResults.length, has_at_least_one_completed: completedResults.length >= 1 }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) 
require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } hasAtLeastOne, _ := result["has_at_least_one_completed"].(bool) assert.True(t, hasAtLeastOne, "Should have at least one completed result") } // TestMCP_All_ResultContent_V8 tests that the result contains parsed content directly func TestMCP_All_ResultContent_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-content") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.mcp.All([ { mcp: "echo", tool: "echo", arguments: { message: "hello world", uppercase: true } } ]); if (results.length !== 1) { return { success: false, error: "Expected 1 result" }; } const r = results[0]; if (r.error) { return { success: false, error: "Tool call failed: " + r.error }; } // Result should contain parsed data directly const data = r.result; if (!data) { return { success: false, error: "Result should have parsed data" }; } return { success: true, echo_message: data.echo, uppercase_flag: data.uppercase, original_length: data.length }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } echoMessage, _ := result["echo_message"].(string) assert.Equal(t, "HELLO WORLD", echoMessage, "Echo message should be uppercase") uppercaseFlag, _ := result["uppercase_flag"].(bool) assert.True(t, 
uppercaseFlag, "Uppercase flag should be true") // Handle different integer types from V8 var originalLength int switch v := result["original_length"].(type) { case int: originalLength = v case int32: originalLength = int(v) case int64: originalLength = int(v) case float64: originalLength = int(v) } assert.Equal(t, 11, originalLength, "Original message length should be 11") } // TestMCP_All_MultipleTools_V8 tests All with multiple tools and verifies parsed results func TestMCP_All_MultipleTools_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-multi") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const results = ctx.mcp.All([ { mcp: "echo", tool: "ping", arguments: { count: 5 } }, { mcp: "echo", tool: "echo", arguments: { message: "test", uppercase: false } } ]); if (results.length !== 2) { return { success: false, error: "Expected 2 results" }; } // Access parsed results directly const ping = results[0]; const echo = results[1]; return { success: true, ping_message: ping.result?.message, echo_message: echo.result?.echo }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } pingMessage, _ := result["ping_message"].(string) assert.Equal(t, "pong", pingMessage, "Ping should return pong") echoMessage, _ := result["echo_message"].(string) assert.Equal(t, "test", echoMessage, "Echo should return the message") } // TestMCP_CallTool_ParsedResult_V8 tests that ctx.mcp.CallTool() returns parsed result 
directly func TestMCP_CallTool_ParsedResult_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-calltool-parsed") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Test CallTool returns parsed result directly const result = ctx.mcp.CallTool("echo", "echo", { message: "test message", uppercase: true }); // Result should be the parsed data directly if (result === undefined || result === null) { return { success: false, error: "Result should not be null" }; } return { success: true, echo_message: result.echo, original_length: result.length }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } echoMessage, _ := result["echo_message"].(string) assert.Equal(t, "TEST MESSAGE", echoMessage, "Echo message should be uppercase") // Handle different integer types from V8 var length int switch v := result["original_length"].(type) { case int: length = v case int32: length = int(v) case int64: length = int(v) case float64: length = int(v) } assert.Equal(t, 12, length, "Original message length should be 12") } // TestMCP_CallToolsParallel_ParsedResult_V8 tests that ctx.mcp.CallToolsParallel() returns parsed results directly func TestMCP_CallToolsParallel_ParsedResult_V8(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } testutils.Prepare(t) defer testutils.Clean(t) authorized := &types.AuthorizedInfo{ Subject: "test-user", UserID: "test-123", TenantID: "test-tenant", } 
ctx := context.New(stdContext.Background(), authorized, "test-chat-mcp-calltools-parsed") ctx.AssistantID = "tests.agent-caller" defer ctx.Release() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Test CallToolsParallel returns parsed results directly as array const results = ctx.mcp.CallToolsParallel("echo", [ { name: "ping", arguments: { count: 2 } }, { name: "echo", arguments: { message: "hello", uppercase: false } } ]); if (!Array.isArray(results)) { return { success: false, error: "Results should be an array" }; } if (results.length !== 2) { return { success: false, error: "Expected 2 results, got " + results.length }; } // Each result is the parsed data directly const pingResult = results[0]; const echoResult = results[1]; return { success: true, ping_message: pingResult?.message, echo_message: echoResult?.echo }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result, ok := res.(map[string]interface{}) require.True(t, ok, "Result should be a map") success, _ := result["success"].(bool) if !success { t.Fatalf("Test failed: %v", result["error"]) } pingMessage, _ := result["ping_message"].(string) assert.Equal(t, "pong", pingMessage, "Ping result should have message='pong'") echoMessage, _ := result["echo_message"].(string) assert.Equal(t, "hello", echoMessage, "Echo message should be preserved (no uppercase)") } ================================================ FILE: agent/context/jsapi_memory_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/memory" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" ) func TestMemoryUserNamespace(t *testing.T) { test.Prepare(t, config.Conf) defer 
test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set values in user namespace ctx.memory.user.Set("name", "John"); ctx.memory.user.Set("age", 30); ctx.memory.user.Set("active", true); // Get values back const name = ctx.memory.user.Get("name"); const age = ctx.memory.user.Get("age"); const active = ctx.memory.user.Get("active"); // Verify if (name !== "John") throw new Error("Name mismatch"); if (age !== 30) throw new Error("Age mismatch"); if (active !== true) throw new Error("Active mismatch"); return { success: true, name: name, age: age, active: active }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.Equal(t, "John", result["name"]) assert.Equal(t, float64(30), result["age"]) assert.Equal(t, true, result["active"]) } func TestMemoryTeamNamespace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set team-wide settings ctx.memory.team.Set("settings", { theme: "dark", language: "en" }); // Get back const settings = ctx.memory.team.Get("settings"); if (settings.theme !== "dark") throw new Error("Theme mismatch"); if (settings.language !== "en") throw new Error("Language mismatch"); return { success: true, settings: settings }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := 
res.(map[string]interface{}) assert.True(t, result["success"].(bool)) } func TestMemoryChatNamespace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set chat context ctx.memory.chat.Set("topic", "AI Discussion"); ctx.memory.chat.Set("participants", ["Alice", "Bob"]); // Get back const topic = ctx.memory.chat.Get("topic"); const participants = ctx.memory.chat.Get("participants"); if (topic !== "AI Discussion") throw new Error("Topic mismatch"); if (participants.length !== 2) throw new Error("Participants mismatch"); return { success: true, topic: topic, participants: participants }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.Equal(t, "AI Discussion", result["topic"]) } func TestMemoryContextNamespace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set temporary context data ctx.memory.context.Set("temp_result", { step: 1, data: "processing" }); // Get back const result = ctx.memory.context.Get("temp_result"); if (result.step !== 1) throw new Error("Step mismatch"); if (result.data !== "processing") throw new Error("Data mismatch"); return { success: true, result: result }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := 
res.(map[string]interface{}) assert.True(t, result["success"].(bool)) } func TestMemoryHasAndDel(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set a value ctx.memory.user.Set("key", "value"); // Check Has const hasBefore = ctx.memory.user.Has("key"); if (!hasBefore) throw new Error("Should have key before delete"); // Delete ctx.memory.user.Del("key"); // Check Has again const hasAfter = ctx.memory.user.Has("key"); if (hasAfter) throw new Error("Should not have key after delete"); return { success: true, hasBefore: hasBefore, hasAfter: hasAfter }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.True(t, result["hasBefore"].(bool)) assert.False(t, result["hasAfter"].(bool)) } func TestMemoryIncrDecr(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Incr on non-existent key const v1 = ctx.memory.user.Incr("counter"); if (v1 !== 1) throw new Error("First incr should be 1, got " + v1); // Incr with delta const v2 = ctx.memory.user.Incr("counter", 5); if (v2 !== 6) throw new Error("Second incr should be 6, got " + v2); // Decr const v3 = ctx.memory.user.Decr("counter", 2); if (v3 !== 4) throw new Error("Decr should be 4, got " + v3); return { success: true, v1: v1, v2: v2, v3: v3 }; } catch (error) { return 
{ success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.Equal(t, float64(1), result["v1"]) assert.Equal(t, float64(6), result["v2"]) assert.Equal(t, float64(4), result["v3"]) } func TestMemoryKeysAndLen(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Use unique IDs to avoid data pollution from other tests mem, err := memory.New(nil, "user-keys-len", "team-keys-len", "chat-keys-len", "ctx-keys-len") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set multiple values ctx.memory.user.Set("a", 1); ctx.memory.user.Set("b", 2); ctx.memory.user.Set("c", 3); // Get keys const keys = ctx.memory.user.Keys(); if (keys.length !== 3) throw new Error("Should have 3 keys, got " + keys.length); // Get len const len = ctx.memory.user.Len(); if (len !== 3) throw new Error("Len should be 3, got " + len); return { success: true, keys: keys, len: len }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) if !result["success"].(bool) { t.Fatalf("Test failed: %v", result["error"]) } assert.Equal(t, float64(3), result["len"]) } func TestMemoryClear(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set values ctx.memory.user.Set("a", 1); ctx.memory.user.Set("b", 2); // Clear ctx.memory.user.Clear(); // Check len const len = ctx.memory.user.Len(); if (len !== 0) throw new 
Error("Len should be 0 after clear, got " + len); return { success: true, len: len }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.Equal(t, float64(0), result["len"]) } func TestMemoryGetDel(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) ctx := &context.Context{ ChatID: "test-chat-id", AssistantID: "test-assistant-id", Locale: "en", Context: stdContext.Background(), Memory: mem, } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set a one-time value ctx.memory.user.Set("token", "secret123"); // GetDel const value = ctx.memory.user.GetDel("token"); if (value !== "secret123") throw new Error("Value mismatch"); // Should be deleted const after = ctx.memory.user.Get("token"); if (after !== null) throw new Error("Should be null after GetDel"); return { success: true, value: value, after: after }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.Equal(t, "secret123", result["value"]) assert.Nil(t, result["after"]) } func TestMemoryIsolation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create two different memory instances mem1, err := memory.New(nil, "user1", "", "", "") require.NoError(t, err) mem2, err := memory.New(nil, "user2", "", "", "") require.NoError(t, err) ctx1 := &context.Context{ ChatID: "chat1", Context: stdContext.Background(), Memory: mem1, } ctx2 := &context.Context{ ChatID: "chat2", Context: stdContext.Background(), Memory: mem2, } // Set value in user1 res1, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { ctx.memory.user.Set("key", "user1_value"); return ctx.memory.user.Get("key"); }`, ctx1) require.NoError(t, err) 
assert.Equal(t, "user1_value", res1) // Set value in user2 res2, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { ctx.memory.user.Set("key", "user2_value"); return ctx.memory.user.Get("key"); }`, ctx2) require.NoError(t, err) assert.Equal(t, "user2_value", res2) // Verify user1 still has its own value res3, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { return ctx.memory.user.Get("key"); }`, ctx1) require.NoError(t, err) assert.Equal(t, "user1_value", res3) } func TestMemoryNoMemory(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := &context.Context{ ChatID: "test-chat-id", Context: stdContext.Background(), Memory: nil, // No memory } res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { const hasMemory = ctx.memory !== undefined && ctx.memory !== null; return { success: true, hasMemory: hasMemory }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.False(t, result["hasMemory"].(bool)) } func TestMemoryWithAuthorizedInfo(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Use context.New to create context with authorized info authorized := &types.AuthorizedInfo{ UserID: "user123", TeamID: "team456", } ctx := context.New(stdContext.Background(), authorized, "chat789") defer ctx.Release() // Verify memory was created with correct IDs require.NotNil(t, ctx.Memory) require.NotNil(t, ctx.Memory.User) require.NotNil(t, ctx.Memory.Team) require.NotNil(t, ctx.Memory.Chat) require.NotNil(t, ctx.Memory.Context) res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Set values in different namespaces ctx.memory.user.Set("pref", "dark"); ctx.memory.team.Set("setting", "shared"); ctx.memory.chat.Set("topic", "test"); ctx.memory.context.Set("temp", "data"); return { success: true, user: ctx.memory.user.Get("pref"), team: ctx.memory.team.Get("setting"), chat: 
ctx.memory.chat.Get("topic"), context: ctx.memory.context.Get("temp") }; } catch (error) { return { success: false, error: error.message }; } }`, ctx) require.NoError(t, err) result := res.(map[string]interface{}) assert.True(t, result["success"].(bool)) assert.Equal(t, "dark", result["user"]) assert.Equal(t, "shared", result["team"]) assert.Equal(t, "test", result["chat"]) assert.Equal(t, "data", result["context"]) } ================================================ FILE: agent/context/jsapi_output_test.go ================================================ package context_test import ( "bytes" stdContext "context" "net/http" "testing" "github.com/stretchr/testify/assert" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // testMockResponseWriter is a mock implementation of http.ResponseWriter for testing type testMockResponseWriter struct { headers http.Header buffer *bytes.Buffer status int } func newTestMockResponseWriter() *testMockResponseWriter { return &testMockResponseWriter{ headers: make(http.Header), buffer: &bytes.Buffer{}, status: http.StatusOK, } } func (m *testMockResponseWriter) Header() http.Header { return m.headers } func (m *testMockResponseWriter) Write(b []byte) (int, error) { return m.buffer.Write(b) } func (m *testMockResponseWriter) WriteHeader(statusCode int) { m.status = statusCode } // TestJsValueSend test the Send method on Context func TestJsValueSend(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptStandard cxt.Locale = "en" cxt.Writer = newTestMockResponseWriter() // Test sending string shorthand res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Send simple string ctx.Send("Hello World"); return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, 
cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["success"], "Send string should succeed") // Test sending message object res, err = v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Send message object ctx.Send({ type: "text", props: { content: "Hello from JavaScript" }, id: "msg_123", metadata: { timestamp: Date.now(), sequence: 1 } }); return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok = res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["success"], "Send message object should succeed") } // TestJsValueSendGroup test the SendGroup method on Context // TestJsValueSendDeltaUpdates test delta updates in Send func TestJsValueSendDeltaUpdates(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptStandard cxt.Locale = "en" cxt.Writer = newTestMockResponseWriter() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Send initial message ctx.Send({ type: "text", props: { content: "Hello" }, id: "msg_1", delta: false }); // Send delta update (append) ctx.Send({ type: "text", props: { content: " World" }, id: "msg_1", delta: true, delta_path: "content", delta_action: "append" }); // Send completion (no done field needed) return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["success"], "Delta updates should succeed") } // TestJsValueSendMultipleTypes test sending different 
message types func TestJsValueSendMultipleTypes(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptStandard cxt.Locale = "en" cxt.Writer = newTestMockResponseWriter() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Text message ctx.Send({ type: "text", props: { content: "Hello" } }); // Thinking message ctx.Send({ type: "thinking", props: { content: "Let me think..." } }); // Loading message ctx.Send({ type: "loading", props: { message: "Processing..." } }); // Tool call message ctx.Send({ type: "tool_call", props: { id: "call_123", name: "get_weather", arguments: '{"location": "San Francisco"}' } }); // Error message ctx.Send({ type: "error", props: { message: "Something went wrong", code: "ERR_500" } }); // Image message ctx.Send({ type: "image", props: { url: "https://example.com/image.jpg", alt: "Example image", width: 800, height: 600 } }); return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["success"], "Multiple message types should succeed") } // TestJsValueSendErrorHandling test error handling in Send func TestJsValueSendErrorHandling(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptStandard cxt.Locale = "en" cxt.Writer = newTestMockResponseWriter() // Test invalid argument - no arguments res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { ctx.Send(); return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, 
ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, false, result["success"], "Send without arguments should fail") assert.Contains(t, result["error"], "Send requires a message argument", "Error should mention missing message") } // TestJsValueSendGroupErrorHandling test error handling in SendGroup // TestJsValueSendWithCUIAccept test Send with CUI accept types func TestJsValueSendWithCUIAccept(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() acceptTypes := []context.Accept{context.AcceptWebCUI, context.AccepNativeCUI, context.AcceptDesktopCUI} for _, acceptType := range acceptTypes { t.Run(string(acceptType), func(t *testing.T) { cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = acceptType cxt.Locale = "en" cxt.Writer = newTestMockResponseWriter() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { ctx.Send({ type: "text", props: { content: "Hello CUI" } }); return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["success"], "Send with "+string(acceptType)+" should succeed") }) } } // TestJsValueSendGroupWithMetadata test SendGroup with various metadata // TestJsValueSendChainedCalls test chained Send calls func TestJsValueSendChainedCalls(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptStandard cxt.Locale = "en" cxt.Writer = newTestMockResponseWriter() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Multiple sequential sends (each auto-flushes) ctx.Send("Step 1"); ctx.Send("Step 2"); ctx.Send("Step 3"); ctx.Send("Step 4"); 
return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["success"], "Chained Send calls should succeed") } // TestJsValueIDGenerators test ID generator methods func TestJsValueIDGenerators(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptStandard cxt.Locale = "en" cxt.Writer = newTestMockResponseWriter() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Test MessageID generator const msgId1 = ctx.MessageID(); const msgId2 = ctx.MessageID(); // Test BlockID generator const blockId1 = ctx.BlockID(); const blockId2 = ctx.BlockID(); // Test ThreadID generator const threadId1 = ctx.ThreadID(); const threadId2 = ctx.ThreadID(); // Verify IDs are strings and sequential if (typeof msgId1 !== 'string' || typeof msgId2 !== 'string') { throw new Error('MessageID should return string'); } if (typeof blockId1 !== 'string' || typeof blockId2 !== 'string') { throw new Error('BlockID should return string'); } if (typeof threadId1 !== 'string' || typeof threadId2 !== 'string') { throw new Error('ThreadID should return string'); } // Verify they follow the pattern (M1, M2, B1, B2, T1, T2) if (!msgId1.startsWith('M') || !msgId2.startsWith('M')) { throw new Error('MessageID should start with M'); } if (!blockId1.startsWith('B') || !blockId2.startsWith('B')) { throw new Error('BlockID should start with B'); } if (!threadId1.startsWith('T') || !threadId2.startsWith('T')) { throw new Error('ThreadID should start with T'); } return { success: true, msgId1: msgId1, msgId2: msgId2, blockId1: blockId1, blockId2: blockId2, threadId1: threadId1, threadId2: threadId2 }; } catch (error) { return { success: 
false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	assert.Equal(t, true, result["success"], "ID generators should succeed")
}

// TestJsValueSendWithBlockID tests ctx.Send with a block ID supplied both as
// the second argument and as a block_id field on the message object.
// It only asserts that the script completes without throwing.
func TestJsValueSendWithBlockID(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptStandard
	cxt.Locale = "en"
	cxt.Writer = newTestMockResponseWriter()

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Generate block ID manually
				const blockId = ctx.BlockID();

				// Send multiple messages with same block ID
				const msg1 = ctx.Send("Message 1", blockId);
				const msg2 = ctx.Send("Message 2", blockId);
				const msg3 = ctx.Send("Message 3", blockId);

				// Send message with block_id in object (higher priority)
				const msg4 = ctx.Send({
					type: "text",
					props: { content: "Message 4" },
					block_id: "B_custom"
				}, blockId); // blockId parameter should be ignored

				return {
					success: true,
					msg1: msg1,
					msg2: msg2,
					msg3: msg3,
					msg4: msg4,
					blockId: blockId
				};
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "Send with blockId should succeed")
}

// TestJsValueReplace tests the ctx.Replace method with both a string and a
// message object as the replacement for a previously sent message.
func TestJsValueReplace(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptStandard
	cxt.Locale = "en"
	cxt.Writer = newTestMockResponseWriter()

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Send initial message
				const msgId = ctx.Send("Initial content");

				// Replace with new content
				ctx.Replace(msgId, "Updated content");

				// Replace with object
				ctx.Replace(msgId, {
					type: "text",
					props: { content: "Final content" }
				});

				return { success: true, msgId: msgId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "Replace should succeed")
}

// TestJsValueAppend tests the ctx.Append method, both without a path and with
// the explicit "props.content" path.
func TestJsValueAppend(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptStandard
	cxt.Locale = "en"
	cxt.Writer = newTestMockResponseWriter()

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Send initial message
				const msgId = ctx.Send("Hello");

				// Append to default path
				ctx.Append(msgId, " World");
				ctx.Append(msgId, "!");

				// Append to specific path
				const msgId2 = ctx.Send({
					type: "data",
					props: { content: "Line 1\n" }
				});
				ctx.Append(msgId2, "Line 2\n", "props.content");
				ctx.Append(msgId2, "Line 3\n", "props.content");

				return { success: true, msgId: msgId, msgId2: msgId2 };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "Append should succeed")
}

// TestJsValueMerge tests the ctx.Merge method by merging two successive
// updates into the "props" path of a status message.
func TestJsValueMerge(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptStandard
	cxt.Locale = "en"
	cxt.Writer = newTestMockResponseWriter()

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Send initial message with object
				const msgId = ctx.Send({
					type: "status",
					props: { status: "running", progress: 0, started: true }
				});

				// Merge updates (keeps other fields)
				ctx.Merge(msgId, { type: "status", props: { progress: 50 } }, "props");
				ctx.Merge(msgId, { type: "status", props: { progress: 100, status: "completed" } }, "props");

				return { success: true, msgId: msgId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "Merge should succeed")
}

// TestJsValueSet tests the ctx.Set method by setting three new fields on an
// existing message, each addressed by its own path.
func TestJsValueSet(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptStandard
	cxt.Locale = "en"
	cxt.Writer = newTestMockResponseWriter()

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Send initial message
				const msgId = ctx.Send({
					type: "result",
					props: { content: "Initial" }
				});

				// Set new fields
				ctx.Set(msgId, { type: "result", props: { status: "success" } }, "props.status");
				ctx.Set(msgId, { type: "result", props: { timestamp: Date.now() } }, "props.timestamp");
				ctx.Set(msgId, { type: "result", props: { metadata: { duration: 1500 } } }, "props.metadata");

				return { success: true, msgId: msgId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "Set should succeed")
}

// TestJsValueBlockIDInheritance test that delta operations inherit block_id
func
TestJsValueBlockIDInheritance(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptStandard
	cxt.Locale = "en"
	cxt.Writer = newTestMockResponseWriter()

	// The script runs Append, Replace, Merge and Set against a message that
	// was sent with a block ID; the test only checks that none of them throws.
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Send message with block_id
				const blockId = ctx.BlockID();
				const msgId = ctx.Send("Initial message", blockId);

				// Delta operations should inherit block_id automatically
				ctx.Append(msgId, " appended");
				ctx.Replace(msgId, "Replaced message");
				ctx.Merge(msgId, { type: "text", props: { status: "done" } }, "props");
				ctx.Set(msgId, { type: "text", props: { state: "final" } }, "props.state");

				return { success: true, msgId: msgId, blockId: blockId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "Delta operations should inherit block_id")
}

// TestJsValueEndBlock tests the EndBlock method on Context and verifies that
// a block_end event reaches the writer.
func TestJsValueEndBlock(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// Setup mock writer
	mockWriter := newTestMockResponseWriter()

	// Use New() to properly initialize messageMetadata
	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptWebCUI
	cxt.Locale = "en"
	cxt.Writer = mockWriter

	// Test EndBlock method
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Create a block and send messages
				const block_id = ctx.BlockID(); // "B1"

				ctx.Send("Message 1", block_id);
				ctx.Send("Message 2", block_id);
				ctx.Send("Message 3", block_id);

				// End the block manually
				ctx.EndBlock(block_id);

				return { success: true, block_id: block_id };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "EndBlock should work correctly")

	// Close SafeWriter to wait for all async writes to complete
	cxt.CloseSafeWriter()

	// Verify that block_end event was sent
	output := mockWriter.buffer.String()
	assert.Contains(t, output, "block_end", "Output should contain block_end event")
}

// TestJsValueSendStream tests the SendStream method on Context: it must
// return a non-empty string ID and emit message_start without message_end.
func TestJsValueSendStream(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// Setup mock writer
	mockWriter := newTestMockResponseWriter()

	// Use New() to properly initialize messageMetadata
	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptWebCUI
	cxt.Locale = "en"
	cxt.Writer = mockWriter

	// Test SendStream method
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Start a streaming message
				const msgId = ctx.SendStream({
					type: "text",
					props: { content: "Initial content" }
				});

				// Verify msgId is returned
				if (typeof msgId !== 'string' || msgId === '') {
					throw new Error('SendStream should return a message ID');
				}

				return { success: true, msgId: msgId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "SendStream should work correctly")

	// Close SafeWriter to wait for all async writes to complete
	cxt.CloseSafeWriter()

	// Verify message_start was sent but NOT message_end
	output := mockWriter.buffer.String()
	assert.Contains(t, output, "message_start", "Output should contain message_start event")
	assert.NotContains(t, output, "message_end", "Output should NOT contain message_end event (streaming)")
}

// TestJsValueSendStreamWithBlockID tests SendStream with a block_id field on
// the message object and verifies that a block_start event is emitted.
func TestJsValueSendStreamWithBlockID(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	mockWriter := newTestMockResponseWriter()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptWebCUI
	cxt.Locale = "en"
	cxt.Writer = mockWriter

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Generate block ID
				const blockId = ctx.BlockID();

				// Start streaming with block_id
				const msgId = ctx.SendStream({
					type: "text",
					props: { content: "Streaming with block" },
					block_id: blockId
				});

				return { success: true, msgId: msgId, blockId: blockId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "SendStream with blockId should succeed")

	// Close SafeWriter to wait for all async writes to complete
	cxt.CloseSafeWriter()

	// Verify block_start was also sent
	output := mockWriter.buffer.String()
	assert.Contains(t, output, "block_start", "Output should contain block_start event")
}

// TestJsValueEnd tests the End method on Context: ending a streamed message
// must emit a message_end event.
func TestJsValueEnd(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	mockWriter := newTestMockResponseWriter()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptWebCUI
	cxt.Locale = "en"
	cxt.Writer = mockWriter

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Start a streaming message
				const msgId = ctx.SendStream({
					type: "text",
					props: { content: "Hello" }
				});

				// End the message
				ctx.End(msgId);

				return { success: true, msgId: msgId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "End should work correctly")

	// Close SafeWriter to wait for all async writes to complete
	cxt.CloseSafeWriter()

	// Verify message_end was sent
	output := mockWriter.buffer.String()
	assert.Contains(t, output, "message_end", "Output should contain message_end event after End()")
}

// TestJsValueEndWithFinalContent tests End with the optional final content
// parameter and verifies that message_end is still emitted.
func TestJsValueEndWithFinalContent(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	mockWriter := newTestMockResponseWriter()

	cxt := context.New(stdContext.Background(), nil, "test-chat-id")
	cxt.AssistantID = "test-assistant-id"
	cxt.Accept = context.AcceptWebCUI
	cxt.Locale = "en"
	cxt.Writer = mockWriter

	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				// Start a streaming message
				const msgId = ctx.SendStream({
					type: "text",
					props: { content: "Start" }
				});

				// End with final content
				ctx.End(msgId, " End");

				return { success: true, msgId: msgId };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, cxt)
	if err != nil {
		t.Fatalf("Call failed: %v", err)
	}

	result, ok := res.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map result, got %T", res)
	}

	// Surface the script's error; comma-ok avoids a panic on a non-bool value.
	if succeeded, _ := result["success"].(bool); !succeeded {
		t.Logf("Error: %v", result["error"])
	}
	assert.Equal(t, true, result["success"], "End with final content should work correctly")

	// Close SafeWriter to wait for all async writes to complete
	cxt.CloseSafeWriter()

	// Verify message_end was sent
	output := mockWriter.buffer.String()
	assert.Contains(t, output, "message_end", "Output should contain message_end event")
}

//
TestJsValueStreamingWorkflow tests the complete streaming workflow: SendStream -> Append -> End func TestJsValueStreamingWorkflow(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Start a streaming message const msgId = ctx.SendStream({ type: "text", props: { content: "# Title\n\n" } }); // Append content in chunks (simulating streaming) ctx.Append(msgId, "First paragraph. "); ctx.Append(msgId, "Second sentence. "); ctx.Append(msgId, "Third sentence.\n\n"); ctx.Append(msgId, "Second paragraph."); // Finalize the message ctx.End(msgId); return { success: true, msgId: msgId }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } if !result["success"].(bool) { t.Logf("Error: %v", result["error"]) } assert.Equal(t, true, result["success"], "Streaming workflow should work correctly") // Close SafeWriter to flush all pending async writes before reading buffer. // SafeWriter processes writes in a background goroutine via channel; // without this, the buffer may still be empty on slow CI runners. 
cxt.CloseSafeWriter() // Verify the complete workflow events output := mockWriter.buffer.String() assert.Contains(t, output, "message_start", "Output should contain message_start") assert.Contains(t, output, "message_end", "Output should contain message_end") assert.Contains(t, output, "# Title", "Output should contain initial content") assert.Contains(t, output, "First paragraph", "Output should contain appended content") } // TestJsValueSendStreamStringShorthand tests SendStream with string shorthand func TestJsValueSendStreamStringShorthand(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // SendStream with string shorthand const msgId = ctx.SendStream("Hello streaming"); if (typeof msgId !== 'string' || msgId === '') { throw new Error('SendStream should return a message ID'); } ctx.End(msgId); return { success: true, msgId: msgId }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["success"], "SendStream with string shorthand should succeed") } // TestJsValueEndErrorHandling tests error handling in End method func TestJsValueEndErrorHandling(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter // Test End without arguments res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { ctx.End(); return { success: true }; } catch 
(error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, false, result["success"], "End without arguments should fail") assert.Contains(t, result["error"], "messageId", "Error should mention missing messageId") } // TestJsValueEndWithInvalidMessageID tests End with invalid messageId type func TestJsValueEndWithInvalidMessageID(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter // Test End with non-string messageId res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { ctx.End(123); return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, false, result["success"], "End with non-string messageId should fail") assert.Contains(t, result["error"], "string", "Error should mention messageId must be string") } // TestJsValueSendStreamErrorHandling tests error handling in SendStream method func TestJsValueSendStreamErrorHandling(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter // Test SendStream without arguments res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { ctx.SendStream(); return { success: true }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { 
t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, false, result["success"], "SendStream without arguments should fail") assert.Contains(t, result["error"], "SendStream requires a message argument", "Error should mention missing message") } // TestJsValueMultipleStreams tests handling multiple concurrent streaming messages func TestJsValueMultipleStreams(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { // Start multiple streaming messages const msg1 = ctx.SendStream({ type: "text", props: { content: "Stream 1: " } }); const msg2 = ctx.SendStream({ type: "text", props: { content: "Stream 2: " } }); // Interleave appends ctx.Append(msg1, "A"); ctx.Append(msg2, "X"); ctx.Append(msg1, "B"); ctx.Append(msg2, "Y"); ctx.Append(msg1, "C"); ctx.Append(msg2, "Z"); // End both streams ctx.End(msg1); ctx.End(msg2); return { success: true, msg1: msg1, msg2: msg2 }; } catch (error) { return { success: false, error: error.message }; } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } if !result["success"].(bool) { t.Logf("Error: %v", result["error"]) } assert.Equal(t, true, result["success"], "Multiple streams should work correctly") assert.NotEqual(t, result["msg1"], result["msg2"], "Message IDs should be different") } // TestJsValueSendVsSendStream tests the difference between Send and SendStream func TestJsValueSendVsSendStream(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Test Send - should auto-send message_end t.Run("Send auto-ends", func(t *testing.T) 
{ mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { ctx.Send("Complete message"); return true; }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } // Close SafeWriter to flush all pending async writes before reading buffer. // SafeWriter processes writes in a background goroutine via channel; // without this, the buffer may still be empty on slow CI runners. cxt.CloseSafeWriter() output := mockWriter.buffer.String() assert.Contains(t, output, "message_start", "Send should emit message_start") assert.Contains(t, output, "message_end", "Send should auto-emit message_end") }) // Test SendStream - should NOT auto-send message_end t.Run("SendStream requires explicit End", func(t *testing.T) { mockWriter := newTestMockResponseWriter() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Accept = context.AcceptWebCUI cxt.Locale = "en" cxt.Writer = mockWriter _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { const msgId = ctx.SendStream("Streaming message"); // Intentionally NOT calling ctx.End(msgId) return msgId; }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } // Close SafeWriter to flush all pending async writes before reading buffer. 
cxt.CloseSafeWriter() output := mockWriter.buffer.String() assert.Contains(t, output, "message_start", "SendStream should emit message_start") assert.NotContains(t, output, "message_end", "SendStream should NOT auto-emit message_end") }) } ================================================ FILE: agent/context/jsapi_release_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/stretchr/testify/assert" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // newReleaseTestContext creates a test context for release testing func newReleaseTestContext() *context.Context { ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.AssistantID = "test-assistant-id" ctx.Referer = context.RefererAPI stack, _, _ := context.EnterStack(ctx, "test-assistant", &context.Options{}) ctx.Stack = stack return ctx } // TestContextRelease tests explicit Release() method on Context func TestContextRelease(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := newReleaseTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Verify context has Release method if (typeof ctx.Release !== 'function') { throw new Error("ctx.Release is not a function") } // Verify context has __release method if (typeof ctx.__release !== 'function') { throw new Error("ctx.__release is not a function") } // Call Release explicitly ctx.Release() // Can call Release multiple times safely (idempotent) ctx.Release() return { has_release: true, success: true } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["has_release"], "should have Release method") assert.Equal(t, true, result["success"], "release should succeed") } // TestTraceRelease tests explicit Release() method on 
Trace func TestTraceRelease(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := newReleaseTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Get trace const trace = ctx.trace // Verify trace has Release method if (typeof trace.Release !== 'function') { throw new Error("trace.Release is not a function") } // Verify trace has __release method if (typeof trace.__release !== 'function') { throw new Error("trace.__release is not a function") } // Use trace const node = trace.Add({ type: "test" }, { label: "Test Node" }) trace.Info("Test message") // Release trace explicitly trace.Release() // Can call Release multiple times safely (idempotent) trace.Release() return { has_release: true, has_node: !!node, success: true } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["has_release"], "should have Release method") assert.Equal(t, true, result["has_node"], "should create node") assert.Equal(t, true, result["success"], "release should succeed") } // TestContextReleaseWithTrace tests that releasing Context also releases Trace func TestContextReleaseWithTrace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := newReleaseTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Get trace const trace = ctx.trace // Use trace const node = trace.Add({ type: "test" }, { label: "Test Node" }) trace.Info("Test message") node.Complete({ result: "done" }) // Release context (should also release trace) ctx.Release() return { trace_released_via_context: true, success: true } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["trace_released_via_context"], "trace should be released via context") assert.Equal(t, true, 
result["success"], "release should succeed") } // TestTryFinallyPattern tests the try-finally pattern with Release() func TestTryFinallyPattern(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := newReleaseTestContext() res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { const trace = ctx.trace // Try-finally pattern for explicit resource management try { const node = trace.Add({ type: "step" }, { label: "Processing" }) // Simulate some work trace.Info("Step 1: Initialize") trace.Info("Step 2: Process") node.Complete({ result: "success" }) return { completed: true } } finally { // Explicit cleanup trace.Release() ctx.Release() } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["completed"], "should complete successfully") } // TestNoOpTraceRelease tests that no-op Trace also has Release method func TestNoOpTraceRelease(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Context without trace initialization cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" res, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Get trace (should be no-op) const trace = ctx.trace // Verify trace has Release method even when it's no-op if (typeof trace.Release !== 'function') { throw new Error("no-op trace.Release is not a function") } // Call methods on no-op trace (should not error) trace.Info("This is a no-op") const node = trace.Add({ type: "test" }, { label: "No-op" }) node.Complete({ result: "done" }) // Release no-op trace (should not error) trace.Release() return { noop_trace_works: true, success: true } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } assert.Equal(t, true, result["noop_trace_works"], "no-op trace should work") 
assert.Equal(t, true, result["success"], "release should succeed") } // TestTryFinallyPatternWithError tests try-finally with error handling func TestTryFinallyPatternWithError(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := newReleaseTestContext() _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { const trace = ctx.trace // Try-finally pattern ensures cleanup even when error occurs try { const node = trace.Add({ type: "step" }, { label: "Processing" }) trace.Info("Starting work") // Simulate an error throw new Error("Simulated error") } finally { // Cleanup happens even after error trace.Release() ctx.Release() } }`, cxt) // Error should be propagated if err == nil { t.Fatal("Expected error to be propagated") } // But cleanup should have happened (no way to verify directly, but test should not crash) assert.Contains(t, err.Error(), "Simulated error", "error should be propagated") } ================================================ FILE: agent/context/jsapi_sandbox.go ================================================ package context import ( "context" "github.com/yaoapp/gou/runtime/v8/bridge" openapiSandbox "github.com/yaoapp/yao/openapi/sandbox" infraSandbox "github.com/yaoapp/yao/sandbox" "rogchap.com/v8go" ) // SandboxExecutor defines the interface for sandbox operations // This interface is implemented by agent/sandbox.Executor // It's defined here to avoid import cycles type SandboxExecutor interface { // Filesystem operations ReadFile(ctx context.Context, path string) ([]byte, error) WriteFile(ctx context.Context, path string, content []byte) error ListDir(ctx context.Context, path string) ([]infraSandbox.FileInfo, error) // Command execution Exec(ctx context.Context, cmd []string) (string, error) // Workspace info GetWorkDir() string // Sandbox identification GetSandboxID() string // VNC access (returns empty string if not available) GetVNCUrl() string } // SetSandboxExecutor sets the sandbox executor for this context // This 
// should be called before hooks are executed
func (ctx *Context) SetSandboxExecutor(executor SandboxExecutor) {
	ctx.sandboxExecutor = executor
}

// GetSandboxExecutor returns the sandbox executor stored on the context.
// The result is nil when no executor has been set.
func (ctx *Context) GetSandboxExecutor() SandboxExecutor {
	return ctx.sandboxExecutor
}

// HasSandbox reports whether a sandbox executor has been set on the context.
func (ctx *Context) HasSandbox() bool {
	return ctx.sandboxExecutor != nil
}

// newSandboxObject builds the object template behind ctx.sandbox and attaches
// the ReadFile, WriteFile, ListDir, Exec, GetVNCUrl and GetSandboxID methods.
// Returns nil if sandbox executor is not available
func (ctx *Context) newSandboxObject(iso *v8go.Isolate) *v8go.ObjectTemplate {
	if ctx.sandboxExecutor == nil {
		return nil
	}

	sandboxObj := v8go.NewObjectTemplate(iso)

	// Set methods (errors returned by Set are not checked here)
	sandboxObj.Set("ReadFile", ctx.sandboxReadFileMethod(iso))
	sandboxObj.Set("WriteFile", ctx.sandboxWriteFileMethod(iso))
	sandboxObj.Set("ListDir", ctx.sandboxListDirMethod(iso))
	sandboxObj.Set("Exec", ctx.sandboxExecMethod(iso))
	sandboxObj.Set("GetVNCUrl", ctx.sandboxGetVNCUrlMethod(iso))
	sandboxObj.Set("GetSandboxID", ctx.sandboxGetSandboxIDMethod(iso))

	return sandboxObj
}

// createSandboxInstance instantiates the sandbox object in the given V8 context.
// In addition to the methods from newSandboxObject, the instance carries the
// workdir, sandbox_id and vnc_url properties.
// Returns nil when no executor is set or when the template cannot be instantiated.
func (ctx *Context) createSandboxInstance(v8ctx *v8go.Context) *v8go.Value {
	if ctx.sandboxExecutor == nil {
		return nil
	}

	sandboxTemplate := ctx.newSandboxObject(v8ctx.Isolate())
	if sandboxTemplate == nil {
		return nil
	}

	// Set workdir as a property
	sandboxTemplate.Set("workdir", ctx.sandboxExecutor.GetWorkDir())

	// Set sandbox_id as a property
	sandboxID := ctx.sandboxExecutor.GetSandboxID()
	sandboxTemplate.Set("sandbox_id", sandboxID)

	// Set vnc_url as a property (empty string if not available)
	// GetVNCUrl returns sandbox ID if VNC is supported, empty otherwise
	vncSandboxID := ctx.sandboxExecutor.GetVNCUrl()
	if vncSandboxID != "" {
		sandboxTemplate.Set("vnc_url", openapiSandbox.GetVNCClientURL(vncSandboxID))
	} else {
		sandboxTemplate.Set("vnc_url", "")
	}

	instance, err := sandboxTemplate.NewInstance(v8ctx)
	if err != nil {
		// The instantiation error is dropped; callers only see a nil value.
		return nil
	}

	return instance.Value
}

// sandboxReadFileMethod implements ctx.sandbox.ReadFile(path)
// The file content is handed back to JavaScript as a string.
// Missing executor, missing argument and read errors are reported through bridge.JsException.
func (ctx *Context) sandboxReadFileMethod(iso *v8go.Isolate) *v8go.FunctionTemplate {
	return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value {
		v8ctx := info.Context()
		args := info.Args()

		if ctx.sandboxExecutor == nil {
			return bridge.JsException(v8ctx, "sandbox executor not available")
		}

		if len(args) < 1 {
			return bridge.JsException(v8ctx, "ReadFile requires path parameter")
		}

		// The argument is converted with String(); its JS type is not checked.
		path := args[0].String()

		// The executor is called with context.Background().
		content, err := ctx.sandboxExecutor.ReadFile(context.Background(), path)
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}

		// Return as string
		jsVal, err := v8go.NewValue(iso, string(content))
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}
		return jsVal
	})
}

// sandboxWriteFileMethod implements ctx.sandbox.WriteFile(path, content)
// The content argument is converted with String() and written as bytes.
// Returns undefined on success; failures are reported through bridge.JsException.
func (ctx *Context) sandboxWriteFileMethod(iso *v8go.Isolate) *v8go.FunctionTemplate {
	return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value {
		v8ctx := info.Context()
		args := info.Args()

		if ctx.sandboxExecutor == nil {
			return bridge.JsException(v8ctx, "sandbox executor not available")
		}

		if len(args) < 2 {
			return bridge.JsException(v8ctx, "WriteFile requires path and content parameters")
		}

		path := args[0].String()
		content := args[1].String()

		// The executor is called with context.Background().
		err := ctx.sandboxExecutor.WriteFile(context.Background(), path, []byte(content))
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}

		// Return undefined on success
		return v8go.Undefined(iso)
	})
}

// sandboxListDirMethod implements ctx.sandbox.ListDir(path)
// The result is a JavaScript array of objects with the keys name, size and is_dir.
// Failures are reported through bridge.JsException.
func (ctx *Context) sandboxListDirMethod(iso *v8go.Isolate) *v8go.FunctionTemplate {
	return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value {
		v8ctx := info.Context()
		args := info.Args()

		if ctx.sandboxExecutor == nil {
			return bridge.JsException(v8ctx, "sandbox executor not available")
		}

		if len(args) < 1 {
			return bridge.JsException(v8ctx, "ListDir requires path parameter")
		}

		path := args[0].String()

		// The executor is called with context.Background().
		files, err := ctx.sandboxExecutor.ListDir(context.Background(), path)
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}

		// Convert to JavaScript array of objects
		result := make([]map[string]interface{}, len(files))
		for i, f := range files {
			result[i] = map[string]interface{}{
				"name":   f.Name,
				"size":   f.Size,
				"is_dir": f.IsDir,
			}
		}

		jsVal, err := bridge.JsValue(v8ctx, result)
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}
		return jsVal
	})
}

// sandboxExecMethod implements ctx.sandbox.Exec(cmd)
// cmd must be a JavaScript array; each element is converted with String().
// The executor output is returned to JavaScript as a string.
// Failures are reported through bridge.JsException.
func (ctx *Context) sandboxExecMethod(iso *v8go.Isolate) *v8go.FunctionTemplate {
	return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value {
		v8ctx := info.Context()
		args := info.Args()

		if ctx.sandboxExecutor == nil {
			return bridge.JsException(v8ctx, "sandbox executor not available")
		}

		if len(args) < 1 {
			return bridge.JsException(v8ctx, "Exec requires cmd parameter (array of strings)")
		}

		// Parse command array
		cmdArg := args[0]
		if !cmdArg.IsArray() {
			return bridge.JsException(v8ctx, "Exec requires cmd to be an array of strings")
		}

		cmdObj, err := cmdArg.AsObject()
		if err != nil {
			return bridge.JsException(v8ctx, "failed to parse cmd array: "+err.Error())
		}

		// Get array length
		lengthVal, err := cmdObj.Get("length")
		if err != nil {
			return bridge.JsException(v8ctx, "failed to get cmd array length: "+err.Error())
		}
		length := int(lengthVal.Integer())

		// Build command slice (element types are not checked, only stringified)
		cmd := make([]string, length)
		for i := 0; i < length; i++ {
			itemVal, err := cmdObj.GetIdx(uint32(i))
			if err != nil {
				return bridge.JsException(v8ctx, "failed to get cmd array element: "+err.Error())
			}
			cmd[i] = itemVal.String()
		}

		// The executor is called with context.Background().
		output, err := ctx.sandboxExecutor.Exec(context.Background(), cmd)
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}

		jsVal, err := v8go.NewValue(v8ctx.Isolate(), output)
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}
		return jsVal
	})
}

// sandboxGetVNCUrlMethod implements ctx.sandbox.GetVNCUrl()
// Returns the URL built by openapiSandbox.GetVNCClientURL, or an empty string
// when the executor reports no VNC sandbox ID.
func (ctx *Context) sandboxGetVNCUrlMethod(iso *v8go.Isolate) *v8go.FunctionTemplate {
	return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value {
		v8ctx := info.Context()

		if ctx.sandboxExecutor == nil {
			return bridge.JsException(v8ctx, "sandbox executor not available")
		}

		// GetVNCUrl returns sandbox ID if VNC is supported, empty otherwise
		vncSandboxID := ctx.sandboxExecutor.GetVNCUrl()
		vncUrl := ""
		if vncSandboxID != "" {
			vncUrl = openapiSandbox.GetVNCClientURL(vncSandboxID)
		}

		jsVal, err := v8go.NewValue(iso, vncUrl)
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}
		return jsVal
	})
}

// sandboxGetSandboxIDMethod implements ctx.sandbox.GetSandboxID()
// Returns the executor's sandbox ID as a JavaScript string.
func (ctx *Context) sandboxGetSandboxIDMethod(iso *v8go.Isolate) *v8go.FunctionTemplate {
	return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value {
		v8ctx := info.Context()

		if ctx.sandboxExecutor == nil {
			return bridge.JsException(v8ctx, "sandbox executor not available")
		}

		sandboxID := ctx.sandboxExecutor.GetSandboxID()

		jsVal, err := v8go.NewValue(iso, sandboxID)
		if err != nil {
			return bridge.JsException(v8ctx, err.Error())
		}
		return jsVal
	})
}

================================================
FILE: agent/context/jsapi_sandbox_test.go
================================================
package context_test

import (
	stdContext "context"
	"os"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	v8 "github.com/yaoapp/gou/runtime/v8"
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/config"
	infraSandbox "github.com/yaoapp/yao/sandbox"
	"github.com/yaoapp/yao/test"
)

// createTestSandboxManager creates a real sandbox manager for testing.
// The test is skipped when the manager cannot be created.
func createTestSandboxManager(t *testing.T) *infraSandbox.Manager {
	// Get data root from environment or use temp directory
	dataRoot := os.Getenv("YAO_ROOT")
	if dataRoot == "" {
		dataRoot = t.TempDir()
	}

	// Create config with proper paths
	cfg := infraSandbox.DefaultConfig()
	cfg.Init(dataRoot)

	manager, err := infraSandbox.NewManager(cfg)
	if err != nil {
		// Skipf stops the calling test, so the return below is never reached;
		// it is only there to satisfy the compiler.
		t.Skipf("Skipping test: Docker not available: %v", err)
		return nil
	}
	return manager
}

// createTestContainer creates (or fetches) the container for userID/chatID and
// returns it together with a cleanup function that removes it.
// The test fails immediately when the container cannot be obtained.
func createTestContainer(t *testing.T, manager *infraSandbox.Manager, userID, chatID string) (*infraSandbox.Container, func()) {
	container, err := manager.GetOrCreate(stdContext.Background(), userID, chatID)
	require.NoError(t, err)
	require.NotNil(t, container)

	// Return cleanup function that removes the container.
	// A removal failure is only logged, it does not fail the test.
	cleanup := func() {
		err := manager.Remove(stdContext.Background(), container.Name)
		if err != nil {
			t.Logf("Warning: failed to cleanup container %s: %v", container.Name, err)
		}
	}

	return container, cleanup
}

// realSandboxExecutor wraps infraSandbox.Manager to implement context.SandboxExecutor
type realSandboxExecutor struct {
	manager       *infraSandbox.Manager // manager that performs the container operations
	containerName string                // container the operations are sent to
	workDir       string                // directory that is prepended to every file path
}

// ReadFile reads workDir + "/" + path from the container.
func (e *realSandboxExecutor) ReadFile(ctx stdContext.Context, path string) ([]byte, error) {
	fullPath := e.workDir + "/" + path
	return e.manager.ReadFile(ctx, e.containerName, fullPath)
}

// WriteFile writes content to workDir + "/" + path in the container.
func (e *realSandboxExecutor) WriteFile(ctx stdContext.Context, path string, content []byte) error {
	fullPath := e.workDir + "/" + path
	return e.manager.WriteFile(ctx, e.containerName, fullPath, content)
}

// ListDir lists workDir + "/" + path in the container.
func (e *realSandboxExecutor) ListDir(ctx stdContext.Context, path string) ([]infraSandbox.FileInfo, error) {
	fullPath := e.workDir + "/" + path
	return e.manager.ListDir(ctx, e.containerName, fullPath)
}

// Exec runs cmd in the container with workDir as working directory and
// returns only the Stdout field of the result.
func (e *realSandboxExecutor) Exec(ctx stdContext.Context, cmd []string) (string, error) {
	result, err := e.manager.Exec(ctx, e.containerName, cmd, &infraSandbox.ExecOptions{
		WorkDir: e.workDir,
	})
	if err != nil {
		return "", err
	}
	return result.Stdout, nil
}

// GetWorkDir returns the configured working directory.
func (e *realSandboxExecutor) GetWorkDir() string {
	return e.workDir
}

// GetSandboxID returns a fixed mock ID; it is not derived from the container name.
func (e *realSandboxExecutor) GetSandboxID() string {
	// Extract sandbox ID from container name (format: yao-sandbox-{userID}-{chatID})
	// For tests, just return a mock ID
	return "test-user-test-chat"
}

// GetVNCUrl always returns an empty string.
func (e *realSandboxExecutor) GetVNCUrl() string {
	// Tests don't use VNC, return empty
	return ""
}

// TestJsSandboxNotAvailable tests ctx.sandbox when not configured
func TestJsSandboxNotAvailable(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// No executor is set on this context.
	ctx := context.New(stdContext.Background(), nil, "test-chat-no-sandbox")
	ctx.AssistantID = "test-assistant"

	// Test that ctx.sandbox is undefined when not configured
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				if (ctx.sandbox === undefined || ctx.sandbox === null) {
					return { success: true, hasSandbox: false };
				}
				return { success: true, hasSandbox: true };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, ctx)

	require.NoError(t, err)
	result, ok := res.(map[string]interface{})
	require.True(t, ok, "Expected map result")
	assert.Equal(t, true, result["success"])
	assert.Equal(t, false, result["hasSandbox"], "ctx.sandbox should not be available when not configured")
}

// TestJsSandboxWriteFile tests ctx.sandbox.WriteFile via JavaScript
func TestJsSandboxWriteFile(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	manager := createTestSandboxManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// Create container with auto-cleanup
	container, cleanup := createTestContainer(t, manager, "test-user", "test-js-writefile")
	defer cleanup()

	executor := &realSandboxExecutor{
		manager:       manager,
		containerName: container.Name,
		workDir:       "/workspace",
	}

	// Create context with sandbox
	ctx := context.New(stdContext.Background(), nil, "test-chat-writefile")
	ctx.AssistantID = "test-assistant"
	ctx.SetSandboxExecutor(executor)

	// Test WriteFile via JavaScript
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				if (!ctx.sandbox) {
					return { success: false, error: "sandbox not available" };
				}

				// Write a file
				ctx.sandbox.WriteFile("js-test.txt", "Hello from JavaScript!");

				return { success: true };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, ctx)

	require.NoError(t, err)
	result, ok := res.(map[string]interface{})
	require.True(t, ok, "Expected map result")
	assert.Equal(t, true, result["success"], "WriteFile should succeed: %v", result["error"])

	// Verify file was written by reading it back directly
	content, err := executor.ReadFile(stdContext.Background(), "js-test.txt")
	require.NoError(t, err)
	assert.Equal(t, "Hello from JavaScript!", string(content))
}

// TestJsSandboxReadFile tests ctx.sandbox.ReadFile via JavaScript
func TestJsSandboxReadFile(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	manager := createTestSandboxManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// Create container with auto-cleanup
	container, cleanup := createTestContainer(t, manager, "test-user", "test-js-readfile")
	defer cleanup()

	executor := &realSandboxExecutor{
		manager:       manager,
		containerName: container.Name,
		workDir:       "/workspace",
	}

	// Write a file first (through the executor, not through JavaScript)
	err := executor.WriteFile(stdContext.Background(), "read-test.txt", []byte("Content to read"))
	require.NoError(t, err)

	// Create context with sandbox
	ctx := context.New(stdContext.Background(), nil, "test-chat-readfile")
	ctx.AssistantID = "test-assistant"
	ctx.SetSandboxExecutor(executor)

	// Test ReadFile via JavaScript
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				if (!ctx.sandbox) {
					return { success: false, error: "sandbox not available" };
				}

				// Read the file
				const content = ctx.sandbox.ReadFile("read-test.txt");

				return { success: true, content: content };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, ctx)

	require.NoError(t, err)
	result, ok := res.(map[string]interface{})
	require.True(t, ok, "Expected map result")
	assert.Equal(t, true, result["success"], "ReadFile should succeed: %v", result["error"])
	assert.Equal(t, "Content to read", result["content"])
}

// TestJsSandboxListDir tests ctx.sandbox.ListDir via JavaScript
func TestJsSandboxListDir(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	manager := createTestSandboxManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// Create container with auto-cleanup
	container, cleanup := createTestContainer(t, manager, "test-user", "test-js-listdir")
	defer cleanup()

	executor := &realSandboxExecutor{
		manager:       manager,
		containerName: container.Name,
		workDir:       "/workspace",
	}

	// Write some files first
	err := executor.WriteFile(stdContext.Background(), "file1.txt", []byte("content1"))
	require.NoError(t, err)
	err = executor.WriteFile(stdContext.Background(), "file2.txt", []byte("content2"))
	require.NoError(t, err)

	// Create context with sandbox
	ctx := context.New(stdContext.Background(), nil, "test-chat-listdir")
	ctx.AssistantID = "test-assistant"
	ctx.SetSandboxExecutor(executor)

	// Test ListDir via JavaScript
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				if (!ctx.sandbox) {
					return { success: false, error: "sandbox not available" };
				}

				// List directory
				const files = ctx.sandbox.ListDir(".");

				// Find our test files
				const fileNames = files.map(f => f.name);
				const hasFile1 = fileNames.includes("file1.txt");
				const hasFile2 = fileNames.includes("file2.txt");

				return {
					success: true,
					fileCount: files.length,
					hasFile1: hasFile1,
					hasFile2: hasFile2,
					files: fileNames
				};
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, ctx)

	require.NoError(t, err)
	result, ok := res.(map[string]interface{})
	require.True(t, ok, "Expected map result")
	assert.Equal(t, true, result["success"], "ListDir should succeed: %v", result["error"])
	assert.Equal(t, true, result["hasFile1"], "Should find file1.txt")
	assert.Equal(t, true, result["hasFile2"], "Should find file2.txt")
}

// TestJsSandboxExec tests ctx.sandbox.Exec via JavaScript
func TestJsSandboxExec(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	manager := createTestSandboxManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// Create container with auto-cleanup
	container, cleanup := createTestContainer(t, manager, "test-user", "test-js-exec")
	defer cleanup()

	executor := &realSandboxExecutor{
		manager:       manager,
		containerName: container.Name,
		workDir:       "/workspace",
	}

	// Create context with sandbox
	ctx := context.New(stdContext.Background(), nil, "test-chat-exec")
	ctx.AssistantID = "test-assistant"
	ctx.SetSandboxExecutor(executor)

	// Test Exec via JavaScript
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				if (!ctx.sandbox) {
					return { success: false, error: "sandbox not available" };
				}

				// Execute echo command
				const output = ctx.sandbox.Exec(["echo", "hello-from-js"]);

				return { success: true, output: output.trim() };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, ctx)

	require.NoError(t, err)
	result, ok := res.(map[string]interface{})
	require.True(t, ok, "Expected map result")
	assert.Equal(t, true, result["success"], "Exec should succeed: %v", result["error"])
	// Output may contain Docker stream header bytes, so use Contains.
	// A non-string output leaves output empty and fails the Contains check below.
	output, _ := result["output"].(string)
	assert.Contains(t, output, "hello-from-js", "Exec output should contain expected text")
}

// TestJsSandboxWorkdir tests ctx.sandbox.workdir property via JavaScript
func TestJsSandboxWorkdir(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	manager := createTestSandboxManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// Create container with auto-cleanup
	container, cleanup := createTestContainer(t, manager, "test-user", "test-js-workdir")
	defer cleanup()

	executor := &realSandboxExecutor{
		manager:       manager,
		containerName: container.Name,
		workDir:       "/workspace",
	}

	// Create context with sandbox
	ctx := context.New(stdContext.Background(), nil, "test-chat-workdir")
	ctx.AssistantID = "test-assistant"
	ctx.SetSandboxExecutor(executor)

	// Test workdir property via JavaScript
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				if (!ctx.sandbox) {
					return { success: false, error: "sandbox not available" };
				}

				// Get workdir property
				const workdir = ctx.sandbox.workdir;

				return { success: true, workdir: workdir };
			} catch (error) {
				return { success: false, error: error.message };
			}
		}`, ctx)

	require.NoError(t, err)
	result, ok := res.(map[string]interface{})
	require.True(t, ok, "Expected map result")
	assert.Equal(t, true, result["success"], "workdir access should succeed: %v", result["error"])
	assert.Equal(t, "/workspace", result["workdir"])
}

// TestJsSandboxCompleteWorkflow tests a complete workflow via JavaScript:
// workdir check, WriteFile, ReadFile, ListDir and two Exec calls in one script.
func TestJsSandboxCompleteWorkflow(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	manager := createTestSandboxManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// Create container with auto-cleanup
	container, cleanup := createTestContainer(t, manager, "test-user", "test-js-workflow")
	defer cleanup()

	executor := &realSandboxExecutor{
		manager:       manager,
		containerName: container.Name,
		workDir:       "/workspace",
	}

	// Create context with sandbox
	ctx := context.New(stdContext.Background(), nil, "test-chat-workflow")
	ctx.AssistantID = "test-assistant"
	ctx.SetSandboxExecutor(executor)

	// Test complete workflow: write file, exec cat, verify content
	res, err := v8.Call(v8.CallOptions{}, `
		function test(ctx) {
			try {
				if (!ctx.sandbox) {
					return { success: false, error: "sandbox not available" };
				}

				// 1. Check workdir
				const workdir = ctx.sandbox.workdir;
				if (workdir !== "/workspace") {
					return { success: false, error: "unexpected workdir: " + workdir };
				}

				// 2. Write a file
				const testContent = "Test workflow content: " + Date.now();
				ctx.sandbox.WriteFile("workflow-test.txt", testContent);

				// 3. Read it back
				const readContent = ctx.sandbox.ReadFile("workflow-test.txt");
				if (readContent !== testContent) {
					return { success: false, error: "content mismatch after read" };
				}

				// 4. List directory and verify file exists
				const files = ctx.sandbox.ListDir(".");
				const fileNames = files.map(f => f.name);
				if (!fileNames.includes("workflow-test.txt")) {
					return { success: false, error: "file not found in listing" };
				}

				// 5. Execute cat command
				const catOutput = ctx.sandbox.Exec(["cat", workdir + "/workflow-test.txt"]);
				if (!catOutput.includes("Test workflow content")) {
					return { success: false, error: "cat output mismatch" };
				}

				// 6. Execute pwd command
				const pwdOutput = ctx.sandbox.Exec(["pwd"]);
				if (!pwdOutput.includes("/workspace")) {
					return { success: false, error: "pwd output mismatch: " + pwdOutput };
				}

				return {
					success: true,
					workdir: workdir,
					content: readContent,
					fileCount: files.length
				};
			} catch (error) {
				return { success: false, error: error.message, stack: error.stack };
			}
		}`, ctx)

	require.NoError(t, err)
	result, ok := res.(map[string]interface{})
	require.True(t, ok, "Expected map result")
	assert.Equal(t, true, result["success"], "Complete workflow should succeed: %v", result["error"])
	assert.Equal(t, "/workspace", result["workdir"])
}

================================================
FILE: agent/context/jsapi_search.go
================================================
package context

import (
	"github.com/yaoapp/gou/runtime/v8/bridge"
	"rogchap.com/v8go"
)

// SearchAPI defines the search JSAPI interface for ctx.search.*
// This interface is defined here to avoid circular dependency between context and search packages.
// The actual implementation is in agent/search/jsapi.go type SearchAPI interface { // Web executes web search // Returns *types.Result or error information Web(query string, opts map[string]interface{}) interface{} // KB executes knowledge base search // Returns *types.Result or error information KB(query string, opts map[string]interface{}) interface{} // DB executes database search // Returns *types.Result or error information DB(query string, opts map[string]interface{}) interface{} // Parallel search methods - inspired by JavaScript Promise // All waits for all searches to complete (like Promise.all) All(requests []interface{}) []interface{} // Any returns when any search succeeds with results (like Promise.any) Any(requests []interface{}) []interface{} // Race returns when any search completes (like Promise.race) Race(requests []interface{}) []interface{} } // SearchAPIFactory is a function type that creates a SearchAPI for a context // This is set by the search package during initialization var SearchAPIFactory func(ctx *Context) SearchAPI // Search returns the search API for this context // Returns nil if SearchAPIFactory is not set func (ctx *Context) Search() SearchAPI { if SearchAPIFactory == nil { return nil } return SearchAPIFactory(ctx) } // newSearchObject creates a new search object with all search methods // This is called from jsapi.go NewObject() to mount ctx.search func (ctx *Context) newSearchObject(iso *v8go.Isolate) *v8go.ObjectTemplate { searchObj := v8go.NewObjectTemplate(iso) // Single search methods searchObj.Set("Web", ctx.searchWebMethod(iso)) searchObj.Set("KB", ctx.searchKBMethod(iso)) searchObj.Set("DB", ctx.searchDBMethod(iso)) // Parallel search methods - inspired by JavaScript Promise searchObj.Set("All", ctx.searchAllMethod(iso)) searchObj.Set("Any", ctx.searchAnyMethod(iso)) searchObj.Set("Race", ctx.searchRaceMethod(iso)) return searchObj } // searchWebMethod implements ctx.search.Web(query, options?) 
// Options: // - limit: number - max results (default: 10) // - sites: string[] - restrict to specific sites // - time_range: string - "day", "week", "month", "year" // - rerank: { top_n: number } - rerank options func (ctx *Context) searchWebMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "Web requires query parameter") } // Get query string if !args[0].IsString() { return bridge.JsException(v8ctx, "query must be a string") } query := args[0].String() // Parse options (optional) var opts map[string]interface{} if len(args) >= 2 && !args[1].IsUndefined() && !args[1].IsNull() { goVal, err := bridge.GoValue(args[1], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid options: "+err.Error()) } if optsMap, ok := goVal.(map[string]interface{}); ok { opts = optsMap } } // Get search API searchAPI := ctx.Search() if searchAPI == nil { return bridge.JsException(v8ctx, "search API not available") } // Execute search result := searchAPI.Web(query, opts) // Convert result to JS value jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, "failed to convert result: "+err.Error()) } return jsVal }) } // searchKBMethod implements ctx.search.KB(query, options?) 
// Options: // - collections: string[] - collection IDs // - threshold: number - similarity threshold (0-1) // - limit: number - max results // - graph: boolean - enable graph association // - rerank: { top_n: number } - rerank options func (ctx *Context) searchKBMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "KB requires query parameter") } // Get query string if !args[0].IsString() { return bridge.JsException(v8ctx, "query must be a string") } query := args[0].String() // Parse options (optional) var opts map[string]interface{} if len(args) >= 2 && !args[1].IsUndefined() && !args[1].IsNull() { goVal, err := bridge.GoValue(args[1], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid options: "+err.Error()) } if optsMap, ok := goVal.(map[string]interface{}); ok { opts = optsMap } } // Get search API searchAPI := ctx.Search() if searchAPI == nil { return bridge.JsException(v8ctx, "search API not available") } // Execute search result := searchAPI.KB(query, opts) // Convert result to JS value jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, "failed to convert result: "+err.Error()) } return jsVal }) } // searchDBMethod implements ctx.search.DB(query, options?) 
// Options: // - models: string[] - model IDs // - wheres: Where[] - pre-defined filters (GOU QueryDSL Where format) // - orders: Order[] - sort orders (GOU QueryDSL Order format) // - select: string[] - fields to return // - limit: number - max results // - rerank: { top_n: number } - rerank options func (ctx *Context) searchDBMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "DB requires query parameter") } // Get query string if !args[0].IsString() { return bridge.JsException(v8ctx, "query must be a string") } query := args[0].String() // Parse options (optional) var opts map[string]interface{} if len(args) >= 2 && !args[1].IsUndefined() && !args[1].IsNull() { goVal, err := bridge.GoValue(args[1], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid options: "+err.Error()) } if optsMap, ok := goVal.(map[string]interface{}); ok { opts = optsMap } } // Get search API searchAPI := ctx.Search() if searchAPI == nil { return bridge.JsException(v8ctx, "search API not available") } // Execute search result := searchAPI.DB(query, opts) // Convert result to JS value jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, "failed to convert result: "+err.Error()) } return jsVal }) } // searchAllMethod implements ctx.search.All(requests) // Waits for all searches to complete (like Promise.all) // Each request should have: // - type: string - "web", "kb", or "db" // - query: string - search query // - ... 
other type-specific options func (ctx *Context) searchAllMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "All requires requests parameter") } // Parse requests array goVal, err := bridge.GoValue(args[0], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid requests: "+err.Error()) } requestsArray, ok := goVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "requests must be an array") } // Get search API searchAPI := ctx.Search() if searchAPI == nil { return bridge.JsException(v8ctx, "search API not available") } // Execute parallel search results := searchAPI.All(requestsArray) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // searchAnyMethod implements ctx.search.Any(requests) // Returns when any search succeeds with results (like Promise.any) // Each request should have: // - type: string - "web", "kb", or "db" // - query: string - search query // - ... 
other type-specific options func (ctx *Context) searchAnyMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "Any requires requests parameter") } // Parse requests array goVal, err := bridge.GoValue(args[0], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid requests: "+err.Error()) } requestsArray, ok := goVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "requests must be an array") } // Get search API searchAPI := ctx.Search() if searchAPI == nil { return bridge.JsException(v8ctx, "search API not available") } // Execute parallel search results := searchAPI.Any(requestsArray) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } // searchRaceMethod implements ctx.search.Race(requests) // Returns when any search completes (like Promise.race) // Each request should have: // - type: string - "web", "kb", or "db" // - query: string - search query // - ... 
other type-specific options func (ctx *Context) searchRaceMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() // Validate arguments if len(args) < 1 { return bridge.JsException(v8ctx, "Race requires requests parameter") } // Parse requests array goVal, err := bridge.GoValue(args[0], v8ctx) if err != nil { return bridge.JsException(v8ctx, "invalid requests: "+err.Error()) } requestsArray, ok := goVal.([]interface{}) if !ok { return bridge.JsException(v8ctx, "requests must be an array") } // Get search API searchAPI := ctx.Search() if searchAPI == nil { return bridge.JsException(v8ctx, "search API not available") } // Execute parallel search results := searchAPI.Race(requestsArray) // Convert results to JS value jsVal, err := bridge.JsValue(v8ctx, results) if err != nil { return bridge.JsException(v8ctx, "failed to convert results: "+err.Error()) } return jsVal }) } ================================================ FILE: agent/context/jsapi_search_test.go ================================================ package context_test import ( stdContext "context" "encoding/json" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) // Note: SearchAPIFactory is set by assistant.init() with proper config getter // We import assistant package to ensure init() runs before tests // newSearchTestContext creates a Context for search JSAPI testing func newSearchTestContext(chatID, assistantID string) *context.Context { authorized := &oauthTypes.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client-id", Scope: "openid profile email", SessionID: "test-session-id", UserID: "test-user-123", } ctx := 
context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = "en-us" ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptWebCUI ctx.Metadata = make(map[string]interface{}) return ctx } // getResponseContent extracts the content from the first assistant message func getResponseContent(res *context.HookCreateResponse) string { if res == nil || len(res.Messages) == 0 { return "" } for _, msg := range res.Messages { if msg.Role == "assistant" { if content, ok := msg.Content.(string); ok { return content } } } return "" } // TestSearchJSAPI_Web tests ctx.search.Web() via Create Hook // Skip: requires external API key (Tavily/Serper) func TestSearchJSAPI_Web(t *testing.T) { t.Skip("Skipping: requires external API key (Tavily/Serper)") testutils.Prepare(t) defer testutils.Clean(t) // Load the search-jsapi test assistant agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err, "Failed to get tests.search-jsapi assistant") require.NotNil(t, agent.HookScript, "The tests.search-jsapi assistant has no script") ctx := newSearchTestContext("chat-search-web", "tests.search-jsapi") // Call Create hook with test:web command res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:web Yao App Engine"}}) require.NoError(t, err, "Create hook failed") require.NotNil(t, res, "Expected non-nil response") // Get response content from messages content := getResponseContent(res) require.NotEmpty(t, content, "Expected response content") // Parse the JSON response var result types.Result err = json.Unmarshal([]byte(content), &result) require.NoError(t, err, "Response should be valid JSON: %s", content) // Verify result assert.Equal(t, types.SearchTypeWeb, result.Type, "type should be web") assert.Equal(t, "Yao App Engine", result.Query, "query should match") assert.Empty(t, result.Error, "should not have error: %s", result.Error) assert.Greater(t, len(result.Items), 0, "should have items") 
t.Logf("Web search returned %d items", len(result.Items)) for i, item := range result.Items { if i < 3 { t.Logf(" [%s] %s - %s", item.CitationID, item.Title, item.URL) } } } // TestSearchJSAPI_WebWithSites tests ctx.search.Web() with site restriction // Skip: requires external API key (Tavily/Serper) func TestSearchJSAPI_WebWithSites(t *testing.T) { t.Skip("Skipping: requires external API key (Tavily/Serper)") testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := newSearchTestContext("chat-search-web-sites", "tests.search-jsapi") res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:web_sites Yao App Engine"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) require.NotEmpty(t, content, "Expected response content") var result types.Result err = json.Unmarshal([]byte(content), &result) require.NoError(t, err, "Response should be valid JSON: %s", content) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Empty(t, result.Error, "should not have error: %s", result.Error) assert.Greater(t, len(result.Items), 0, "should have items") // Verify all results are from allowed sites allowedSites := []string{"github.com", "yaoapps.com"} for _, item := range result.Items { isAllowed := false for _, site := range allowedSites { if strings.Contains(item.URL, site) { isAllowed = true break } } assert.True(t, isAllowed, "URL %s should be from allowed sites", item.URL) } t.Logf("Site-restricted search returned %d items", len(result.Items)) } // TestSearchJSAPI_KB tests ctx.search.KB() via Create Hook (skeleton) func TestSearchJSAPI_KB(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := newSearchTestContext("chat-search-kb", "tests.search-jsapi") res, _, err := 
agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:kb test query"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) require.NotEmpty(t, content, "Expected response content") var result types.Result err = json.Unmarshal([]byte(content), &result) require.NoError(t, err, "Response should be valid JSON: %s", content) assert.Equal(t, types.SearchTypeKB, result.Type, "type should be kb") assert.Equal(t, "test query", result.Query, "query should match") assert.Equal(t, types.SourceHook, result.Source, "source should be hook") } // TestSearchJSAPI_DB tests ctx.search.DB() via Create Hook (skeleton) func TestSearchJSAPI_DB(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := newSearchTestContext("chat-search-db", "tests.search-jsapi") res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:db test query"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) require.NotEmpty(t, content, "Expected response content") var result types.Result err = json.Unmarshal([]byte(content), &result) require.NoError(t, err, "Response should be valid JSON: %s", content) assert.Equal(t, types.SearchTypeDB, result.Type, "type should be db") assert.Equal(t, "test query", result.Query, "query should match") assert.Equal(t, types.SourceHook, result.Source, "source should be hook") } // TestSearchJSAPI_All tests ctx.search.All() via Create Hook // Skip: requires external API key (Tavily/Serper) func TestSearchJSAPI_All(t *testing.T) { t.Skip("Skipping: requires external API key (Tavily/Serper)") testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := newSearchTestContext("chat-search-all", "tests.search-jsapi") res, _, err := 
agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:all"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) require.NotEmpty(t, content, "Expected response content") // Parse as array of results var results []*types.Result err = json.Unmarshal([]byte(content), &results) require.NoError(t, err, "Response should be valid JSON array: %s", content) assert.Len(t, results, 2, "should have 2 results") // Both should succeed successCount := 0 totalItems := 0 for _, r := range results { if r != nil && r.Error == "" { successCount++ totalItems += len(r.Items) } } assert.Equal(t, 2, successCount, "both searches should succeed") assert.Greater(t, totalItems, 0, "should have items") t.Logf("All search: %d results, %d total items", len(results), totalItems) } // TestSearchJSAPI_Any tests ctx.search.Any() via Create Hook // Skip: requires external API key (Tavily/Serper) func TestSearchJSAPI_Any(t *testing.T) { t.Skip("Skipping: requires external API key (Tavily/Serper)") testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := newSearchTestContext("chat-search-any", "tests.search-jsapi") res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:any"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) require.NotEmpty(t, content, "Expected response content") var results []*types.Result err = json.Unmarshal([]byte(content), &results) require.NoError(t, err, "Response should be valid JSON array: %s", content) assert.Len(t, results, 2, "should have 2 result slots") // At least one should have results hasSuccess := false for _, r := range results { if r != nil && len(r.Items) > 0 && r.Error == "" { hasSuccess = true break } } assert.True(t, hasSuccess, "at least one search should succeed") t.Logf("Any search completed") } // TestSearchJSAPI_Race tests 
ctx.search.Race() via Create Hook // Skip: requires external API key (Tavily/Serper) func TestSearchJSAPI_Race(t *testing.T) { t.Skip("Skipping: requires external API key (Tavily/Serper)") testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := newSearchTestContext("chat-search-race", "tests.search-jsapi") res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:race"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) require.NotEmpty(t, content, "Expected response content") var results []*types.Result err = json.Unmarshal([]byte(content), &results) require.NoError(t, err, "Response should be valid JSON array: %s", content) assert.Len(t, results, 2, "should have 2 result slots") // At least one should have completed hasResult := false for _, r := range results { if r != nil { hasResult = true break } } assert.True(t, hasResult, "at least one search should complete") t.Logf("Race search completed") } // TestSearchJSAPI_InvalidCommand tests invalid test command func TestSearchJSAPI_InvalidCommand(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := newSearchTestContext("chat-search-invalid", "tests.search-jsapi") res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "invalid command"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) assert.Contains(t, content, "Invalid test command", "should return error message") } // TestSearchJSAPI_UnknownMethod tests unknown test method func TestSearchJSAPI_UnknownMethod(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) agent, err := assistant.Get("tests.search-jsapi") require.NoError(t, err) require.NotNil(t, agent.HookScript) ctx := 
newSearchTestContext("chat-search-unknown", "tests.search-jsapi") res, _, err := agent.HookScript.Create(ctx, []context.Message{{Role: "user", Content: "test:unknown"}}) require.NoError(t, err) require.NotNil(t, res) content := getResponseContent(res) assert.Contains(t, content, "Unknown test method", "should return error message") } ================================================ FILE: agent/context/jsapi_stress_test.go ================================================ package context_test import ( stdContext "context" "fmt" "os" "runtime" "sync" "testing" "time" "github.com/stretchr/testify/assert" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // newStressTestContext creates a test context for stress testing func newStressTestContext(chatID string) *context.Context { ctx := context.New(stdContext.Background(), nil, chatID) ctx.AssistantID = "test-assistant" ctx.Referer = context.RefererAPI stack, _, _ := context.EnterStack(ctx, "test-assistant", &context.Options{}) ctx.Stack = stack return ctx } // TestStressContextCreationAndRelease tests massive context creation and cleanup func TestStressContextCreationAndRelease(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } test.Prepare(t, config.Conf) defer test.Clean() iterations := 1000 startMemory := getMemStats() for i := 0; i < iterations; i++ { cxt := newStressTestContext(fmt.Sprintf("chat-%d", i)) _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Use trace ctx.trace.Add({ type: "test" }, { label: "Test" }) ctx.trace.Info("Processing") // Explicit release ctx.Release() return { iteration: true } }`, cxt) if err != nil { t.Fatalf("Iteration %d failed: %v", i, err) } // Force GC every 100 iterations to check for leaks if i%100 == 0 { runtime.GC() currentMemory := getMemStats() t.Logf("Iteration %d: Memory usage: %d MB", i, currentMemory/1024/1024) } } // Final GC and memory check 
runtime.GC() time.Sleep(100 * time.Millisecond) endMemory := getMemStats() t.Logf("Start memory: %d MB", startMemory/1024/1024) t.Logf("End memory: %d MB", endMemory/1024/1024) // Calculate memory growth (handle case where end < start) var memoryGrowth int64 if endMemory > startMemory { memoryGrowth = int64(endMemory - startMemory) t.Logf("Memory growth: %d MB", memoryGrowth/1024/1024) } else { memoryGrowth = -int64(startMemory - endMemory) t.Logf("Memory decreased: %d MB", -memoryGrowth/1024/1024) } // Allow reasonable memory growth (not more than 50MB for 1000 iterations) // Memory can decrease due to GC, which is fine if memoryGrowth > 0 { assert.Less(t, memoryGrowth, int64(50*1024*1024), "Memory leak detected") } } // TestStressTraceOperations tests intensive trace operations func TestStressTraceOperations(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } test.Prepare(t, config.Conf) defer test.Clean() iterations := 500 nodesPerIteration := 10 startMemory := getMemStats() for i := 0; i < iterations; i++ { // Create new context for each iteration to avoid context cancellation issues cxt := newStressTestContext(fmt.Sprintf("stress-test-chat-%d", i)) _, err := v8.Call(v8.CallOptions{}, fmt.Sprintf(` function test(ctx) { const trace = ctx.trace const nodes = [] // Create multiple nodes for (let j = 0; j < %d; j++) { const node = trace.Add( { type: "step", data: "data-" + j }, { label: "Step " + j } ) nodes.push(node) // Add logs node.Info("Processing step " + j) node.Debug("Debug info " + j) } // Complete all nodes for (const node of nodes) { node.Complete({ result: "success" }) } // Release resources ctx.Release() return { nodes: nodes.length } }`, nodesPerIteration), cxt) if err != nil { t.Fatalf("Iteration %d failed: %v", i, err) } if i%50 == 0 { runtime.GC() currentMemory := getMemStats() t.Logf("Iteration %d: Created %d nodes, Memory: %d MB", i, i*nodesPerIteration, currentMemory/1024/1024) } } runtime.GC() time.Sleep(100 * 
time.Millisecond) endMemory := getMemStats() t.Logf("Total nodes created: %d", iterations*nodesPerIteration) t.Logf("Start memory: %d MB", startMemory/1024/1024) t.Logf("End memory: %d MB", endMemory/1024/1024) if endMemory > startMemory { t.Logf("Memory growth: %d MB", (endMemory-startMemory)/1024/1024) } else { t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } } // TestStressMCPOperations tests intensive MCP operations func TestStressMCPOperations(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } test.Prepare(t, config.Conf) defer test.Clean() iterations := 500 cxt := newStressTestContext("mcp-stress-test") startMemory := getMemStats() for i := 0; i < iterations; i++ { _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // List operations const tools = ctx.mcp.ListTools("echo", "") const resources = ctx.mcp.ListResources("echo", "") const prompts = ctx.mcp.ListPrompts("echo", "") // Call operations const result1 = ctx.mcp.CallTool("echo", "ping", { count: 1 }) const result2 = ctx.mcp.CallTool("echo", "status", { verbose: false }) // Read operations const info = ctx.mcp.ReadResource("echo", "echo://info") return { tools: tools.tools.length, resources: resources.resources.length, prompts: prompts.prompts.length } }`, cxt) if err != nil { t.Fatalf("Iteration %d failed: %v", i, err) } if i%50 == 0 { runtime.GC() currentMemory := getMemStats() t.Logf("Iteration %d: Memory: %d MB", i, currentMemory/1024/1024) } } runtime.GC() time.Sleep(100 * time.Millisecond) endMemory := getMemStats() t.Logf("MCP operations: %d iterations", iterations) t.Logf("Start memory: %d MB", startMemory/1024/1024) t.Logf("End memory: %d MB", endMemory/1024/1024) if endMemory > startMemory { t.Logf("Memory growth: %d MB", (endMemory-startMemory)/1024/1024) } else { t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } } // TestStressConcurrentContexts tests concurrent context creation and usage func 
TestStressConcurrentContexts(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } test.Prepare(t, config.Conf) defer test.Clean() goroutines := 50 iterationsPerGoroutine := 20 startMemory := getMemStats() var wg sync.WaitGroup errors := make(chan error, goroutines*iterationsPerGoroutine) for g := 0; g < goroutines; g++ { wg.Add(1) go func(goroutineID int) { defer wg.Done() for i := 0; i < iterationsPerGoroutine; i++ { cxt := newStressTestContext(fmt.Sprintf("chat-%d-%d", goroutineID, i)) _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { // Use trace const node = ctx.trace.Add({ type: "test" }, { label: "Concurrent Test" }) ctx.trace.Info("Processing concurrent request") node.Complete({ result: "success" }) // Use MCP const tools = ctx.mcp.ListTools("echo", "") // Release resources ctx.Release() return { success: true } }`, cxt) if err != nil { errors <- fmt.Errorf("goroutine %d iteration %d: %v", goroutineID, i, err) return } } }(g) } wg.Wait() close(errors) // Check for errors errorCount := 0 for err := range errors { t.Error(err) errorCount++ } assert.Equal(t, 0, errorCount, "No errors should occur in concurrent operations") runtime.GC() time.Sleep(100 * time.Millisecond) endMemory := getMemStats() totalOperations := goroutines * iterationsPerGoroutine t.Logf("Total operations: %d (goroutines: %d, iterations: %d)", totalOperations, goroutines, iterationsPerGoroutine) t.Logf("Start memory: %d MB", startMemory/1024/1024) t.Logf("End memory: %d MB", endMemory/1024/1024) if endMemory > startMemory { t.Logf("Memory growth: %d MB", (endMemory-startMemory)/1024/1024) } else { t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } } // TestStressNoOpTracePerformance tests no-op trace performance func TestStressNoOpTracePerformance(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } test.Prepare(t, config.Conf) defer test.Clean() iterations := 1000 // Context without trace 
initialization (no-op trace) cxt := context.New(stdContext.Background(), nil, "noop-stress-test") cxt.AssistantID = "test-assistant" startMemory := getMemStats() startTime := time.Now() for i := 0; i < iterations; i++ { _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { const trace = ctx.trace // no-op trace // All operations should be no-ops and fast trace.Info("No-op info") const node = trace.Add({ type: "test" }, { label: "No-op" }) node.Info("No-op node info") node.Complete({ result: "done" }) trace.Release() return { noop: true } }`, cxt) if err != nil { t.Fatalf("Iteration %d failed: %v", i, err) } } duration := time.Since(startTime) runtime.GC() endMemory := getMemStats() avgTimePerOp := duration / time.Duration(iterations) t.Logf("No-op trace operations: %d iterations", iterations) t.Logf("Total time: %v", duration) t.Logf("Average time per operation: %v", avgTimePerOp) t.Logf("Start memory: %d MB", startMemory/1024/1024) t.Logf("End memory: %d MB", endMemory/1024/1024) // No-op operations should be reasonably fast // Note: CI environments may be slower due to resource limits // Local: ~2ms, CI: ~10ms maxTimePerOp := 5 * time.Millisecond if os.Getenv("CI") != "" || os.Getenv("GITHUB_ACTIONS") != "" { maxTimePerOp = 15 * time.Millisecond // More lenient for CI } assert.Less(t, avgTimePerOp, maxTimePerOp, "No-op operations should be fast") // No-op operations should not leak memory (< 5MB growth) if endMemory > startMemory { memoryGrowth := int64(endMemory - startMemory) assert.Less(t, memoryGrowth, int64(5*1024*1024), "No-op operations should not leak memory") t.Logf("Memory growth: %d MB", memoryGrowth/1024/1024) } else { t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } } // TestStressReleasePatterns tests different release patterns func TestStressReleasePatterns(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } test.Prepare(t, config.Conf) defer test.Clean() iterations := 200 
t.Run("ManualRelease", func(t *testing.T) { startMemory := getMemStats() for i := 0; i < iterations; i++ { cxt := newStressTestContext(fmt.Sprintf("manual-%d", i)) _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { ctx.trace.Add({ type: "test" }, { label: "Manual Release" }) return { success: true } } finally { ctx.Release() // Manual release } }`, cxt) if err != nil { t.Fatalf("Manual release iteration %d failed: %v", i, err) } } runtime.GC() endMemory := getMemStats() if endMemory > startMemory { t.Logf("Manual release: Memory growth: %d MB", (endMemory-startMemory)/1024/1024) } else { t.Logf("Manual release: Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } }) t.Run("NoRelease_RelyOnGC", func(t *testing.T) { startMemory := getMemStats() for i := 0; i < iterations; i++ { cxt := newStressTestContext(fmt.Sprintf("gc-%d", i)) _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { ctx.trace.Add({ type: "test" }, { label: "GC Release" }) return { success: true } // No manual release - rely on GC }`, cxt) if err != nil { t.Fatalf("GC release iteration %d failed: %v", i, err) } } // Force GC multiple times for i := 0; i < 3; i++ { runtime.GC() time.Sleep(50 * time.Millisecond) } endMemory := getMemStats() if endMemory > startMemory { t.Logf("GC release: Memory growth: %d MB", (endMemory-startMemory)/1024/1024) } else { t.Logf("GC release: Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } }) t.Run("SeparateTraceRelease", func(t *testing.T) { startMemory := getMemStats() for i := 0; i < iterations; i++ { cxt := newStressTestContext(fmt.Sprintf("separate-%d", i)) _, err := v8.Call(v8.CallOptions{}, ` function test(ctx) { try { ctx.trace.Add({ type: "test" }, { label: "Separate Release" }) ctx.trace.Release() // Release trace separately return { success: true } } finally { ctx.Release() // Release context } }`, cxt) if err != nil { t.Fatalf("Separate release iteration %d failed: %v", i, err) } } runtime.GC() endMemory := 
getMemStats() if endMemory > startMemory { t.Logf("Separate release: Memory growth: %d MB", (endMemory-startMemory)/1024/1024) } else { t.Logf("Separate release: Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } }) } // TestStressLongRunningTrace tests long-running trace with many operations func TestStressLongRunningTrace(t *testing.T) { if testing.Short() { t.Skip("Skipping stress test in short mode") } test.Prepare(t, config.Conf) defer test.Clean() cxt := newStressTestContext("long-running-test") startMemory := getMemStats() operations := 100 _, err := v8.Call(v8.CallOptions{}, fmt.Sprintf(` function test(ctx) { const trace = ctx.trace const allNodes = [] // Create many nested nodes for (let i = 0; i < %d; i++) { const parentNode = trace.Add( { type: "parent", index: i }, { label: "Parent " + i } ) allNodes.push(parentNode) // Create child nodes for (let j = 0; j < 5; j++) { const childNode = parentNode.Add( { type: "child", parent: i, index: j }, { label: "Child " + i + "-" + j } ) allNodes.push(childNode) // Add logs childNode.Info("Processing child " + i + "-" + j) childNode.Complete({ result: "success" }) } parentNode.Complete({ result: "all children completed" }) } // Release at the end trace.Release() ctx.Release() return { totalNodes: allNodes.length, operations: %d } }`, operations, operations), cxt) if err != nil { t.Fatalf("Long running trace failed: %v", err) } runtime.GC() endMemory := getMemStats() expectedNodes := operations * 6 // parent + 5 children t.Logf("Long-running trace: %d operations, %d nodes", operations, expectedNodes) t.Logf("Start memory: %d MB", startMemory/1024/1024) t.Logf("End memory: %d MB", endMemory/1024/1024) if endMemory > startMemory { t.Logf("Memory growth: %d MB", (endMemory-startMemory)/1024/1024) } else { t.Logf("Memory decreased: %d MB", (startMemory-endMemory)/1024/1024) } } // Helper function to get current memory usage func getMemStats() uint64 { runtime.GC() var m runtime.MemStats 
runtime.ReadMemStats(&m) return m.Alloc } ================================================ FILE: agent/context/jsapi_test.go ================================================ package context_test import ( stdContext "context" "fmt" "sync" "testing" "github.com/stretchr/testify/assert" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" "rogchap.com/v8go" ) // TestJsValue test the JsValue function func TestJsValue(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "ChatID-123456") cxt.AssistantID = "AssistantID-1234" v8.RegisterFunction("testContextJsvalue", testContextJsvalueEmbed) res, err := v8.Call(v8.CallOptions{}, ` function test(cxt) { return testContextJsvalue(cxt) }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } assert.Equal(t, "ChatID-123456", res) // Note: We can't directly check goMaps cleanup as it's in the bridge package } func testContextJsvalueEmbed(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, testContextJsvalueFunction) } func testContextJsvalueFunction(info *v8go.FunctionCallbackInfo) *v8go.Value { var args = info.Args() if len(args) < 1 { return bridge.JsException(info.Context(), "Missing parameters") } ctx, err := args[0].AsObject() if err != nil { return bridge.JsException(info.Context(), err) } chatID, err := ctx.Get("chat_id") if err != nil { return bridge.JsException(info.Context(), err) } return chatID } // TestJsValueConcurrent test the JsValue function with concurrent requests func TestJsValueConcurrent(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() v8.RegisterFunction("testContextJsvalue", testContextJsvalueEmbed) // Number of concurrent goroutines concurrency := 10 iterationsPerGoroutine := 5 var wg sync.WaitGroup errors := make(chan error, 
concurrency*iterationsPerGoroutine) results := make(chan string, concurrency*iterationsPerGoroutine) // Launch concurrent goroutines for i := 0; i < concurrency; i++ { wg.Add(1) go func(routineID int) { defer wg.Done() for j := 0; j < iterationsPerGoroutine; j++ { chatID := fmt.Sprintf("ChatID-%d-%d", routineID, j) assistantID := fmt.Sprintf("AssistantID-%d-%d", routineID, j) cxt := context.New(stdContext.Background(), nil, chatID) cxt.AssistantID = assistantID res, err := v8.Call(v8.CallOptions{}, ` function test(cxt) { return testContextJsvalue(cxt) }`, cxt) if err != nil { errors <- fmt.Errorf("routine %d iteration %d failed: %v", routineID, j, err) return } results <- res.(string) } }(i) } // Wait for all goroutines to complete wg.Wait() close(errors) close(results) // Check for errors for err := range errors { t.Error(err) } // Verify all results resultCount := 0 for res := range results { assert.Contains(t, res, "ChatID-") resultCount++ } // Verify the correct number of results expectedResults := concurrency * iterationsPerGoroutine assert.Equal(t, expectedResults, resultCount, "Should have %d results", expectedResults) // Verify all objects are cleaned up after GC // Note: objects should be released when v8 values are garbage collected // Note: We can't directly check goMaps cleanup as it's in the bridge package } // TestJsValueRegistrationAndCleanup test the object registration and cleanup mechanism func TestJsValueRegistrationAndCleanup(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() v8.RegisterFunction("testContextRegistration", testContextRegistrationEmbed) // Create multiple contexts and verify registration contextCount := 5 for i := 0; i < contextCount; i++ { cxt := context.New(stdContext.Background(), nil, fmt.Sprintf("ChatID-%d", i)) cxt.AssistantID = fmt.Sprintf("AssistantID-%d", i) _, err := v8.Call(v8.CallOptions{}, ` function test(cxt) { return testContextRegistration(cxt) }`, cxt) if err != nil { t.Fatalf("Call %d failed: %v", i, 
err) } } // All objects should be cleaned up after v8.Call completes // Note: We can't directly check goMaps cleanup as it's in the bridge package } func testContextRegistrationEmbed(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, testContextRegistrationFunction) } func testContextRegistrationFunction(info *v8go.FunctionCallbackInfo) *v8go.Value { var args = info.Args() if len(args) < 1 { return bridge.JsException(info.Context(), "Missing parameters") } ctx, err := args[0].AsObject() if err != nil { return bridge.JsException(info.Context(), err) } // Verify the object has __release function release, err := ctx.Get("__release") if err != nil { return bridge.JsException(info.Context(), err) } if !release.IsFunction() { return bridge.JsException(info.Context(), fmt.Errorf("__release should be a function")) } // Verify the object has internal field (goValueID is stored in internal field, not accessible from JS) if ctx.InternalFieldCount() == 0 { return bridge.JsException(info.Context(), fmt.Errorf("object should have internal field")) } goValueID := ctx.GetInternalField(0) if goValueID == nil || !goValueID.IsString() { return bridge.JsException(info.Context(), fmt.Errorf("internal field should contain goValueID string")) } val, err := v8go.NewValue(info.Context().Isolate(), true) if err != nil { return bridge.JsException(info.Context(), err) } return val } // TestJsValueAllFields test that all Context fields are properly exported to JavaScript func TestJsValueAllFields(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() authInfo := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client", UserID: "user-123", TeamID: "team-456", TenantID: "tenant-789", Constraints: types.DataConstraints{ OwnerOnly: true, CreatorOnly: false, TeamOnly: true, Extra: map[string]interface{}{ "department": "engineering", "region": "us-west", }, }, } cxt := context.New(stdContext.Background(), authInfo, "test-chat-id") cxt.AssistantID = 
"test-assistant-id" cxt.Locale = "zh-cn" cxt.Theme = "dark" cxt.Client = context.Client{ Type: "web", UserAgent: "Mozilla/5.0", IP: "127.0.0.1", } cxt.Referer = "api" cxt.Accept = "cui-web" cxt.Route = "/dashboard/home" cxt.Metadata = map[string]interface{}{ "key1": "value1", "key2": 123, "key3": true, } v8.RegisterFunction("testAllFields", testAllFieldsEmbed) res, err := v8.Call(v8.CallOptions{}, ` function test(cxt) { return testAllFields(cxt) }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } // Verify all fields assert.Equal(t, "test-chat-id", result["chat_id"], "chat_id mismatch") assert.Equal(t, "test-assistant-id", result["assistant_id"], "assistant_id mismatch") assert.Equal(t, "zh-cn", result["locale"], "locale mismatch") assert.Equal(t, "dark", result["theme"], "theme mismatch") assert.Equal(t, "api", result["referer"], "referer mismatch") assert.Equal(t, "cui-web", result["accept"], "accept mismatch") assert.Equal(t, "/dashboard/home", result["route"], "route mismatch") // Verify client object client, ok := result["client"].(map[string]interface{}) assert.True(t, ok, "client should be an object") assert.Equal(t, "web", client["type"], "client.type mismatch") assert.Equal(t, "Mozilla/5.0", client["user_agent"], "client.user_agent mismatch") assert.Equal(t, "127.0.0.1", client["ip"], "client.ip mismatch") // Verify metadata object metadata, ok := result["metadata"].(map[string]interface{}) assert.True(t, ok, "metadata should be an object") assert.Equal(t, "value1", metadata["key1"], "metadata.key1 mismatch") assert.Equal(t, float64(123), metadata["key2"], "metadata.key2 mismatch") assert.Equal(t, true, metadata["key3"], "metadata.key3 mismatch") // Verify authorized object authorized, ok := result["authorized"].(map[string]interface{}) assert.True(t, ok, "authorized should be an object") assert.Equal(t, "test-user", authorized["sub"], "authorized.sub 
mismatch") assert.Equal(t, "test-client", authorized["client_id"], "authorized.client_id mismatch") assert.Equal(t, "user-123", authorized["user_id"], "authorized.user_id mismatch") assert.Equal(t, "team-456", authorized["team_id"], "authorized.team_id mismatch") assert.Equal(t, "tenant-789", authorized["tenant_id"], "authorized.tenant_id mismatch") // Verify authorized.constraints object constraints, ok := authorized["constraints"].(map[string]interface{}) assert.True(t, ok, "authorized.constraints should be an object") assert.Equal(t, true, constraints["owner_only"], "constraints.owner_only mismatch") // creator_only is false, and with omitempty it may not be present if creatorOnly, exists := constraints["creator_only"]; exists { assert.Equal(t, false, creatorOnly, "constraints.creator_only mismatch") } assert.Equal(t, true, constraints["team_only"], "constraints.team_only mismatch") // Verify constraints.extra object extra, ok := constraints["extra"].(map[string]interface{}) assert.True(t, ok, "constraints.extra should be an object") assert.Equal(t, "engineering", extra["department"], "constraints.extra.department mismatch") assert.Equal(t, "us-west", extra["region"], "constraints.extra.region mismatch") // Note: We can't directly check goMaps cleanup as it's in the bridge package } func testAllFieldsEmbed(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, testAllFieldsFunction) } func testAllFieldsFunction(info *v8go.FunctionCallbackInfo) *v8go.Value { var args = info.Args() if len(args) < 1 { return bridge.JsException(info.Context(), "Missing parameters") } ctx, err := args[0].AsObject() if err != nil { return bridge.JsException(info.Context(), err) } // Extract all fields and return as a map result := map[string]interface{}{} // Helper function to get field value getField := func(name string) (interface{}, bool) { val, err := ctx.Get(name) if err != nil || val.IsUndefined() { return nil, false } goVal, err := bridge.GoValue(val, 
info.Context()) if err != nil { return nil, false } return goVal, true } if val, ok := getField("chat_id"); ok { result["chat_id"] = val } if val, ok := getField("assistant_id"); ok { result["assistant_id"] = val } if val, ok := getField("locale"); ok { result["locale"] = val } if val, ok := getField("theme"); ok { result["theme"] = val } if val, ok := getField("client"); ok { result["client"] = val } if val, ok := getField("referer"); ok { result["referer"] = val } if val, ok := getField("accept"); ok { result["accept"] = val } if val, ok := getField("route"); ok { result["route"] = val } if val, ok := getField("metadata"); ok { result["metadata"] = val } if val, ok := getField("authorized"); ok { result["authorized"] = val } // Check for deprecated fields - they should NOT exist if val, ok := getField("sid"); ok { result["sid"] = val } if val, ok := getField("silent"); ok { result["silent"] = val } jsVal, err := bridge.JsValue(info.Context(), result) if err != nil { return bridge.JsException(info.Context(), err) } return jsVal } // TestJsValueTrace test the Trace method on Context func TestJsValueTrace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Stack = &context.Stack{ TraceID: "test-trace-id", } res, err := v8.Call(v8.CallOptions{}, ` function test(cxt) { // Get trace from context (property, not method call) const trace = cxt.trace // Verify trace object exists if (!trace) { throw new Error("Trace returned null or undefined") } // Verify trace has expected methods if (typeof trace.Add !== 'function') { throw new Error("trace.Add is not a function") } if (typeof trace.Info !== 'function') { throw new Error("trace.Info is not a function") } // Actually use the trace - add a node const node = trace.Add({ type: "test", content: "Test from context" }, { label: "Test Node" }) // Log some info trace.Info("Testing trace from context") 
node.Info("Node info message") // Complete the node node.Complete({ result: "success" }) // Return verification info return { trace_id: trace.id, node_id: node.id, success: true } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } // Verify trace was accessible and operations succeeded assert.Equal(t, "test-trace-id", result["trace_id"], "trace_id should match") assert.NotEmpty(t, result["node_id"], "node_id should not be empty") assert.Equal(t, true, result["success"], "operation should succeed") } // TestJsValueAuthorizedAndMetadata test the authorized and metadata fields func TestJsValueAuthorizedAndMetadata(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() authInfo := &types.AuthorizedInfo{ UserID: "user-123", TenantID: "tenant-456", ClientID: "client-789", } cxt := context.New(stdContext.Background(), authInfo, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Metadata = map[string]interface{}{ "request_id": "req-001", "source": "api", "version": "1.0.0", } v8.RegisterFunction("testAuthorizedMetadata", testAuthorizedMetadataEmbed) res, err := v8.Call(v8.CallOptions{}, ` function test(cxt) { return testAuthorizedMetadata(cxt) }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } // Verify authorized object authorized, ok := result["authorized"].(map[string]interface{}) assert.True(t, ok, "authorized should be an object") assert.Equal(t, "user-123", authorized["user_id"], "authorized.user_id mismatch") assert.Equal(t, "tenant-456", authorized["tenant_id"], "authorized.tenant_id mismatch") assert.Equal(t, "client-789", authorized["client_id"], "authorized.client_id mismatch") // Verify metadata object metadata, ok := result["metadata"].(map[string]interface{}) assert.True(t, ok, "metadata should be an object") assert.Equal(t, 
"req-001", metadata["request_id"], "metadata.request_id mismatch") assert.Equal(t, "api", metadata["source"], "metadata.source mismatch") assert.Equal(t, "1.0.0", metadata["version"], "metadata.version mismatch") } func testAuthorizedMetadataEmbed(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, testAuthorizedMetadataFunction) } func testAuthorizedMetadataFunction(info *v8go.FunctionCallbackInfo) *v8go.Value { var args = info.Args() if len(args) < 1 { return bridge.JsException(info.Context(), "Missing parameters") } ctx, err := args[0].AsObject() if err != nil { return bridge.JsException(info.Context(), err) } // Extract authorized and metadata fields result := map[string]interface{}{} // Get authorized authorizedVal, err := ctx.Get("authorized") if err != nil { return bridge.JsException(info.Context(), err) } if !authorizedVal.IsUndefined() && !authorizedVal.IsNull() { authorized, err := bridge.GoValue(authorizedVal, info.Context()) if err != nil { return bridge.JsException(info.Context(), err) } result["authorized"] = authorized } // Get metadata metadataVal, err := ctx.Get("metadata") if err != nil { return bridge.JsException(info.Context(), err) } if !metadataVal.IsUndefined() && !metadataVal.IsNull() { metadata, err := bridge.GoValue(metadataVal, info.Context()) if err != nil { return bridge.JsException(info.Context(), err) } result["metadata"] = metadata } jsVal, err := bridge.JsValue(info.Context(), result) if err != nil { return bridge.JsException(info.Context(), err) } return jsVal } // TestJsValueAuthorizedNil test when authorized is nil func TestJsValueAuthorizedNil(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() cxt := context.New(stdContext.Background(), nil, "test-chat-id") cxt.AssistantID = "test-assistant-id" cxt.Metadata = nil // Explicitly nil (should be empty object) res, err := v8.Call(v8.CallOptions{}, ` function test(cxt) { // Debug: check the actual values const authorized = cxt.authorized; 
const metadata = cxt.metadata; return { authorized_type: typeof authorized, authorized_is_null: authorized === null, authorized_is_undefined: authorized === undefined, metadata_type: typeof metadata, metadata_is_object: typeof metadata === 'object' && metadata !== null, metadata_is_empty: metadata && Object.keys(metadata).length === 0, has_authorized: 'authorized' in cxt, has_metadata: 'metadata' in cxt } }`, cxt) if err != nil { t.Fatalf("Call failed: %v", err) } result, ok := res.(map[string]interface{}) if !ok { t.Fatalf("Expected map result, got %T", res) } // Verify authorized exists and is an empty object when nil assert.Equal(t, true, result["has_authorized"], "authorized property should exist") assert.Equal(t, "object", result["authorized_type"], "authorized should be an object") assert.Equal(t, true, result["metadata_is_object"], "authorized should be an object (not null)") // Verify metadata is an empty object when not set assert.Equal(t, true, result["has_metadata"], "metadata property should exist") assert.Equal(t, "object", result["metadata_type"], "metadata should be an object") assert.Equal(t, true, result["metadata_is_object"], "metadata should be an object") assert.Equal(t, true, result["metadata_is_empty"], "metadata should be empty object when not set") } ================================================ FILE: agent/context/jsapi_workspace.go ================================================ package context import ( "io/fs" "os" "github.com/yaoapp/gou/runtime/v8/bridge" "rogchap.com/v8go" ) // createWorkspaceInstance creates the ctx.workspace JavaScript object. 
func (ctx *Context) createWorkspaceInstance(v8ctx *v8go.Context) *v8go.Value { if ctx.workspace == nil { return nil } iso := v8ctx.Isolate() objTpl := v8go.NewObjectTemplate(iso) objTpl.Set("ReadFile", ctx.wsReadFileMethod(iso)) objTpl.Set("WriteFile", ctx.wsWriteFileMethod(iso)) objTpl.Set("ReadDir", ctx.wsReadDirMethod(iso)) objTpl.Set("MkdirAll", ctx.wsMkdirAllMethod(iso)) objTpl.Set("Remove", ctx.wsRemoveMethod(iso)) objTpl.Set("RemoveAll", ctx.wsRemoveAllMethod(iso)) objTpl.Set("Rename", ctx.wsRenameMethod(iso)) objTpl.Set("Copy", ctx.wsCopyMethod(iso)) objTpl.Set("Stat", ctx.wsStatMethod(iso)) objTpl.Set("Exists", ctx.wsExistsMethod(iso)) instance, err := objTpl.NewInstance(v8ctx) if err != nil { return nil } return instance.Value } // wsReadFileMethod implements ctx.workspace.ReadFile(path) func (ctx *Context) wsReadFileMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 1 { return bridge.JsException(v8ctx, "ReadFile requires a path argument") } data, err := ctx.workspace.ReadFile(args[0].String()) if err != nil { return bridge.JsException(v8ctx, "ReadFile failed: "+err.Error()) } jsVal, err := v8go.NewValue(iso, string(data)) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // wsWriteFileMethod implements ctx.workspace.WriteFile(path, content) func (ctx *Context) wsWriteFileMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 2 { return bridge.JsException(v8ctx, "WriteFile requires path and content arguments") } path := args[0].String() content := args[1].String() 
if err := ctx.workspace.WriteFile(path, []byte(content), 0o644); err != nil { return bridge.JsException(v8ctx, "WriteFile failed: "+err.Error()) } return v8go.Undefined(iso) }) } // wsReadDirMethod implements ctx.workspace.ReadDir(path) func (ctx *Context) wsReadDirMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } path := "." if len(args) >= 1 && args[0].IsString() { path = args[0].String() } entries, err := ctx.workspace.ReadDir(path) if err != nil { return bridge.JsException(v8ctx, "ReadDir failed: "+err.Error()) } result := make([]map[string]interface{}, 0, len(entries)) for _, e := range entries { fi, _ := e.Info() item := map[string]interface{}{ "name": e.Name(), "is_dir": e.IsDir(), } if fi != nil { item["size"] = int32(fi.Size()) } result = append(result, item) } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // wsMkdirAllMethod implements ctx.workspace.MkdirAll(path) func (ctx *Context) wsMkdirAllMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 1 { return bridge.JsException(v8ctx, "MkdirAll requires a path argument") } if err := ctx.workspace.MkdirAll(args[0].String(), 0o755); err != nil { return bridge.JsException(v8ctx, "MkdirAll failed: "+err.Error()) } return v8go.Undefined(iso) }) } // wsRemoveMethod implements ctx.workspace.Remove(path) func (ctx *Context) wsRemoveMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := 
info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 1 { return bridge.JsException(v8ctx, "Remove requires a path argument") } if err := ctx.workspace.Remove(args[0].String()); err != nil { return bridge.JsException(v8ctx, "Remove failed: "+err.Error()) } return v8go.Undefined(iso) }) } // wsRemoveAllMethod implements ctx.workspace.RemoveAll(path) func (ctx *Context) wsRemoveAllMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 1 { return bridge.JsException(v8ctx, "RemoveAll requires a path argument") } if err := ctx.workspace.RemoveAll(args[0].String()); err != nil { return bridge.JsException(v8ctx, "RemoveAll failed: "+err.Error()) } return v8go.Undefined(iso) }) } // wsRenameMethod implements ctx.workspace.Rename(oldName, newName) func (ctx *Context) wsRenameMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 2 { return bridge.JsException(v8ctx, "Rename requires oldName and newName arguments") } if err := ctx.workspace.Rename(args[0].String(), args[1].String()); err != nil { return bridge.JsException(v8ctx, "Rename failed: "+err.Error()) } return v8go.Undefined(iso) }) } // wsCopyMethod implements ctx.workspace.Copy(src, dst) func (ctx *Context) wsCopyMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") 
} if len(args) < 2 { return bridge.JsException(v8ctx, "Copy requires src and dst arguments") } if _, err := ctx.workspace.Copy(args[0].String(), args[1].String()); err != nil { return bridge.JsException(v8ctx, "Copy failed: "+err.Error()) } return v8go.Undefined(iso) }) } // wsStatMethod implements ctx.workspace.Stat(path) func (ctx *Context) wsStatMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 1 { return bridge.JsException(v8ctx, "Stat requires a path argument") } fi, err := ctx.workspace.Stat(args[0].String()) if err != nil { return bridge.JsException(v8ctx, "Stat failed: "+err.Error()) } result := map[string]interface{}{ "name": fi.Name(), "size": int32(fi.Size()), "is_dir": fi.IsDir(), "mode": int32(fi.Mode()), "mtime": fi.ModTime().UnixMilli(), } jsVal, err := bridge.JsValue(v8ctx, result) if err != nil { return bridge.JsException(v8ctx, err.Error()) } return jsVal }) } // wsExistsMethod implements ctx.workspace.Exists(path) func (ctx *Context) wsExistsMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if ctx.workspace == nil { return bridge.JsException(v8ctx, "workspace not available") } if len(args) < 1 { return bridge.JsException(v8ctx, "Exists requires a path argument") } _, err := ctx.workspace.Stat(args[0].String()) exists := err == nil || !isNotExist(err) jsVal, _ := v8go.NewValue(iso, exists) return jsVal }) } func isNotExist(err error) bool { if os.IsNotExist(err) { return true } pathErr, ok := err.(*fs.PathError) if ok && os.IsNotExist(pathErr.Err) { return true } return false } ================================================ FILE: agent/context/log.go 
================================================ package context import ( "fmt" "strings" "sync" "time" kunlog "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/config" ) // ============================================================================= // ANSI Color Codes // ============================================================================= const ( colorReset = "\033[0m" colorRed = "\033[31m" colorGreen = "\033[32m" colorYellow = "\033[33m" colorBlue = "\033[34m" colorMagenta = "\033[35m" colorCyan = "\033[36m" colorWhite = "\033[37m" colorGray = "\033[90m" colorBoldRed = "\033[1;31m" colorBoldGreen = "\033[1;32m" colorBoldYellow = "\033[1;33m" colorBoldBlue = "\033[1;34m" colorBoldMagenta = "\033[1;35m" colorBoldCyan = "\033[1;36m" ) // ============================================================================= // Log Level // ============================================================================= // LogLevel represents log severity type LogLevel int const ( // LogLevelTrace represents the most verbose logging level for detailed tracing LogLevelTrace LogLevel = iota // LogLevelDebug represents debug level logging for development diagnostics LogLevelDebug // LogLevelInfo represents informational messages for normal operation LogLevelInfo // LogLevelWarn represents warning messages for potentially harmful situations LogLevelWarn // LogLevelError represents error messages for serious problems LogLevelError ) // ============================================================================= // Log Entry // ============================================================================= // LogEntry represents a single log entry type LogEntry struct { Level LogLevel Message string Timestamp time.Time Phase string // For phase logging Elapsed time.Duration } // ============================================================================= // Request Logger // ============================================================================= // RequestLogger 
provides request-scoped async logging type RequestLogger struct { assistantIDStack []string // Stack-based: delegate calls push, pop on exit; top = current chatID string requestID string shortID string // Short version of requestID for display parentID string // Parent request ID for A2A tree structure startTime time.Time ch chan LogEntry done chan struct{} once sync.Once closed bool noop bool // noop logger does nothing (for nil safety) mu sync.RWMutex } // LoggerOption configures a RequestLogger type LoggerOption func(*RequestLogger) // WithParentID sets the parent request ID for A2A tree structure func WithParentID(parentID string) LoggerOption { return func(l *RequestLogger) { l.parentID = parentID } } // noopLogger is a shared no-op logger instance var noopLogger = &RequestLogger{noop: true} // NewRequestLogger creates a new request-scoped logger with async processing func NewRequestLogger(assistantID, chatID, requestID string, opts ...LoggerOption) *RequestLogger { l := &RequestLogger{ assistantIDStack: []string{assistantID}, chatID: chatID, requestID: requestID, shortID: shortID(requestID), startTime: time.Now(), ch: make(chan LogEntry, 100), // Buffered channel done: make(chan struct{}), } for _, opt := range opts { opt(l) } // Start consumer goroutine go l.consume() return l } // Noop returns a no-op logger that does nothing (nil-safe) func Noop() *RequestLogger { return noopLogger } // SetAssistantID pushes a new assistant ID onto the stack (called when entering Stream). // Each SetAssistantID must be paired with a RestoreAssistantID on exit. func (l *RequestLogger) SetAssistantID(id string) { if l.noop { return } l.mu.Lock() l.assistantIDStack = append(l.assistantIDStack, id) l.mu.Unlock() } // RestoreAssistantID pops the current assistant ID, reverting to the previous one. // Safe to call even if the stack has only one entry (the initial ID is never removed). 
func (l *RequestLogger) RestoreAssistantID() { if l.noop { return } l.mu.Lock() if len(l.assistantIDStack) > 1 { l.assistantIDStack = l.assistantIDStack[:len(l.assistantIDStack)-1] } l.mu.Unlock() } func (l *RequestLogger) currentAssistantID() string { if len(l.assistantIDStack) == 0 { return "" } return l.assistantIDStack[len(l.assistantIDStack)-1] } // Close closes the logger and waits for all entries to be processed func (l *RequestLogger) Close() { if l.noop { return } l.once.Do(func() { l.mu.Lock() l.closed = true l.mu.Unlock() close(l.ch) <-l.done // Wait for consumer to finish }) } // consume processes log entries from the channel func (l *RequestLogger) consume() { defer close(l.done) for entry := range l.ch { l.processEntry(entry) } } // processEntry handles a single log entry based on mode func (l *RequestLogger) processEntry(entry LogEntry) { if config.IsDevelopment() { l.printDev(entry) l.writeLog(entry, true) } else { l.writeLog(entry, false) } } // printDev prints colored output to stdout in development mode func (l *RequestLogger) printDev(entry LogEntry) { switch entry.Level { case LogLevelTrace: fmt.Printf("%s → %s%s\n", colorGray, entry.Message, colorReset) case LogLevelDebug: fmt.Printf("%s • %s%s\n", colorGray, entry.Message, colorReset) case LogLevelInfo: fmt.Printf("%s ℹ %s%s\n", colorCyan, entry.Message, colorReset) case LogLevelWarn: fmt.Printf("%s ⚠ %s%s\n", colorYellow, entry.Message, colorReset) case LogLevelError: fmt.Printf("%s ✗ %s%s\n", colorRed, entry.Message, colorReset) } } // writeLog writes structured events to kun/log func (l *RequestLogger) writeLog(entry LogEntry, devMode bool) { prefix := fmt.Sprintf("[AGENT] %s ", l.shortID) if devMode { kunlog.Trace("%s%s", prefix, entry.Message) return } switch entry.Level { case LogLevelTrace: kunlog.Trace("%s%s", prefix, entry.Message) case LogLevelDebug: // Skip debug in production case LogLevelInfo: kunlog.Info("%s%s", prefix, entry.Message) case LogLevelWarn: kunlog.Warn("%s%s", 
prefix, entry.Message) case LogLevelError: kunlog.Error("%s%s", prefix, entry.Message) } } // send sends an entry to the channel (non-blocking if closed) func (l *RequestLogger) send(entry LogEntry) { if l.noop { return } l.mu.RLock() closed := l.closed l.mu.RUnlock() if closed { return } entry.Timestamp = time.Now() select { case l.ch <- entry: default: // Channel full, drop the log (shouldn't happen with buffered channel) } } // ============================================================================= // Standard Log Interface // ============================================================================= // Trace logs a trace level message func (l *RequestLogger) Trace(format string, args ...interface{}) { l.send(LogEntry{ Level: LogLevelTrace, Message: fmt.Sprintf(format, args...), }) } // Debug logs a debug level message func (l *RequestLogger) Debug(format string, args ...interface{}) { l.send(LogEntry{ Level: LogLevelDebug, Message: fmt.Sprintf(format, args...), }) } // Info logs an info level message func (l *RequestLogger) Info(format string, args ...interface{}) { l.send(LogEntry{ Level: LogLevelInfo, Message: fmt.Sprintf(format, args...), }) } // Warn logs a warning level message func (l *RequestLogger) Warn(format string, args ...interface{}) { l.send(LogEntry{ Level: LogLevelWarn, Message: fmt.Sprintf(format, args...), }) } // Error logs an error level message func (l *RequestLogger) Error(format string, args ...interface{}) { l.send(LogEntry{ Level: LogLevelError, Message: fmt.Sprintf(format, args...), }) } // ============================================================================= // Business Quick Functions // ============================================================================= // Start logs the start of a request with visual separator func (l *RequestLogger) Start() { if l.noop { return } kunlog.Trace("[AGENT] Request %s started: assistant=%s, chat=%s, request=%s", l.shortID, l.currentAssistantID(), shortID(l.chatID), 
shortID(l.requestID)) if !config.IsDevelopment() { return } fmt.Println() fmt.Printf("%s%s%s\n", colorBoldCyan, strings.Repeat("═", 60), colorReset) fmt.Printf("%s AGENT REQUEST %s%s\n", colorBoldCyan, l.shortID, colorReset) fmt.Printf("%s%s%s\n", colorBoldCyan, strings.Repeat("─", 60), colorReset) fmt.Printf("%s Assistant: %s%s%s\n", colorGray, colorWhite, l.currentAssistantID(), colorReset) fmt.Printf("%s Chat ID: %s%s%s\n", colorGray, colorWhite, l.chatID, colorReset) fmt.Printf("%s Request: %s%s%s\n", colorGray, colorWhite, l.requestID, colorReset) fmt.Printf("%s Time: %s%s%s\n", colorGray, colorWhite, l.startTime.Format("15:04:05.000"), colorReset) fmt.Printf("%s%s%s\n", colorCyan, strings.Repeat("─", 60), colorReset) } // End logs the end of a request with summary func (l *RequestLogger) End(success bool, err error) { if l.noop { return } duration := time.Since(l.startTime) if success { kunlog.Trace("[AGENT] Request %s completed: assistant=%s, duration=%v", l.shortID, l.currentAssistantID(), duration.Round(time.Millisecond)) } else { kunlog.Error("[AGENT] Request %s failed: assistant=%s, duration=%v, error=%v", l.shortID, l.currentAssistantID(), duration.Round(time.Millisecond), err) } if !config.IsDevelopment() { return } fmt.Printf("%s%s%s\n", colorCyan, strings.Repeat("─", 60), colorReset) if success { fmt.Printf("%s REQUEST %s COMPLETED%s\n", colorBoldGreen, l.shortID, colorReset) } else { fmt.Printf("%s REQUEST %s FAILED%s\n", colorBoldRed, l.shortID, colorReset) if err != nil { fmt.Printf("%s Error: %s%v%s\n", colorGray, colorRed, err, colorReset) } } fmt.Printf("%s Assistant: %s%s%s\n", colorGray, colorWhite, l.currentAssistantID(), colorReset) fmt.Printf("%s Duration: %s%v%s\n", colorGray, colorWhite, duration.Round(time.Millisecond), colorReset) fmt.Printf("%s%s%s\n", colorCyan, strings.Repeat("─", 60), colorReset) fmt.Println() } // Phase logs a major phase in the request lifecycle func (l *RequestLogger) Phase(name string) { if l.noop { return } 
elapsed := time.Since(l.startTime).Round(time.Millisecond) kunlog.Trace("[AGENT] %s Phase: %s (+%v)", l.shortID, name, elapsed) if !config.IsDevelopment() { return } fmt.Printf("%s > %s%s %s[+%v]%s\n", colorBoldBlue, name, colorReset, colorGray, elapsed, colorReset) } // PhaseComplete logs the completion of a phase func (l *RequestLogger) PhaseComplete(name string) { if l.noop { return } elapsed := time.Since(l.startTime).Round(time.Millisecond) kunlog.Trace("[AGENT] %s Phase completed: %s (+%v)", l.shortID, name, elapsed) if !config.IsDevelopment() { return } fmt.Printf("%s + %s%s %s[+%v]%s\n", colorGreen, name, colorReset, colorGray, elapsed, colorReset) } // PhaseSkip logs a skipped phase (development only) func (l *RequestLogger) PhaseSkip(name, reason string) { if l.noop { return } if !config.IsDevelopment() { return } fmt.Printf("%s - %s (%s)%s\n", colorGray, name, reason, colorReset) } // LLMStart logs the start of an LLM call func (l *RequestLogger) LLMStart(connector, model string, messageCount int) { if l.noop { return } elapsed := time.Since(l.startTime).Round(time.Millisecond) kunlog.Trace("[AGENT] %s LLM call: connector=%s, model=%s, messages=%d (+%v)", l.shortID, connector, model, messageCount, elapsed) if !config.IsDevelopment() { return } fmt.Printf("%s LLM Call%s %s[+%v]%s\n", colorBoldMagenta, colorReset, colorGray, elapsed, colorReset) fmt.Printf("%s Connector: %s%s%s\n", colorGray, colorWhite, connector, colorReset) if model != "" { fmt.Printf("%s Model: %s%s%s\n", colorGray, colorWhite, model, colorReset) } fmt.Printf("%s Messages: %s%d%s\n", colorGray, colorWhite, messageCount, colorReset) } // LLMComplete logs the completion of an LLM call func (l *RequestLogger) LLMComplete(tokens int, hasToolCalls bool) { if l.noop { return } elapsed := time.Since(l.startTime).Round(time.Millisecond) status := "streaming" if hasToolCalls { status = "tool_calls" } kunlog.Trace("[AGENT] %s LLM response: status=%s, tokens=%d (+%v)", l.shortID, status, tokens, 
elapsed) if !config.IsDevelopment() { return } fmt.Printf("%s + LLM Response (%s)%s", colorGreen, status, colorReset) if tokens > 0 { fmt.Printf(" %s[tokens: %d]%s", colorGray, tokens, colorReset) } fmt.Printf(" %s[+%v]%s\n", colorGray, elapsed, colorReset) } // ToolStart logs the start of tool execution func (l *RequestLogger) ToolStart(toolName string) { if l.noop { return } kunlog.Trace("[AGENT] %s Tool call: %s", l.shortID, toolName) if !config.IsDevelopment() { return } fmt.Printf("%s Tool: %s%s\n", colorYellow, toolName, colorReset) } // ToolComplete logs the completion of tool execution func (l *RequestLogger) ToolComplete(toolName string, success bool) { if l.noop { return } if success { kunlog.Trace("[AGENT] %s Tool completed: %s", l.shortID, toolName) } else { kunlog.Error("[AGENT] %s Tool failed: %s", l.shortID, toolName) } if !config.IsDevelopment() { return } if success { fmt.Printf("%s + %s completed%s\n", colorGreen, toolName, colorReset) } else { fmt.Printf("%s x %s failed%s\n", colorRed, toolName, colorReset) } } // HookStart logs the start of a hook execution func (l *RequestLogger) HookStart(hookName string) { if l.noop { return } elapsed := time.Since(l.startTime).Round(time.Millisecond) kunlog.Trace("[AGENT] %s Hook: %s (+%v)", l.shortID, hookName, elapsed) if !config.IsDevelopment() { return } fmt.Printf("%s Hook: %s%s %s[+%v]%s\n", colorMagenta, hookName, colorReset, colorGray, elapsed, colorReset) } // HookComplete logs the completion of a hook func (l *RequestLogger) HookComplete(hookName string) { if l.noop { return } kunlog.Trace("[AGENT] %s Hook completed: %s", l.shortID, hookName) if !config.IsDevelopment() { return } fmt.Printf("%s + %s done%s\n", colorGreen, hookName, colorReset) } // Cleanup logs resource cleanup func (l *RequestLogger) Cleanup(resource string) { if l.noop { return } kunlog.Trace("[AGENT] %s Cleanup: %s", l.shortID, resource) if !config.IsDevelopment() { return } fmt.Printf("%s + %s%s\n", colorGray, resource, 
colorReset) } // HistoryLoad logs history loading func (l *RequestLogger) HistoryLoad(count, maxSize int) { if l.noop { return } kunlog.Trace("[AGENT] %s History loaded: %d/%d messages", l.shortID, count, maxSize) if !config.IsDevelopment() { return } fmt.Printf("%s Loaded %d/%d history messages%s\n", colorGray, count, maxSize, colorReset) } // HistoryOverlap logs overlap detection func (l *RequestLogger) HistoryOverlap(overlapCount int) { if l.noop { return } if overlapCount > 0 { kunlog.Trace("[AGENT] %s History overlap removed: %d messages", l.shortID, overlapCount) if !config.IsDevelopment() { return } fmt.Printf("%s Removed %d overlapping messages%s\n", colorYellow, overlapCount, colorReset) } } // Release logs the start of resource release phase func (l *RequestLogger) Release() { if l.noop { return } kunlog.Trace("[AGENT] %s Release started", l.shortID) if !config.IsDevelopment() { return } fmt.Printf("%s RELEASE %s%s %s(%s)%s\n", colorBoldYellow, l.shortID, colorReset, colorGray, l.currentAssistantID(), colorReset) } // ============================================================================= // Helper // ============================================================================= // shortID returns first 8 characters of an ID func shortID(id string) string { if len(id) > 8 { return id[:8] } return id } ================================================ FILE: agent/context/mcp.go ================================================ package context import ( "fmt" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/mcp" "github.com/yaoapp/gou/mcp/types" "github.com/yaoapp/yao/agent/i18n" traceTypes "github.com/yaoapp/yao/trace/types" ) // MCP Client Operations with automatic trace logging and resource management // Resource Operations // ================== // ListResources lists all available resources from an MCP client // Automatically creates trace node and handles client lifecycle func (ctx *Context) ListResources(mcpID string, cursor string) 
(*types.ListResourcesResponse, error) {
	// Resolve the MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Display label falls back to the client ID
	label := client.GetMetaInfo().Label
	if label == "" {
		label = mcpID
	}

	// Open a trace node when a trace manager is available
	var node traceTypes.Node
	if tracer, _ := ctx.Trace(); tracer != nil {
		node, _ = tracer.Add(
			map[string]any{
				"mcp":    mcpID,
				"cursor": cursor,
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.list_resources.label"), // "MCP: List Resources"
				Type:        "mcp",
				Icon:        "list",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.list_resources.description"), label), // "List resources from MCP client '%s'"
			},
		)
	}

	// Perform the call and mirror the outcome onto the trace node
	result, err := client.ListResources(ctx.Context, cursor)
	if err != nil {
		if node != nil {
			node.Fail(err)
		}
		return nil, err
	}

	if node != nil {
		node.Complete(map[string]any{
			"resources":  len(result.Resources),
			"nextCursor": result.NextCursor,
		})
	}
	return result, nil
}

// ReadResource reads the resource identified by uri from the MCP client mcpID.
// When a trace manager is available the call is recorded as a trace node that
// is failed or completed according to the outcome.
func (ctx *Context) ReadResource(mcpID string, uri string) (*types.ReadResourceResponse, error) {
	// Resolve the MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Display label falls back to the client ID
	label := client.GetMetaInfo().Label
	if label == "" {
		label = mcpID
	}

	// Open a trace node when a trace manager is available
	var node traceTypes.Node
	if tracer, _ := ctx.Trace(); tracer != nil {
		node, _ = tracer.Add(
			map[string]any{
				"mcp": mcpID,
				"uri": uri,
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.read_resource.label"), // "MCP: Read Resource"
				Type:        "mcp",
				Icon:        "description",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.read_resource.description"), uri, label), // "Read resource '%s' from MCP client '%s'"
			},
		)
	}

	// Perform the call and mirror the outcome onto the trace node
	result, err := client.ReadResource(ctx.Context, uri)
	if err != nil {
		if node != nil {
			node.Fail(err)
		}
		return nil, err
	}

	if node != nil {
		node.Complete(map[string]any{
			"contents": len(result.Contents),
		})
	}
	return result, nil
}

// Tool Operations
// ===============

// ListTools lists the tools exposed by the MCP client mcpID, starting at cursor.
// When a trace manager is available the call is recorded as a trace node that
// is failed or completed according to the outcome.
func (ctx *Context) ListTools(mcpID string, cursor string) (*types.ListToolsResponse, error) {
	// Resolve the MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Display label falls back to the client ID
	label := client.GetMetaInfo().Label
	if label == "" {
		label = mcpID
	}

	// Open a trace node when a trace manager is available
	var node traceTypes.Node
	if tracer, _ := ctx.Trace(); tracer != nil {
		node, _ = tracer.Add(
			map[string]any{
				"mcp":    mcpID,
				"cursor": cursor,
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.list_tools.label"), // "MCP: List Tools"
				Type:        "mcp",
				Icon:        "build",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.list_tools.description"), label), // "List tools from MCP client '%s'"
			},
		)
	}

	// Perform the call and mirror the outcome onto the trace node
	result, err := client.ListTools(ctx.Context, cursor)
	if err != nil {
		if node != nil {
			node.Fail(err)
		}
		return nil, err
	}

	if node != nil {
		node.Complete(map[string]any{
			"tools":      len(result.Tools),
			"nextCursor": result.NextCursor,
		})
	}
	return result, nil
}

// CallTool invokes the tool called name on the MCP client mcpID with the given
// arguments. When a trace manager is available the call is recorded as a trace
// node that is failed or completed according to the outcome.
func (ctx *Context) CallTool(mcpID string, name string, arguments interface{}) (*types.CallToolResponse, error) {
	// Resolve the MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Display label falls back to the client ID
	label := client.GetMetaInfo().Label
	if label == "" {
		label = mcpID
	}

	// Open a trace node when a trace manager is available
	var node traceTypes.Node
	if tracer, _ := ctx.Trace(); tracer != nil {
		node, _ = tracer.Add(
			map[string]any{
				"mcp":       mcpID,
				"tool":      name,
				"arguments": arguments,
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.call_tool.label"), // "MCP: Call Tool"
				Type:        "mcp",
				Icon:        "settings",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.call_tool.description"), name, label), // "Call tool '%s' from MCP client '%s'"
			},
		)
	}

	// Call tool (pass ctx as extraArgs for Process transport to propagate Authorized())
	result, err := client.CallTool(ctx.Context, name, arguments, ctx)
	if err != nil {
		if node != nil {
			node.Fail(err)
		}
		return nil, err
	}

	if node != nil {
		node.Complete(map[string]any{
			"contents": len(result.Content),
		})
	}
	return result, nil
}

// CallTools invokes several tools one after another on the MCP client mcpID.
// When a trace manager is available the call is recorded as a trace node that
// is failed or completed according to the outcome.
func (ctx *Context) CallTools(mcpID string, tools []types.ToolCall) (*types.CallToolsResponse, error) {
	// Resolve the MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Display label falls back to the client ID
	label := client.GetMetaInfo().Label
	if label == "" {
		label = mcpID
	}

	// Open a trace node when a trace manager is available
	var node traceTypes.Node
	if tracer, _ := ctx.Trace(); tracer != nil {
		node, _ = tracer.Add(
			map[string]any{
				"mcp":   mcpID,
				"tools": tools,
				"count": len(tools),
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.call_tools.label"), // "MCP: Call Tools"
				Type:        "mcp",
				Icon:        "settings",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.call_tools.description"), len(tools), label), // "Call %d tools sequentially from MCP client '%s'"
			},
		)
	}

	// Call tools sequentially (pass ctx as extraArgs for Process transport to propagate Authorized())
	result, err := client.CallTools(ctx.Context, tools, ctx)
	if err != nil {
		if node != nil {
			node.Fail(err)
		}
		return nil, err
	}

	if node != nil {
		node.Complete(map[string]any{
			"results": len(result.Results),
		})
	}
	return result, nil
}

// CallToolsParallel invokes several tools in parallel on the MCP client mcpID.
// When a trace manager is available the call is recorded as a trace node that
// is failed or completed according to the outcome.
func (ctx *Context) CallToolsParallel(mcpID string, tools []types.ToolCall) (*types.CallToolsResponse, error) {
	// Resolve the MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Display label falls back to the client ID
	label := client.GetMetaInfo().Label
	if label == "" {
		label = mcpID
	}

	// Open a trace node when a trace manager is available
	var node traceTypes.Node
	if tracer, _ := ctx.Trace(); tracer != nil {
		node, _ = tracer.Add(
			map[string]any{
				"mcp":   mcpID,
				"tools": tools,
				"count": len(tools),
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.call_tools_parallel.label"), // "MCP: Call Tools (Parallel)"
				Type:        "mcp",
				Icon:        "settings",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.call_tools_parallel.description"), len(tools), label), // "Call %d tools in parallel from MCP client '%s'"
			},
		)
	}

	// Call tools in parallel (pass ctx as extraArgs for Process transport to propagate Authorized())
	result, err := client.CallToolsParallel(ctx.Context, tools, ctx)
	if err != nil {
		if node != nil {
			node.Fail(err)
		}
		return nil, err
	}

	if node != nil {
		node.Complete(map[string]any{
			"results": len(result.Results),
		})
	}
	return result, nil
}

// Prompt Operations
// =================

// ListPrompts lists the prompts exposed by the MCP client mcpID, starting at
// cursor. When a trace manager is available the call is recorded as a trace
// node that is failed or completed according to the outcome.
func (ctx *Context) ListPrompts(mcpID string, cursor string) (*types.ListPromptsResponse, error) {
	// Resolve the MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Display label falls back to the client ID
	label := client.GetMetaInfo().Label
	if label == "" {
		label = mcpID
	}

	// Open a trace node when a trace manager is available
	var node traceTypes.Node
	if tracer, _ := ctx.Trace(); tracer != nil {
		node, _ = tracer.Add(
			map[string]any{
				"mcp":    mcpID,
				"cursor": cursor,
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.list_prompts.label"), // "MCP: List Prompts"
				Type:        "mcp",
				Icon:        "chat",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.list_prompts.description"), label), // "List prompts from MCP client '%s'"
			},
		)
	}

	// Perform the call and mirror the outcome onto the trace node
	result, err := client.ListPrompts(ctx.Context, cursor)
	if err != nil {
		if node != nil {
			node.Fail(err)
		}
		return nil, err
	}

	if node != nil {
		node.Complete(map[string]any{
			"prompts":    len(result.Prompts),
			"nextCursor": result.NextCursor,
		})
	}
	return result, nil
}

// GetPrompt gets a prompt with arguments from an MCP client
// Automatically creates trace node and handles client lifecycle
func (ctx *Context) GetPrompt(mcpID string, name string, arguments map[string]interface{}) (*types.GetPromptResponse, error) {
	// Get MCP client
	client, err := mcp.Select(mcpID)
	if err != nil {
		return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err)
	}

	// Get client label for display
	clientLabel := client.GetMetaInfo().Label
	if clientLabel == "" {
		clientLabel = mcpID
	}

	// Get trace manager
	trace, _ := ctx.Trace()

	// Create trace node
	var node traceTypes.Node
	if trace != nil {
		node, _ = trace.Add(
			map[string]any{
				"mcp":       mcpID,
				"prompt":    name,
				"arguments": arguments,
			},
			traceTypes.TraceNodeOption{
				Label:       i18n.T(ctx.Locale, "mcp.get_prompt.label"), // "MCP: Get Prompt"
				Type:        "mcp",
				Icon:        "chat",
				Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.get_prompt.description"), name, clientLabel), // "Get prompt '%s' from MCP client '%s'"
			},
		)
	}

	// Get prompt
result, err := client.GetPrompt(ctx.Context, name, arguments) if err != nil { if node != nil { node.Fail(err) } return nil, err } // Complete trace node with result if node != nil { node.Complete(map[string]any{ "messages": len(result.Messages), }) } return result, nil } // Sample Operations // ================= // ListSamples lists samples for a tool or resource from an MCP client // Automatically creates trace node and handles client lifecycle func (ctx *Context) ListSamples(mcpID string, itemType types.SampleItemType, itemName string) (*types.ListSamplesResponse, error) { // Get MCP client client, err := mcp.Select(mcpID) if err != nil { return nil, fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err) } // Get client label for display clientLabel := client.GetMetaInfo().Label if clientLabel == "" { clientLabel = mcpID } // Get trace manager trace, _ := ctx.Trace() // Create trace node var node traceTypes.Node if trace != nil { node, _ = trace.Add( map[string]any{ "mcp": mcpID, "itemType": itemType, "itemName": itemName, }, traceTypes.TraceNodeOption{ Label: i18n.T(ctx.Locale, "mcp.list_samples.label"), // "MCP: List Samples" Type: "mcp", Icon: "library_books", Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.list_samples.description"), itemName, clientLabel), // "List samples for '%s' from MCP client '%s'" }, ) } // Call ListSamples result, err := client.ListSamples(ctx.Context, itemType, itemName) if err != nil { if node != nil { node.Fail(err) } return nil, err } // Complete trace node with result if node != nil { node.Complete(map[string]any{ "samples": len(result.Samples), }) } return result, nil } // GetSample gets a specific sample by index from an MCP client // Automatically creates trace node and handles client lifecycle func (ctx *Context) GetSample(mcpID string, itemType types.SampleItemType, itemName string, index int) (*types.SampleData, error) { // Get MCP client client, err := mcp.Select(mcpID) if err != nil { return nil, 
fmt.Errorf("failed to select MCP client '%s': %w", mcpID, err) } // Get client label for display clientLabel := client.GetMetaInfo().Label if clientLabel == "" { clientLabel = mcpID } // Get trace manager trace, _ := ctx.Trace() // Create trace node var node traceTypes.Node if trace != nil { node, _ = trace.Add( map[string]any{ "mcp": mcpID, "itemType": itemType, "itemName": itemName, "index": index, }, traceTypes.TraceNodeOption{ Label: i18n.T(ctx.Locale, "mcp.get_sample.label"), // "MCP: Get Sample" Type: "mcp", Icon: "library_books", Description: fmt.Sprintf(i18n.T(ctx.Locale, "mcp.get_sample.description"), index, itemName, clientLabel), // "Get sample #%d for '%s' from MCP client '%s'" }, ) } // Get sample result, err := client.GetSample(ctx.Context, itemType, itemName, index) if err != nil { if node != nil { node.Fail(err) } return nil, err } // Complete trace node with result if node != nil { node.Complete(result) } return result, nil } // Single-Server Tool Response Helpers // ==================================== // parseCallToolResponse parses a CallToolResponse and returns the parsed content directly func parseCallToolResponse(response *types.CallToolResponse) interface{} { if response == nil { return nil } return parseToolResponseContent(response) } // parseCallToolsResponse parses a CallToolsResponse and returns an array of parsed results func parseCallToolsResponse(response *types.CallToolsResponse) []interface{} { if response == nil { return nil } results := make([]interface{}, len(response.Results)) for i, r := range response.Results { results[i] = parseToolResponseContent(&r) } return results } // Cross-Server Tool Operations // ============================ // MCPToolRequest represents a request to call a tool on a specific MCP server type MCPToolRequest struct { MCP string `json:"mcp"` // MCP server ID Tool string `json:"tool"` // Tool name Arguments interface{} `json:"arguments"` // Tool arguments } // MCPToolResult represents the result of a 
cross-server tool call // Returns parsed result directly, with error field for failures type MCPToolResult struct { MCP string `json:"mcp"` // MCP server ID Tool string `json:"tool"` // Tool name Result interface{} `json:"result,omitempty"` // Parsed result content (directly usable) Error string `json:"error,omitempty"` // Error message (on failure) } // callToolResult is used internally to pass results through channels type callToolResult struct { idx int result *MCPToolResult } // CallToolAll calls tools on multiple MCP servers concurrently and waits for all to complete // Returns results in the same order as requests, regardless of completion order (like Promise.all) func (ctx *Context) CallToolAll(requests []*MCPToolRequest) []*MCPToolResult { if len(requests) == 0 { return []*MCPToolResult{} } results := make([]*MCPToolResult, len(requests)) done := make(chan struct{}) remaining := len(requests) for i, req := range requests { go func(idx int, r *MCPToolRequest) { defer func() { if err := recover(); err != nil { results[idx] = &MCPToolResult{ MCP: r.MCP, Tool: r.Tool, Error: fmt.Sprintf("panic: %v", err), } } done <- struct{}{} }() results[idx] = ctx.callToolSingle(r) }(i, req) } // Wait for all to complete for remaining > 0 { <-done remaining-- } return results } // CallToolAny calls tools on multiple MCP servers concurrently and returns when any succeeds // Returns all results received so far when first success is found (like Promise.any) func (ctx *Context) CallToolAny(requests []*MCPToolRequest) []*MCPToolResult { if len(requests) == 0 { return []*MCPToolResult{} } resultChan := make(chan callToolResult, len(requests)) remaining := len(requests) for i, req := range requests { go func(idx int, r *MCPToolRequest) { defer func() { if err := recover(); err != nil { resultChan <- callToolResult{ idx: idx, result: &MCPToolResult{ MCP: r.MCP, Tool: r.Tool, Error: fmt.Sprintf("panic: %v", err), }, } } }() resultChan <- callToolResult{idx: idx, result: 
ctx.callToolSingle(r)} }(i, req) } // Collect results until we find a success or all fail results := make([]*MCPToolResult, len(requests)) for remaining > 0 { cr := <-resultChan remaining-- results[cr.idx] = cr.result // Check if this is a success (no error) if cr.result.Error == "" { break // Stop waiting, we have a success } } // Drain remaining results in background (don't block) if remaining > 0 { go func(count int) { for i := 0; i < count; i++ { <-resultChan } }(remaining) } return results } // CallToolRace calls tools on multiple MCP servers concurrently and returns when any completes // Returns all results received so far when first completion (like Promise.race) func (ctx *Context) CallToolRace(requests []*MCPToolRequest) []*MCPToolResult { if len(requests) == 0 { return []*MCPToolResult{} } resultChan := make(chan callToolResult, len(requests)) remaining := len(requests) for i, req := range requests { go func(idx int, r *MCPToolRequest) { defer func() { if err := recover(); err != nil { resultChan <- callToolResult{ idx: idx, result: &MCPToolResult{ MCP: r.MCP, Tool: r.Tool, Error: fmt.Sprintf("panic: %v", err), }, } } }() resultChan <- callToolResult{idx: idx, result: ctx.callToolSingle(r)} }(i, req) } // Get first result (success or failure) results := make([]*MCPToolResult, len(requests)) cr := <-resultChan remaining-- results[cr.idx] = cr.result // Drain remaining results in background (don't block) if remaining > 0 { go func(count int) { for i := 0; i < count; i++ { <-resultChan } }(remaining) } return results } // callToolSingle executes a single tool call on an MCP server // This is a helper method for the parallel call methods func (ctx *Context) callToolSingle(req *MCPToolRequest) *MCPToolResult { result := &MCPToolResult{ MCP: req.MCP, Tool: req.Tool, } // Call the tool using existing CallTool method response, err := ctx.CallTool(req.MCP, req.Tool, req.Arguments) if err != nil { result.Error = err.Error() return result } // Parse and return 
result directly result.Result = parseToolResponseContent(response) return result } // parseToolResponseContent extracts and parses the actual content from a CallToolResponse // Similar to ToolCallResult.ParsedContent() in assistant/types.go // - For "text" type, parses the Text field as JSON (or returns as string if not JSON) // - For "image" type, returns the Data and MimeType // - For "resource" type, returns the Resource object // - If only one content item, returns it directly (not as array) func parseToolResponseContent(response *types.CallToolResponse) interface{} { if response == nil || len(response.Content) == 0 { return nil } var results []interface{} for _, tc := range response.Content { switch tc.Type { case types.ToolContentTypeText: // For text type, try to parse as JSON if tc.Text != "" { var parsed interface{} if err := jsoniter.UnmarshalFromString(tc.Text, &parsed); err == nil { results = append(results, parsed) } else { // If not JSON, return as plain string results = append(results, tc.Text) } } case types.ToolContentTypeImage: // For image type, return data and mimeType results = append(results, map[string]interface{}{ "type": "image", "data": tc.Data, "mimeType": tc.MimeType, }) case types.ToolContentTypeResource: // For resource type, return the resource object if tc.Resource != nil { results = append(results, tc.Resource) } default: // Unknown type, include as-is with type info results = append(results, map[string]interface{}{ "type": tc.Type, "text": tc.Text, }) } } // If only one result, return it directly (not as array) if len(results) == 1 { return results[0] } return results } ================================================ FILE: agent/context/mcp_test.go ================================================ package context_test import ( stdContext "context" "testing" "github.com/yaoapp/gou/mcp/types" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // newTestMCPContext creates a test context 
func newTestMCPContext() *context.Context { ctx := context.New(stdContext.Background(), nil, "test-chat") ctx.AssistantID = "test-assistant" ctx.Locale = "en" ctx.Referer = context.RefererAPI // Initialize stack and trace stack, traceID, _ := context.EnterStack(ctx, "test-assistant", &context.Options{}) ctx.Stack = stack _ = traceID // traceID is set in stack return ctx } // TestListResources tests the ListResources function func TestListResources(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() result, err := ctx.ListResources("echo", "") if err != nil { t.Fatalf("ListResources failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Resources) == 0 { t.Error("Expected resources, got empty list") } t.Logf("✓ ListResources returned %d resources", len(result.Resources)) // Check if specific resources exist resourceNames := make(map[string]bool) for _, resource := range result.Resources { resourceNames[resource.Name] = true t.Logf(" - Resource: %s (URI: %s)", resource.Name, resource.URI) } if !resourceNames["info"] { t.Error("Expected 'info' resource not found") } if !resourceNames["health"] { t.Error("Expected 'health' resource not found") } } // TestReadResource tests the ReadResource function func TestReadResource(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() t.Run("ReadServerInfo", func(t *testing.T) { result, err := ctx.ReadResource("echo", "echo://info") if err != nil { t.Fatalf("ReadResource failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Contents) == 0 { t.Error("Expected contents, got empty list") } t.Logf("✓ ReadResource returned %d contents", len(result.Contents)) }) t.Run("ReadHealthCheck", func(t *testing.T) { result, err := ctx.ReadResource("echo", "echo://health?check=all") if err != nil { t.Fatalf("ReadResource failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if 
len(result.Contents) == 0 { t.Error("Expected contents, got empty list") } t.Logf("✓ ReadResource for health check returned %d contents", len(result.Contents)) }) } // TestListTools tests the ListTools function func TestListTools(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() result, err := ctx.ListTools("echo", "") if err != nil { t.Fatalf("ListTools failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Tools) == 0 { t.Error("Expected tools, got empty list") } t.Logf("✓ ListTools returned %d tools", len(result.Tools)) // Check if specific tools exist toolNames := make(map[string]bool) for _, tool := range result.Tools { toolNames[tool.Name] = true } if !toolNames["ping"] { t.Error("Expected 'ping' tool not found") } if !toolNames["status"] { t.Error("Expected 'status' tool not found") } if !toolNames["echo"] { t.Error("Expected 'echo' tool not found") } } // TestCallTool tests the CallTool function func TestCallTool(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() t.Run("CallPing", func(t *testing.T) { result, err := ctx.CallTool("echo", "ping", map[string]interface{}{ "count": 3, "message": "test", }) if err != nil { t.Fatalf("CallTool failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Content) == 0 { t.Error("Expected content, got empty list") } t.Logf("✓ CallTool (ping) returned %d contents", len(result.Content)) }) t.Run("CallStatus", func(t *testing.T) { result, err := ctx.CallTool("echo", "status", map[string]interface{}{ "verbose": true, }) if err != nil { t.Fatalf("CallTool failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Content) == 0 { t.Error("Expected content, got empty list") } t.Logf("✓ CallTool (status) returned %d contents", len(result.Content)) }) t.Run("CallEcho", func(t *testing.T) { result, err := ctx.CallTool("echo", "echo", 
map[string]interface{}{ "message": "Hello World", "uppercase": true, }) if err != nil { t.Fatalf("CallTool failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Content) == 0 { t.Error("Expected content, got empty list") } t.Logf("✓ CallTool (echo) returned %d contents", len(result.Content)) }) } // TestCallTools tests the CallTools function (sequential) func TestCallTools(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() tools := []types.ToolCall{ { Name: "ping", Arguments: map[string]interface{}{ "count": 1, }, }, { Name: "status", Arguments: map[string]interface{}{ "verbose": false, }, }, { Name: "echo", Arguments: map[string]interface{}{ "message": "test", }, }, } result, err := ctx.CallTools("echo", tools) if err != nil { t.Fatalf("CallTools failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Results) != 3 { t.Errorf("Expected 3 results, got %d", len(result.Results)) } t.Logf("✓ CallTools returned %d results", len(result.Results)) } // TestCallToolsParallel tests the CallToolsParallel function func TestCallToolsParallel(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() tools := []types.ToolCall{ { Name: "ping", Arguments: map[string]interface{}{ "count": 1, }, }, { Name: "status", Arguments: map[string]interface{}{ "verbose": true, }, }, } result, err := ctx.CallToolsParallel("echo", tools) if err != nil { t.Fatalf("CallToolsParallel failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Results) != 2 { t.Errorf("Expected 2 results, got %d", len(result.Results)) } t.Logf("✓ CallToolsParallel returned %d results", len(result.Results)) } // TestListPrompts tests the ListPrompts function func TestListPrompts(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() result, err := ctx.ListPrompts("echo", "") if err != nil { 
t.Fatalf("ListPrompts failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Prompts) == 0 { t.Error("Expected prompts, got empty list") } t.Logf("✓ ListPrompts returned %d prompts", len(result.Prompts)) // Check if specific prompts exist promptNames := make(map[string]bool) for _, prompt := range result.Prompts { promptNames[prompt.Name] = true } if !promptNames["test_connection"] { t.Error("Expected 'test_connection' prompt not found") } if !promptNames["test_echo"] { t.Error("Expected 'test_echo' prompt not found") } } // TestGetPrompt tests the GetPrompt function func TestGetPrompt(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() t.Run("GetTestConnectionPrompt", func(t *testing.T) { result, err := ctx.GetPrompt("echo", "test_connection", map[string]interface{}{ "detailed": "true", }) if err != nil { t.Fatalf("GetPrompt failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Messages) == 0 { t.Error("Expected messages, got empty list") } t.Logf("✓ GetPrompt returned %d messages", len(result.Messages)) }) t.Run("GetTestEchoPrompt", func(t *testing.T) { result, err := ctx.GetPrompt("echo", "test_echo", map[string]interface{}{ "message": "Hello", "format": "uppercase", }) if err != nil { t.Fatalf("GetPrompt failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Messages) == 0 { t.Error("Expected messages, got empty list") } t.Logf("✓ GetPrompt returned %d messages", len(result.Messages)) }) } // TestListSamples tests the ListSamples function func TestListSamples(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() t.Run("ListToolSamples", func(t *testing.T) { result, err := ctx.ListSamples("echo", types.SampleTool, "ping") if err != nil { t.Fatalf("ListSamples failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Samples) == 0 { 
t.Error("Expected samples, got empty list") } t.Logf("✓ ListSamples for tool 'ping' returned %d samples", len(result.Samples)) }) t.Run("ListResourceSamples", func(t *testing.T) { result, err := ctx.ListSamples("echo", types.SampleResource, "info") if err != nil { t.Fatalf("ListSamples failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if len(result.Samples) == 0 { t.Error("Expected samples, got empty list") } t.Logf("✓ ListSamples for resource 'info' returned %d samples", len(result.Samples)) }) } // TestGetSample tests the GetSample function func TestGetSample(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() t.Run("GetToolSample", func(t *testing.T) { result, err := ctx.GetSample("echo", types.SampleTool, "ping", 0) if err != nil { t.Fatalf("GetSample failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if result.Name == "" { t.Error("Expected sample name, got empty string") } t.Logf("✓ GetSample for tool 'ping' returned sample '%s'", result.Name) }) t.Run("GetResourceSample", func(t *testing.T) { result, err := ctx.GetSample("echo", types.SampleResource, "info", 0) if err != nil { t.Fatalf("GetSample failed: %v", err) } if result == nil { t.Fatal("Expected result, got nil") } if result.Name == "" { t.Error("Expected sample name, got empty string") } t.Logf("✓ GetSample for resource 'info' returned sample '%s'", result.Name) }) } // TestMCPWithTrace tests MCP operations with trace func TestMCPWithTrace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestMCPContext() // Initialize trace trace, err := ctx.Trace() if err != nil { t.Fatalf("Failed to initialize trace: %v", err) } if trace == nil { t.Fatal("Expected trace, got nil") } // Call tool with trace result, err := ctx.CallTool("echo", "ping", map[string]interface{}{ "count": 5, }) if err != nil { t.Fatalf("CallTool with trace failed: %v", err) } if result == nil { t.Fatal("Expected result, got 
nil") } // Get trace nodes to verify trace was created nodes, err := trace.GetAllNodes() if err != nil { t.Fatalf("Failed to get trace nodes: %v", err) } if len(nodes) == 0 { t.Error("Expected trace nodes, got empty list") } t.Logf("✓ MCP operation created %d trace nodes", len(nodes)) } ================================================ FILE: agent/context/message.go ================================================ package context import ( "encoding/json" "fmt" "sync" )

// messageMetadataStore keeps message and block metadata behind a read/write
// mutex so that it can be shared between goroutines
type messageMetadataStore struct {
	messages map[string]*MessageMetadata // Message metadata by MessageID
	blocks   map[string]*BlockMetadata   // Block metadata by BlockID
	mu       sync.RWMutex
}

// newMessageMetadataStore returns an empty store with both maps allocated
func newMessageMetadataStore() *messageMetadataStore {
	store := new(messageMetadataStore)
	store.messages = map[string]*MessageMetadata{}
	store.blocks = map[string]*BlockMetadata{}
	return store
}

// setMessage records metadata under messageID, replacing any previous entry
func (s *messageMetadataStore) setMessage(messageID string, metadata *MessageMetadata) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.messages[messageID] = metadata
}

// getMessage returns the metadata stored under messageID, or nil if there is none
func (s *messageMetadataStore) getMessage(messageID string) *MessageMetadata {
	s.mu.RLock()
	defer s.mu.RUnlock()

	meta := s.messages[messageID]
	return meta
}

// setBlock records metadata under blockID, replacing any previous entry
func (s *messageMetadataStore) setBlock(blockID string, metadata *BlockMetadata) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.blocks[blockID] = metadata
}

// getBlock returns the metadata stored under blockID, or nil if there is none
func (s *messageMetadataStore) getBlock(blockID string) *BlockMetadata {
	s.mu.RLock()
	defer s.mu.RUnlock()

	meta := s.blocks[blockID]
	return meta
}

// updateBlock runs update on the metadata stored under blockID while holding
// the write lock; it does nothing when the block is unknown
func (s *messageMetadataStore) updateBlock(blockID string, update func(*BlockMetadata)) {
	s.mu.Lock()
	defer s.mu.Unlock()

	block, found := s.blocks[blockID]
	if !found {
		return
	}
	update(block)
}

// UnmarshalJSON decodes a Message whose "content" may be absent, null, a JSON
// string, or an array of ContentPart; any other shape is rejected
func (m *Message) UnmarshalJSON(data []byte) error {
	// Alias has Message's fields but none of its methods, which prevents
	// json.Unmarshal from recursing back into this function
	type Alias Message

	var aux struct {
		Content json.RawMessage `json:"content,omitempty"`
		*Alias
	}
	aux.Alias = (*Alias)(m)

	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	raw := aux.Content

	// Missing or null content
	if len(raw) == 0 || string(raw) == "null" {
		m.Content = nil
		return nil
	}

	// Plain string content
	var text string
	if json.Unmarshal(raw, &text) == nil {
		m.Content = text
		return nil
	}

	// Multipart content
	var parts []ContentPart
	if json.Unmarshal(raw, &parts) == nil {
		m.Content = parts
		return nil
	}

	return fmt.Errorf("content must be either a string or an array of ContentPart")
}

// MarshalJSON encodes the Message through a method-less alias of its type
func (m *Message) MarshalJSON() ([]byte, error) {
	type Alias Message

	wrapper := struct {
		*Alias
	}{
		Alias: (*Alias)(m),
	}
	return json.Marshal(&wrapper)
}

// NewTextMessage builds a Message with the given role and plain text content
func NewTextMessage(role MessageRole, text string) *Message {
	msg := &Message{Role: role}
	msg.Content = text
	return msg
}

// NewMultipartMessage builds a Message with the given role and content parts
func NewMultipartMessage(role MessageRole, parts []ContentPart) *Message {
	msg := &Message{Role: role}
	msg.Content = parts
	return msg
}

// GetContentAsString returns the content and true when it is a string,
// otherwise "" and false
func (m *Message) GetContentAsString() (string, bool) {
	text, isString := m.Content.(string)
	if !isString {
		return "", false
	}
	return text, true
}

// GetContentAsParts returns the content and true when it is a []ContentPart,
// otherwise nil and false
func (m *Message) GetContentAsParts() ([]ContentPart, bool) {
	parts, isParts := m.Content.([]ContentPart)
	if !isParts {
		return nil, false
	}
	return parts, true
}

// HasToolCalls reports whether the message carries at least one tool call
func (m *Message) HasToolCalls() bool {
	return len(m.ToolCalls) != 0
}

// IsRefusal reports whether the message has a non-empty refusal text
func (m *Message) IsRefusal() bool {
	if m.Refusal == nil {
		return false
	}
	return *m.Refusal != ""
}

================================================ FILE: agent/context/message_events_test.go ================================================ package context_test import ( "bytes" stdContext "context" "encoding/json" "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" ) func TestMessageLifecycleEvents(t *testing.T) { // Create a mock response writer var buf bytes.Buffer mockWriter := &mockResponseWriter{ buffer: &buf, headers: make(http.Header), } // Create context using New() to ensure proper initialization ctx := context.New(stdContext.Background(), nil, "test-chat") ctx.Accept = context.AcceptWebCUI ctx.Writer = mockWriter ctx.AssistantID = "test-assistant" ctx.Locale = "en" // Send a simple text message err := ctx.Send(&message.Message{ Type: message.TypeText, Props: map[string]interface{}{ "content": "Hello World", }, }) assert.NoError(t, err) // Flush to ensure all messages are written ctx.Flush() // Close SafeWriter to wait for all async writes to complete // SafeWriter uses a channel-based queue, so we must close it before reading buffer ctx.CloseSafeWriter() // Parse output to find events output := buf.String() t.Logf("Output:\n%s", output) lines := bytes.Split([]byte(output), []byte("\n")) var messages []map[string]interface{} for _, line := range lines { if len(line) == 0 { continue } // CUI format: data: {...} if bytes.HasPrefix(line, []byte("data: ")) { line = bytes.TrimPrefix(line, []byte("data: ")) } var msg map[string]interface{} if err := json.Unmarshal(line, &msg); err == nil { messages = append(messages, msg) t.Logf("Message: type=%s", msg["type"]) } } // Check for events hasMessageStart := false hasMessageEnd := false hasTextMessage := false for _, msg := range messages { msgType, _ :=
msg["type"].(string) if msgType == "event" { if props, ok := msg["props"].(map[string]interface{}); ok { if eventType, ok := props["event"].(string); ok { t.Logf("Event type: %s", eventType) if eventType == "message_start" { hasMessageStart = true t.Logf("✓ Found message_start event") } if eventType == "message_end" { hasMessageEnd = true t.Logf("✓ Found message_end event: %+v", props) } } } } else if msgType == "text" { hasTextMessage = true t.Logf("✓ Found text message") } } t.Logf("Summary: start=%v, text=%v, end=%v", hasMessageStart, hasTextMessage, hasMessageEnd) assert.True(t, hasMessageStart, "Should have message_start event") assert.True(t, hasTextMessage, "Should have text message") assert.True(t, hasMessageEnd, "Should have message_end event") }

// mockResponseWriter is an in-memory http.ResponseWriter for tests: the body
// goes to buffer, the last status code and the headers are kept on the struct
type mockResponseWriter struct {
	buffer     *bytes.Buffer
	statusCode int
	headers    http.Header
}

// Header returns the header map held by the mock
func (w *mockResponseWriter) Header() http.Header {
	return w.headers
}

// Write appends data to the underlying buffer
func (w *mockResponseWriter) Write(data []byte) (int, error) {
	n, err := w.buffer.Write(data)
	return n, err
}

// WriteHeader remembers the status code
func (w *mockResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
}

================================================ FILE: agent/context/message_test.go ================================================ package context_test import ( "encoding/json" "testing" "github.com/yaoapp/yao/agent/context" )

// TestMessage_UnmarshalJSON_StringContent decodes a message with string content
func TestMessage_UnmarshalJSON_StringContent(t *testing.T) {
	payload := `{ "role": "user", "content": "Hello, world!" }`

	var msg context.Message
	if err := json.Unmarshal([]byte(payload), &msg); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if msg.Role != context.RoleUser {
		t.Errorf("Expected role %s, got %s", context.RoleUser, msg.Role)
	}

	text, isString := msg.GetContentAsString()
	if !isString {
		t.Fatal("Expected content to be string")
	}
	if text != "Hello, world!" {
		t.Errorf("Expected content 'Hello, world!', got '%s'", text)
	}
}

// TestMessage_UnmarshalJSON_ArrayContent decodes a message with text + image parts
func TestMessage_UnmarshalJSON_ArrayContent(t *testing.T) {
	payload := `{ "role": "user", "content": [ { "type": "text", "text": "What's in this image?" }, { "type": "image_url", "image_url": { "url": "https://example.com/image.jpg", "detail": "high" } } ] }`

	var msg context.Message
	if err := json.Unmarshal([]byte(payload), &msg); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if msg.Role != context.RoleUser {
		t.Errorf("Expected role %s, got %s", context.RoleUser, msg.Role)
	}

	parts, isParts := msg.GetContentAsParts()
	if !isParts {
		t.Fatal("Expected content to be array of ContentPart")
	}
	if len(parts) != 2 {
		t.Fatalf("Expected 2 content parts, got %d", len(parts))
	}

	// First part: text
	textPart := parts[0]
	if textPart.Type != context.ContentText {
		t.Errorf("Expected type %s, got %s", context.ContentText, textPart.Type)
	}
	if textPart.Text != "What's in this image?" {
		t.Errorf("Expected text 'What's in this image?', got '%s'", textPart.Text)
	}

	// Second part: image
	imagePart := parts[1]
	if imagePart.Type != context.ContentImageURL {
		t.Errorf("Expected type %s, got %s", context.ContentImageURL, imagePart.Type)
	}
	if imagePart.ImageURL == nil {
		t.Fatal("Expected ImageURL to be non-nil")
	}
	if imagePart.ImageURL.URL != "https://example.com/image.jpg" {
		t.Errorf("Expected URL 'https://example.com/image.jpg', got '%s'", imagePart.ImageURL.URL)
	}
	if imagePart.ImageURL.Detail != context.DetailHigh {
		t.Errorf("Expected detail %s, got %s", context.DetailHigh, imagePart.ImageURL.Detail)
	}
}

// TestMessage_UnmarshalJSON_NullContent decodes an assistant message whose
// content is null and which carries a tool call
func TestMessage_UnmarshalJSON_NullContent(t *testing.T) {
	payload := `{ "role": "assistant", "content": null, "tool_calls": [ { "id": "call_123", "type": "function", "function": { "name": "get_weather", "arguments": "{\"location\":\"Tokyo\"}" } } ] }`

	var msg context.Message
	if err := json.Unmarshal([]byte(payload), &msg); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if msg.Role != context.RoleAssistant {
		t.Errorf("Expected role %s, got %s", context.RoleAssistant, msg.Role)
	}
	if msg.Content != nil {
		t.Errorf("Expected content to be nil, got %v", msg.Content)
	}
	if !msg.HasToolCalls() {
		t.Fatal("Expected message to have tool calls")
	}
	if len(msg.ToolCalls) != 1 {
		t.Fatalf("Expected 1 tool call, got %d", len(msg.ToolCalls))
	}
	if msg.ToolCalls[0].ID != "call_123" {
		t.Errorf("Expected tool call ID 'call_123', got '%s'", msg.ToolCalls[0].ID)
	}
}

// TestMessage_UnmarshalJSON_WithRefusal decodes a message carrying a refusal
func TestMessage_UnmarshalJSON_WithRefusal(t *testing.T) {
	refusalText := "I cannot help with that request."
	payload := `{ "role": "assistant", "content": "I'm sorry, but I can't assist with that.", "refusal": "I cannot help with that request." }`

	var msg context.Message
	if err := json.Unmarshal([]byte(payload), &msg); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if !msg.IsRefusal() {
		t.Error("Expected message to be a refusal")
	}
	if msg.Refusal == nil {
		t.Fatal("Expected refusal to be non-nil")
	}
	if *msg.Refusal != refusalText {
		t.Errorf("Expected refusal '%s', got '%s'", refusalText, *msg.Refusal)
	}
}

// TestMessage_UnmarshalJSON_AudioContent decodes a message with text + audio parts
func TestMessage_UnmarshalJSON_AudioContent(t *testing.T) {
	payload := `{ "role": "user", "content": [ { "type": "text", "text": "Transcribe this audio" }, { "type": "input_audio", "input_audio": { "data": "base64encodedaudiodata", "format": "wav" } } ] }`

	var msg context.Message
	if err := json.Unmarshal([]byte(payload), &msg); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	parts, isParts := msg.GetContentAsParts()
	if !isParts {
		t.Fatal("Expected content to be array of ContentPart")
	}
	if len(parts) != 2 {
		t.Fatalf("Expected 2 content parts, got %d", len(parts))
	}

	// Second part: audio
	audioPart := parts[1]
	if audioPart.Type != context.ContentInputAudio {
		t.Errorf("Expected type %s, got %s", context.ContentInputAudio, audioPart.Type)
	}
	if audioPart.InputAudio == nil {
		t.Fatal("Expected InputAudio to be non-nil")
	}
	if audioPart.InputAudio.Data != "base64encodedaudiodata" {
		t.Errorf("Expected audio data 'base64encodedaudiodata', got '%s'", audioPart.InputAudio.Data)
	}
	if audioPart.InputAudio.Format != "wav" {
		t.Errorf("Expected format 'wav', got '%s'", audioPart.InputAudio.Format)
	}
}

// TestMessage_MarshalJSON_StringContent encodes a text message and inspects the JSON
func TestMessage_MarshalJSON_StringContent(t *testing.T) {
	msg := context.NewTextMessage(context.RoleUser, "Hello, AI!")

	encoded, err := json.Marshal(msg)
	if err != nil {
		t.Fatalf("Failed to marshal: %v", err)
	}

	var decoded map[string]interface{}
	if err = json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal result: %v", err)
	}

	if decoded["role"] != string(context.RoleUser) {
		t.Errorf("Expected role %s, got %v", context.RoleUser, decoded["role"])
	}
	if decoded["content"] != "Hello, AI!" {
		t.Errorf("Expected content 'Hello, AI!', got %v", decoded["content"])
	}
}

// TestMessage_MarshalJSON_ArrayContent round-trips a multipart message
func TestMessage_MarshalJSON_ArrayContent(t *testing.T) {
	textPart := context.ContentPart{
		Type: context.ContentText,
		Text: "Describe this image",
	}
	imagePart := context.ContentPart{
		Type: context.ContentImageURL,
		ImageURL: &context.ImageURL{
			URL:    "https://example.com/test.jpg",
			Detail: context.DetailLow,
		},
	}
	msg := context.NewMultipartMessage(context.RoleUser, []context.ContentPart{textPart, imagePart})

	encoded, err := json.Marshal(msg)
	if err != nil {
		t.Fatalf("Failed to marshal: %v", err)
	}

	// Decode again to verify
	var decoded context.Message
	if err = json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal result: %v", err)
	}

	decodedParts, isParts := decoded.GetContentAsParts()
	if !isParts {
		t.Fatal("Expected content to be array of ContentPart")
	}
	if len(decodedParts) != 2 {
		t.Fatalf("Expected 2 content parts, got %d", len(decodedParts))
	}
}

// TestMessage_MarshalJSON_WithToolCalls round-trips a message carrying a tool call
func TestMessage_MarshalJSON_WithToolCalls(t *testing.T) {
	call := context.ToolCall{
		ID:   "call_abc123",
		Type: context.ToolTypeFunction,
		Function: context.Function{
			Name:      "get_weather",
			Arguments: `{"location":"San Francisco"}`,
		},
	}
	msg := &context.Message{
		Role:      context.RoleAssistant,
		Content:   nil,
		ToolCalls: []context.ToolCall{call},
	}

	encoded, err := json.Marshal(msg)
	if err != nil {
		t.Fatalf("Failed to marshal: %v", err)
	}

	// Decode again to verify
	var decoded context.Message
	if err = json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal result: %v", err)
	}

	if !decoded.HasToolCalls() {
		t.Error("Expected message to have tool calls")
	}
	if len(decoded.ToolCalls) != 1 {
		t.Fatalf("Expected 1 tool call, got %d", len(decoded.ToolCalls))
	}
	if decoded.ToolCalls[0].Function.Name != "get_weather" {
		t.Errorf("Expected function name 'get_weather', got '%s'", decoded.ToolCalls[0].Function.Name)
	}
}

// TestMessage_ToolMessage decodes a tool-role message with a tool_call_id
func TestMessage_ToolMessage(t *testing.T) {
	toolCallID := "call_abc123"
	payload := `{ "role": "tool", "tool_call_id": "call_abc123", "content": "The weather in San Francisco is sunny, 72°F" }`

	var msg context.Message
	if err := json.Unmarshal([]byte(payload), &msg); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if msg.Role != context.RoleTool {
		t.Errorf("Expected role %s, got %s", context.RoleTool, msg.Role)
	}
	if msg.ToolCallID == nil {
		t.Fatal("Expected tool_call_id to be non-nil")
	}
	if *msg.ToolCallID != toolCallID {
		t.Errorf("Expected tool_call_id '%s', got '%s'", toolCallID, *msg.ToolCallID)
	}

	text, isString := msg.GetContentAsString()
	if !isString {
		t.Fatal("Expected content to be string")
	}
	if text != "The weather in San Francisco is sunny, 72°F" {
		t.Errorf("Unexpected content: %s", text)
	}
}

func TestNewTextMessage(t *testing.T) { msg := context.NewTextMessage(context.RoleSystem, "You are a helpful assistant.") if msg.Role != context.RoleSystem { t.Errorf("Expected role %s, got %s", context.RoleSystem, msg.Role) } content, ok := msg.GetContentAsString() if !ok { t.Fatal("Expected content to be string") } if content != "You are a helpful assistant."
{ t.Errorf("Expected content 'You are a helpful assistant.', got '%s'", content) } } func TestNewMultipartMessage(t *testing.T) { parts := []context.ContentPart{ {Type: context.ContentText, Text: "Hello"}, } msg := context.NewMultipartMessage(context.RoleUser, parts) if msg.Role != context.RoleUser { t.Errorf("Expected role %s, got %s", context.RoleUser, msg.Role) } resultParts, ok := msg.GetContentAsParts() if !ok { t.Fatal("Expected content to be array of ContentPart") } if len(resultParts) != 1 { t.Fatalf("Expected 1 content part, got %d", len(resultParts)) } } ================================================ FILE: agent/context/openapi.go ================================================ package context import ( "bytes" "encoding/base64" "encoding/json" "fmt" "io" "strings" "github.com/gin-gonic/gin" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/store" "github.com/yaoapp/yao/openapi/oauth/authorized" ) // GetCompletionRequest parse completion request and create context from openapi request // Returns: *CompletionRequest, *Context, *Options, error func GetCompletionRequest(c *gin.Context, cache store.Store) (*CompletionRequest, *Context, *Options, error) { // Get authorized information authInfo := authorized.GetInfo(c) // Parse completion request from payload or query first completionReq, err := parseCompletionRequestData(c) if err != nil { return nil, nil, nil, fmt.Errorf("failed to parse completion request: %w", err) } // Extract assistant ID using completionReq (can extract from model field) assistantID, err := GetAssistantID(c, completionReq) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get assistant ID: %w", err) } // Extract chat ID (may generate from messages if not provided) chatID, err := GetChatID(c, cache, completionReq) if err != nil { // Fallback: Generate a new chat ID if extraction fails chatID = GenChatID() } // Parse client information from User-Agent header userAgent := c.GetHeader("User-Agent") clientType := 
getClientType(userAgent) clientIP := c.ClientIP() // Create context with unique ID using New() to ensure proper initialization ctx := New(c.Request.Context(), authInfo, chatID) // Set context fields (session-level state) ctx.Cache = cache ctx.Writer = c.Writer ctx.AssistantID = assistantID ctx.Locale = GetLocale(c, completionReq) ctx.Theme = GetTheme(c, completionReq) ctx.Referer = GetReferer(c, completionReq) ctx.Accept = GetAccept(c, completionReq) ctx.Client = Client{ Type: clientType, UserAgent: userAgent, IP: clientIP, } ctx.Route = GetRoute(c, completionReq) ctx.Metadata = GetMetadata(c, completionReq) // Create Options (call-level parameters) opts := &Options{ Context: c.Request.Context(), Skip: GetSkip(c, completionReq), Mode: GetMode(c, completionReq), } // Try to extract custom connector from model field // If model is a valid connector ID, set it to opts.Connector // Otherwise, keep the standard OpenAI-compatible behavior (model as assistant ID) if completionReq != nil && completionReq.Model != "" { // Check if model is a valid connector (not containing "-yao_" which indicates assistant ID format) if !strings.Contains(completionReq.Model, "-yao_") { // Try to validate if it's a real connector if _, err := connector.Select(completionReq.Model); err == nil { // It's a valid connector, use it opts.Connector = completionReq.Model } // If not a valid connector, ignore it (keep opts.Connector empty to use assistant's default) } } // Initialize interrupt controller ctx.Interrupt = NewInterruptController() // Register context to global registry first (required for interrupt handler callback) if err := Register(ctx); err != nil { return nil, nil, nil, fmt.Errorf("failed to register context: %w", err) } // Start interrupt listener after registration // Only monitors interrupt signals (user stop button for appending messages) // HTTP context cancellation is handled by LLM/Agent layers naturally ctx.Interrupt.Start(ctx.ID) return completionReq, ctx, opts, nil } // 
getClientType parses the client type from User-Agent header func getClientType(userAgent string) string { if userAgent == "" { return "web" // Default to web } ua := strings.ToLower(userAgent) // Check for specific client types switch { case strings.Contains(ua, "yao-agent") || strings.Contains(ua, "agent"): return "agent" case strings.Contains(ua, "yao-jssdk") || strings.Contains(ua, "jssdk"): return "jssdk" case strings.Contains(ua, "android"): return "android" case strings.Contains(ua, "iphone") || strings.Contains(ua, "ipad") || strings.Contains(ua, "ipod"): return "ios" case strings.Contains(ua, "windows"): return "windows" case strings.Contains(ua, "mac os x") || strings.Contains(ua, "macintosh"): return "macos" case strings.Contains(ua, "linux"): return "linux" default: return "web" } } // GetAssistantID extracts assistant ID from request with priority: // 1. Query parameter "assistant_id" // 2. Header "X-Yao-Assistant" // 3. Extract from model field (from CompletionRequest or Query) - splits by "-" takes last field, extracts ID from "yao_xxx" prefix func GetAssistantID(c *gin.Context, req *CompletionRequest) (string, error) { // Priority 1: Query parameter assistant_id if assistantID := c.Query("assistant_id"); assistantID != "" { return assistantID, nil } // Priority 2: Header X-Yao-Assistant if assistantID := c.GetHeader("X-Yao-Assistant"); assistantID != "" { return assistantID, nil } // Priority 3: Extract from model field (from CompletionRequest or Query) model := "" if req != nil && req.Model != "" { model = req.Model } else { model = c.Query("model") } if model != "" { // Parse model ID using the same logic as ParseModelID // Expected format: [prefix-]assistantName-model-yao_assistantID // Find the last occurrence of "-yao_" parts := strings.Split(model, "-yao_") if len(parts) >= 2 { assistantID := parts[len(parts)-1] if assistantID != "" { return assistantID, nil } } } // If no assistant ID found, return error return "", fmt.Errorf("assistant_id is 
required") } // GetMessages extracts messages from the request // Priority: // 1. Query parameter "messages" (JSON string) // 2. CompletionRequest.Messages (from payload) func GetMessages(c *gin.Context, req *CompletionRequest) ([]Message, error) { // Priority 1: Query parameter messages if messagesJSON := c.Query("messages"); messagesJSON != "" { var messages []Message if err := json.Unmarshal([]byte(messagesJSON), &messages); err == nil && len(messages) > 0 { return messages, nil } } // Priority 2: From CompletionRequest (payload) if req != nil && len(req.Messages) > 0 { return req.Messages, nil } return nil, fmt.Errorf("messages field is required") } // GetLocale extracts locale from request with priority: // 1. Query parameter "locale" // 2. Header "Accept-Language" // 3. CompletionRequest metadata "locale" (from payload) func GetLocale(c *gin.Context, req *CompletionRequest) string { // Priority 1: Query parameter if locale := c.Query("locale"); locale != "" { return strings.ToLower(locale) } // Priority 2: Header Accept-Language if acceptLang := c.GetHeader("Accept-Language"); acceptLang != "" { // Parse Accept-Language header (e.g., "en-US,en;q=0.9,zh;q=0.8") // Take the first language parts := strings.Split(acceptLang, ",") if len(parts) > 0 { // Remove quality value if present lang := strings.Split(parts[0], ";")[0] return strings.ToLower(strings.TrimSpace(lang)) } } // Priority 3: From CompletionRequest metadata if req != nil && req.Metadata != nil { if locale, ok := req.Metadata["locale"]; ok { if localeStr, ok := locale.(string); ok && localeStr != "" { return strings.ToLower(localeStr) } } } return "" } // GetTheme extracts theme from request with priority: // 1. Query parameter "theme" // 2. Header "X-Yao-Theme" // 3. 
CompletionRequest metadata "theme" (from payload) func GetTheme(c *gin.Context, req *CompletionRequest) string { // Priority 1: Query parameter if theme := c.Query("theme"); theme != "" { return strings.ToLower(theme) } // Priority 2: Header if theme := c.GetHeader("X-Yao-Theme"); theme != "" { return strings.ToLower(theme) } // Priority 3: From CompletionRequest metadata if req != nil && req.Metadata != nil { if theme, ok := req.Metadata["theme"]; ok { if themeStr, ok := theme.(string); ok && themeStr != "" { return strings.ToLower(themeStr) } } } return "" } // GetReferer extracts referer from request with priority: // 1. Query parameter "referer" // 2. Header "X-Yao-Referer" // 3. CompletionRequest metadata "referer" (from payload) // 4. Default to "api" func GetReferer(c *gin.Context, req *CompletionRequest) string { // Priority 1: Query parameter if referer := c.Query("referer"); referer != "" { return validateReferer(referer) } // Priority 2: Header if referer := c.GetHeader("X-Yao-Referer"); referer != "" { return validateReferer(referer) } // Priority 3: From CompletionRequest metadata if req != nil && req.Metadata != nil { if referer, ok := req.Metadata["referer"]; ok { if refererStr, ok := referer.(string); ok && refererStr != "" { return validateReferer(refererStr) } } } // Priority 4: Default return RefererAPI } // GetAccept extracts accept type from request with priority: // 1. Query parameter "accept" // 2. Header "X-Yao-Accept" // 3. CompletionRequest metadata "accept" (from payload) // 4. 
Default to "standard" (OpenAI-compatible format) func GetAccept(c *gin.Context, req *CompletionRequest) Accept { // Priority 1: Query parameter if accept := c.Query("accept"); accept != "" { return validateAccept(accept) } // Priority 2: Header if accept := c.GetHeader("X-Yao-Accept"); accept != "" { return validateAccept(accept) } // Priority 3: From CompletionRequest metadata if req != nil && req.Metadata != nil { if accept, ok := req.Metadata["accept"]; ok { if acceptStr, ok := accept.(string); ok && acceptStr != "" { return validateAccept(acceptStr) } } } // Priority 4: Default to "standard" (OpenAI-compatible format) return AcceptStandard // // Future: Parse from User-Agent if needed // userAgent := c.GetHeader("User-Agent") // clientType := getClientType(userAgent) // return parseAccept(clientType) } // GetChatID get the chat ID from the request // Priority: // 1. Query parameter "chat_id" // 2. Header "X-Yao-Chat" // 3. CompletionRequest metadata "chat_id" (from payload) // 4. Generate from messages using GetChatIDByMessages func GetChatID(c *gin.Context, cache store.Store, req *CompletionRequest) (string, error) { // Priority 1: Query parameter chat_id if chatID := c.Query("chat_id"); chatID != "" { return chatID, nil } // Priority 2: Header X-Yao-Chat if chatID := c.GetHeader("X-Yao-Chat"); chatID != "" { return chatID, nil } // Priority 3: From CompletionRequest metadata if req != nil && req.Metadata != nil { if chatID, ok := req.Metadata["chat_id"]; ok { if chatIDStr, ok := chatID.(string); ok && chatIDStr != "" { return chatIDStr, nil } } } // Priority 4: Generate from messages messages, err := GetMessages(c, req) if err != nil { return "", fmt.Errorf("failed to get messages for chat ID generation: %w", err) } chatID, err := GetChatIDByMessages(cache, messages) if err != nil { return "", fmt.Errorf("failed to generate chat ID from messages: %w", err) } return chatID, nil } // GetRoute extracts route from request with priority: // 1. 
Query parameter "route" // 2. Header "X-Yao-Route" // 3. CompletionRequest.Route (from payload) func GetRoute(c *gin.Context, req *CompletionRequest) string { // Priority 1: Query parameter if route := c.Query("route"); route != "" { return route } // Priority 2: Header if route := c.GetHeader("X-Yao-Route"); route != "" { return route } // Priority 3: From CompletionRequest if req != nil && req.Route != "" { return req.Route } return "" } // GetMode extracts mode from request with priority: // 1. Query parameter "mode" // 2. Header "X-Yao-Mode" // 3. CompletionRequest metadata "mode" (from payload) func GetMode(c *gin.Context, req *CompletionRequest) string { // Priority 1: Query parameter if mode := c.Query("mode"); mode != "" { return mode } // Priority 2: Header if mode := c.GetHeader("X-Yao-Mode"); mode != "" { return mode } // Priority 3: From CompletionRequest metadata if req != nil && req.Metadata != nil { if mode, ok := req.Metadata["mode"]; ok { if modeStr, ok := mode.(string); ok && modeStr != "" { return modeStr } } } return "" } // GetSkip extracts skip configuration from request with priority: // 1. CompletionRequest.Skip (from payload body) - Priority // 2. Individual query parameters: "skip_history", "skip_trace" func GetSkip(c *gin.Context, req *CompletionRequest) *Skip { // Priority 1: From CompletionRequest body (most direct) if req != nil && req.Skip != nil { return req.Skip } // Priority 2: Individual query parameters (recommended for query usage) skipHistory := c.Query("skip_history") == "true" || c.Query("skip_history") == "1" skipTrace := c.Query("skip_trace") == "true" || c.Query("skip_trace") == "1" // Check if any skip parameter is set if c.Query("skip_history") != "" || c.Query("skip_trace") != "" { return &Skip{ History: skipHistory, Trace: skipTrace, } } return nil } // GetMetadata extracts metadata from request with priority: // 1. Query parameter "metadata" (JSON string) // 2. 
Header "X-Yao-Metadata" (Base64 encoded JSON string) // 3. CompletionRequest.Metadata (from payload) func GetMetadata(c *gin.Context, req *CompletionRequest) map[string]interface{} { // Priority 1: Query parameter (JSON string) if metadataJSON := c.Query("metadata"); metadataJSON != "" { var metadata map[string]interface{} if err := json.Unmarshal([]byte(metadataJSON), &metadata); err == nil { return metadata } } // Priority 2: Header (Base64 encoded JSON string) if metadataBase64 := c.GetHeader("X-Yao-Metadata"); metadataBase64 != "" { // Try to decode Base64 if decoded, err := base64.StdEncoding.DecodeString(metadataBase64); err == nil { var metadata map[string]interface{} if err := json.Unmarshal(decoded, &metadata); err == nil { return metadata } } // Fallback: try to parse as plain JSON var metadata map[string]interface{} if err := json.Unmarshal([]byte(metadataBase64), &metadata); err == nil { return metadata } } // Priority 3: From CompletionRequest if req != nil && req.Metadata != nil { return req.Metadata } return nil } // parseCompletionRequestData extracts CompletionRequest from the request // Data can be passed via: // 1. Request body (JSON payload) - Priority // 2. 
Query parameters func parseCompletionRequestData(c *gin.Context) (*CompletionRequest, error) { var req CompletionRequest // Try to parse from request body first if c.Request.Body != nil { body, err := io.ReadAll(c.Request.Body) if err != nil { return nil, fmt.Errorf("failed to read request body: %w", err) } // Restore body for further processing c.Request.Body = io.NopCloser(bytes.NewBuffer(body)) // If body is not empty, try to parse it if len(body) > 0 { if err := json.Unmarshal(body, &req); err != nil { return nil, fmt.Errorf("failed to parse completion request from body: %w", err) } // If we got valid data from body, validate and return // Model is optional if assistant_id can be extracted later if len(req.Messages) > 0 { return &req, nil } } } // Fallback: Try to parse from query parameters // Model is optional (can be extracted from assistant_id) model := c.Query("model") req.Model = model // Messages (required, must be JSON string in query) messagesJSON := c.Query("messages") if messagesJSON == "" { return nil, fmt.Errorf("messages field is required") } var messages []Message if err := json.Unmarshal([]byte(messagesJSON), &messages); err != nil { return nil, fmt.Errorf("failed to parse messages from query: %w", err) } if len(messages) == 0 { return nil, fmt.Errorf("messages field must not be empty") } req.Messages = messages // Optional fields from query if tempStr := c.Query("temperature"); tempStr != "" { var temp float64 if _, err := fmt.Sscanf(tempStr, "%f", &temp); err == nil { req.Temperature = &temp } } if maxTokensStr := c.Query("max_tokens"); maxTokensStr != "" { var maxTokens int if _, err := fmt.Sscanf(maxTokensStr, "%d", &maxTokens); err == nil { req.MaxTokens = &maxTokens } } if maxCompletionTokensStr := c.Query("max_completion_tokens"); maxCompletionTokensStr != "" { var maxCompletionTokens int if _, err := fmt.Sscanf(maxCompletionTokensStr, "%d", &maxCompletionTokens); err == nil { req.MaxCompletionTokens = &maxCompletionTokens } } if 
streamStr := c.Query("stream"); streamStr != "" { stream := streamStr == "true" || streamStr == "1" req.Stream = &stream } // Audio config from query (JSON string) if audioJSON := c.Query("audio"); audioJSON != "" { var audio AudioConfig if err := json.Unmarshal([]byte(audioJSON), &audio); err == nil { req.Audio = &audio } } // Stream options from query (JSON string) if streamOptionsJSON := c.Query("stream_options"); streamOptionsJSON != "" { var streamOptions StreamOptions if err := json.Unmarshal([]byte(streamOptionsJSON), &streamOptions); err == nil { req.StreamOptions = &streamOptions } } // Metadata from query (JSON string) if metadataJSON := c.Query("metadata"); metadataJSON != "" { var metadata map[string]interface{} if err := json.Unmarshal([]byte(metadataJSON), &metadata); err == nil { req.Metadata = metadata } } return &req, nil } ================================================ FILE: agent/context/openapi_test.go ================================================ package context_test import ( "bytes" "encoding/json" "io" "net/http/httptest" "testing" "github.com/gin-gonic/gin" "github.com/yaoapp/gou/store" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // parseCompletionRequestData is a helper function for tests to parse completion request data func parseCompletionRequestData(c *gin.Context) (*context.CompletionRequest, error) { var req context.CompletionRequest if c.Request.Body != nil { body, err := io.ReadAll(c.Request.Body) if err != nil { return nil, err } c.Request.Body = io.NopCloser(bytes.NewBuffer(body)) if len(body) > 0 { if err := json.Unmarshal(body, &req); err != nil { return nil, err } if len(req.Messages) > 0 { return &req, nil } } } return &req, nil } func TestGetMessages_FromBody(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) messages := []context.Message{ { Role: context.RoleUser, Content: "Hello, world!", }, { Role: context.RoleAssistant, 
Content: "Hi there!", }, } requestBody := map[string]interface{}{ "messages": messages, "model": "gpt-4", } bodyBytes, _ := json.Marshal(requestBody) req := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req // Parse request first completionReq, _ := parseCompletionRequestData(c) result, err := context.GetMessages(c, completionReq) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(result) != 2 { t.Errorf("Expected 2 messages, got %d", len(result)) } if result[0].Role != context.RoleUser { t.Errorf("Expected first message role to be %s, got %s", context.RoleUser, result[0].Role) } } func TestGetMessages_FromQuery(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) messages := []context.Message{ { Role: context.RoleUser, Content: "Test message", }, } messagesJSON, _ := json.Marshal(messages) req := httptest.NewRequest("GET", "/chat/completions", nil) q := req.URL.Query() q.Add("messages", string(messagesJSON)) req.URL.RawQuery = q.Encode() w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req result, err := context.GetMessages(c, nil) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(result) != 1 { t.Errorf("Expected 1 message, got %d", len(result)) } } func TestGetMessages_EmptyMessages(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) requestBody := map[string]interface{}{ "messages": []context.Message{}, "model": "gpt-4", } bodyBytes, _ := json.Marshal(requestBody) req := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq, _ := parseCompletionRequestData(c) _, err := context.GetMessages(c, completionReq) if err 
== nil { t.Error("Expected error for empty messages") } } func TestGetChatID_FromQuery(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } expectedChatID := "test-chat-123" req := httptest.NewRequest("GET", "/chat/completions?chat_id="+expectedChatID, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req chatID, err := context.GetChatID(c, cache, nil) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID != expectedChatID { t.Errorf("Expected chat ID %s, got %s", expectedChatID, chatID) } } func TestGetChatID_FromHeader(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } expectedChatID := "header-chat-456" req := httptest.NewRequest("GET", "/chat/completions", nil) req.Header.Set("X-Yao-Chat", expectedChatID) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req chatID, err := context.GetChatID(c, cache, nil) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID != expectedChatID { t.Errorf("Expected chat ID %s, got %s", expectedChatID, chatID) } } func TestGetChatID_FromMetadata(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } expectedChatID := "metadata-chat-789" requestBody := map[string]interface{}{ "model": "gpt-4", "messages": []map[string]interface{}{ {"role": "user", "content": "Test"}, }, "metadata": map[string]interface{}{ "chat_id": expectedChatID, }, } bodyBytes, _ := json.Marshal(requestBody) req := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") w := 
httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq, _ := parseCompletionRequestData(c) chatID, err := context.GetChatID(c, cache, completionReq) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID != expectedChatID { t.Errorf("Expected chat ID %s, got %s", expectedChatID, chatID) } } func TestGetChatID_FromMessages(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } cache.Clear() // First request with one user message messages1 := []context.Message{ { Role: context.RoleUser, Content: "First message", }, } requestBody1 := map[string]interface{}{ "model": "gpt-4", "messages": messages1, } bodyBytes1, _ := json.Marshal(requestBody1) req := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes1)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq1, _ := parseCompletionRequestData(c) chatID1, err := context.GetChatID(c, cache, completionReq1) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID1 == "" { t.Error("Expected non-empty chat ID") } // Second request with two user messages (continuation) messages2 := []context.Message{ { Role: context.RoleUser, Content: "First message", }, { Role: context.RoleUser, Content: "Second message", }, } requestBody2 := map[string]interface{}{ "model": "gpt-4", "messages": messages2, } bodyBytes2, _ := json.Marshal(requestBody2) req2 := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes2)) req2.Header.Set("Content-Type", "application/json") w2 := httptest.NewRecorder() c2, _ := gin.CreateTestContext(w2) c2.Request = req2 completionReq2, _ := parseCompletionRequestData(c2) chatID2, err := context.GetChatID(c2, cache, completionReq2) if err != nil { t.Fatalf("Failed to get chat ID second time: 
%v", err) } // Should get same chat ID (continuation of conversation) if chatID1 != chatID2 { t.Errorf("Expected same chat ID for continuation, got %s and %s", chatID1, chatID2) } } func TestGetChatID_Priority(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } queryChatID := "query-chat-id" headerChatID := "header-chat-id" metadataChatID := "metadata-chat-id" messages := []context.Message{ { Role: context.RoleUser, Content: "This should not be used", }, } requestBody := map[string]interface{}{ "model": "gpt-4", "messages": messages, "metadata": map[string]interface{}{ "chat_id": metadataChatID, }, } bodyBytes, _ := json.Marshal(requestBody) // Test priority: query > header > metadata > messages req := httptest.NewRequest("POST", "/chat/completions?chat_id="+queryChatID, bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Yao-Chat", headerChatID) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq, _ := parseCompletionRequestData(c) chatID, err := context.GetChatID(c, cache, completionReq) if err != nil { t.Fatalf("Failed to get chat ID: %v", err) } if chatID != queryChatID { t.Errorf("Expected query parameter to take priority, got %s instead of %s", chatID, queryChatID) } } func TestGetLocale_FromQuery(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?locale=zh-CN", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req locale := context.GetLocale(c, nil) if locale != "zh-cn" { t.Errorf("Expected locale 'zh-cn', got '%s'", locale) } } func TestGetLocale_FromHeader(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) req.Header.Set("Accept-Language", "en-US,en;q=0.9,zh;q=0.8") w := httptest.NewRecorder() c, _ := 
gin.CreateTestContext(w) c.Request = req locale := context.GetLocale(c, nil) if locale != "en-us" { t.Errorf("Expected locale 'en-us', got '%s'", locale) } } func TestGetLocale_FromMetadata(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "locale": "ja-JP", }, } locale := context.GetLocale(c, completionReq) if locale != "ja-jp" { t.Errorf("Expected locale 'ja-jp' from metadata, got '%s'", locale) } } func TestGetLocale_Priority(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?locale=fr-FR", nil) req.Header.Set("Accept-Language", "en-US") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "locale": "de-DE", }, } locale := context.GetLocale(c, completionReq) if locale != "fr-fr" { t.Errorf("Expected query parameter to take priority, got '%s'", locale) } } func TestGetTheme_FromQuery(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?theme=dark", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req theme := context.GetTheme(c, nil) if theme != "dark" { t.Errorf("Expected theme 'dark', got '%s'", theme) } } func TestGetTheme_FromHeader(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) req.Header.Set("X-Yao-Theme", "light") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req theme := context.GetTheme(c, nil) if theme != "light" { t.Errorf("Expected theme 'light', got '%s'", theme) } } func TestGetTheme_FromMetadata(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := 
gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "theme": "auto", }, } theme := context.GetTheme(c, completionReq) if theme != "auto" { t.Errorf("Expected theme 'auto' from metadata, got '%s'", theme) } } func TestGetReferer_FromMetadata(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "referer": "tool", }, } referer := context.GetReferer(c, completionReq) if referer != context.RefererTool { t.Errorf("Expected referer 'tool' from metadata, got '%s'", referer) } } func TestGetAccept_FromQuery(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?accept=cui-web", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req accept := context.GetAccept(c, nil) if accept != context.AcceptWebCUI { t.Errorf("Expected accept 'cui-web' from query, got '%s'", accept) } } func TestGetAccept_FromHeader(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) req.Header.Set("X-Yao-Accept", "cui-desktop") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req accept := context.GetAccept(c, nil) if accept != context.AcceptDesktopCUI { t.Errorf("Expected accept 'cui-desktop' from header, got '%s'", accept) } } func TestGetAccept_FromMetadata(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "accept": "cui-native", }, } accept := context.GetAccept(c, completionReq) if accept != context.AccepNativeCUI { t.Errorf("Expected accept 'cui-native' from metadata, got '%s'", accept) 
} } func TestGetAccept_Default(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req accept := context.GetAccept(c, nil) if accept != context.AcceptStandard { t.Errorf("Expected default accept 'standard', got '%s'", accept) } } func TestGetAccept_Priority(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?accept=cui-web", nil) req.Header.Set("X-Yao-Accept", "cui-desktop") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "accept": "cui-native", }, } accept := context.GetAccept(c, completionReq) if accept != context.AcceptWebCUI { t.Errorf("Expected query parameter to take priority, got '%s'", accept) } } func TestGetAssistantID_FromModel(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Model: "gpt-4-turbo-yao_myassistant", } assistantID, err := context.GetAssistantID(c, completionReq) if err != nil { t.Fatalf("Failed to get assistant ID: %v", err) } if assistantID != "myassistant" { t.Errorf("Expected assistant ID 'myassistant', got '%s'", assistantID) } } func TestGetAssistantID_Priority(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?assistant_id=from_query", nil) req.Header.Set("X-Yao-Assistant", "from_header") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Model: "gpt-4-yao_from_model", } assistantID, err := context.GetAssistantID(c, completionReq) if err != nil { t.Fatalf("Failed to get assistant ID: %v", err) } if assistantID != "from_query" { t.Errorf("Expected query parameter to take priority, 
got '%s'", assistantID) } } func TestGetRoute_FromQuery(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?route=/dashboard/home", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req route := context.GetRoute(c, nil) if route != "/dashboard/home" { t.Errorf("Expected route '/dashboard/home', got '%s'", route) } } func TestGetRoute_FromHeader(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) req.Header.Set("X-Yao-Route", "/settings/profile") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req route := context.GetRoute(c, nil) if route != "/settings/profile" { t.Errorf("Expected route '/settings/profile', got '%s'", route) } } func TestGetRoute_FromPayload(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Route: "/admin/users", } route := context.GetRoute(c, completionReq) if route != "/admin/users" { t.Errorf("Expected route '/admin/users' from payload, got '%s'", route) } } func TestGetRoute_Priority(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?route=/from/query", nil) req.Header.Set("X-Yao-Route", "/from/header") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Route: "/from/payload", } route := context.GetRoute(c, completionReq) if route != "/from/query" { t.Errorf("Expected query parameter to take priority, got '%s'", route) } } func TestGetMetadata_FromQuery(t *testing.T) { gin.SetMode(gin.TestMode) data := map[string]interface{}{ "key1": "value1", "key2": float64(123), } dataJSON, _ := json.Marshal(data) req := httptest.NewRequest("GET", "/chat/completions?metadata="+string(dataJSON), nil) w := httptest.NewRecorder() c, 
_ := gin.CreateTestContext(w) c.Request = req result := context.GetMetadata(c, nil) if result == nil { t.Fatal("Expected data to be returned") } if result["key1"] != "value1" { t.Errorf("Expected key1='value1', got '%v'", result["key1"]) } if result["key2"] != float64(123) { t.Errorf("Expected key2=123, got '%v'", result["key2"]) } } func TestGetMetadata_FromHeader_Base64(t *testing.T) { gin.SetMode(gin.TestMode) dataBase64 := "eyJ1c2VyX2lkIjo0NTYsImFjdGlvbiI6ImNyZWF0ZSJ9" // base64 of {"user_id":456,"action":"create"} req := httptest.NewRequest("GET", "/chat/completions", nil) req.Header.Set("X-Yao-Metadata", dataBase64) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req result := context.GetMetadata(c, nil) if result == nil { t.Fatal("Expected data to be returned") } if result["action"] != "create" { t.Errorf("Expected action='create', got '%v'", result["action"]) } if result["user_id"] != float64(456) { t.Errorf("Expected user_id=456, got '%v'", result["user_id"]) } } func TestGetMetadata_FromPayload(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req data := map[string]interface{}{ "page": float64(1), "limit": float64(10), } completionReq := &context.CompletionRequest{ Metadata: data, } result := context.GetMetadata(c, completionReq) if result == nil { t.Fatal("Expected data to be returned") } if result["page"] != float64(1) { t.Errorf("Expected page=1, got '%v'", result["page"]) } if result["limit"] != float64(10) { t.Errorf("Expected limit=10, got '%v'", result["limit"]) } } func TestGetMetadata_Priority(t *testing.T) { gin.SetMode(gin.TestMode) queryData := map[string]interface{}{ "source": "query", } queryDataJSON, _ := json.Marshal(queryData) headerDataBase64 := "eyJzb3VyY2UiOiJoZWFkZXIifQ==" // base64 of {"source":"header"} req := httptest.NewRequest("GET", 
"/chat/completions?metadata="+string(queryDataJSON), nil) req.Header.Set("X-Yao-Metadata", headerDataBase64) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req payloadData := map[string]interface{}{ "source": "payload", } completionReq := &context.CompletionRequest{ Metadata: payloadData, } result := context.GetMetadata(c, completionReq) if result == nil { t.Fatal("Expected data to be returned") } if result["source"] != "query" { t.Errorf("Expected query parameter to take priority, got '%v'", result["source"]) } } func TestGetMetadata_EmptyData(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req result := context.GetMetadata(c, nil) if result != nil { t.Errorf("Expected nil data, got '%v'", result) } } func TestGetCompletionRequest_WriterInitialized(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Test message", }, } requestBody := map[string]interface{}{ "model": "gpt-4-yao_test", "messages": messages, } bodyBytes, _ := json.Marshal(requestBody) req := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq, ctx, opts, err := context.GetCompletionRequest(c, cache) if err != nil { t.Fatalf("Failed to get completion request: %v", err) } defer ctx.Release() // Check that Writer is initialized if ctx.Writer == nil { t.Error("Expected ctx.Writer to be initialized, got nil") } // Check that Writer is the same as gin context writer if ctx.Writer != c.Writer { t.Error("Expected ctx.Writer to be the same as gin context writer") } // Check that Options is 
initialized if opts == nil { t.Error("Expected opts to be initialized, got nil") } // Check other fields if completionReq.Model != "gpt-4-yao_test" { t.Errorf("Expected model 'gpt-4-yao_test', got '%s'", completionReq.Model) } if ctx.AssistantID != "test" { t.Errorf("Expected assistant ID 'test', got '%s'", ctx.AssistantID) } // Check that ChatID was generated (fallback) if ctx.ChatID == "" { t.Error("Expected ChatID to be generated, got empty string") } } func TestGetCompletionRequest_ChatIDFallback(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) cache, err := store.Get("__yao.agent.cache") if err != nil { t.Fatalf("Failed to get cache: %v", err) } // Request without explicit chat_id should generate one messages := []context.Message{ { Role: context.RoleUser, Content: "Test message", }, } requestBody := map[string]interface{}{ "model": "gpt-4-yao_assistant1", "messages": messages, } bodyBytes, _ := json.Marshal(requestBody) req := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req _, ctx, opts, err := context.GetCompletionRequest(c, cache) if err != nil { t.Fatalf("Failed to get completion request: %v", err) } defer ctx.Release() // Check that Options is initialized if opts == nil { t.Error("Expected opts to be initialized, got nil") } // ChatID should be generated (not empty) if ctx.ChatID == "" { t.Error("Expected ChatID to be generated via fallback, got empty string") } // ChatID should be a valid NanoID format (16 characters) if len(ctx.ChatID) < 8 { t.Errorf("Expected ChatID to be at least 8 characters, got %d", len(ctx.ChatID)) } } func TestGetSkip_FromBody(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("POST", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := 
&context.CompletionRequest{ Skip: &context.Skip{ History: true, Trace: false, }, } skip := context.GetSkip(c, completionReq) if skip == nil { t.Fatal("Expected skip to be returned") } if !skip.History { t.Error("Expected skip.History to be true") } if skip.Trace { t.Error("Expected skip.Trace to be false") } } func TestGetSkip_FromQueryParams(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?skip_history=true&skip_trace=false", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req skip := context.GetSkip(c, nil) if skip == nil { t.Fatal("Expected skip to be returned") } if !skip.History { t.Error("Expected skip.History to be true from query param") } if skip.Trace { t.Error("Expected skip.Trace to be false") } } func TestGetSkip_FromQueryParams_ShortForm(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?skip_history=1&skip_trace=1", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req skip := context.GetSkip(c, nil) if skip == nil { t.Fatal("Expected skip to be returned") } if !skip.History { t.Error("Expected skip.History to be true from query param (1)") } if !skip.Trace { t.Error("Expected skip.Trace to be true from query param (1)") } } func TestGetSkip_Priority(t *testing.T) { gin.SetMode(gin.TestMode) // Body should take priority over query req := httptest.NewRequest("POST", "/chat/completions?skip_history=false&skip_trace=false", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Skip: &context.Skip{ History: true, Trace: true, }, } skip := context.GetSkip(c, completionReq) if skip == nil { t.Fatal("Expected skip to be returned") } // Body should take priority if !skip.History { t.Error("Expected body parameter to take priority, skip.History should be true") } if !skip.Trace { t.Error("Expected body parameter to take priority, skip.Trace 
should be true") } } func TestGetSkip_Nil(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req skip := context.GetSkip(c, nil) if skip != nil { t.Errorf("Expected skip to be nil, got %v", skip) } } func TestGetSkip_OnlyHistorySet(t *testing.T) { gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?skip_history=true", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req skip := context.GetSkip(c, nil) if skip == nil { t.Fatal("Expected skip to be returned") } if !skip.History { t.Error("Expected skip.History to be true") } if skip.Trace { t.Error("Expected skip.Trace to be false (default)") } } func TestGetSkip_FromBodyViaParseRequest(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) // Test parsing Skip from full request body messages := []context.Message{ { Role: context.RoleUser, Content: "Generate a title for this chat", }, } requestBody := map[string]interface{}{ "model": "workers.system.title-yao_test", "messages": messages, "skip": map[string]interface{}{ "history": true, "trace": false, }, } bodyBytes, _ := json.Marshal(requestBody) req := httptest.NewRequest("POST", "/chat/completions", bytes.NewBuffer(bodyBytes)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req // Parse the request completionReq, err := parseCompletionRequestData(c) if err != nil { t.Fatalf("Failed to parse completion request: %v", err) } // Verify Skip was parsed correctly if completionReq.Skip == nil { t.Fatal("Expected Skip to be parsed from body, got nil") } if !completionReq.Skip.History { t.Error("Expected Skip.History to be true from body") } if completionReq.Skip.Trace { t.Error("Expected Skip.Trace to be false from body") } // Now test GetSkip function with the parsed request skip := 
context.GetSkip(c, completionReq) if skip == nil { t.Fatal("Expected GetSkip to return skip configuration") } if !skip.History { t.Error("Expected GetSkip to return History=true") } if skip.Trace { t.Error("Expected GetSkip to return Trace=false") } } func TestGetMode_FromQuery(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions?mode=task", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req mode := context.GetMode(c, nil) if mode != "task" { t.Errorf("Expected mode 'task' from query, got '%s'", mode) } } func TestGetMode_FromHeader(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) req.Header.Set("X-Yao-Mode", "chat") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req mode := context.GetMode(c, nil) if mode != "chat" { t.Errorf("Expected mode 'chat' from header, got '%s'", mode) } } func TestGetMode_FromMetadata(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "mode": "task", }, } mode := context.GetMode(c, completionReq) if mode != "task" { t.Errorf("Expected mode 'task' from metadata, got '%s'", mode) } } func TestGetMode_Priority(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) // Query has highest priority req := httptest.NewRequest("GET", "/chat/completions?mode=query_mode", nil) req.Header.Set("X-Yao-Mode", "header_mode") w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req completionReq := &context.CompletionRequest{ Metadata: map[string]interface{}{ "mode": "metadata_mode", }, } mode := 
context.GetMode(c, completionReq) if mode != "query_mode" { t.Errorf("Expected mode 'query_mode' (query has priority), got '%s'", mode) } } func TestGetMode_Empty(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() gin.SetMode(gin.TestMode) req := httptest.NewRequest("GET", "/chat/completions", nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req mode := context.GetMode(c, nil) if mode != "" { t.Errorf("Expected empty mode, got '%s'", mode) } } ================================================ FILE: agent/context/options.go ================================================ package context // ToMap converts Options struct to map for JSON serialization func (opts *Options) ToMap() map[string]interface{} { if opts == nil { return nil } result := make(map[string]interface{}) // Add configurable fields (with json tags) if opts.Connector != "" { result["connector"] = opts.Connector } if opts.Mode != "" { result["mode"] = opts.Mode } if opts.Search != nil { result["search"] = opts.Search } if opts.Skip != nil { result["skip"] = opts.Skip } // Only add DisableGlobalPrompts if true (avoid false values in map) if opts.DisableGlobalPrompts { result["disable_global_prompts"] = opts.DisableGlobalPrompts } if opts.Metadata != nil { result["metadata"] = opts.Metadata } // Note: Runtime fields (Context, Writer) are not serialized (json:"-") // They should not be included in the map return result } // OptionsFromMap creates Options struct from map (e.g., from JS Hook) func OptionsFromMap(m map[string]interface{}) *Options { if m == nil { return &Options{} } opts := &Options{} // Extract configurable fields if connector, ok := m["connector"].(string); ok { opts.Connector = connector } if mode, ok := m["mode"].(string); ok { opts.Mode = mode } // Search supports: bool | SearchIntent | map[string]any | nil if search := m["search"]; search != nil { opts.Search = search } if skipMap, ok := m["skip"].(map[string]interface{}); ok { skip := &Skip{} if 
history, ok := skipMap["history"].(bool); ok { skip.History = history } if trace, ok := skipMap["trace"].(bool); ok { skip.Trace = trace } if output, ok := skipMap["output"].(bool); ok { skip.Output = output } opts.Skip = skip } if disableGlobalPrompts, ok := m["disable_global_prompts"].(bool); ok { opts.DisableGlobalPrompts = disableGlobalPrompts } if metadata, ok := m["metadata"].(map[string]interface{}); ok { opts.Metadata = metadata } // Note: Context and Writer are runtime fields, not restored from map // They should be set by the caller if needed return opts } ================================================ FILE: agent/context/output.go ================================================ package context import ( "time" "github.com/yaoapp/gou/llm" "github.com/yaoapp/yao/agent/output" "github.com/yaoapp/yao/agent/output/message" ) // Send sends a message via the output module // Automatically manages BlockID, ThreadID, lifecycle events, and metadata for delta operations // - For delta operations: inherits BlockID and ThreadID from original message, increments chunk count // - For new messages: auto-sets ThreadID from Stack, sends message_start event // - Sends block_start event when a new BlockID is first encountered // - Records metadata for all sent messages to enable delta inheritance func (ctx *Context) Send(msg *message.Message) error { // Call OnMessage callback if provided (for ctx.agent.Call with onChunk) if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.OnMessage != nil { if ret := ctx.Stack.Options.OnMessage(msg); ret != 0 { return nil // Callback requested stop } } out, err := ctx.getOutput() if err != nil { return err } // Skip lifecycle events for event-type messages (prevent recursion) isEventMessage := msg.Type == message.TypeEvent // === Handle message_start event: record metadata for future delta chunks === if isEventMessage && msg.Props != nil { if event, ok := msg.Props["event"].(string); ok && event == 
message.EventMessageStart { if data, ok := msg.Props["data"].(message.EventMessageStartData); ok { // Record metadata from message_start event if data.MessageID != "" && ctx.messageMetadata != nil { ctx.messageMetadata.setMessage(data.MessageID, &MessageMetadata{ MessageID: data.MessageID, ThreadID: data.ThreadID, Type: data.Type, StartTime: time.Now(), ChunkCount: 0, // Will be incremented by delta chunks }) } } } } // === Delta operations: Auto-inherit and update metadata === if msg.Delta && msg.MessageID != "" && ctx.messageMetadata != nil { if metadata := ctx.getMessageMetadata(msg.MessageID); metadata != nil { // Inherit BlockID if not specified if msg.BlockID == "" { msg.BlockID = metadata.BlockID } // Inherit ThreadID if not specified if msg.ThreadID == "" { msg.ThreadID = metadata.ThreadID } // Increment chunk count for this message metadata.ChunkCount++ // Update Buffer content for streaming messages (for storage) if ctx.Buffer != nil && msg.Props != nil { if content, ok := msg.Props["content"].(string); ok { ctx.Buffer.AppendMessageContent(msg.MessageID, content) } } } } // === Auto-generate ChunkID (always) === if msg.ChunkID == "" && !isEventMessage { if ctx.IDGenerator != nil { msg.ChunkID = ctx.IDGenerator.GenerateChunkID() } else { msg.ChunkID = message.GenerateNanoID() } } // === Non-delta operations: New message logic === if !msg.Delta && !isEventMessage { // Auto-set ThreadID for non-root Stack (nested agent calls) if msg.ThreadID == "" && ctx.Stack != nil && !ctx.Stack.IsRoot() { msg.ThreadID = ctx.Stack.ID } // BlockID is NOT auto-generated by default (only manually specified in special cases) // Example: Send a web card after LLM output, group them in the same Block // Developers can specify via ctx.Send(message, blockId) or message.block_id // === Send block_start event if this is a new block === if msg.BlockID != "" && ctx.messageMetadata != nil { if ctx.messageMetadata.getBlock(msg.BlockID) == nil { // New block, send block_start event 
blockStartData := message.EventBlockStartData{ BlockID: msg.BlockID, Type: "mixed", // Default type, can be enhanced later Timestamp: time.Now().UnixMilli(), } blockStartEvent := output.NewEventMessage(message.EventBlockStart, "Block started", blockStartData) if err := ctx.sendRaw(blockStartEvent); err != nil { return err } // Record block metadata ctx.messageMetadata.setBlock(msg.BlockID, &BlockMetadata{ BlockID: msg.BlockID, Type: "mixed", StartTime: time.Now(), MessageCount: 0, }) } // Increment message count for this block ctx.messageMetadata.updateBlock(msg.BlockID, func(block *BlockMetadata) { block.MessageCount++ }) } // === Generate MessageID if not provided === if msg.MessageID == "" { if ctx.IDGenerator != nil { msg.MessageID = ctx.IDGenerator.GenerateMessageID() } else { msg.MessageID = message.GenerateNanoID() // Use NanoID generator } } // === Send message_start event === messageStartData := message.EventMessageStartData{ MessageID: msg.MessageID, Type: msg.Type, Timestamp: time.Now().UnixMilli(), ThreadID: msg.ThreadID, // Include ThreadID for concurrent stream identification } messageStartEvent := output.NewEventMessage(message.EventMessageStart, "Message started", messageStartData) if err := ctx.sendRaw(messageStartEvent); err != nil { return err } // === Record message metadata with start time === if ctx.messageMetadata != nil { ctx.messageMetadata.setMessage(msg.MessageID, &MessageMetadata{ MessageID: msg.MessageID, BlockID: msg.BlockID, ThreadID: msg.ThreadID, Type: msg.Type, StartTime: time.Now(), ChunkCount: 1, // Initial chunk }) } } // === Actually send the message === if err := out.Send(msg); err != nil { return err } // === Buffer message for batch saving (non-delta, non-event messages only) === // Delta messages are streaming chunks; only final content should be saved // Event messages are transient lifecycle signals, not stored // Skip if History is disabled in options if !msg.Delta && !isEventMessage && ctx.Buffer != nil && 
!ctx.shouldSkipHistory() { assistantID := "" if ctx.Stack != nil { assistantID = ctx.Stack.AssistantID } ctx.Buffer.AddAssistantMessage( msg.MessageID, // Use the same MessageID as sent to client msg.Type, msg.Props, msg.BlockID, msg.ThreadID, assistantID, nil, // metadata can be added if needed ) } // === Auto-send message_end for non-delta messages (complete messages) === if !msg.Delta && !isEventMessage && msg.MessageID != "" && ctx.messageMetadata != nil { metadata := ctx.messageMetadata.getMessage(msg.MessageID) if metadata != nil { // Calculate duration durationMs := time.Since(metadata.StartTime).Milliseconds() // Extract content for the extra field var content interface{} if msg.Props != nil { if c, ok := msg.Props["content"]; ok { content = c } } // Build message_end event data endData := message.EventMessageEndData{ MessageID: msg.MessageID, Type: msg.Type, Timestamp: time.Now().UnixMilli(), ThreadID: metadata.ThreadID, // Include ThreadID for concurrent stream identification DurationMs: durationMs, ChunkCount: metadata.ChunkCount, Status: "completed", } // Add content to extra if available if content != nil { endData.Extra = map[string]interface{}{ "content": content, } } // Send message_end event messageEndEvent := output.NewEventMessage(message.EventMessageEnd, "Message completed", endData) ctx.sendRaw(messageEndEvent) } } return nil } // SendStream sends a streaming message that can be appended to later // Unlike Send(), this does NOT automatically send message_end event // Use ctx.Append() to add content, then ctx.End() to finalize // Returns the message ID for use with Append/End func (ctx *Context) SendStream(msg *message.Message) (string, error) { out, err := ctx.getOutput() if err != nil { return "", err } // Skip lifecycle events for event-type messages isEventMessage := msg.Type == message.TypeEvent if isEventMessage { // Event messages should use Send(), not SendStream() return "", ctx.Send(msg) } // === Auto-generate ChunkID === if 
msg.ChunkID == "" { if ctx.IDGenerator != nil { msg.ChunkID = ctx.IDGenerator.GenerateChunkID() } else { msg.ChunkID = message.GenerateNanoID() } } // === Auto-set ThreadID for non-root Stack === if msg.ThreadID == "" && ctx.Stack != nil && !ctx.Stack.IsRoot() { msg.ThreadID = ctx.Stack.ID } // === Handle BlockID and block_start event === if msg.BlockID != "" && ctx.messageMetadata != nil { if ctx.messageMetadata.getBlock(msg.BlockID) == nil { blockStartData := message.EventBlockStartData{ BlockID: msg.BlockID, Type: "mixed", Timestamp: time.Now().UnixMilli(), } blockStartEvent := output.NewEventMessage(message.EventBlockStart, "Block started", blockStartData) if err := ctx.sendRaw(blockStartEvent); err != nil { return "", err } ctx.messageMetadata.setBlock(msg.BlockID, &BlockMetadata{ BlockID: msg.BlockID, Type: "mixed", StartTime: time.Now(), MessageCount: 0, }) } ctx.messageMetadata.updateBlock(msg.BlockID, func(block *BlockMetadata) { block.MessageCount++ }) } // === Generate MessageID if not provided === if msg.MessageID == "" { if ctx.IDGenerator != nil { msg.MessageID = ctx.IDGenerator.GenerateMessageID() } else { msg.MessageID = message.GenerateNanoID() } } // === Send message_start event === messageStartData := message.EventMessageStartData{ MessageID: msg.MessageID, Type: msg.Type, Timestamp: time.Now().UnixMilli(), ThreadID: msg.ThreadID, } messageStartEvent := output.NewEventMessage(message.EventMessageStart, "Message started", messageStartData) if err := ctx.sendRaw(messageStartEvent); err != nil { return "", err } // === Record message metadata === if ctx.messageMetadata != nil { ctx.messageMetadata.setMessage(msg.MessageID, &MessageMetadata{ MessageID: msg.MessageID, BlockID: msg.BlockID, ThreadID: msg.ThreadID, Type: msg.Type, StartTime: time.Now(), ChunkCount: 1, }) } // === Actually send the message === if err := out.Send(msg); err != nil { return "", err } // === Buffer streaming message (will be completed by End()) === if ctx.Buffer != nil && 
!ctx.shouldSkipHistory() { assistantID := "" if ctx.Stack != nil { assistantID = ctx.Stack.AssistantID } ctx.Buffer.AddStreamingMessage( msg.MessageID, msg.Type, msg.Props, msg.BlockID, msg.ThreadID, assistantID, nil, ) } // NOTE: No message_end event here - will be sent by End() return msg.MessageID, nil } // End finalizes a streaming message started with SendStream // Optionally appends final content before sending message_end event // This also saves the complete message to the buffer for storage func (ctx *Context) End(messageID string, finalContent ...string) error { if messageID == "" { return nil } // Append final content if provided if len(finalContent) > 0 && finalContent[0] != "" { // Create a delta message for the final content deltaMsg := &message.Message{ MessageID: messageID, Type: message.TypeText, Delta: true, DeltaAction: message.DeltaAppend, Props: map[string]interface{}{ "content": finalContent[0], }, } if err := ctx.Send(deltaMsg); err != nil { return err } } // Get complete content from buffer var completeContent string if ctx.Buffer != nil { completeContent, _ = ctx.Buffer.CompleteStreamingMessage(messageID) } // Get metadata for duration calculation var durationMs int64 var threadID string var chunkCount int var msgType string = message.TypeText if ctx.messageMetadata != nil { if metadata := ctx.messageMetadata.getMessage(messageID); metadata != nil { durationMs = time.Since(metadata.StartTime).Milliseconds() threadID = metadata.ThreadID chunkCount = metadata.ChunkCount msgType = metadata.Type } } // Build message_end event data endData := message.EventMessageEndData{ MessageID: messageID, Type: msgType, Timestamp: time.Now().UnixMilli(), ThreadID: threadID, DurationMs: durationMs, ChunkCount: chunkCount, Status: "completed", } // Add complete content to extra if completeContent != "" { endData.Extra = map[string]interface{}{ "content": completeContent, } } // Send message_end event messageEndEvent := 
output.NewEventMessage(message.EventMessageEnd, "Message completed", endData)
	return ctx.sendRaw(messageEndEvent)
}

// EndMessage emits a message_end event for messageID.
//
// It does nothing (and returns nil) when messageID is empty, when no metadata
// store exists, or when no metadata was recorded for the message. A non-nil
// content is attached to the event under Extra["content"].
// Send already emits message_end for non-delta messages; per the original
// notes this method is meant for delta streaming, after the last chunk.
func (ctx *Context) EndMessage(messageID string, content interface{}) error {
	if messageID == "" || ctx.messageMetadata == nil {
		return nil
	}

	recorded := ctx.messageMetadata.getMessage(messageID)
	if recorded == nil {
		// Unknown message: nothing to close.
		return nil
	}

	elapsedMs := time.Since(recorded.StartTime).Milliseconds()

	payload := message.EventMessageEndData{
		MessageID:  messageID,
		Type:       recorded.Type,
		Timestamp:  time.Now().UnixMilli(),
		ThreadID:   recorded.ThreadID, // lets clients tell concurrent streams apart
		DurationMs: elapsedMs,
		ChunkCount: recorded.ChunkCount,
		Status:     "completed",
	}

	if content != nil {
		payload.Extra = map[string]interface{}{
			"content": content,
		}
	}

	return ctx.sendRaw(output.NewEventMessage(message.EventMessageEnd, "Message completed", payload))
}

// EndBlock emits a block_end event for blockID.
//
// It does nothing (and returns nil) when blockID is empty, when no metadata
// store exists, or when the block was never recorded. Nothing in this file
// calls it automatically; callers invoke it once the block is complete.
func (ctx *Context) EndBlock(blockID string) error {
	if blockID == "" || ctx.messageMetadata == nil {
		return nil
	}

	recorded := ctx.messageMetadata.getBlock(blockID)
	if recorded == nil {
		// Unknown block: nothing to close.
		return nil
	}

	elapsedMs := time.Since(recorded.StartTime).Milliseconds()

	payload := message.EventBlockEndData{
		BlockID:      blockID,
		Type:         recorded.Type,
		Timestamp:    time.Now().UnixMilli(),
		DurationMs:   elapsedMs,
		MessageCount: recorded.MessageCount,
		Status:       "completed",
	}

	return ctx.sendRaw(output.NewEventMessage(message.EventBlockEnd, "Block completed", payload))
}

// SendGroup forwards a message group to the context's output.
//
// Deprecated: This method is deprecated and will be removed in future versions
func (ctx *Context) SendGroup(group *message.Group) error {
	out, err := ctx.getOutput()
	if err != nil {
		return err
	}
	return out.SendGroup(group)
}

// Flush flushes the context's output.
func (ctx *Context) Flush() error {
	out, err := ctx.getOutput()
	if err != nil {
		return err
	}
	return out.Flush()
}

// CloseOutput closes the context's output.
func (ctx *Context) CloseOutput() error {
	out, err := ctx.getOutput()
	if err != nil {
		return err
	}
	return out.Close()
}

// sendRaw hands msg straight to the output without any of the lifecycle
// handling done by Send. It is used for the event messages themselves, so that
// sending an event does not trigger further events.
func (ctx *Context) sendRaw(msg *message.Message) error {
	out, err := ctx.getOutput()
	if err != nil {
		return err
	}
	return out.Send(msg)
}

// getWriter gets the effective Writer for the current context
// Priority: Skip.Output > Stack.Options.Writer > ctx.Writer
// Note: The Writer returned is always a SafeWriter (wrapped at context creation)
// to ensure thread-safe concurrent writes for SSE streaming.
func (ctx *Context) getWriter() Writer { // Check if output is explicitly skipped (for internal A2A calls) if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Skip != nil && ctx.Stack.Options.Skip.Output { return nil // Explicitly disable output } // Check if current Stack has a Writer override if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Writer != nil { return ctx.Stack.Options.Writer } return ctx.Writer } // getOutput gets the output writer for the context func (ctx *Context) getOutput() (*output.Output, error) { // Check if current Stack has cached output if ctx.Stack != nil && ctx.Stack.output != nil { return ctx.Stack.output, nil } // Ensure Writer is wrapped in SafeWriter for concurrent-safe SSE writes // This is essential for ctx.agent.All where multiple sub-agents // write to the same SSE stream concurrently. // We wrap once at the context level so all forked contexts share the same SafeWriter. writer := ctx.getWriter() if writer != nil { // Check if it's already a SafeWriter if _, ok := writer.(*output.SafeWriter); !ok { // Wrap in SafeWriter with context for automatic cleanup on client disconnect // This prevents goroutine leaks in enterprise applications var safeWriter *output.SafeWriter if ctx.Context != nil { // Use request context to detect client disconnection safeWriter = output.NewSafeWriterWithContext(ctx.Context, writer) } else { // Fallback to basic SafeWriter if no context available safeWriter = output.NewSafeWriter(writer) } ctx.Writer = safeWriter writer = safeWriter } } trace, _ := ctx.Trace() var options message.Options = message.Options{ BaseURL: "/", Writer: writer, Trace: trace, Locale: ctx.Locale, Accept: string(ctx.Accept), } if ctx.Capabilities != nil { caps := llm.Capabilities(*ctx.Capabilities) options.Capabilities = &caps } out, err := output.NewOutput(options) if err != nil { return nil, err } // Cache to current Stack (each Stack has its own output with its own Writer) if ctx.Stack != nil 
{ ctx.Stack.output = out } return out, nil } // CloseSafeWriter closes the SafeWriter if one was created // This should be called at the end of the root request to flush any pending writes // and stop the background goroutine. func (ctx *Context) CloseSafeWriter() { if ctx.Writer == nil { return } if sw, ok := ctx.Writer.(*output.SafeWriter); ok { sw.Close() } } ================================================ FILE: agent/context/stack.go ================================================ package context import ( "fmt" "time" "github.com/google/uuid" "github.com/yaoapp/yao/trace" ) // NewStack creates a new root stack with the given trace ID and assistant ID func NewStack(traceID, assistantID, referer string, opts *Options) *Stack { if traceID == "" { traceID = uuid.New().String() } stackID := uuid.New().String() now := time.Now().UnixMilli() return &Stack{ ID: stackID, TraceID: traceID, AssistantID: assistantID, Referer: referer, Depth: 0, ParentID: "", Path: []string{stackID}, Options: opts, CreatedAt: now, Status: StackStatusRunning, } } // NewChildStack creates a child stack from the current stack func (s *Stack) NewChildStack(assistantID, referer string, opts *Options) *Stack { stackID := uuid.New().String() now := time.Now().UnixMilli() // Build path by appending current stack's path with new ID path := make([]string, len(s.Path)+1) copy(path, s.Path) path[len(s.Path)] = stackID return &Stack{ ID: stackID, TraceID: s.TraceID, // Inherit trace ID AssistantID: assistantID, Referer: referer, Depth: s.Depth + 1, ParentID: s.ID, Path: path, Options: opts, CreatedAt: now, Status: StackStatusRunning, } } // NewChildStackFromForkParent creates a child stack from ForkParentInfo // This is used by forked contexts (ctx.agent.Call) to create a child stack // without sharing the actual Stack reference (which would cause race conditions) func NewChildStackFromForkParent(parent *ForkParentInfo, assistantID, referer string, opts *Options) *Stack { stackID := 
uuid.New().String() now := time.Now().UnixMilli() // Build path by appending parent's path with new ID path := make([]string, len(parent.Path)+1) copy(path, parent.Path) path[len(parent.Path)] = stackID return &Stack{ ID: stackID, TraceID: parent.TraceID, // Inherit trace ID from parent AssistantID: assistantID, Referer: referer, Depth: parent.Depth + 1, ParentID: parent.StackID, // Use parent's stack ID Path: path, Options: opts, CreatedAt: now, Status: StackStatusRunning, } } // Complete marks the stack as completed and calculates duration func (s *Stack) Complete() { now := time.Now().UnixMilli() s.CompletedAt = &now s.Status = StackStatusCompleted duration := now - s.CreatedAt s.DurationMs = &duration } // Fail marks the stack as failed with an error message func (s *Stack) Fail(err error) { now := time.Now().UnixMilli() s.CompletedAt = &now s.Status = StackStatusFailed if err != nil { s.Error = err.Error() } duration := now - s.CreatedAt s.DurationMs = &duration } // Timeout marks the stack as timeout func (s *Stack) Timeout() { now := time.Now().UnixMilli() s.CompletedAt = &now s.Status = StackStatusTimeout duration := now - s.CreatedAt s.DurationMs = &duration } // IsRoot returns true if this is a root stack (no parent) func (s *Stack) IsRoot() bool { return s.ParentID == "" } // IsCompleted returns true if the stack has completed (success, failed, or timeout) func (s *Stack) IsCompleted() bool { return s.Status == StackStatusCompleted || s.Status == StackStatusFailed || s.Status == StackStatusTimeout } // IsRunning returns true if the stack is currently running func (s *Stack) IsRunning() bool { return s.Status == StackStatusRunning } // GetPathString returns the path as a string (e.g., "root_id -> parent_id -> current_id") func (s *Stack) GetPathString() string { if len(s.Path) == 0 { return s.ID } result := s.Path[0] for i := 1; i < len(s.Path); i++ { result += " -> " + s.Path[i] } return result } // String returns a string representation of the stack for 
debugging func (s *Stack) String() string { status := s.Status if s.IsCompleted() && s.DurationMs != nil { status = fmt.Sprintf("%s (%dms)", s.Status, *s.DurationMs) } return fmt.Sprintf("Stack[ID=%s, TraceID=%s, Assistant=%s, Depth=%d, Status=%s]", s.ID[:8], s.TraceID[:8], s.AssistantID, s.Depth, status) } // Clone creates a deep copy of the stack func (s *Stack) Clone() *Stack { clone := &Stack{ ID: s.ID, TraceID: s.TraceID, AssistantID: s.AssistantID, Referer: s.Referer, Depth: s.Depth, ParentID: s.ParentID, Path: make([]string, len(s.Path)), Options: s.Options, // Shallow copy of Options pointer CreatedAt: s.CreatedAt, Status: s.Status, Error: s.Error, } copy(clone.Path, s.Path) if s.CompletedAt != nil { completedAt := *s.CompletedAt clone.CompletedAt = &completedAt } if s.DurationMs != nil { durationMs := *s.DurationMs clone.DurationMs = &durationMs } return clone } // EnterStack initializes or creates a child stack and returns it along with trace ID and completion function // This is a helper function to manage stack context for nested calls // The stack will be automatically saved to ctx.Stacks for trace logging // // Returns: // - *Stack: current stack // - string: trace ID (generated for root, inherited for children) // - func(): completion function to be deferred // // Usage: // // stack, traceID, done := context.EnterStack(ctx, assistantID, opts) // defer done() // // ... your code here ... 
func EnterStack(ctx *Context, assistantID string, opts *Options) (*Stack, string, func()) { var stack *Stack var parentStack *Stack var traceID string // Get referer from ctx (request source) referer := ctx.Referer // Initialize Stacks map if not exists if ctx.Stacks == nil { ctx.Stacks = make(map[string]*Stack) } if ctx.Stack == nil { // Check if this is a forked context with parent stack info if ctx.ForkParent != nil { // Create child stack using ForkParent info // This is for forked contexts (ctx.agent.Call) to have proper ThreadID traceID = ctx.ForkParent.TraceID stack = NewChildStackFromForkParent(ctx.ForkParent, assistantID, referer, opts) ctx.Stack = stack } else { // Create root stack for this assistant call (entry point) // Generate a new trace ID for root traceID = trace.GenTraceID() stack = NewStack(traceID, assistantID, referer, opts) ctx.Stack = stack } } else { // Create child stack for nested agent call (delegate) // Inherit trace ID from parent parentStack = ctx.Stack traceID = parentStack.TraceID stack = ctx.Stack.NewChildStack(assistantID, referer, opts) ctx.Stack = stack } // Mark stack as running (in case it was pending) if stack.Status == StackStatusPending { stack.Status = StackStatusRunning } // Save stack to collection for trace logging ctx.Stacks[stack.ID] = stack // Return completion function done := func() { // Mark as completed if no panic occurred if !stack.IsCompleted() { stack.Complete() } // Restore parent stack if parentStack != nil { ctx.Stack = parentStack } } return stack, traceID, done } // GetAllStacks returns all stacks collected during the request // This is useful for trace logging after the request completes func (ctx *Context) GetAllStacks() []*Stack { if ctx.Stacks == nil { return nil } stacks := make([]*Stack, 0, len(ctx.Stacks)) for _, s := range ctx.Stacks { stacks = append(stacks, s) } return stacks } // GetStackByID returns a specific stack by its ID // This is useful for querying stack information during request 
processing func (ctx *Context) GetStackByID(id string) *Stack { if ctx.Stacks == nil { return nil } return ctx.Stacks[id] } // GetStacksByTraceID returns all stacks with the given trace ID // This is useful for getting the complete call tree for a trace func (ctx *Context) GetStacksByTraceID(traceID string) []*Stack { if ctx.Stacks == nil { return nil } stacks := make([]*Stack, 0) for _, s := range ctx.Stacks { if s.TraceID == traceID { stacks = append(stacks, s) } } return stacks } // GetRootStack returns the root stack (depth = 0) of current trace func (ctx *Context) GetRootStack() *Stack { if ctx.Stacks == nil { return nil } for _, s := range ctx.Stacks { if s.IsRoot() { return s } } return nil } ================================================ FILE: agent/context/stack_test.go ================================================ package context_test import ( stdContext "context" "testing" "time" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestNewStack(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() traceID := "12345678" assistantID := "test-assistant" referer := context.RefererAPI opts := &context.Options{} stack := context.NewStack(traceID, assistantID, referer, opts) if stack == nil { t.Fatal("Expected stack to be created, got nil") } if stack.TraceID != traceID { t.Errorf("Expected TraceID '%s', got '%s'", traceID, stack.TraceID) } if stack.AssistantID != assistantID { t.Errorf("Expected AssistantID '%s', got '%s'", assistantID, stack.AssistantID) } if stack.Referer != referer { t.Errorf("Expected Referer '%s', got '%s'", referer, stack.Referer) } if stack.Depth != 0 { t.Errorf("Expected Depth 0, got %d", stack.Depth) } if stack.ParentID != "" { t.Errorf("Expected empty ParentID, got '%s'", stack.ParentID) } if !stack.IsRoot() { t.Error("Expected stack to be root") } if stack.Status != context.StackStatusRunning { t.Errorf("Expected Status '%s', got '%s'", context.StackStatusRunning, 
stack.Status) } } func TestNewStack_GenerateTraceID(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Empty traceID should generate a UUID stack := context.NewStack("", "test-assistant", context.RefererAPI, &context.Options{}) if stack.TraceID == "" { t.Error("Expected TraceID to be generated, got empty string") } // Should be a valid UUID (36 characters with dashes) if len(stack.TraceID) < 8 { t.Errorf("Expected TraceID to be at least 8 characters, got %d", len(stack.TraceID)) } } func TestNewChildStack(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create parent stack parentStack := context.NewStack("12345678", "parent-assistant", context.RefererAPI, &context.Options{}) // Create child stack childStack := parentStack.NewChildStack("child-assistant", context.RefererAgent, &context.Options{}) if childStack == nil { t.Fatal("Expected child stack to be created, got nil") } // Child should inherit TraceID if childStack.TraceID != parentStack.TraceID { t.Errorf("Expected child TraceID '%s', got '%s'", parentStack.TraceID, childStack.TraceID) } // Child should have parent ID if childStack.ParentID != parentStack.ID { t.Errorf("Expected ParentID '%s', got '%s'", parentStack.ID, childStack.ParentID) } // Child should have incremented depth if childStack.Depth != parentStack.Depth+1 { t.Errorf("Expected Depth %d, got %d", parentStack.Depth+1, childStack.Depth) } // Child should not be root if childStack.IsRoot() { t.Error("Expected child stack not to be root") } // Path should include both parent and child if len(childStack.Path) != 2 { t.Errorf("Expected Path length 2, got %d", len(childStack.Path)) } if childStack.Path[0] != parentStack.ID { t.Errorf("Expected first path element '%s', got '%s'", parentStack.ID, childStack.Path[0]) } if childStack.Path[1] != childStack.ID { t.Errorf("Expected second path element '%s', got '%s'", childStack.ID, childStack.Path[1]) } } func TestStackComplete(t *testing.T) { test.Prepare(t, config.Conf) 
defer test.Clean() stack := context.NewStack("12345678", "test-assistant", context.RefererAPI, &context.Options{}) // Wait a bit to have measurable duration time.Sleep(10 * time.Millisecond) stack.Complete() if stack.Status != context.StackStatusCompleted { t.Errorf("Expected Status '%s', got '%s'", context.StackStatusCompleted, stack.Status) } if stack.CompletedAt == nil { t.Error("Expected CompletedAt to be set, got nil") } if stack.DurationMs == nil { t.Error("Expected DurationMs to be set, got nil") } if *stack.DurationMs < 10 { t.Errorf("Expected DurationMs to be at least 10ms, got %d", *stack.DurationMs) } if !stack.IsCompleted() { t.Error("Expected stack to be completed") } if stack.IsRunning() { t.Error("Expected stack not to be running") } } func TestStackFail(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() stack := context.NewStack("12345678", "test-assistant", context.RefererAPI, &context.Options{}) testError := "test error message" stack.Fail(nil) stack.Error = testError if stack.Status != context.StackStatusFailed { t.Errorf("Expected Status '%s', got '%s'", context.StackStatusFailed, stack.Status) } if stack.Error != testError { t.Errorf("Expected Error '%s', got '%s'", testError, stack.Error) } if !stack.IsCompleted() { t.Error("Expected failed stack to be completed") } } func TestStackTimeout(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() stack := context.NewStack("12345678", "test-assistant", context.RefererAPI, &context.Options{}) stack.Timeout() if stack.Status != context.StackStatusTimeout { t.Errorf("Expected Status '%s', got '%s'", context.StackStatusTimeout, stack.Status) } if !stack.IsCompleted() { t.Error("Expected timeout stack to be completed") } } func TestEnterStack_RootCreation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.Referer = context.RefererAPI stack, traceID, done := context.EnterStack(ctx, "test-assistant", 
&context.Options{}) defer done() if stack == nil { t.Fatal("Expected stack to be created, got nil") } if traceID == "" { t.Error("Expected traceID to be generated, got empty string") } // TraceID should be at least 8 digits (from trace.GenTraceID) if len(traceID) < 8 { t.Errorf("Expected traceID length at least 8, got %d", len(traceID)) } if stack.TraceID != traceID { t.Errorf("Expected stack TraceID '%s', got '%s'", traceID, stack.TraceID) } if ctx.Stack != stack { t.Error("Expected ctx.Stack to be set to created stack") } if ctx.Stacks == nil { t.Fatal("Expected ctx.Stacks to be initialized, got nil") } if ctx.Stacks[stack.ID] != stack { t.Error("Expected stack to be saved in ctx.Stacks") } if !stack.IsRoot() { t.Error("Expected stack to be root") } } func TestEnterStack_ChildCreation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.Referer = context.RefererAPI // Create parent parentStack, parentTraceID, parentDone := context.EnterStack(ctx, "parent-assistant", &context.Options{}) defer parentDone() if parentStack == nil { t.Fatal("Expected parent stack to be created, got nil") } // Create child childStack, childTraceID, childDone := context.EnterStack(ctx, "child-assistant", &context.Options{}) defer childDone() if childStack == nil { t.Fatal("Expected child stack to be created, got nil") } // Child should inherit trace ID if childTraceID != parentTraceID { t.Errorf("Expected child traceID '%s', got '%s'", parentTraceID, childTraceID) } // Child should have parent ID if childStack.ParentID != parentStack.ID { t.Errorf("Expected child ParentID '%s', got '%s'", parentStack.ID, childStack.ParentID) } // Both should be saved in ctx.Stacks if len(ctx.Stacks) != 2 { t.Errorf("Expected 2 stacks in ctx.Stacks, got %d", len(ctx.Stacks)) } // Current stack should be child if ctx.Stack != childStack { t.Error("Expected ctx.Stack to be child stack") } } func TestEnterStack_DoneCallback(t 
*testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.Referer = context.RefererAPI // Create parent parentStack, _, parentDone := context.EnterStack(ctx, "parent-assistant", &context.Options{}) // Create child childStack, _, childDone := context.EnterStack(ctx, "child-assistant", &context.Options{}) // Child should be current if ctx.Stack != childStack { t.Error("Expected ctx.Stack to be child stack before done") } // Call child done childDone() // Parent should be restored if ctx.Stack != parentStack { t.Error("Expected ctx.Stack to be restored to parent stack after child done") } // Child should be completed if !childStack.IsCompleted() { t.Error("Expected child stack to be completed after done") } // Call parent done parentDone() // Parent should be completed if !parentStack.IsCompleted() { t.Error("Expected parent stack to be completed after done") } } func TestContextGetAllStacks(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.Referer = context.RefererAPI // Create multiple stacks _, _, done1 := context.EnterStack(ctx, "assistant1", &context.Options{}) defer done1() _, _, done2 := context.EnterStack(ctx, "assistant2", &context.Options{}) defer done2() _, _, done3 := context.EnterStack(ctx, "assistant3", &context.Options{}) defer done3() // Get all stacks allStacks := ctx.GetAllStacks() if len(allStacks) != 3 { t.Errorf("Expected 3 stacks, got %d", len(allStacks)) } } func TestContextGetStackByID(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.Referer = context.RefererAPI stack, _, done := context.EnterStack(ctx, "test-assistant", &context.Options{}) defer done() // Get stack by ID found := ctx.GetStackByID(stack.ID) if found == nil { t.Fatal("Expected to find stack, got nil") } if found.ID != stack.ID { 
t.Errorf("Expected stack ID '%s', got '%s'", stack.ID, found.ID) } // Try to get non-existent stack notFound := ctx.GetStackByID("non-existent-id") if notFound != nil { t.Error("Expected nil for non-existent stack ID") } } func TestContextGetStacksByTraceID(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.Referer = context.RefererAPI // Create parent and child (same trace ID) _, traceID, done1 := context.EnterStack(ctx, "parent-assistant", &context.Options{}) defer done1() _, _, done2 := context.EnterStack(ctx, "child-assistant", &context.Options{}) defer done2() // Get stacks by trace ID stacks := ctx.GetStacksByTraceID(traceID) if len(stacks) != 2 { t.Errorf("Expected 2 stacks with trace ID '%s', got %d", traceID, len(stacks)) } // All should have same trace ID for _, s := range stacks { if s.TraceID != traceID { t.Errorf("Expected TraceID '%s', got '%s'", traceID, s.TraceID) } } } func TestContextGetRootStack(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := context.New(stdContext.Background(), nil, "test-chat-id") ctx.Referer = context.RefererAPI // Create parent parentStack, _, done1 := context.EnterStack(ctx, "parent-assistant", &context.Options{}) defer done1() // Create child _, _, done2 := context.EnterStack(ctx, "child-assistant", &context.Options{}) defer done2() // Get root stack rootStack := ctx.GetRootStack() if rootStack == nil { t.Fatal("Expected to find root stack, got nil") } if rootStack.ID != parentStack.ID { t.Errorf("Expected root stack ID '%s', got '%s'", parentStack.ID, rootStack.ID) } if !rootStack.IsRoot() { t.Error("Expected returned stack to be root") } } func TestStackClone(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() original := context.NewStack("12345678", "test-assistant", context.RefererAPI, &context.Options{}) original.Complete() clone := original.Clone() if clone == nil { t.Fatal("Expected clone to be 
created, got nil") } // Check all fields are copied if clone.ID != original.ID { t.Error("ID not cloned correctly") } if clone.TraceID != original.TraceID { t.Error("TraceID not cloned correctly") } if clone.AssistantID != original.AssistantID { t.Error("AssistantID not cloned correctly") } if clone.Status != original.Status { t.Error("Status not cloned correctly") } // Check deep copy of Path if len(clone.Path) != len(original.Path) { t.Error("Path length not cloned correctly") } // Modify clone's path shouldn't affect original if len(clone.Path) > 0 { clone.Path[0] = "modified" if original.Path[0] == "modified" { t.Error("Path is not deeply copied") } } } ================================================ FILE: agent/context/types.go ================================================ package context import ( "context" "sync" "time" "github.com/yaoapp/gou/llm" "github.com/yaoapp/gou/store" "github.com/yaoapp/yao/agent/memory" "github.com/yaoapp/yao/agent/output" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/openapi/oauth/types" infraV2 "github.com/yaoapp/yao/sandbox/v2" "github.com/yaoapp/yao/tai/workspace" traceTypes "github.com/yaoapp/yao/trace/types" ) // Accept the accept of the request, it will be used to identify the accept of the request. type Accept string // Referer the referer of the request, it will be used to identify the referer of the request. 
// Accept values: the response formats a request can ask for.
const (
	// AcceptStandard standard response format compatible with OpenAI API and general chat UIs (default)
	AcceptStandard = "standard"

	// AcceptWebCUI web-based CUI format with action request support for Yao Chat User Interface
	AcceptWebCUI = "cui-web"

	// AcceptNativeCUI native mobile/tablet CUI format with action request support
	AcceptNativeCUI = "cui-native"

	// AccepNativeCUI is the original, misspelled name of AcceptNativeCUI.
	// It is kept with the same value so existing callers continue to compile.
	//
	// Deprecated: use AcceptNativeCUI instead.
	AccepNativeCUI = AcceptNativeCUI

	// AcceptDesktopCUI desktop CUI format with action request support
	AcceptDesktopCUI = "cui-desktop"
)
RefererHook = "hook" // RefererSchedule request from scheduled task or cron job RefererSchedule = "schedule" // RefererScript request from custom script execution RefererScript = "script" // RefererInternal request from internal system call RefererInternal = "internal" ) // ValidReferers is the map of valid referer types var ValidReferers = map[string]bool{ RefererAPI: true, RefererProcess: true, RefererMCP: true, RefererJSSDK: true, RefererAgent: true, RefererAgentFork: true, RefererTool: true, RefererHook: true, RefererSchedule: true, RefererScript: true, RefererInternal: true, } const ( // StackStatusPending stack is created but not started yet StackStatusPending = "pending" // StackStatusRunning stack is currently executing StackStatusRunning = "running" // StackStatusCompleted stack completed successfully StackStatusCompleted = "completed" // StackStatusFailed stack failed with error StackStatusFailed = "failed" // StackStatusTimeout stack execution timeout StackStatusTimeout = "timeout" ) // ValidStackStatus is the map of valid stack status types var ValidStackStatus = map[string]bool{ StackStatusPending: true, StackStatusRunning: true, StackStatusCompleted: true, StackStatusFailed: true, StackStatusTimeout: true, } // Interrupt Types and Constants // =============================== // InterruptType represents the type of interrupt type InterruptType string const ( // InterruptGraceful waits for current step to complete before handling interrupt InterruptGraceful InterruptType = "graceful" // InterruptForce immediately cancels current operation and handles interrupt InterruptForce InterruptType = "force" ) // InterruptAction represents the action to take after interrupt is handled type InterruptAction string const ( // InterruptActionContinue appends new messages and continues execution InterruptActionContinue InterruptAction = "continue" // InterruptActionRestart restarts execution with only new messages InterruptActionRestart InterruptAction = "restart" // 
InterruptActionAbort terminates the request InterruptActionAbort InterruptAction = "abort" ) // InterruptSignal represents an interrupt signal with new messages from user type InterruptSignal struct { Type InterruptType `json:"type"` // Interrupt type: graceful or force Messages []Message `json:"messages"` // User's new messages (can be multiple) Timestamp int64 `json:"timestamp"` // Interrupt timestamp in milliseconds Metadata map[string]interface{} `json:"metadata,omitempty"` // Additional metadata } // InterruptHandler is the function signature for handling interrupts // This handler is registered in the InterruptController and called when interrupt signal is received // Parameters: // - ctx: The context being interrupted // - signal: The interrupt signal (contains Type and Messages) // // Returns: // - error: Error if interrupt handling failed type InterruptHandler func(ctx *Context, signal *InterruptSignal) error // InterruptController manages interrupt handling for a context // All interrupt-related fields are encapsulated in this type type InterruptController struct { queue chan *InterruptSignal `json:"-"` // Queue to receive interrupt signals current *InterruptSignal `json:"-"` // Current interrupt being processed pending []*InterruptSignal `json:"-"` // Pending interrupts in queue mutex sync.RWMutex `json:"-"` // Protects current and pending ctx context.Context `json:"-"` // Interrupt control context (independent from HTTP context) cancel context.CancelFunc `json:"-"` // Cancel function for force interrupt listenerStarted bool `json:"-"` // Whether listener goroutine is started handler InterruptHandler `json:"-"` // Handler to process interrupt signals contextID string `json:"-"` // Context ID to retrieve the parent context } // AssistantInfo represents the assistant information structure type AssistantInfo struct { ID string `json:"assistant_id"` // Assistant ID Type string `json:"type,omitempty"` // Assistant Type, default is assistant Name string 
// Skip configuration for what to skip in this request.
// Each flag switches off one part of request handling; the zero value skips
// nothing. It is attached to a call through Options.Skip, where a nil pointer
// likewise means "skip nothing".
type Skip struct {
	History        bool `json:"history"`         // Skip saving chat history (for internal calls like title/prompt generation)
	Trace          bool `json:"trace"`           // Skip trace logging
	Output         bool `json:"output"`          // Skip output to client (for internal A2A calls that only need response data)
	Keyword        bool `json:"keyword"`         // Skip keyword extraction for web search (use raw query directly)
	Search         bool `json:"search"`          // Skip auto search (for internal calls like needsearch intent detection)
	ContentParsing bool `json:"content_parsing"` // Skip content parsing (vision, PDF, docx, etc.), convert files to raw text directly
}
StartTime time.Time // Block start time MessageCount int // Number of messages in this block } // Context the context type Context struct { // Context context.Context // External ID string `json:"id"` // Context ID for external interrupt identification Memory *memory.Memory `json:"-"` // Agent memory with four spaces: User, Team, Chat, Context Cache store.Store `json:"-"` // Cache store, it will be used to store the message cache, default is "__yao.agent.cache" Stack *Stack `json:"-"` // Stack, current active stack of the request Stacks map[string]*Stack `json:"-"` // Stacks, all stacks in this request (for trace logging) Writer Writer `json:"-"` // Writer, it will be used to write response data to the client IDGenerator *message.IDGenerator `json:"-"` // ID generator for this context (chunk, message, block, thread IDs) Logger *RequestLogger `json:"-"` // Request-scoped async logger // ForkParent stores parent stack info for forked contexts (set by Fork()) // This allows EnterStack to create a child stack instead of root stack // without sharing the actual Stack reference (which would cause race conditions) ForkParent *ForkParentInfo `json:"-"` // Chat buffer for batch saving messages and resume steps Buffer *ChatBuffer `json:"-"` // Chat buffer for batch saving at end of Stream() // Internal trace traceTypes.Manager `json:"-"` // Trace manager, lazy initialized on first access messageMetadata *messageMetadataStore `json:"-"` // Thread-safe message metadata store for delta operations sandboxExecutor SandboxExecutor `json:"-"` // Sandbox executor for hooks (set by assistant when sandbox is configured) computer infraV2.Computer `json:"-"` // V2 sandbox computer (set by assistant when V2 sandbox is configured) workspace workspace.FS `json:"-"` // V2 workspace FS (derived from computer.Workplace()) // Model capabilities (set by assistant, used by output adapters) Capabilities *llm.Capabilities `json:"-"` // Model capabilities for the current connector // Interrupt 
control (all interrupt-related logic is encapsulated in InterruptController) Interrupt *InterruptController `json:"-"` // Interrupt controller for handling user interrupts during streaming // Authorized information Authorized *types.AuthorizedInfo `json:"authorized,omitempty"` // Authorized information ChatID string `json:"chat_id,omitempty"` // Chat ID, use to select chat AssistantID string `json:"assistant_id,omitempty"` // Assistant ID, use to select assistant // Locale information Locale string `json:"locale,omitempty"` // Locale Theme string `json:"theme,omitempty"` // Theme // Request information Client Client `json:"client,omitempty"` // Client information from HTTP request Referer string `json:"referer,omitempty"` // Request source: api, process, mcp, jssdk, agent, tool, hook, schedule, script, internal Accept Accept `json:"accept,omitempty"` // Response format: standard, cui-web, cui-native, cui-desktop // CUI Context information Route string `json:"route,omitempty"` // The route of the request, it will be used to identify the route of the request Metadata map[string]interface{} `json:"metadata,omitempty"` // The metadata of the request, it will be used to pass data to the page } // SearchIntent represents the result of search intent detection // Used by Create hook to specify fine-grained search behavior type SearchIntent struct { NeedSearch bool `json:"need_search"` // Whether search is needed SearchTypes []string `json:"search_types,omitempty"` // Types of search to perform: "web", "kb", "db" Confidence float64 `json:"confidence,omitempty"` // Confidence level (0-1) Reason string `json:"reason,omitempty"` // Reason for the decision } // Options represents the options for the context type Options struct { // Original context, override the default context Context context.Context `json:"-"` // Context, it will be used to pass the context to the call // Writer, use to write response data to the client (override the default writer) Writer Writer 
`json:"writer,omitempty"` // Writer, use to write response data to the client // Skip configuration (history, trace, etc.), nil means don't skip anything Skip *Skip `json:"skip,omitempty"` // Skip configuration (history, trace, etc.), nil means don't skip anything // Connector, use to select the connector of the LLM Model, Default is Assistant.Connector Connector string `json:"connector,omitempty"` // Connector, use to select the connector of the LLM Model, Default is Assistant.Connector // Disable global prompts, default is false DisableGlobalPrompts bool `json:"disable_global_prompts,omitempty"` // Temporarily disable global prompts for this request // Search controls search behavior, supports multiple types: // - bool: true = enable all search types, false = disable all search // - SearchIntent: fine-grained control with specific types, confidence, etc. // - nil: use default behavior (determined by __yao.needsearch agent) Search any `json:"search,omitempty"` // Search mode: bool | SearchIntent | nil // Agent mode, use to select the mode of the request, default is "chat" Mode string `json:"mode,omitempty"` // Agent mode, use to select the mode of the request, default is "chat" // Uses configuration, allow hook to override wrapper configurations for vision, audio, search, and fetch Uses *Uses `json:"uses,omitempty"` // Uses configuration, allow hook to override wrapper configurations for vision, audio, search, and fetch // Metadata for passing custom data to hooks (e.g., scenario selection) Metadata map[string]any `json:"metadata,omitempty"` // Custom metadata passed to Create/Next hooks // HistorySize controls the max number of history messages loaded for LLM context. // Priority: HistorySize > StoreSetting.MaxSize > default (20) // 0 means use StoreSetting or default. 
HistorySize int `json:"history_size,omitempty"` // OnMessage is called for each message sent via ctx.Send() // Used by ctx.agent.Call with onChunk callback to receive SSE messages // Returns: 0 = continue, non-zero = stop OnMessage OnMessageFunc `json:"-"` } // ForceA2A sets the options for Agent-to-Agent (A2A) calls. // For A2A calls: // - Output is NOT skipped - sub-agents output normally with ThreadID // - History IS skipped - A2A messages should not be saved to chat history // If Skip is nil, it creates a new Skip instance. func (opts *Options) ForceA2A() { if opts.Skip == nil { opts.Skip = &Skip{} } opts.Skip.History = true // Note: skip.output is NOT set - sub-agents output normally with ThreadID } // OnMessageFunc is a callback function for receiving output messages // Called for each message sent via ctx.Send() - same as SSE messages to client // Returns: 0 = continue, non-zero = stop sending type OnMessageFunc func(msg *message.Message) int // ForkParentInfo stores parent stack information for forked contexts // This is used by EnterStack to create a child stack with proper inheritance // without sharing the actual Stack reference (which would cause race conditions in parallel calls) type ForkParentInfo struct { StackID string // Parent stack ID (used as ParentID for child stack) TraceID string // Parent trace ID (inherited by child stack) Depth int // Parent depth (child depth = parent depth + 1) Path []string // Parent path (child path = parent path + child ID) } // Stack represents the call stack node for tracing agent-to-agent calls // Uses a flat structure to avoid circular references and memory overhead type Stack struct { // Identity ID string `json:"id"` // Unique stack node ID, used to identify this specific call TraceID string `json:"trace_id"` // Shared trace ID for entire call tree, inherited from root // Options Options *Options `json:"options,omitempty"` // Options for the call // Call context AssistantID string `json:"assistant_id"` // 
Assistant handling this call Referer string `json:"referer,omitempty"` // Call source: api, agent, tool, process, etc. Depth int `json:"depth"` // Call depth in the tree (0=root) // Relationships ParentID string `json:"parent_id,omitempty"` // Parent stack ID (empty for root call) Path []string `json:"path"` // Full path from root: [root_id, parent_id, ..., this_id] // Tracking CreatedAt int64 `json:"created_at"` // Unix timestamp in milliseconds CompletedAt *int64 `json:"completed_at,omitempty"` // Unix timestamp when completed (nil if ongoing) Status string `json:"status"` // Status: pending, running, completed, failed, timeout Error string `json:"error,omitempty"` // Error message if failed // Metrics DurationMs *int64 `json:"duration_ms,omitempty"` // Duration in milliseconds (calculated when completed) // Runtime cache (not serialized) output *output.Output `json:"-"` // Cached output instance for this stack } // Response the response // 100% compatible with the OpenAI API type Response struct { RequestID string `json:"request_id"` // Request ID for the response ContextID string `json:"context_id"` // Context ID for the response TraceID string `json:"trace_id"` // Trace ID for the response ChatID string `json:"chat_id"` // Chat ID for the response AssistantID string `json:"assistant_id"` // Assistant ID for the response Create *HookCreateResponse `json:"create,omitempty"` // Create response from the create hook Next interface{} `json:"next,omitempty"` // Next response from the next hook Completion *CompletionResponse `json:"completion,omitempty"` // Completion response from the completion hook Tools []ToolCallResponse `json:"tools,omitempty"` // Tool call results (if any tools were executed) } // HookCreateResponse the response of the create hook type HookCreateResponse struct { // Messages to be sent to the assistant Messages []Message `json:"messages,omitempty"` // Audio configuration (for models that support audio output) Audio *AudioConfig 
`json:"audio,omitempty"` // Generation parameters Temperature *float64 `json:"temperature,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"` MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // MCP configuration - allow hook to add/override MCP servers for this request MCPServers []MCPServerConfig `json:"mcp_servers,omitempty"` // Prompt configuration PromptPreset string `json:"prompt_preset,omitempty"` // Select prompt preset (e.g., "chat.friendly", "task.analysis") DisableGlobalPrompts *bool `json:"disable_global_prompts,omitempty"` // Temporarily disable global prompts for this request // Context adjustments - allow hook to modify context fields Connector string `json:"connector,omitempty"` // Override connector (call-level) Locale string `json:"locale,omitempty"` // Override locale (session-level) Theme string `json:"theme,omitempty"` // Override theme (session-level) Route string `json:"route,omitempty"` // Override route (session-level) Metadata map[string]interface{} `json:"metadata,omitempty"` // Override or merge metadata (session-level) // Uses configuration - allow hook to override wrapper configurations Uses *Uses `json:"uses,omitempty"` // Override wrapper configurations for vision, audio, search, and fetch // ForceUses controls whether to force using Uses tools even when model has native capabilities ForceUses *bool `json:"force_uses,omitempty"` // Force using Uses tools regardless of model capabilities // Search controls search behavior, supports multiple types: // - bool: true = enable all search types, false = disable all search // - SearchIntent: fine-grained control with specific types, confidence, etc. 
// ToolCallResponse is the response of a tool call.
//
// ToolCallID, Server and Tool are always present in the JSON encoding.
// Arguments and Result are omitted when nil, and Error is omitted when empty.
type ToolCallResponse struct {
	ToolCallID string `json:"toolcall_id"`
	Server     string `json:"server"`
	Tool       string `json:"tool"`
	Arguments  any    `json:"arguments,omitempty"`
	Result     any    `json:"result,omitempty"`
	Error      string `json:"error,omitempty"`
}
// NextActionReturn returns data to user (standard or custom) NextActionReturn NextAction = "return" // NextActionDelegate delegates to another agent NextActionDelegate NextAction = "delegate" ) // Action returns the determined action based on NextHookResponse fields func (n *NextHookResponse) Action() NextAction { if n.Delegate != nil { return NextActionDelegate } return NextActionReturn } // ResponseHookNext the response of the next hook type ResponseHookNext interface{} // ResponseHookMCP the response of the mcp hook type ResponseHookMCP struct{} // ResponseHookFailback the response of the failback hook type ResponseHookFailback struct{} // HookInterruptedResponse the response of the interrupted hook type HookInterruptedResponse struct { // Action to take after interrupt is handled Action InterruptAction `json:"action"` // continue, restart, or abort // Messages to use for next execution (if action is continue or restart) Messages []Message `json:"messages,omitempty"` // Context adjustments - allow hook to modify context fields AssistantID string `json:"assistant_id,omitempty"` // Override assistant ID Connector string `json:"connector,omitempty"` // Override connector Locale string `json:"locale,omitempty"` // Override locale Theme string `json:"theme,omitempty"` // Override theme Route string `json:"route,omitempty"` // Override route Metadata map[string]interface{} `json:"metadata,omitempty"` // Override or merge metadata // Notice to send to client Notice string `json:"notice,omitempty"` // Message to display to user (e.g., "Processing your new question...") } // Message Structure ( OpenAI Chat Completion Input Message Structure, https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages ) // =============================== // MessageRole represents the role of a message author type MessageRole string // Message role constants const ( RoleDeveloper MessageRole = "developer" // Developer-provided instructions (o1 models and newer) 
RoleSystem MessageRole = "system" // System instructions RoleUser MessageRole = "user" // User messages RoleAssistant MessageRole = "assistant" // Assistant responses RoleTool MessageRole = "tool" // Tool responses ) // Message represents a message in the conversation, compatible with OpenAI's chat completion API // Supports message types: developer, system, user, assistant, and tool type Message struct { // Common fields for all message types Role MessageRole `json:"role"` // Required: message author role Content interface{} `json:"content,omitempty"` // string or array of ContentPart; Required for most types, optional for assistant with tool_calls Name *string `json:"name,omitempty"` // Optional: participant name to differentiate between participants of the same role // Tool message specific fields ToolCallID *string `json:"tool_call_id,omitempty"` // Required for tool messages: tool call that this message is responding to // Assistant message specific fields ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Optional for assistant: tool calls generated by the model Refusal *string `json:"refusal,omitempty"` // Optional for assistant: refusal message (null when not refusing) } // ContentPartType represents the type of content part type ContentPartType string // Content part type constants const ( ContentText ContentPartType = "text" // Text content ContentImageURL ContentPartType = "image_url" // Image URL content (Vision) ContentInputAudio ContentPartType = "input_audio" // Input audio content (Audio) ContentFile ContentPartType = "file" // File attachment (documents, etc.) ContentData ContentPartType = "data" // Generic data content (base64, binary, etc.) 
// DataSourceType represents the type of data source.
type DataSourceType string

// Data source type constants.
const (
	// DataSourceModel is a data model.
	DataSourceModel DataSourceType = "model"
	// DataSourceKBCollection is a knowledge base collection.
	DataSourceKBCollection DataSourceType = "kb_collection"
	// DataSourceKBDocument is a knowledge base document/file.
	DataSourceKBDocument DataSourceType = "kb_document"
	// DataSourceTable is a database table.
	DataSourceTable DataSourceType = "table"
	// DataSourceAPI is an API endpoint.
	DataSourceAPI DataSourceType = "api"
	// DataSourceMCPResource is an MCP (Model Context Protocol) resource.
	DataSourceMCPResource DataSourceType = "mcp_resource"
)

// DataSource represents a single data source reference.
type DataSource struct {
	// Type is the type of data source (required).
	Type DataSourceType `json:"type"`

	// Name is the name/identifier of the data source (required).
	Name string `json:"name"`

	// ID is an optional specific ID (e.g., document ID, record ID).
	ID string `json:"id,omitempty"`

	// Filters holds optional filters to apply.
	Filters map[string]any `json:"filters,omitempty"`

	// Metadata holds optional additional metadata.
	Metadata map[string]any `json:"metadata,omitempty"`
}
// ToolCallType represents the type of tool call.
type ToolCallType string

// ToolTypeFunction is the tool call type for a function call.
const ToolTypeFunction ToolCallType = "function"

// ToolCall represents a tool call generated by the model (for assistant messages).
type ToolCall struct {
	// ID is the unique identifier for the tool call (required).
	ID string `json:"id"`

	// Type is the type of tool call (required); currently only "function".
	Type ToolCallType `json:"type"`

	// Function carries the function call details (required).
	Function Function `json:"function"`
}

// Function represents a function call with name and arguments.
type Function struct {
	// Name is the name of the function to call (required).
	Name string `json:"name"`

	// Arguments holds the arguments to pass to the function, as a JSON string.
	// Optional; omitted from the encoding when empty.
	Arguments string `json:"arguments,omitempty"`
}
// AudioConfig represents the audio output configuration for models that support audio.
type AudioConfig struct {
	// Voice is the voice to use for audio output (required),
	// e.g. "alloy", "echo", "fable", "onyx", "nova", "shimmer".
	Voice string `json:"voice"`

	// Format is the audio output format (required),
	// e.g. "wav", "mp3", "flac", "opus", "pcm16".
	Format string `json:"format"`
}

// StreamOptions represents options for streaming responses.
type StreamOptions struct {
	// IncludeUsage, if true, includes usage statistics in the final chunk.
	IncludeUsage bool `json:"include_usage,omitempty"`
}

// MCPServerConfig represents an MCP server configuration.
// This mirrors agent/store/types.MCPServerConfig to avoid import cycles.
type MCPServerConfig struct {
	// ServerID is the MCP server ID (required).
	ServerID string `json:"server_id"`

	// Tools is the tool name filter (empty = all tools).
	Tools []string `json:"tools,omitempty"`

	// Resources is the resource URI filter (empty = all resources).
	Resources []string `json:"resources,omitempty"`
}
Format: "agent" or "mcp:server_id" Audio string `json:"audio,omitempty"` // Audio processing tool. Format: "agent" or "mcp:server_id" Search string `json:"search,omitempty"` // Search tool. Format: "builtin", "disabled", "", "mcp:." Fetch string `json:"fetch,omitempty"` // Fetch/retrieval tool. Format: "agent" or "mcp:server_id" // Search-related processing tools (NLP) Web string `json:"web,omitempty"` // Web search handler: "builtin", "", "mcp:." Keyword string `json:"keyword,omitempty"` // Keyword extraction: "builtin", "", "mcp:." QueryDSL string `json:"querydsl,omitempty"` // QueryDSL generation: "builtin", "", "mcp:." Rerank string `json:"rerank,omitempty"` // Result reranking: "builtin", "", "mcp:." } // VisionFormat specifies the vision input format type VisionFormat string // Vision format constants define how image inputs are processed const ( // VisionFormatNone indicates no vision support VisionFormatNone VisionFormat = "" // VisionFormatOpenAI indicates OpenAI format (image_url with URL) VisionFormatOpenAI VisionFormat = "openai" // VisionFormatClaude indicates Claude/Anthropic format (image with base64) VisionFormatClaude VisionFormat = "claude" // VisionFormatBase64 forces base64 conversion (alias for claude) VisionFormatBase64 VisionFormat = "base64" // VisionFormatDefault enables auto-detection of format VisionFormatDefault VisionFormat = "default" ) // GetVisionSupport returns whether vision is supported and the format func GetVisionSupport(cap *llm.Capabilities) (bool, VisionFormat) { if cap == nil || cap.Vision == nil { return false, VisionFormatNone } switch v := cap.Vision.(type) { case bool: // Legacy bool format return v, VisionFormatDefault case string: // String format if v == "" || v == string(VisionFormatNone) { return false, VisionFormatNone } return true, VisionFormat(v) case VisionFormat: // Direct VisionFormat type if v == VisionFormatNone || v == "" { return false, VisionFormatNone } return true, v default: return false, 
VisionFormatNone } } // CompletionOptions the completion request options // These options are extracted from HookCreateResponse and Context, then passed to the LLM connector // Compatible with OpenAI Chat Completion API: https://platform.openai.com/docs/api-reference/chat/create type CompletionOptions struct { // Model capabilities (used by LLM to select appropriate provider) // nil means capabilities are not specified/checked Capabilities *llm.Capabilities `json:"capabilities,omitempty"` // User-specified tools for vision, audio, search, and fetch processing Uses *Uses `json:"uses,omitempty"` // ForceUses controls whether to force using Uses tools even when model has native capabilities // When true: Always use tools specified in Uses, ignore model's native multimodal capabilities // When false (default): Use model's native capabilities if available, fallback to Uses tools // This is useful when you want consistent behavior across different models or prefer specific tools ForceUses bool `json:"force_uses,omitempty"` // Audio configuration (for models that support audio output) Audio *AudioConfig `json:"audio,omitempty"` // Generation parameters Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0-2), defaults to 1 MaxTokens *int `json:"max_tokens,omitempty"` // Maximum tokens to generate (deprecated, use MaxCompletionTokens) MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // Maximum tokens in completion TopP *float64 `json:"top_p,omitempty"` // Nucleus sampling parameter (0-1), alternative to temperature N *int `json:"n,omitempty"` // Number of chat completion choices to generate // Control parameters Stop interface{} `json:"stop,omitempty"` // Up to 4 sequences where the API will stop generating (string or []string) PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Presence penalty (-2.0 to 2.0) FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Frequency penalty (-2.0 to 2.0) LogitBias 
map[string]float64 `json:"logit_bias,omitempty"` // Modify likelihood of specified tokens appearing // User and response format User string `json:"user,omitempty"` // Unique identifier representing end-user ResponseFormat *ResponseFormat `json:"response_format,omitempty"` // Format of the model's output Seed *int `json:"seed,omitempty"` // Seed for deterministic sampling // Tool calling Tools []map[string]interface{} `json:"tools,omitempty"` // List of tools the model may call ToolChoice interface{} `json:"tool_choice,omitempty"` // Controls which tool is called ("none", "auto", "required", or specific tool) // Streaming configuration Stream *bool `json:"stream,omitempty"` // If true, stream partial message deltas StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Options for streaming response // Reasoning configuration (for reasoning models like o1, GPT-5) ReasoningEffort *string `json:"reasoning_effort,omitempty"` // Reasoning effort level: "low", "medium", "high" (o1 and GPT-5 only) // CUI Context information (from Context) Route string `json:"route,omitempty"` // Route of the request for CUI context Metadata map[string]interface{} `json:"metadata,omitempty"` // Metadata to pass to the page for CUI context } // CompletionResponse represents the unified LLM completion response // This is Yao's internal representation that works with multiple LLM providers (OpenAI, Claude, DeepSeek, etc.) 
type CompletionResponse struct { // Response metadata ID string `json:"id"` // Unique identifier for the completion Object string `json:"object"` // Object type (e.g., "chat.completion") Created int64 `json:"created"` // Unix timestamp of creation Model string `json:"model"` // Model used for completion // Response message (similar to OpenAI's message structure) Role string `json:"role"` // Role of the response, typically "assistant" Content interface{} `json:"content,omitempty"` // string (text) or []ContentPart (multimodal: text, image, audio) // Tool calls (when model calls functions/tools) ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Tool calls made by the model // Refusal (when model refuses to respond due to policy) Refusal string `json:"refusal,omitempty"` // Refusal message if model refused to answer // Reasoning content (for reasoning models like o1, DeepSeek R1) ReasoningContent string `json:"reasoning_content,omitempty"` // Thinking/reasoning process // Completion metadata FinishReason string `json:"finish_reason"` // Why generation stopped (stop, length, tool_calls, content_filter, etc.) 
// Usage statistics Usage *message.UsageInfo `json:"usage,omitempty"` // Token usage statistics // Additional metadata SystemFingerprint string `json:"system_fingerprint,omitempty"` // System fingerprint for reproducibility Metadata map[string]interface{} `json:"metadata,omitempty"` // Additional provider-specific metadata // Raw response data (for debugging and special cases) Raw interface{} `json:"raw,omitempty"` // Original raw response from the LLM provider } // FinishReason constants - why the model stopped generating tokens const ( FinishReasonStop = "stop" // Natural stop point or provided stop sequence reached FinishReasonLength = "length" // Max tokens limit reached FinishReasonToolCalls = "tool_calls" // Model called a tool FinishReasonContentFilter = "content_filter" // Content filtered due to safety FinishReasonFunctionCall = "function_call" // Model called a function (deprecated, use tool_calls) ) // ResponseFormat specifies the format of the model's output // Reference: https://platform.openai.com/docs/api-reference/chat/create#chat_create-response_format type ResponseFormat struct { Type ResponseFormatType `json:"type"` // Required: type of response format JSONSchema *JSONSchema `json:"json_schema,omitempty"` // Optional: for type="json_schema", defines the schema } // ResponseFormatType represents the type of response format type ResponseFormatType string // Response format type constants const ( ResponseFormatText ResponseFormatType = "text" // Default text format ResponseFormatJSON ResponseFormatType = "json_object" // JSON object format (no schema) ResponseFormatJSONSchema ResponseFormatType = "json_schema" // JSON with strict schema validation ) // JSONSchema defines a JSON schema for structured output // Used when ResponseFormat.Type is "json_schema" type JSONSchema struct { Name string `json:"name"` // Required: name of the schema Description string `json:"description,omitempty"` // Optional: description of the schema Schema interface{} 
`json:"schema"` // Required: JSON schema (*jsonschema.Schema or map[string]interface{}) Strict *bool `json:"strict,omitempty"` // Optional: whether to enforce strict schema validation (default: true) } ================================================ FILE: agent/context/utils.go ================================================ package context // getValidatedValue gets value from query, header, or default, and validates it func getValidatedValue(queryValue, headerValue, defaultValue string, validator func(string) string) string { if queryValue != "" { return validator(queryValue) } if headerValue != "" { return validator(headerValue) } return defaultValue } // getValidatedAccept gets Accept from query, header, or parse from client type func getValidatedAccept(queryValue, headerValue, clientType string) Accept { if queryValue != "" { return validateAccept(queryValue) } if headerValue != "" { return validateAccept(headerValue) } return parseAccept(clientType) } // validateReferer validates and returns a valid Referer, returns RefererAPI if invalid func validateReferer(referer string) string { if ValidReferers[referer] { return referer } return RefererAPI } // validateAccept validates and returns a valid Accept type, returns AcceptStandard if invalid func validateAccept(accept string) Accept { if ValidAccepts[accept] { return Accept(accept) } return AcceptStandard } // parseAccept determines the accept type based on client type func parseAccept(clientType string) Accept { switch clientType { case "web": return AcceptWebCUI case "android", "ios": return AccepNativeCUI case "windows", "macos", "linux": return AcceptDesktopCUI default: return AcceptStandard } } ================================================ FILE: agent/docs/configuration.md ================================================ # Assistant Configuration ## Directory Structure ``` assistants/ └── / ├── package.yao # Required: Configuration ├── prompts.yml # Optional: Default prompts ├── prompts/ # Optional: 
Prompt presets │ ├── chat.yml │ └── task.yml ├── locales/ # Optional: Translations │ ├── en-us.yml │ └── zh-cn.yml ├── src/ # Optional: Hook scripts │ └── index.ts └── mcps/ # Optional: MCP servers └── tools.mcp.yao ``` ## package.yao ### Basic Fields ```json { "name": "{{ name }}", "type": "assistant", "avatar": "/assets/avatar.png", "description": "{{ description }}", "connector": "gpt-4o", "tags": ["Category1", "Category2"], "sort": 1 } ``` | Field | Type | Description | | ------------- | -------- | ------------------------------------ | | `name` | string | Display name (supports i18n `{{ }}`) | | `type` | string | Type: `assistant` (default) | | `avatar` | string | Avatar image path | | `description` | string | Description (supports i18n) | | `connector` | string | LLM connector ID | | `tags` | string[] | Categorization tags | | `sort` | number | Display order | ### Connector Options ```json { "connector": "gpt-4o", "connector_options": { "optional": true, "connectors": ["gpt-4o", "gpt-4o-mini", "claude-3"], "filters": ["tool_calls", "vision"] } } ``` | Field | Type | Description | | ------------ | -------- | -------------------------------------------- | | `optional` | boolean | Allow user to select connector | | `connectors` | string[] | Available connectors (empty = all) | | `filters` | string[] | Required capabilities: `vision`, `audio`, `tool_calls`, `reasoning` | ### Generation Options ```json { "options": { "temperature": 0.7, "max_tokens": 4096 } } ``` ### Placeholder (UI Hints) ```json { "placeholder": { "title": "{{ chat.title }}", "description": "{{ chat.description }}", "prompts": [ "{{ chat.prompts.0 }}", "{{ chat.prompts.1 }}" ] } } ``` ### Visibility & Access ```json { "public": true, "share": "team", "readonly": true, "built_in": true, "mentionable": true, "automated": false } ``` | Field | Type | Description | | ------------- | ------- | --------------------------------- | | `public` | boolean | Visible to all users | | `share` | string | 
Sharing scope: `private`, `team` | | `readonly` | boolean | Prevent user modifications | | `built_in` | boolean | System-managed assistant | | `mentionable` | boolean | Can be @mentioned in chat | | `automated` | boolean | Can be triggered automatically | ### Modes ```json { "modes": ["chat", "task"], "default_mode": "task" } ``` ### MCP Servers ```json { "mcp": { "servers": [ "server-id", { "server_id": "tools", "tools": ["tool1", "tool2"] }, { "server_id": "resources", "resources": ["uri://pattern"] } ] } } ``` ### Knowledge Base ```json { "kb": { "collections": ["collection-id-1", "collection-id-2"] } } ``` ### Database Models ```json { "db": { "models": ["model.name", "another.model"] } } ``` ### Uses (Wrapper Tools) ```json { "uses": { "vision": "vision-agent", "audio": "audio-agent", "search": "disabled", "fetch": "mcp:fetcher" } } ``` | Field | Description | | -------- | -------------------------------------------------- | | `vision` | Vision processing: `<agent-id>` or `mcp:<server-id>` | | `audio` | Audio processing: `<agent-id>` or `mcp:<server-id>` | | `search` | Search: `disabled`, `<agent-id>`, or `mcp:<server-id>` | | `fetch` | HTTP fetching: `<agent-id>` or `mcp:<server-id>` | ### Search Configuration ```json { "search": { "web": { "provider": "tavily", "max_results": 10 }, "kb": { "threshold": 0.7, "graph": true }, "db": { "max_results": 20 }, "citation": { "format": "[{index}]", "auto_inject_prompt": true } } } ``` ## Environment Variables Use `$ENV.VAR_NAME` for sensitive values: ```json { "connector": "$ENV.LLM_CONNECTOR" } ``` ## Complete Example ```json { "name": "{{ name }}", "type": "assistant", "avatar": "/assets/assistant.png", "connector": "gpt-4o", "connector_options": { "optional": true, "connectors": ["gpt-4o", "gpt-4o-mini"], "filters": ["tool_calls"] }, "mcp": { "servers": [{ "server_id": "tools", "tools": ["search", "calculate"] }] }, "description": "{{ description }}", "options": { "temperature": 0.7 }, "public": true, "placeholder": { "title": "{{ chat.title }}", "description": "{{ chat.description }}", "prompts":
["{{ chat.prompts.0 }}", "{{ chat.prompts.1 }}"] }, "tags": ["Productivity"], "modes": ["chat", "task"], "default_mode": "chat", "sort": 1, "readonly": true, "mentionable": true } ``` ================================================ FILE: agent/docs/context-api.md ================================================ # Context API The `ctx` object provides access to messaging, memory, tracing, MCP, search, agent-to-agent, LLM, and sandbox operations. ## Properties ```typescript interface Context { chat_id: string; // Chat session ID assistant_id: string; // Assistant ID locale: string; // User locale (e.g., "en-us") theme: string; // UI theme route: string; // Request route referer: string; // Request source metadata: Record<string, any>; // Custom metadata authorized: Record<string, any>; // Auth info memory: Memory; // Memory namespaces trace: Trace; // Tracing API mcp: MCP; // MCP operations search: Search; // Search API agent: Agent; // Agent-to-Agent calls (A2A) llm: LLM; // Direct LLM calls sandbox?: Sandbox; // Sandbox operations (optional) } ``` ## Messaging ### Send Complete Message ```typescript ctx.Send({ type: "text", props: { content: "Hello!" } }); ctx.Send("Hello!"); // Shorthand for text ``` ### Streaming Messages ```typescript const msgId = ctx.SendStream("Starting..."); ctx.Append(msgId, " processing..."); ctx.Append(msgId, " done!"); ctx.End(msgId); ``` ### Update Streaming Message ```typescript const msgId = ctx.SendStream({ type: "loading", props: { message: "Loading..." } }); // ... do work ... ctx.Replace(msgId, { type: "text", props: { content: "Complete!"
} }); ctx.End(msgId); ``` ### Merge Data ```typescript const msgId = ctx.SendStream({ type: "status", props: { progress: 0 } }); ctx.Merge(msgId, { progress: 50 }, "props"); ctx.Merge(msgId, { progress: 100, status: "done" }, "props"); ctx.End(msgId); ``` ### Set Field ```typescript const msgId = ctx.SendStream({ type: "result", props: {} }); ctx.Set(msgId, "success", "props.status"); ctx.Set(msgId, { count: 10 }, "props.data"); ctx.End(msgId); ``` ### Block Grouping ```typescript const blockId = ctx.BlockID(); ctx.Send("Step 1", blockId); ctx.Send("Step 2", blockId); ctx.Send("Step 3", blockId); ctx.EndBlock(blockId); ``` ### ID Generators ```typescript const msgId = ctx.MessageID(); // "M1", "M2", ... const blockId = ctx.BlockID(); // "B1", "B2", ... const threadId = ctx.ThreadID(); // "T1", "T2", ... ``` ## Memory Four-level hierarchical memory system: | Namespace | Scope | Persistence | | -------------------- | ------------ | ----------- | | `ctx.memory.user` | Per user | Persistent | | `ctx.memory.team` | Per team | Persistent | | `ctx.memory.chat` | Per chat | Persistent | | `ctx.memory.context` | Per request | Temporary | ### Basic Operations ```typescript // Get/Set ctx.memory.user.Set("theme", "dark"); const theme = ctx.memory.user.Get("theme"); // With TTL (seconds) ctx.memory.context.Set("temp", data, 300); // Check/Delete if (ctx.memory.chat.Has("topic")) { ctx.memory.chat.Del("topic"); } // Get and delete atomically const token = ctx.memory.context.GetDel("one_time_token"); // Collection operations const keys = ctx.memory.user.Keys(); const count = ctx.memory.chat.Len(); ctx.memory.context.Clear(); ``` ### Counters ```typescript const views = ctx.memory.user.Incr("page_views"); const credits = ctx.memory.user.Decr("credits", 5); ``` ### Lists ```typescript ctx.memory.chat.Push("history", [msg1, msg2]); const last = ctx.memory.chat.Pop("queue"); const items = ctx.memory.chat.Pull("queue", 5); const all = ctx.memory.chat.PullAll("queue"); ``` ### Sets 
```typescript ctx.memory.user.AddToSet("visited", ["/home", "/about"]); ``` ### Array Access ```typescript const len = ctx.memory.chat.ArrayLen("messages"); const first = ctx.memory.chat.ArrayGet("messages", 0); const last = ctx.memory.chat.ArrayGet("messages", -1); ctx.memory.chat.ArraySet("messages", 0, newMsg); const slice = ctx.memory.chat.ArraySlice("messages", -10, -1); const page = ctx.memory.chat.ArrayPage("messages", 1, 20); const all = ctx.memory.chat.ArrayAll("messages"); ``` ## Trace ### Create Nodes ```typescript const node = ctx.trace.Add( { query: "input data" }, { label: "Processing", type: "process", icon: "play", description: "Processing user request" } ); ``` ### Logging ```typescript ctx.trace.Info("Starting process"); ctx.trace.Debug("Variable: " + value); ctx.trace.Warn("Deprecated feature"); ctx.trace.Error("Operation failed"); // Or on node node.Info("Step completed"); ``` ### Node Lifecycle ```typescript node.SetOutput({ result: data }); node.SetMetadata("duration", 1500); node.Complete({ status: "done" }); // or node.Fail("Error message"); ``` ### Parallel Nodes ```typescript const nodes = ctx.trace.Parallel([ { input: { url: "api1" }, option: { label: "API 1" } }, { input: { url: "api2" }, option: { label: "API 2" } } ]); ``` ### Child Nodes ```typescript const parent = ctx.trace.Add({}, { label: "Parent" }); const child = parent.Add({}, { label: "Child" }); ``` ## MCP ### Tools ```typescript // List tools const tools = ctx.mcp.ListTools("server-id"); // Call single tool - returns parsed result directly const result = ctx.mcp.CallTool("server-id", "tool-name", { arg: "value" }); console.log(result.field); // Direct access to parsed data // Call multiple sequentially - returns array of parsed results const results = ctx.mcp.CallTools("server-id", [ { name: "tool1", arguments: { a: 1 } }, { name: "tool2", arguments: { b: 2 } } ]); results.forEach(r => console.log(r)); // Call multiple in parallel - returns array of parsed results const 
results = ctx.mcp.CallToolsParallel("server-id", [ { name: "tool1", arguments: {} }, { name: "tool2", arguments: {} } ]); results.forEach(r => console.log(r)); ``` ### Cross-Server Tool Calls ```typescript // Call tools across multiple MCP servers (like Promise.all) const results = ctx.mcp.All([ { mcp: "server1", tool: "search", arguments: { q: "query" } }, { mcp: "server2", tool: "fetch", arguments: { id: 123 } } ]); // First success wins (like Promise.any) const results = ctx.mcp.Any([ { mcp: "primary", tool: "search", arguments: { q: "query" } }, { mcp: "backup", tool: "search", arguments: { q: "query" } } ]); // First complete wins (like Promise.race) const results = ctx.mcp.Race([ { mcp: "region-us", tool: "ping", arguments: {} }, { mcp: "region-eu", tool: "ping", arguments: {} } ]); // Result structure interface MCPToolResult { mcp: string; // Server ID tool: string; // Tool name result?: any; // Parsed result content error?: string; // Error if failed } ``` ### Resources ```typescript const resources = ctx.mcp.ListResources("server-id"); const data = ctx.mcp.ReadResource("server-id", "resource://uri"); ``` ### Prompts ```typescript const prompts = ctx.mcp.ListPrompts("server-id"); const prompt = ctx.mcp.GetPrompt("server-id", "prompt-name", { arg: "value" }); ``` ## Search ### Single Search ```typescript // Web search const webResult = ctx.search.Web("query", { limit: 10, sites: ["example.com"], time_range: "week" }); // Knowledge base const kbResult = ctx.search.KB("query", { collections: ["docs"], threshold: 0.7, graph: true }); // Database const dbResult = ctx.search.DB("query", { models: ["model.name"], wheres: [{ column: "status", value: "active" }], limit: 20 }); ``` ### Parallel Search ```typescript // Wait for all const results = ctx.search.All([ { type: "web", query: "topic" }, { type: "kb", query: "topic", collections: ["docs"] } ]); // First success const results = ctx.search.Any([ { type: "web", query: "topic" }, { type: "kb", query: "topic" } 
]); // First complete const results = ctx.search.Race([ { type: "web", query: "topic" }, { type: "kb", query: "topic" } ]); ``` ### Result Structure ```typescript interface SearchResult { type: "web" | "kb" | "db"; query: string; source: "hook" | "auto" | "user"; items: { citation_id: string; title: string; url: string; content: string; score: number; }[]; error?: string; } ``` ## Agent API The `ctx.agent` object provides methods to call other agents from within hooks, enabling agent-to-agent communication (A2A). ### Single Agent Call ```typescript // Basic call const result = ctx.agent.Call("assistant-id", messages); // With options and callback const result = ctx.agent.Call("assistant-id", messages, { connector: "gpt-4o", mode: "chat", metadata: { source: "hook" }, skip: { history: false, trace: false, output: false }, onChunk: (msg) => { console.log("Received:", msg.type, msg.props); return 0; // 0 = continue, non-zero = stop } }); ``` ### Agent Options ```typescript interface AgentCallOptions { connector?: string; // Override LLM connector mode?: string; // Agent mode ("chat", "task") metadata?: Record<string, any>; // Custom metadata passed to hooks skip?: { history?: boolean; // Skip loading chat history trace?: boolean; // Skip trace recording output?: boolean; // Skip output to client keyword?: boolean; // Skip keyword extraction search?: boolean; // Skip search content_parsing?: boolean; // Skip content parsing }; onChunk?: (msg: Message) => number; // Callback (0=continue, non-zero=stop) } ``` ### Parallel Agent Calls ```typescript // Wait for all agents to complete (like Promise.all) const results = ctx.agent.All([ { agent: "agent-1", messages: [...] }, { agent: "agent-2", messages: [...] } ]); // Return first successful result (like Promise.any) const results = ctx.agent.Any([ { agent: "agent-1", messages: [...] }, { agent: "agent-2", messages: [...]
} ]); // Return first completed result (like Promise.race) const results = ctx.agent.Race([ { agent: "agent-1", messages: [...] }, { agent: "agent-2", messages: [...] } ]); // With global callback for all responses const results = ctx.agent.All([ { agent: "agent-1", messages: [...] }, { agent: "agent-2", messages: [...] } ], { onChunk: (agentId, index, msg) => { console.log(`Agent ${agentId} [${index}]:`, msg.type); return 0; } }); ``` ### Result Structure ```typescript interface AgentResult { agent_id: string; response?: Response; content?: string; error?: string; } ``` ### Message Object (onChunk callback) ```typescript interface Message { type: string; // "text", "thinking", "tool_call", "error" props?: Record<string, any>; // e.g., { content: "Hello" } chunk_id?: string; // C1, C2, ... message_id?: string; // M1, M2, ... delta?: boolean; // Incremental update flag } ``` ## Sandbox API The `ctx.sandbox` object provides access to sandbox operations when the assistant is configured with a sandbox executor (e.g., Claude CLI). Only available when `sandbox` is configured in `package.yao`.
### Properties ```typescript ctx.sandbox.workdir // Workspace directory path (e.g., "/workspace") ``` ### File Operations ```typescript // Read file const content = ctx.sandbox.ReadFile("config.json"); // Write file ctx.sandbox.WriteFile("output.txt", "Hello World"); // List directory const files = ctx.sandbox.ListDir("src"); files.forEach(f => console.log(f.name, f.is_dir, f.size)); ``` ### Command Execution ```typescript // Execute command (returns stdout) const output = ctx.sandbox.Exec(["npm", "test"]); // Handle errors try { ctx.sandbox.Exec(["git", "commit", "-m", "fix"]); } catch (e) { console.error("Command failed:", e.message); } ``` ### FileInfo Structure ```typescript interface FileInfo { name: string; // File/directory name size: number; // Size in bytes is_dir: boolean; // True if directory } ``` ### Use Cases ```typescript // Prepare workspace before execution function Create(ctx, messages) { if (ctx.sandbox) { ctx.sandbox.WriteFile("config.json", JSON.stringify({ debug: true })); } return { messages }; } // Post-process results function Next(ctx, payload) { if (ctx.sandbox && !payload.error) { const files = ctx.sandbox.ListDir("output"); return { data: { generated: files.map(f => f.name) } }; } return null; } ``` ## LLM API The `ctx.llm` object provides direct access to LLM connectors for streaming completions. ### Single LLM Call ```typescript // Basic streaming call const result = ctx.llm.Stream("gpt-4o", [ { role: "user", content: "Hello" } ]); // With options and callback const result = ctx.llm.Stream("gpt-4o", messages, { temperature: 0.7, max_tokens: 2000, onChunk: (msg) => { console.log("Chunk:", msg.props?.content); return 0; } }); ``` ### Parallel LLM Calls ```typescript // Wait for all LLM calls (like Promise.all) const results = ctx.llm.All([ { connector: "gpt-4o", messages: [...] }, { connector: "claude-3", messages: [...] 
} ]); // Return first successful result (like Promise.any) const results = ctx.llm.Any([ { connector: "gpt-4o", messages: [...] }, { connector: "claude-3", messages: [...] } ]); // Return first completed result (like Promise.race) const results = ctx.llm.Race([ { connector: "gpt-4o", messages: [...] }, { connector: "claude-3", messages: [...] } ]); // With global callback const results = ctx.llm.All([ { connector: "gpt-4o", messages: [...] }, { connector: "claude-3", messages: [...] } ], { onChunk: (connectorId, index, msg) => { console.log(`LLM ${connectorId} [${index}]:`, msg.type); return 0; } }); ``` ### LLM Options ```typescript interface LlmOptions { temperature?: number; max_tokens?: number; max_completion_tokens?: number; top_p?: number; presence_penalty?: number; frequency_penalty?: number; stop?: string | string[]; user?: string; seed?: number; tools?: object[]; tool_choice?: string | object; response_format?: { type: string; json_schema?: object }; reasoning_effort?: string; onChunk?: (msg: Message) => number; } ``` ### Result Structure ```typescript interface LlmResult { connector: string; response?: CompletionResponse; content?: string; error?: string; } ``` ================================================ FILE: agent/docs/hooks.md ================================================ # Hooks Hooks allow you to customize agent behavior at key points in the execution lifecycle. ## Lifecycle ``` User Input → Create Hook → LLM Call → Tool Execution → Next Hook → Response ``` ## Create Hook Called before the LLM call. Use it to preprocess messages, configure the request, or delegate.
```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { // Return null for default behavior return null; // Or return configuration return { messages, // Modified messages temperature: 0.7, // Override temperature max_tokens: 2000, // Override max tokens connector: "gpt-4o-mini", // Override connector prompt_preset: "task", // Select prompt preset disable_global_prompts: true,// Skip global prompts mcp_servers: [ // Add MCP servers { server_id: "tools", tools: ["search"] } ], uses: { // Override wrapper tools vision: "vision-agent", search: "disabled" }, force_uses: true, // Force use wrapper tools locale: "zh-cn", // Override locale metadata: { key: "value" }, // Pass data to context }; } ``` ### Delegation (Skip LLM) Route to another agent immediately: ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { if (shouldDelegate(messages)) { return { delegate: { agent_id: "specialist.agent", messages: messages, options: { metadata: { source: "main" } } } }; } return { messages }; } ``` ## Next Hook Called after LLM response and tool execution. Use to post-process or delegate. 
```typescript function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { const { messages, completion, tools, error } = payload; // Handle errors if (error) { return { data: { status: "error", message: error } }; } // Process tool results if (tools?.length > 0) { const results = tools.map(t => t.result); return { data: { status: "success", results } }; } // Delegate based on response if (completion?.content?.includes("transfer")) { return { delegate: { agent_id: "transfer.agent", messages: payload.messages } }; } // Return null for standard response return null; } ``` ### Payload Structure ```typescript interface Payload { messages: Message[]; // Messages sent to LLM completion?: { content: string; // LLM text response tool_calls?: ToolCall[]; // Tool calls from LLM usage?: UsageInfo; // Token usage }; tools?: ToolCallResponse[]; // Tool execution results error?: string; // Error message } interface ToolCallResponse { toolcall_id: string; server: string; // MCP server ID tool: string; // Tool name arguments?: any; // Tool arguments result?: any; // Tool result error?: string; // Tool error } ``` ### Return Values ```typescript interface NextResponse { delegate?: { // Route to another agent agent_id: string; messages: Message[]; options?: Record<string, any>; }; data?: any; // Custom response data metadata?: Record<string, any>; // Debug metadata } ``` ## Sending Messages Use `ctx` to send messages to the client: ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { // Send complete message ctx.Send({ type: "text", props: { content: "Processing..."
} }); // Streaming message const msgId = ctx.SendStream("Starting..."); ctx.Append(msgId, " step 1..."); ctx.Append(msgId, " step 2..."); ctx.End(msgId); return { messages }; } ``` ## Memory Share data between hooks: ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { // Store in request-scoped memory ctx.memory.context.Set("start_time", Date.now()); ctx.memory.context.Set("query", messages[0]?.content); return { messages }; } function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { // Retrieve data const startTime = ctx.memory.context.Get("start_time"); const duration = Date.now() - startTime; return { data: { duration_ms: duration } }; } ``` ## Tracing Add trace nodes for debugging and UI: ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { const node = ctx.trace.Add( { query: messages[0]?.content }, { label: "Preprocessing", type: "process", icon: "play" } ); node.Info("Starting analysis"); // ... processing ... node.Complete({ status: "done" }); return { messages }; } ``` ## Error Handling ```typescript function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { try { if (payload.error) { ctx.trace.Error(payload.error); return { data: { status: "error", message: "Something went wrong" } }; } // ... 
normal processing } catch (e) { ctx.trace.Error(e.message); return { data: { status: "error", message: e.message } }; } } ``` ## Multi-Agent Orchestration ```typescript // Main agent delegates based on intent function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { const { tools } = payload; // Route based on tool result const intent = tools?.[0]?.result?.intent; const agentMap = { "search": "search.agent", "calculate": "calc.agent", "translate": "translate.agent" }; if (intent && agentMap[intent]) { return { delegate: { agent_id: agentMap[intent], messages: payload.messages } }; } return null; } ``` ## Complete Example ```typescript import { agent } from "@yao/runtime"; function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { const query = messages[messages.length - 1]?.content || ""; // Store for Next hook ctx.memory.context.Set("query", query); ctx.memory.context.Set("start", Date.now()); // Add trace ctx.trace.Add({ query }, { label: "Create", type: "hook" }); // Check if needs special handling if (query.toLowerCase().includes("urgent")) { return { messages, temperature: 0, prompt_preset: "task" }; } return { messages }; } function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { const { completion, tools, error } = payload; const start = ctx.memory.context.Get("start"); const duration = Date.now() - start; ctx.trace.Add( { duration }, { label: "Next", type: "hook" } ).Complete(); if (error) { return { data: { status: "error", error } }; } if (tools?.length > 0) { return { data: { status: "success", response: completion?.content, tools: tools.map(t => ({ name: t.tool, result: t.result })), duration_ms: duration } }; } return null; } ``` ================================================ FILE: agent/docs/i18n.md ================================================ # Internationalization (i18n) ## Locale Files Create `locales/` directory in the assistant: ``` assistants/my-assistant/ └── locales/ ├── en-us.yml ├── 
zh-cn.yml └── ja.yml ``` ## Locale File Format ```yaml # locales/en-us.yml name: My Assistant description: A helpful AI assistant chat: title: New Chat description: How can I help you today? prompts: - What can you do? - Help me with a task - Tell me about yourself messages: welcome: Welcome back! error: Something went wrong processing: Processing your request... ``` ```yaml # locales/zh-cn.yml name: 我的助手 description: 一个有帮助的AI助手 chat: title: 新对话 description: 今天我能帮您什么? prompts: - 你能做什么? - 帮我完成一个任务 - 介绍一下你自己 messages: welcome: 欢迎回来! error: 出了点问题 processing: 正在处理您的请求... ``` ## Using Translations ### In package.yao Use `{{ key }}` syntax: ```json { "name": "{{ name }}", "description": "{{ description }}", "placeholder": { "title": "{{ chat.title }}", "description": "{{ chat.description }}", "prompts": [ "{{ chat.prompts.0 }}", "{{ chat.prompts.1 }}", "{{ chat.prompts.2 }}" ] } } ``` ### In Prompts ```yaml - role: system content: | You are {{ name }}. {{ description }} Respond in the user's language. ``` ## Locale Detection The system detects locale from: 1. Request header `Accept-Language` 2. User preference (stored in memory) 3. Default: `en-us` ### Override in Hook ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { // Get user preference const userLocale = ctx.memory.user.Get("preferred_locale"); return { messages, locale: userLocale || "en-us" }; } ``` ## Global Translations Define global translations in `agent/locales/`: ``` agent/ └── locales/ ├── en-us.yml └── zh-cn.yml ``` These are available to all assistants via the `__global__` namespace. 
## Accessing Translations in Hooks ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { const locale = ctx.locale; // e.g., "en-us" // Use locale for custom logic if (locale.startsWith("zh")) { return { messages, prompt_preset: "chinese" }; } return { messages }; } ``` ## Nested Keys Access nested values with dot notation: ```yaml # locales/en-us.yml errors: validation: required: This field is required invalid: Invalid value network: timeout: Connection timed out ``` ```json { "placeholder": { "title": "{{ errors.validation.required }}" } } ``` ## Fallback Behavior If a translation key is not found: 1. Try the requested locale 2. Fall back to `en-us` 3. Return the key itself if not found ## Dynamic Locale Content ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { // Add locale-specific system message const localeGreeting = { "en-us": "Hello! How can I help you?", "zh-cn": "你好!有什么可以帮您的?", "ja": "こんにちは!何かお手伝いできますか?" }; const greeting = localeGreeting[ctx.locale] || localeGreeting["en-us"]; return { messages: [ { role: "system", content: `Greeting: ${greeting}` }, ...messages ] }; } ``` ## Best Practices 1. **Keep keys consistent** - Use the same keys across all locale files 2. **Use nested structure** - Organize related translations together 3. **Provide fallbacks** - Always have `en-us` as the base locale 4. **Test all locales** - Verify translations render correctly 5. **Use context variables** - Combine with `$CTX.locale` in prompts ================================================ FILE: agent/docs/iframe.md ================================================ # Iframe Integration Agent Pages can be embedded in CUI via `/web/` routes. This document covers the iframe communication mechanism between embedded pages and the CUI host. 
## Route Mapping Pages are accessible via: ``` /web/<assistant-id>/<page-path> ``` Example: | Page File | URL | | -------------------------- | --------------------------------- | | `pages/index/index.html` | `/web/my-assistant/index` | | `pages/result/index.html` | `/web/my-assistant/result` | | `pages/report/detail.html` | `/web/my-assistant/report/detail` | ## URL Parameters CUI automatically injects context via URL parameters: | Parameter | Value | Description | | ---------- | ---------------------- | ------------- | | `__theme` | `light` / `dark` | Current theme | | `__locale` | `en-us`, `zh-cn`, etc. | User locale | > **Note**: Authentication uses secure HTTP-only cookies, so the `__token` parameter is not needed. **Usage in page URL:** ``` /web/my-assistant/result?theme=__theme&locale=__locale ``` CUI replaces `__theme` and `__locale` with actual values before loading. ## Message Communication ### Receiving Setup Message When the iframe loads, CUI sends a `setup` message: ```typescript // In your page script window.addEventListener("message", (e) => { if (e.data.type === "setup") { const { theme, locale } = e.data.message; // Apply theme, set locale document.documentElement.setAttribute("data-theme", theme); } }); ``` ### Sending Actions to CUI Pages can trigger CUI actions via `postMessage` using the unified Action system: ```typescript // Send action to parent CUI window.parent.postMessage( { type: "action", message: { name: "notify.success", payload: { message: "Operation completed" }, }, }, window.location.origin ); ``` ### Action Types #### Navigate | Action | Description | Payload | | --------------- | ------------------------------- | ------------------------------------------- | | `navigate` | Open page in sidebar or new tab | `{ route, title?, icon?, query?, target?
}` | | `navigate.back` | Navigate back in history | - | **Navigate Payload:** | Field | Type | Required | Description | | -------- | ------------------------ | -------- | ----------------------------------------------- | | `route` | `string` | ✅ | Target route (`$dashboard/xxx`, `/xxx`, or URL) | | `title` | `string` | - | Page title (shows title bar with back button) | | `icon` | `string` | - | Tab icon (e.g., `material-folder`) | | `query` | `Record` | - | Query parameters | | `target` | `'_self'` \| `'_blank'` | - | `_self` (sidebar) or `_blank` (new window) | #### Notify | Action | Description | Payload | | ---------------- | ------------------------- | ------------------------------------------ | | `notify.success` | Show success notification | `{ message, duration?, icon?, closable? }` | | `notify.error` | Show error notification | `{ message, duration?, icon?, closable? }` | | `notify.warning` | Show warning notification | `{ message, duration?, icon?, closable? }` | | `notify.info` | Show info notification | `{ message, duration?, icon?, closable? 
}` | #### App | Action | Description | | ----------------- | ------------------------ | | `app.menu.reload` | Refresh application menu | #### Modal | Action | Description | | ------------- | ----------------- | | `modal.open` | Open modal dialog | | `modal.close` | Close modal | #### Table | Action | Description | | --------------- | -------------------- | | `table.search` | Trigger table search | | `table.refresh` | Refresh table data | | `table.save` | Save table row | | `table.delete` | Delete table row(s) | #### Form | Action | Description | | ----------------- | --------------------- | | `form.find` | Load form data by ID | | `form.submit` | Submit form | | `form.reset` | Reset form | | `form.setFields` | Set form field values | | `form.fullscreen` | Toggle fullscreen | #### MCP (Client-side) | Action | Description | | ------------------- | ------------------ | | `mcp.tool.call` | Execute MCP tool | | `mcp.resource.read` | Read MCP resource | | `mcp.resource.list` | List MCP resources | | `mcp.prompt.get` | Get MCP prompt | | `mcp.prompt.list` | List MCP prompts | #### Event | Action | Description | | ------------ | ----------------- | | `event.emit` | Emit custom event | #### Confirm | Action | Description | | --------- | ------------------------ | | `confirm` | Show confirmation dialog | ### Receiving Events from CUI CUI can send messages to iframe via `web/sendMessage` event: ```typescript // In your page script window.addEventListener("message", (e) => { const { type, message } = e.data; switch (type) { case "setup": // Initial setup with theme, locale break; case "refresh": // CUI requests page refresh location.reload(); break; case "data": // CUI sends data update handleDataUpdate(message); break; } }); ``` ## Complete Example ### Page HTML (pages/result/index.html) ```html Result Page
``` ### Page Script (pages/result/result.ts) ```typescript import { $Backend, Component, EventData } from "@yao/sui"; const self = this as Component; // Helper: Send action to CUI parent const sendAction = (name: string, payload?: any) => { try { window.parent.postMessage( { type: "action", message: { name, payload } }, window.location.origin ); } catch (err) { console.error("Failed to send action to parent:", err); } }; // Initialize message listener function init() { window.addEventListener("message", (e) => { if (e.origin !== window.location.origin) return; const { type, message } = e.data; switch (type) { case "setup": // Apply theme, locale from CUI document.documentElement.setAttribute("data-theme", message.theme); break; case "update": // Handle data updates from CUI console.log("Received update:", message); break; } }); // Make helper available globally (window as any).sendAction = sendAction; } init(); // Event handler: Show success notification self.HandleSuccess = (event: Event, data: EventData) => { sendAction("notify.success", { message: data.message || "Success!" 
}); }; // Event handler: Navigate to page self.HandleNavigate = (event: Event, data: EventData) => { sendAction("navigate", { route: data.path, title: data.title, }); }; // Event handler: Close sidebar self.HandleClose = () => { sendAction("event.emit", { key: "app/closeSidebar", value: {} }); }; // Event handler: Call backend and display result self.HandleQuery = async (event: Event, data: EventData) => { try { const result = await $Backend().Call("Query", data.id); console.log(result); } catch (error: any) { sendAction("notify.error", { message: error.message }); } }; ``` ## Triggering from Hooks Open page in sidebar from agent hooks: ```typescript function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { // Open result page in sidebar ctx.Send({ type: "action", props: { name: "navigate", payload: { route: `/agents/my-assistant/result`, title: "Results", query: { id: resultId }, }, }, }); return null; } ``` See [Pages](pages.md) for more details on triggering pages from hooks. ## Security Notes 1. **Same-origin only**: Messages are only processed from same-origin iframes 2. **Secure cookies**: Authentication uses HTTP-only cookies, no token in URL 3. **Validate messages**: Always validate message structure before processing ================================================ FILE: agent/docs/mcp.md ================================================ # MCP Integration Model Context Protocol (MCP) enables tool integration with external services. ## Directory Structure Assistants can define their own namespaced MCP servers in the `mcps/` directory: ``` assistants/ └── my-assistant/ ├── package.yao └── mcps/ ├── tools.mcp.yao # → agents.my-assistant.tools ├── calculator.mcp.yao # → agents.my-assistant.calculator └── mapping/ └── tools/ └── schemes/ ├── search.in.yao └── search.out.yao ``` MCP servers are automatically loaded with the `agents.<assistant-id>.` prefix. 
## Defining MCP Servers Create `mcps/tools.mcp.yao` in the assistant directory: ```json { "label": "Tools", "description": "Custom tools for the assistant", "transport": "process", "tools": { "search": "scripts.tools.Search", "create": "models.data.Create" } } ``` ### Transport Types **Process (Yao Internal)** Map Yao Processes directly to MCP tools: ```json { "transport": "process", "tools": { "search": "models.data.Paginate", "create": "models.data.Create" }, "resources": { "detail": "models.data.Find" } } ``` **STDIO (Local Server)** ```json { "transport": "stdio", "command": "python", "arguments": ["mcp_server.py"], "env": { "API_KEY": "$ENV.API_KEY" } } ``` **HTTP (REST API)** ```json { "transport": "http", "url": "https://mcp.example.com/api", "authorization_token": "$ENV.TOKEN" } ``` **SSE (Server-Sent Events)** ```json { "transport": "sse", "url": "https://mcp.example.com/events", "authorization_token": "$ENV.TOKEN" } ``` ## Configuring in package.yao ### All Tools ```json { "mcp": { "servers": ["tools"] } } ``` ### Specific Tools ```json { "mcp": { "servers": [{ "server_id": "tools", "tools": ["search", "calculate"] }] } } ``` ### With Resources ```json { "mcp": { "servers": [ { "server_id": "data", "tools": ["query"], "resources": ["data://users/*"] } ] } } ``` ## Dynamic Configuration in Hooks ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { return { messages, mcp_servers: [ { server_id: "tools", tools: ["search"] }, { server_id: "data", resources: ["data://reports"] }, ], }; } ``` ## Using MCP in Hooks ### List Available Tools ```typescript const tools = ctx.mcp.ListTools("server-id"); // { tools: [{ name: "search", description: "...", inputSchema: {...} }] } ``` ### Call Tool ```typescript // Returns parsed result directly - no wrapper object const result = ctx.mcp.CallTool("server-id", "search", { query: "example", limit: 10, }); console.log(result.items); // Direct access to parsed data ``` ### Batch Tool 
Calls ```typescript // Sequential - returns array of parsed results const results = ctx.mcp.CallTools("server-id", [ { name: "step1", arguments: { input: "a" } }, { name: "step2", arguments: { input: "b" } }, ]); results.forEach(r => console.log(r)); // Parallel - returns array of parsed results const results = ctx.mcp.CallToolsParallel("server-id", [ { name: "api1", arguments: {} }, { name: "api2", arguments: {} }, ]); results.forEach(r => console.log(r)); ``` ### Read Resources ```typescript const resources = ctx.mcp.ListResources("server-id"); const data = ctx.mcp.ReadResource("server-id", "data://users/123"); ``` ### Get Prompts ```typescript const prompts = ctx.mcp.ListPrompts("server-id"); const prompt = ctx.mcp.GetPrompt("server-id", "system", { role: "helper" }); ``` ### Cross-Server Tool Calls Call tools across multiple MCP servers concurrently: ```typescript // Wait for all (like Promise.all) const results = ctx.mcp.All([ { mcp: "server1", tool: "search", arguments: { q: "query" } }, { mcp: "server2", tool: "analyze", arguments: { data: "input" } } ]); // First success (like Promise.any) - good for fallback const results = ctx.mcp.Any([ { mcp: "primary", tool: "fetch", arguments: { id: 1 } }, { mcp: "backup", tool: "fetch", arguments: { id: 1 } } ]); // First complete (like Promise.race) - good for latency const results = ctx.mcp.Race([ { mcp: "region-us", tool: "ping", arguments: {} }, { mcp: "region-eu", tool: "ping", arguments: {} } ]); // Access results results.forEach(r => { if (r.error) { console.log(`${r.mcp}/${r.tool} failed: ${r.error}`); } else { console.log(`${r.mcp}/${r.tool} result:`, r.result); } }); ``` ## Tool Schema Mapping Define input schemas for process transport tools: ``` mcps/ └── mapping/ └── <server-id>/ └── schemes/ ├── search.in.yao # Input schema └── search.out.yao # Output schema (optional) ``` **mapping/tools/schemes/search.in.yao** ```json { "type": "object", "description": "Search data", "properties": { "keyword": { "type": "string" 
}, "page": { "type": "integer" } }, "x-process-args": [":arguments"] } ``` The `x-process-args` maps MCP arguments to Yao Process parameters: - `":arguments"` - Pass entire arguments object - `"$args.field"` - Extract specific field ### Schema with Nested Objects ```json { "type": "object", "description": "Extract structured data from input", "properties": { "intent": { "type": "string", "enum": ["query", "create", "update"], "description": "Operation intent" }, "items": { "type": "array", "items": { "type": "object", "properties": { "name": { "type": "string" }, "value": { "type": "number" } }, "required": ["name", "value"] } } }, "required": ["intent"], "x-process-args": [":arguments"] } ``` ## Using Assistant Models in MCP MCP tools can reference assistant's own models: **mcps/data.mcp.yao** ```json { "label": "Data Tools", "transport": "process", "tools": { "list_orders": "models.agents.my-assistant.order.Paginate", "get_order": "models.agents.my-assistant.order.Find", "create_order": "models.agents.my-assistant.order.Create", "custom_query": "agents.my-assistant.orders.Query" } } ``` See [Models](models.md) for defining assistant models. 
## Error Handling ```typescript function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { const { tools } = payload; if (tools) { for (const tool of tools) { if (tool.error) { ctx.trace.Error(`Tool ${tool.tool} failed: ${tool.error}`); // Handle error } else { // Process result console.log(tool.result); } } } return null; } ``` ## Complete Example **mcps/calculator.mcp.yao** ```json { "label": "Calculator", "description": "Math operations", "transport": "process", "tools": { "add": "scripts.math.Add", "multiply": "scripts.math.Multiply" } } ``` **package.yao** ```json { "name": "Math Assistant", "connector": "gpt-4o", "mcp": { "servers": [{ "server_id": "calculator", "tools": ["add", "multiply"] }] } } ``` **src/index.ts** ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { // Check if calculation is needed const query = messages[messages.length - 1]?.content || ""; if (/\d+\s*[\+\-\*\/]\s*\d+/.test(query)) { // Enable calculator return { messages, mcp_servers: [{ server_id: "calculator" }], }; } return { messages }; } function Next(ctx: agent.Context, payload: agent.Payload): agent.Next { const { tools } = payload; if (tools?.length > 0) { const calcResult = tools.find((t) => t.server === "calculator"); if (calcResult?.result) { return { data: { answer: calcResult.result, expression: calcResult.arguments, }, }; } } return null; } ``` ================================================ FILE: agent/docs/models.md ================================================ # Assistant Models Assistants can define their own namespaced data models in the `models/` directory. These models are automatically loaded with the `agents.<assistant-id>.` prefix and use isolated database tables. 
## Directory Structure ``` assistants/ └── my-assistant/ ├── package.yao └── models/ ├── order.mod.yao # → agents.my-assistant.order ├── item.mod.yao # → agents.my-assistant.item └── nested/ └── log.mod.yao # → agents.my-assistant.nested.log ``` ## Model Definition Standard Yao model definition with automatic table prefixing. **models/order.mod.yao** ```json { "name": "Order", "label": "Order Record", "description": "Customer orders", "table": { "name": "order", "comment": "Order records" }, "columns": [ { "name": "id", "type": "ID", "label": "ID", "primary": true }, { "name": "order_no", "type": "string", "label": "Order Number", "length": 100, "nullable": false, "unique": true, "index": true }, { "name": "customer_id", "type": "string", "label": "Customer ID", "length": 255, "nullable": false, "index": true }, { "name": "total_amount", "type": "decimal", "label": "Total Amount", "precision": 15, "scale": 2, "nullable": false }, { "name": "status", "type": "enum", "label": "Status", "option": ["pending", "confirmed", "shipped", "completed", "cancelled"], "default": "pending", "nullable": false, "index": true }, { "name": "metadata", "type": "json", "label": "Metadata", "nullable": true } ], "relations": { "items": { "type": "hasMany", "model": "item", "key": "order_id", "foreign": "id" } }, "indexes": [ { "name": "idx_customer_status", "columns": ["customer_id", "status"], "type": "index" } ], "option": { "timestamps": true, "soft_deletes": true } } ``` ## Table Naming Tables are automatically prefixed with `agents_<assistant-id>_` (dots in the assistant ID become underscores): | Assistant ID | Model File | Model ID | Table Name | | ------------ | ---------------------- | ------------------------ | ------------------------ | | `expense` | `models/order.mod.yao` | `agents.expense.order` | `agents_expense_order` | | `tests.demo` | `models/user.mod.yao` | `agents.tests.demo.user` | `agents_tests_demo_user` | ## Using Models ### In Hooks ```typescript import { Process } from "@yao/runtime"; function Create(ctx: agent.Context, 
messages: agent.Message[]): agent.Create { // Query assistant's own model const orders = Process("models.agents.my-assistant.order.Paginate", { wheres: [{ column: "status", value: "pending" }], limit: 10, }); return { messages }; } ``` ### In MCP Tools ```json { "transport": "process", "tools": { "list_orders": "models.agents.my-assistant.order.Paginate", "get_order": "models.agents.my-assistant.order.Find", "create_order": "models.agents.my-assistant.order.Create", "update_order": "models.agents.my-assistant.order.Update" } } ``` ### In Scripts **src/orders.ts** ```typescript import { Process } from "@yao/runtime"; export function ListPending(): any[] { return Process("models.agents.my-assistant.order.Get", { wheres: [{ column: "status", value: "pending" }], orders: [{ column: "created_at", option: "desc" }], }); } export function CreateOrder(data: any): any { return Process("models.agents.my-assistant.order.Create", data); } export function UpdateStatus(id: number, status: string): any { return Process("models.agents.my-assistant.order.Update", id, { status }); } ``` ## Column Types | Type | Description | Options | | ------------ | -------------------------- | ----------------------- | | `ID` | Auto-increment primary key | `primary: true` | | `string` | VARCHAR | `length` (default: 255) | | `text` | TEXT | - | | `integer` | INT | - | | `bigInteger` | BIGINT | - | | `float` | FLOAT | `precision`, `scale` | | `decimal` | DECIMAL | `precision`, `scale` | | `boolean` | BOOLEAN | - | | `date` | DATE | - | | `datetime` | DATETIME | - | | `timestamp` | TIMESTAMP | - | | `json` | JSON/JSONB | - | | `enum` | ENUM | `option: [...]` | ## Column Options | Option | Type | Description | | ----------- | --------- | ----------------- | | `nullable` | `boolean` | Allow NULL values | | `default` | `any` | Default value | | `unique` | `boolean` | Unique constraint | | `index` | `boolean` | Create index | | `primary` | `boolean` | Primary key | | `length` | `integer` | String length 
| | `precision` | `integer` | Decimal precision | | `scale` | `integer` | Decimal scale | | `comment` | `string` | Column comment | ## Relations ```json { "relations": { "items": { "type": "hasMany", "model": "item", "key": "order_id", "foreign": "id" }, "customer": { "type": "hasOne", "model": "customer", "key": "id", "foreign": "customer_id" } } } ``` | Type | Description | | ---------------- | ----------------------------- | | `hasOne` | One-to-one relationship | | `hasMany` | One-to-many relationship | | `hasOneThrough` | Has one through intermediate | | `hasManyThrough` | Has many through intermediate | ## Model Options ```json { "option": { "timestamps": true, "soft_deletes": true, "permission": true } } ``` | Option | Description | | -------------- | -------------------------------------- | | `timestamps` | Add `created_at`, `updated_at` columns | | `soft_deletes` | Add `deleted_at` for soft delete | | `permission` | Enable permission checks | ## Process Reference Common model processes: | Process | Description | Arguments | | ------------- | ---------------- | --------------------------- | | `Find` | Get by ID | `id`, `query?` | | `Get` | Get records | `query` | | `Paginate` | Paginated list | `query`, `page`, `pagesize` | | `Create` | Create record | `data` | | `Update` | Update record | `id`, `data` | | `Save` | Create or update | `data` | | `Delete` | Delete record | `id` | | `Destroy` | Hard delete | `id` | | `Insert` | Batch insert | `columns`, `rows` | | `UpdateWhere` | Batch update | `query`, `data` | | `DeleteWhere` | Batch delete | `query` | ## Migration Models are automatically migrated when Yao starts. The migration: 1. Creates tables if not exist 2. Adds new columns 3. Creates indexes 4. 
Does NOT drop columns (safe migration) To force schema sync: ```bash yao migrate --reset # Warning: drops and recreates tables ``` ## Example: Complete Assistant with Models **assistants/inventory/package.yao** ```json { "name": "Inventory Assistant", "connector": "gpt-4o", "mcp": { "servers": [{ "server_id": "inventory" }] } } ``` **assistants/inventory/models/product.mod.yao** ```json { "name": "Product", "table": { "name": "product" }, "columns": [ { "name": "id", "type": "ID", "primary": true }, { "name": "sku", "type": "string", "length": 50, "unique": true }, { "name": "name", "type": "string", "length": 200 }, { "name": "quantity", "type": "integer", "default": 0 }, { "name": "price", "type": "decimal", "precision": 10, "scale": 2 } ], "option": { "timestamps": true } } ``` **assistants/inventory/mcps/inventory.mcp.yao** ```json { "label": "Inventory", "transport": "process", "tools": { "list_products": "models.agents.inventory.product.Paginate", "get_product": "models.agents.inventory.product.Find", "update_stock": "agents.inventory.stock.Update" } } ``` **assistants/inventory/src/stock.ts** ```typescript import { Process } from "@yao/runtime"; export function Update(args: { sku: string; quantity: number }): any { const product = Process("models.agents.inventory.product.Get", { wheres: [{ column: "sku", value: args.sku }], limit: 1, }); if (!product || product.length === 0) { throw new Error(`Product not found: ${args.sku}`); } return Process("models.agents.inventory.product.Update", product[0].id, { quantity: args.quantity, }); } ``` ================================================ FILE: agent/docs/pages.md ================================================ # Agent Pages Agent Pages provide a built-in SUI (Simple User Interface) framework for building web interfaces for AI agents. Pages are automatically loaded from the `/agent/template/` directory for global templates and `/assistants/<assistant-id>/pages/` for individual assistant pages. 
## Directory Structure ``` / ├── agent/ │ └── template/ # Global template directory │ ├── __document.html # Document template │ ├── __data.json # Global data │ ├── __assets/ # Global assets │ │ ├── css/ │ │ ├── js/ │ │ └── images/ │ ├── pages/ # Global pages (login, error, etc.) │ │ └── login/ │ │ └── login.html │ └── __locales/ # Internationalization │ └── assistants/ └── my-assistant/ ├── package.yao └── pages/ # Assistant-specific pages ├── index/ │ ├── index.html │ ├── index.css │ ├── index.ts │ └── index.backend.ts └── __assets/ # Optional assistant assets ``` ## Route Mapping | File Path | Public URL | | ----------------------------------------- | -------------------- | | `/agent/template/pages/login/login.html` | `/agents/login` | | `/assistants/demo/pages/index/index.html` | `/agents/demo/index` | | `/assistants/demo/pages/chat/chat.html` | `/agents/demo/chat` | ## Quick Start ### 1. Create Document Template **`/agent/template/__document.html`**: ```html {{ $global.title }}
{{ __page }}
``` ### 2. Create Global Data **`/agent/template/__data.json`**: ```json { "title": "AI Agent", "version": "1.0.0" } ``` ### 3. Create a Page **`/assistants/my-assistant/pages/index/index.html`**: ```html

{{ title }}

{{ msg.content }}
``` **`/assistants/my-assistant/pages/index/index.json`**: ```json { "title": "Chat", "messages": [] } ``` **`/assistants/my-assistant/pages/index/index.css`**: ```css .page { max-width: 800px; margin: 0 auto; padding: 24px; } .messages { display: flex; flex-direction: column; gap: 12px; } .message.user { align-self: flex-end; background: #007bff; color: white; } .message.assistant { align-self: flex-start; background: #f0f0f0; } ``` ### 4. Add Backend Script **`/assistants/my-assistant/pages/index/index.backend.ts`**: ```typescript function BeforeRender(request: Request): Record<string, any> { const chatId = request.query.chat_id; return { messages: chatId ? Process("scripts.chat.GetHistory", chatId) : [], user: request.authorized?.user_id, }; } function ApiGetData(request: Request): any { const { id } = request.payload; return Process("models.data.Find", id, {}); } ``` ### 5. Add Frontend Script **`/assistants/my-assistant/pages/index/index.ts`**: Frontend scripts can be written in two styles: **Style 1: Direct Code (Simple Pages)** ```typescript // Runs immediately when script loads document.addEventListener("DOMContentLoaded", () => { const form = document.querySelector("#myForm") as HTMLFormElement; form.addEventListener("submit", async (e) => { e.preventDefault(); // Handle form submission }); }); // Smooth scrolling for navigation document.querySelectorAll('a[href^="#"]').forEach((anchor) => { anchor.addEventListener("click", function (e) { e.preventDefault(); const target = document.querySelector(this.getAttribute("href")); target?.scrollIntoView({ behavior: "smooth" }); }); }); ``` **Style 2: Component Pattern (Interactive Pages)** ```typescript import { $Backend, Component, EventData } from "@yao/sui"; const self = this as Component; // Event handler bound to s:on-click="HandleClick" self.HandleClick = async (event: Event, data: EventData) => { const result = await $Backend().Call("GetData", data.id); console.log(result); }; // Form submission handler 
self.HandleSubmit = async (event: Event) => { event.preventDefault(); const form = event.target as HTMLFormElement; const formData = new FormData(form); await $Backend().Call("Submit", Object.fromEntries(formData)); }; ``` **Using Backend API:** ```typescript import { $Backend, Yao } from "@yao/sui"; // Call backend method const data = await $Backend().Call("MethodName", arg1, arg2); // Direct API calls const yao = new Yao(); const res = await yao.Get("/api/endpoint", { param: "value" }); await yao.Post("/api/endpoint", { data: "value" }); ``` ### 6. Build and Run ```bash # Build pages yao sui build agent # Or watch for changes yao sui watch agent # Start server yao start ``` Access at: `http://localhost:5099/agents/my-assistant/index` ## Template Syntax ### Data Binding ```html

{{ title }}

{{ user.name }}

{{ description || "No description" }}

``` ### Conditionals ```html
Welcome, {{ user.name }}!
Welcome, Guest!
Please log in
``` ### Loops ```html
  • {{ i + 1 }}. {{ item.name }}
``` ### Events ```html ``` ### Components Pages can use other pages as components: ```html
Content
``` ## Built-in Variables | Variable | Description | | ---------- | ------------------------------------------- | | `$global` | Global data from `__data.json` | | `$query` | URL query parameters | | `$param` | URL path parameters | | `$payload` | POST request body | | `$cookie` | Cookie values | | `$url` | Current URL info | | `$theme` | Current theme | | `$locale` | Current locale | | `$auth` | OAuth authorization info (if authenticated) | ## Page Configuration Create `<page>.config` for page settings: ```json { "title": "Page Title", "guard": "bearer-jwt", "cache": 3600, "data": { "key": "value" } } ``` ## Asset Paths - **Global assets**: `/agents/assets/...` → `/agent/template/__assets/...` - **Assistant assets**: `/agents/<assistant-id>/assets/...` → `/assistants/<assistant-id>/pages/__assets/...` ## Build Output ``` /public/agents/ ├── assets/ │ ├── libsui.min.js # SUI frontend SDK │ ├── css/ # Global CSS │ ├── js/ # Global JS │ └── images/ # Global images │ ├── login.sui # Global page ├── login.cfg │ └── my-assistant/ ├── index.sui # Assistant page └── index.cfg ``` ## Authentication Pages default to public access. 
To require authentication: **`/assistants/my-assistant/pages/dashboard/dashboard.config`**: ```json { "guard": "bearer-jwt" } ``` Available guards: | Guard | Description | | ------------ | --------------------------------- | | `-` | No authentication (default) | | `bearer-jwt` | JWT token in Authorization header | | `cookie-jwt` | JWT token in cookie | | `oauth` | OAuth 2.0 authentication | ## Triggering Pages from Hooks Use `action` messages to open pages in the sidebar during conversation: ```typescript // Navigate to a page in sidebar ctx.Send({ type: "action", props: { name: "navigate", payload: { route: "/agents/my-assistant/result", // Page route title: "Query Results", // Sidebar title query: { id: "123" }, // Passed as $query in page }, }, }); // Open in new tab ctx.Send({ type: "action", props: { name: "navigate", payload: { route: "/agents/my-assistant/detail", target: "_blank", }, }, }); ``` ### Action Reference #### Navigate Open a route in the sidebar or new window. **Payload:** | Field | Type | Required | Description | | -------- | ------------------------ | -------- | ---------------------------------------------------- | | `route` | `string` | ✅ | Target route or URL | | `title` | `string` | - | Page title (shows custom title bar with back button) | | `icon` | `string` | - | Tab icon (e.g., `material-folder`) | | `query` | `Record` | - | Query parameters (passed as `$query` in page) | | `target` | `'_self'` \| `'_blank'` | - | `_self` (sidebar, default) or `_blank` (new window) | **Route Types:** | Prefix | Type | Description | | ----------------- | -------- | ----------------------------------------------- | | `$dashboard/` | CUI Page | Dashboard pages (e.g., `$dashboard/kb` → `/kb`) | | `/` | SUI Page | Custom pages (e.g., `/agents/demo/result`) | | `http://https://` | External | External URL (loaded in iframe) | **Examples:** ```typescript // Open agent page in sidebar with title ctx.Send({ type: "action", props: { name: "navigate", payload: { 
route: "/agents/my-assistant/result", title: "Query Results", icon: "material-table_chart", query: { id: "123" }, }, }, }); // Open CUI dashboard page ctx.Send({ type: "action", props: { name: "navigate", payload: { route: "$dashboard/users" }, }, }); // Open external URL in new tab ctx.Send({ type: "action", props: { name: "navigate", payload: { route: "https://docs.example.com", target: "_blank", }, }, }); ``` #### Navigate Back Navigate back in history. ```typescript ctx.Send({ type: "action", props: { name: "navigate.back" }, }); ``` #### Notify Show notification messages. **Actions:** | Action | Description | | ---------------- | ----------------------------- | | `notify.success` | Success notification (green) | | `notify.error` | Error notification (red) | | `notify.warning` | Warning notification (yellow) | | `notify.info` | Info notification (blue) | **Payload:** | Field | Type | Required | Description | | ---------- | --------- | -------- | ---------------------------------------------- | | `message` | `string` | ✅ | Notification message | | `duration` | `number` | - | Auto-close seconds (default: 3, 0 = keep open) | | `icon` | `string` | - | Custom icon (overrides default) | | `closable` | `boolean` | - | Show close button (default: false) | **Examples:** ```typescript // Success notification ctx.Send({ type: "action", props: { name: "notify.success", payload: { message: "Data saved successfully!" }, }, }); // Error with custom duration ctx.Send({ type: "action", props: { name: "notify.error", payload: { message: "Operation failed", duration: 5, closable: true, }, }, }); ``` #### App Menu Refresh application menu/navigation. 
```typescript ctx.Send({ type: "action", props: { name: "app.menu.reload" }, }); ``` #### All Actions | Category | Action | Description | | -------- | ------------------- | ------------------------------- | | Navigate | `navigate` | Open page in sidebar or new tab | | | `navigate.back` | Navigate back in history | | Notify | `notify.success` | Show success notification | | | `notify.error` | Show error notification | | | `notify.warning` | Show warning notification | | | `notify.info` | Show info notification | | App | `app.menu.reload` | Refresh application menu | | Modal | `modal.open` | Open content in modal dialog | | | `modal.close` | Close modal | | Table | `table.search` | Trigger table search | | | `table.refresh` | Refresh table data | | | `table.save` | Save table row data | | | `table.delete` | Delete table row(s) | | Form | `form.find` | Load form data by ID | | | `form.submit` | Submit form data | | | `form.reset` | Reset form to initial state | | | `form.setFields` | Set form field values | | MCP | `mcp.tool.call` | Execute MCP tool (client-side) | | | `mcp.resource.read` | Read MCP resource | | Event | `event.emit` | Emit custom event | | Confirm | `confirm` | Show confirmation dialog | ## Frontend API ### Backend Calls ```typescript import { $Backend, Yao } from "@yao/sui"; // Call backend method defined in .backend.ts const data = await $Backend().Call("MethodName", arg1, arg2); // Direct API calls const yao = new Yao(); const res = await yao.Get("/api/endpoint", { query: "value" }); await yao.Post("/api/endpoint", { body: "data" }); ``` ### State Management ```typescript import { Component } from "@yao/sui"; const self = this as Component; // Store values (per component instance) self.store.Set("key", value); const value = self.store.Get("key"); ``` ### Parent Communication (Iframe) ```typescript // Helper: Send action to CUI parent const sendAction = (name: string, payload?: any) => { window.parent.postMessage( { type: "action", message: { name, 
payload } }, window.location.origin ); }; // Usage sendAction("notify.success", { message: "Done!" }); sendAction("navigate", { route: "/agents/my-assistant/detail", title: "Details", }); // Receive messages from parent window.addEventListener("message", (e) => { if (e.origin !== window.location.origin) return; const { type, message } = e.data; if (type === "setup") { document.documentElement.setAttribute("data-theme", message.theme); } }); ``` ## Iframe Communication When pages are embedded in CUI via `/web/<assistant-id>/<page-path>`, they can communicate with the host: ### Receiving Context ```javascript window.addEventListener("message", (e) => { if (e.origin !== window.location.origin) return; if (e.data.type === "setup") { const { theme, locale } = e.data.message; // Apply theme, set locale document.documentElement.setAttribute("data-theme", theme); } }); ``` ### Sending Actions ```javascript // Helper function const sendAction = (name, payload) => { window.parent.postMessage( { type: "action", message: { name, payload } }, window.location.origin ); }; // Show notification sendAction("notify.success", { message: "Done!" }); // Navigate to page sendAction("navigate", { route: "/agents/my-assistant/detail", title: "Details", }); ``` See [Iframe Integration](iframe.md) for complete documentation. ## Related Documentation - [Iframe Integration](iframe.md) - CUI iframe communication - [SUI Template Syntax](../../sui/docs/template-syntax.md) - [SUI Data Binding](../../sui/docs/data-binding.md) - [SUI Components](../../sui/docs/components.md) - [SUI Frontend API](../../sui/docs/frontend-api.md) ================================================ FILE: agent/docs/prompts.md ================================================ # Prompts ## Default Prompts Create `prompts.yml` in the assistant directory: ```yaml - role: system content: | You are a helpful assistant. 
## Guidelines - Be concise and accurate - Ask clarifying questions when needed - role: system name: context content: | Current date: {{ $CTX.date }} User locale: {{ $CTX.locale }} ``` ### Prompt Structure ```yaml - role: system | user | assistant content: string name: string (optional) ``` ### Context Variables Use `$CTX.*` for runtime context: | Variable | Description | | ---------------- | -------------------------- | | `$CTX.date` | Current date | | `$CTX.time` | Current time | | `$CTX.locale` | User locale (e.g., en-us) | | `$CTX.timezone` | User timezone | | `$CTX.user_id` | Current user ID | | `$CTX.team_id` | Current team ID | | `$CTX.chat_id` | Current chat session ID | ## Prompt Presets Create presets in `prompts/` directory for different scenarios: ``` prompts/ ├── chat.yml # Casual conversation ├── task.yml # Task-oriented └── analysis.yml # Data analysis ``` **prompts/chat.yml** ```yaml - role: system content: | You are a friendly conversational assistant. Be warm and engaging. ``` **prompts/task.yml** ```yaml - role: system content: | You are a task-focused assistant. Be precise and efficient. 
``` ### Using Presets Select preset in Create hook: ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { return { messages, prompt_preset: "task", // Use prompts/task.yml }; } ``` Or via mode configuration in `package.yao`: ```json { "modes": ["chat", "task"], "default_mode": "task" } ``` ## Global Prompts Define global prompts in `agent/prompts.yml` (applies to all assistants): ```yaml - role: system content: | # Global Guidelines - Always be helpful and respectful - Follow company policies ``` ### Disabling Global Prompts Per assistant: ```json { "disable_global_prompts": true } ``` Per request (in Create hook): ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { return { messages, disable_global_prompts: true, }; } ``` ## Multi-line Content Use YAML block scalars for long content: ```yaml - role: system content: | # Assistant Role You are an expert in data analysis. ## Capabilities - Statistical analysis - Data visualization - Report generation ## Guidelines 1. Always validate input data 2. Explain your methodology 3. Provide actionable insights ``` ## Dynamic Prompts Inject dynamic content in Create hook: ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { const userPrefs = ctx.memory.user.Get("preferences"); // Add dynamic system message const dynamicPrompt = { role: "system", content: `User preferences: ${JSON.stringify(userPrefs)}`, }; return { messages: [dynamicPrompt, ...messages], }; } ``` ## Prompt Best Practices 1. **Be specific** - Clear instructions produce better results 2. **Use structure** - Headers, lists, and sections improve readability 3. **Set boundaries** - Define what the assistant should and shouldn't do 4. **Include examples** - Show expected input/output formats 5. 
**Layer prompts** - Use global + assistant + dynamic prompts together ================================================ FILE: agent/docs/search.md ================================================ # Search The agent search system provides automatic search across web, knowledge base (KB), and database (DB). ## Auto Search Flow 1. **Intent Detection** - `__yao.needsearch` agent analyzes user message 2. **Search Execution** - Executes web/kb/db searches based on intent 3. **Context Injection** - Results injected as system message 4. **Citation** - LLM can cite results using `[1]`, `[2]` format ## Configuration ### Global (agent/search.yml) ```yaml web: provider: tavily # tavily, serper, serpapi api_key_env: TAVILY_API_KEY max_results: 10 kb: threshold: 0.7 graph: true db: max_results: 20 keyword: max_keywords: 5 language: en citation: format: "[{index}]" auto_inject_prompt: true ``` ### Per Assistant (package.yao) ```json { "search": { "web": { "max_results": 5 }, "kb": { "threshold": 0.8 }, "citation": { "format": "[{index}]" } }, "kb": { "collections": ["docs", "faq"] }, "db": { "models": ["articles", "products"] } } ``` ## Controlling Search in Hooks ### Disable Search ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { return { messages, search: false }; } ``` ### Enable Specific Types ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { return { messages, search: { need_search: true, search_types: ["kb", "db"], // Only KB and DB confidence: 1.0, reason: "controlled by hook" } }; } ``` ### Disable via Uses ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { return { messages, uses: { search: "disabled" } }; } ``` ## Search API (ctx.search) ### Web Search ```typescript const result = ctx.search.Web("query", { limit: 10, sites: ["example.com", "docs.example.com"], time_range: "week", // day, week, month, year rerank: { top_n: 5 } }); ``` ### 
Knowledge Base Search ```typescript const result = ctx.search.KB("query", { collections: ["docs", "faq"], threshold: 0.7, limit: 10, graph: true, rerank: { top_n: 5 } }); ``` ### Database Search ```typescript const result = ctx.search.DB("query", { models: ["articles"], wheres: [{ column: "status", value: "published" }], orders: [{ column: "created_at", option: "desc" }], select: ["id", "title", "content"], limit: 20, rerank: { top_n: 10 } }); ``` ### Parallel Search ```typescript // Wait for all const results = ctx.search.All([ { type: "web", query: "topic" }, { type: "kb", query: "topic", collections: ["docs"] }, { type: "db", query: "topic", models: ["articles"] } ]); // First success with results const results = ctx.search.Any([ { type: "web", query: "topic" }, { type: "kb", query: "topic" } ]); // First to complete const results = ctx.search.Race([ { type: "web", query: "topic" }, { type: "kb", query: "topic" } ]); ``` ## Result Structure ```typescript interface SearchResult { type: "web" | "kb" | "db"; query: string; source: "hook" | "auto" | "user"; items: SearchItem[]; error?: string; } interface SearchItem { citation_id: string; // "1", "2", etc. 
title: string; url: string; content: string; score: number; } ``` ## Custom Search in Hooks ```typescript function Create(ctx: agent.Context, messages: agent.Message[]): agent.Create { const query = messages[messages.length - 1]?.content || ""; // Custom KB search const kbResult = ctx.search.KB(query, { collections: ["internal_docs"], threshold: 0.8 }); if (kbResult.items?.length > 0) { // Format results as context const context = kbResult.items .map((item, i) => `[${i + 1}] ${item.title}\n${item.content}`) .join("\n\n"); // Inject as system message const contextMsg = { role: "system", content: `Reference information:\n${context}` }; return { messages: [contextMsg, ...messages], search: false // Skip auto search }; } return { messages }; } ``` ## Authorization ### KB Collections Collections are filtered by user authorization: ```typescript // Only collections user has access to are searched const result = ctx.search.KB("query", { collections: ["public", "internal", "secret"] // User without "secret" access won't search that collection }); ``` ### DB Models Database queries include permission filters: ```typescript // Auth filters are automatically added // e.g., { column: "__yao_created_by", value: user_id } const result = ctx.search.DB("query", { models: ["user_documents"] }); ``` ## Web Search Providers ### Tavily ```yaml web: provider: tavily api_key_env: TAVILY_API_KEY ``` ### Serper ```yaml web: provider: serper api_key_env: SERPER_API_KEY ``` ### SerpAPI ```yaml web: provider: serpapi api_key_env: SERPAPI_API_KEY ``` ## Citation LLM responses can include citations: ``` Based on the documentation [1], the feature works by... [2] References: [1] Getting Started Guide - https://docs.example.com/start [2] API Reference - https://docs.example.com/api ``` ### Citation Format ```yaml citation: format: "[{index}]" # or "({index})" or "[^{index}]" auto_inject_prompt: true custom_prompt: | When citing sources, use the format [N] where N is the reference number. 
``` ================================================ FILE: agent/docs/testing.md ================================================ # Agent Testing A comprehensive testing framework for Yao AI agents with support for standard testing, dynamic (simulator-driven) testing, agent-driven assertions, and CI integration. ## Quick Start ```bash # Test with direct message (auto-detect agent from current directory) cd assistants/my-assistant yao agent test -i "Hello, how are you?" # Test with JSONL file yao agent test -i tests/inputs.jsonl # Generate HTML report yao agent test -i tests/inputs.jsonl -o report.html # Stability analysis (run each test 5 times) yao agent test -i tests/inputs.jsonl --runs 5 ``` ## Input Modes The `-i` flag supports multiple input modes: | Mode | Example | Description | |------|---------|-------------| | Direct message | `-i "Hello"` | Single message test | | JSONL file | `-i tests/inputs.jsonl` | Multiple test cases | | Agent-driven | `-i "agents:tests.generator?count=10"` | Generate tests with agent | | Script test | `-i scripts.expense.setup` | Test handler scripts | | Script-generated | `-i "scripts:tests.gen.Generate"` | Generate tests from script | ## Test Case Format (JSONL) ### Basic Test ```jsonl {"id": "greeting", "input": "Hello", "assert": {"type": "contains", "value": "Hi"}} ``` ### With Conversation History ```jsonl { "id": "multi-turn", "input": [ {"role": "user", "content": "What's 2+2?"}, {"role": "assistant", "content": "4"}, {"role": "user", "content": "Multiply by 3"} ], "assert": {"type": "contains", "value": "12"} } ``` ### With File Attachments ```jsonl { "id": "image-test", "input": { "role": "user", "content": [ {"type": "text", "text": "Describe this image"}, {"type": "image", "source": "file://fixtures/test.jpg"} ] } } ``` ## Assertions ### Static Assertions | Type | Description | Example | |------|-------------|---------| | `equals` | Exact match | `{"type": "equals", "value": {"key": "val"}}` | | `contains` | Output 
contains value | `{"type": "contains", "value": "keyword"}` | | `not_contains` | Output does not contain | `{"type": "not_contains", "value": "error"}` | | `regex` | Match regex pattern | `{"type": "regex", "value": "\\d+"}` | | `json_path` | Extract and compare | `{"type": "json_path", "path": "$.field", "value": true}` | | `type` | Check output type | `{"type": "type", "value": "object"}` | | `tool_called` | Check tool was called | `{"type": "tool_called", "value": "setup"}` | | `tool_result` | Check tool result | `{"type": "tool_result", "value": {"tool": "setup", "result": {"success": true}}}` | ### Agent-Driven Assertions Use LLM to validate response semantics: ```jsonl { "id": "helpful-response", "input": "How do I reset my password?", "assert": { "type": "agent", "use": "agents:tests.validator-agent", "value": "Response should provide clear step-by-step instructions" } } ``` ### Multiple Assertions All assertions must pass: ```jsonl { "id": "complete-check", "input": "Submit expense", "assert": [ {"type": "contains", "value": "expense"}, {"type": "not_contains", "value": "error"}, {"type": "regex", "value": "(?i)(submitted|created)"} ] } ``` ## Dynamic Mode (Simulator) For testing complex conversation flows with a user simulator: ```jsonl { "id": "order-flow", "input": "I want to order coffee", "simulator": { "use": "tests.simulator-agent", "options": { "metadata": { "persona": "Customer", "goal": "Order a medium latte" } } }, "checkpoints": [ { "id": "greeting", "assert": {"type": "regex", "value": "(?i)(hello|hi)"} }, { "id": "ask-size", "after": ["greeting"], "assert": {"type": "regex", "value": "(?i)size"} }, { "id": "confirm", "after": ["ask-size"], "assert": {"type": "regex", "value": "(?i)confirm"} } ], "max_turns": 10 } ``` Run with: ```bash yao agent test -i tests/dynamic.jsonl --simulator tests.simulator-agent -v ``` ## Script Testing Test agent handler scripts with the `t.assert` API: ```typescript // assistants/my-assistant/src/setup_test.ts 
import { SystemReady } from "./setup"; export function TestSystemReady(t: TestingT, ctx: Context) { const result = SystemReady(ctx); t.assert.True(result.success, "Should succeed"); t.assert.Equal(result.status, "ready", "Status should be ready"); t.assert.NotNil(result.data, "Data should not be nil"); } export function TestWithAgentAssertion(t: TestingT, ctx: Context) { const response = Process("agents.my-assistant.Stream", ctx, messages); // Static assertion t.assert.Contains(response.content, "confirm"); // Agent-driven assertion t.assert.Agent(response.content, "tests.validator-agent", { criteria: "Response should ask for confirmation" }); } ``` Run with: ```bash yao agent test -i scripts.my-assistant.setup -v ``` ### Available Assertions | Method | Description | |--------|-------------| | `t.assert.True(value, msg)` | Assert value is true | | `t.assert.False(value, msg)` | Assert value is false | | `t.assert.Equal(a, b, msg)` | Assert a equals b | | `t.assert.NotEqual(a, b, msg)` | Assert a not equals b | | `t.assert.Nil(value, msg)` | Assert value is null/undefined | | `t.assert.NotNil(value, msg)` | Assert value is not nil | | `t.assert.Contains(s, sub, msg)` | Assert string contains substr | | `t.assert.Len(arr, n, msg)` | Assert array/string length | | `t.assert.Agent(resp, id, opts)` | Agent-driven assertion | ## Before/After Hooks ### Per-Test Hooks ```jsonl { "id": "with-setup", "input": "Show my data", "before": "env_test.Before", "after": "env_test.After" } ``` ### Global Hooks ```bash yao agent test -i tests/inputs.jsonl --before env_test.BeforeAll --after env_test.AfterAll ``` ### Hook Implementation ```typescript // assistants/my-assistant/src/env_test.ts export function Before(ctx: Context, testCase: TestCase): any { const userId = Process("models.user.Create", { name: "Test User" }); return { userId }; // Passed to After } export function After(ctx: Context, testCase: TestCase, result: TestResult, beforeData: any) { if (beforeData?.userId) { 
Process("models.user.Delete", beforeData.userId); } } export function BeforeAll(ctx: Context, testCases: TestCase[]): any { Process("models.migrate"); return { initialized: true }; } export function AfterAll(ctx: Context, results: TestResult[], beforeData: any) { const passed = results.filter(r => r.status === "passed").length; console.log(`Tests completed: ${passed}/${results.length} passed`); } ``` ## Custom Context Create a JSON file for custom authorization: ```json { "chat_id": "test-chat-001", "authorized": { "user_id": "test-user-123", "team_id": "test-team-456", "constraints": { "owner_only": true, "extra": { "department": "engineering" } } } } ``` Use with `--ctx`: ```bash yao agent test -i scripts.my-assistant.setup --ctx tests/context.json -v ``` ## Command Line Options | Flag | Description | Default | |------|-------------|---------| | `-i` | Input: JSONL file, message, `agents:xxx`, or `scripts:xxx` | (required) | | `-o` | Output file path | `output-{timestamp}.jsonl` | | `-n` | Agent ID (optional, auto-detected) | auto-detect | | `-a` | Application directory | auto-detect | | `-e` | Environment file | - | | `-c` | Override connector | agent default | | `-u` | Test user ID | `test-user` | | `-t` | Test team ID | `test-team` | | `-r` | Reporter agent ID | built-in | | `-v` | Verbose output | false | | `--ctx` | Path to context JSON file | - | | `--simulator` | Default simulator agent ID | - | | `--before` | Global BeforeAll hook | - | | `--after` | Global AfterAll hook | - | | `--runs` | Runs per test (stability analysis) | 1 | | `--run` | Regex pattern to filter tests | - | | `--timeout` | Timeout per test | 2m | | `--parallel` | Parallel test cases | 1 | | `--fail-fast` | Stop on first failure | false | | `--dry-run` | Generate tests without running | false | ## Output Formats Determined by `-o` file extension: | Extension | Format | Description | |-----------|--------|-------------| | `.jsonl` | JSONL | Streaming (default) | | `.json` | JSON | 
Complete structured | | `.md` | Markdown | Human-readable | | `.html` | HTML | Interactive web report | ## Stability Analysis Run each test multiple times to measure consistency: ```bash yao agent test -i tests/inputs.jsonl --runs 5 -o stability.json ``` | Pass Rate | Classification | |-----------|----------------| | 100% | Stable | | 80-99% | Mostly Stable | | 50-79% | Unstable | | < 50% | Highly Unstable | ## CI Integration ```bash # Exit code: 0 = all passed, 1 = failures yao agent test -i tests/inputs.jsonl --fail-fast # Run with parallel execution yao agent test -i tests/inputs.jsonl --parallel 4 ``` ### GitHub Actions Example ```yaml - name: Run Agent Tests run: | yao agent test -i assistants/my-assistant/tests/inputs.jsonl \ -u ci-user -t ci-team \ --runs 3 \ -o report.json - name: Run Dynamic Tests run: | yao agent test -i assistants/my-assistant/tests/dynamic.jsonl \ --simulator tests.simulator-agent \ -v - name: Run Script Tests run: | yao agent test -i scripts.my-assistant.setup -v ``` ## Exit Codes | Code | Description | |------|-------------| | 0 | All tests passed | | 1 | Tests failed, configuration error, or runtime error | ================================================ FILE: agent/i18n/builtin.go ================================================ package i18n // init registers built-in global messages func init() { // Initialize __global__ if not exists if Locales["__global__"] == nil { Locales["__global__"] = make(map[string]I18n) } // Built-in English messages Locales["__global__"]["en"] = I18n{ Locale: "en", Messages: map[string]any{ // Assistant: agent.go Stream() function "assistant.agent.stream.label": "{{name}}", "assistant.agent.stream.description": "{{name}} is processing the request", "assistant.agent.stream.history": "Get Chat History", "assistant.agent.stream.capabilities": "Get Connector Capabilities", "assistant.agent.stream.create_hook": "Call Create Hook", "assistant.agent.stream.closing": "Closing output (root call)", 
"assistant.agent.stream.skipping": "Skipping output close (nested call)", "assistant.agent.stream.close_error": "Failed to close output", "assistant.agent.completion.label": "Agent Completion", "assistant.agent.completion.description": "Final output from {{name}}", // LLM: providers/openai/openai.go Stream() function "llm.openai.stream.label": "LLM %s", "llm.openai.stream.description": "LLM %s is processing the request", "llm.openai.stream.starting": "Starting stream request", "llm.openai.stream.request": "Stream Request", "llm.openai.stream.retry": "Stream request failed, retrying", "llm.openai.stream.api_error": "OpenAI API returned error response", "llm.openai.stream.error": "OpenAI Stream Error", "llm.openai.stream.no_data": "Request body that caused empty response", "llm.openai.stream.no_data_info": "Request details", "llm.openai.post.api_error": "OpenAI API error response", // LLM: handlers/stream.go (general LLM stream handler) "llm.handlers.stream.info": "LLM Stream", "llm.handlers.stream.raw_output": "LLM Raw Output", // Output: adapters/openai/writer.go "output.openai.writer.sending_chunk": "Sending chunk to client", "output.openai.writer.sending_done": "Sending [DONE] to client", "output.openai.writer.adapt_error": "Failed to adapt message", "output.openai.writer.chunk_error": "Failed to send chunk", "output.openai.writer.group_error": "Failed to write message in group", "output.openai.writer.send_error": "Failed to send data to client", "output.openai.writer.marshal_error": "Failed to marshal chunk", "output.openai.writer.done_error": "Failed to send [DONE] to client", // Output: adapters/cui/writer.go "output.cui.writer.sending_chunk": "Sending chunk to client", "output.cui.writer.adapt_error": "Failed to adapt message", "output.cui.writer.chunk_error": "Failed to send chunk", "output.cui.writer.group_error": "Failed to send message group", "output.cui.writer.send_error": "Failed to send data to client", "output.cui.writer.marshal_error": "Failed to 
marshal chunk", // Output: Stream event messages "output.stream_start": "Assistant is processing", "output.view_trace": "View process", // Common status messages "common.status.processing": "Processing", "common.status.completed": "Completed", "common.status.failed": "Failed", "common.status.retrying": "Retrying", // MCP: context/mcp.go - Resource operations "mcp.list_resources.label": "MCP: List Resources", "mcp.list_resources.description": "List resources from MCP client '%s'", "mcp.read_resource.label": "MCP: Read Resource", "mcp.read_resource.description": "Read resource '%s' from MCP client '%s'", // MCP: context/mcp.go - Tool operations "mcp.list_tools.label": "MCP: List Tools", "mcp.list_tools.description": "List tools from MCP client '%s'", "mcp.call_tool.label": "MCP: Call Tool", "mcp.call_tool.description": "Call tool '%s' from MCP client '%s'", "mcp.call_tools.label": "MCP: Call Tools", "mcp.call_tools.description": "Call %d tools sequentially from MCP client '%s'", "mcp.call_tools_parallel.label": "MCP: Call Tools (Parallel)", "mcp.call_tools_parallel.description": "Call %d tools in parallel from MCP client '%s'", // MCP: context/mcp.go - Prompt operations "mcp.list_prompts.label": "MCP: List Prompts", "mcp.list_prompts.description": "List prompts from MCP client '%s'", "mcp.get_prompt.label": "MCP: Get Prompt", "mcp.get_prompt.description": "Get prompt '%s' from MCP client '%s'", // MCP: context/mcp.go - Sample operations "mcp.list_samples.label": "MCP: List Samples", "mcp.list_samples.description": "List samples for '%s' from MCP client '%s'", "mcp.get_sample.label": "MCP: Get Sample", "mcp.get_sample.description": "Get sample #%d for '%s' from MCP client '%s'", // KB: Chat collection "kb.chat.name": "Chat Knowledge Base", "kb.chat.description": "Auto-created knowledge base collection for chat sessions", // Sandbox: assistant/sandbox.go - Sandbox status messages "sandbox.preparing": "Getting things ready...", "sandbox.ready": "Sandbox ready", 
"sandbox.working": "Working on your request", "sandbox.completed": "Completed", "sandbox.failed": "Execution failed", "sandbox.starting": "Setting up workspace...", "sandbox.pulling_image": "Preparing environment (first time may take a moment)", "sandbox.waiting_response": "Waiting for AI response...", // Sandbox: claude/executor.go - Tool execution messages "sandbox.tool.read": "Reading file", "sandbox.tool.write": "Writing file", "sandbox.tool.edit": "Editing file", "sandbox.tool.bash": "Running command", "sandbox.tool.glob": "Finding files", "sandbox.tool.grep": "Searching code", "sandbox.tool.ls": "Listing directory", "sandbox.tool.task": "Running subtask", "sandbox.tool.web_search": "Searching web", "sandbox.tool.web_fetch": "Fetching URL", "sandbox.tool.todo_write": "Managing tasks", "sandbox.tool.ask_question": "Asking question", "sandbox.tool.switch_mode": "Switching mode", "sandbox.tool.read_lints": "Checking lints", "sandbox.tool.edit_notebook": "Editing notebook", "sandbox.tool.unknown": "Executing {{name}}", // Content: content/image/image.go - Image processing messages "content.image.analyzing": "Analyzing image", // Content: content/pdf/pdf.go - PDF processing messages "content.pdf.analyzing_page": "Analyzing PDF page %d/%d", // Search: assistant/search.go - Output messages "search.loading": "Searching", "search.success": "Found %d references", "search.success.one": "Found 1 reference", "search.partial": "Found %d references (some sources failed)", "search.failed": "Search failed", "search.no_results": "No references found", // Search Intent: assistant/search.go - Intent detection messages "search.intent.loading": "Checking if references are needed", "search.intent.need_search": "Searching for references", "search.intent.no_search": "No references needed", // Keyword Extraction: assistant/search.go - Keyword extraction messages "search.keyword.loading": "Analyzing conversation", "search.keyword.done": "Analysis complete", // Search: 
assistant/search.go - Trace labels "search.trace.label": "Search", "search.trace.description": "Search the web and knowledge base for relevant information", "search.trace.web.label": "Web Search", "search.trace.web.description": "Searching the web", "search.trace.kb.label": "KB Search", "search.trace.kb.description": "Searching knowledge base", "search.trace.db.label": "DB Search", "search.trace.db.description": "Searching database", }, } // Built-in Chinese (Simplified) messages Locales["__global__"]["zh-cn"] = I18n{ Locale: "zh-cn", Messages: map[string]any{ // Assistant: agent.go Stream() function "assistant.agent.stream.label": "{{name}}", "assistant.agent.stream.description": "{{name}} 正在处理请求", "assistant.agent.stream.history": "获取聊天历史", "assistant.agent.stream.capabilities": "获取连接器能力", "assistant.agent.stream.create_hook": "调用 Create Hook", "assistant.agent.stream.closing": "关闭输出(根调用)", "assistant.agent.stream.skipping": "跳过输出关闭(嵌套调用)", "assistant.agent.stream.close_error": "关闭输出失败", "assistant.agent.completion.label": "智能体完成", "assistant.agent.completion.description": "{{name}} 最终输出", // LLM: providers/openai/openai.go Stream() function "llm.openai.stream.label": "LLM %s", "llm.openai.stream.description": "LLM %s 正在处理请求", "llm.openai.stream.starting": "开始流式请求", "llm.openai.stream.request": "流式请求", "llm.openai.stream.retry": "流式请求失败,正在重试", "llm.openai.stream.api_error": "OpenAI API 返回错误响应", "llm.openai.stream.error": "OpenAI 流错误", "llm.openai.stream.no_data": "导致空响应的请求体", "llm.openai.stream.no_data_info": "请求详情", "llm.openai.post.api_error": "OpenAI API 错误响应", // LLM: handlers/stream.go (general LLM stream handler) "llm.handlers.stream.info": "LLM 流式输出", "llm.handlers.stream.raw_output": "LLM 原始输出", // Output: adapters/openai/writer.go "output.openai.writer.sending_chunk": "向客户端发送数据块", "output.openai.writer.sending_done": "向客户端发送 [DONE]", "output.openai.writer.adapt_error": "适配消息失败", "output.openai.writer.chunk_error": "发送数据块失败", 
"output.openai.writer.group_error": "写入消息组中的消息失败", "output.openai.writer.send_error": "发送数据到客户端失败", "output.openai.writer.marshal_error": "序列化数据块失败", "output.openai.writer.done_error": "发送 [DONE] 到客户端失败", // Output: adapters/cui/writer.go "output.cui.writer.sending_chunk": "向客户端发送数据块", "output.cui.writer.adapt_error": "适配消息失败", "output.cui.writer.chunk_error": "发送数据块失败", "output.cui.writer.group_error": "发送消息组失败", "output.cui.writer.send_error": "发送数据到客户端失败", "output.cui.writer.marshal_error": "序列化数据块失败", // Output: Stream event messages "output.stream_start": "智能体正在处理", "output.view_trace": "查看处理详情", // Common status messages "common.status.processing": "处理中", "common.status.completed": "已完成", "common.status.failed": "失败", "common.status.retrying": "重试中", // KB: Chat collection "kb.chat.name": "聊天知识库", "kb.chat.description": "自动为聊天会话创建的知识库集合", // Sandbox: assistant/sandbox.go - Sandbox status messages "sandbox.preparing": "正在准备...", "sandbox.ready": "就绪", "sandbox.working": "正在处理您的请求", "sandbox.completed": "处理完成", "sandbox.failed": "执行失败", "sandbox.starting": "正在启动工作区...", "sandbox.pulling_image": "正在准备运行环境(首次可能需要一点时间)", "sandbox.waiting_response": "等待 AI 响应...", // Sandbox: claude/executor.go - Tool execution messages "sandbox.tool.read": "正在读取文件", "sandbox.tool.write": "正在写入文件", "sandbox.tool.edit": "正在编辑文件", "sandbox.tool.bash": "正在执行命令", "sandbox.tool.glob": "正在查找文件", "sandbox.tool.grep": "正在搜索代码", "sandbox.tool.ls": "正在列出目录", "sandbox.tool.task": "正在执行子任务", "sandbox.tool.web_search": "正在搜索网页", "sandbox.tool.web_fetch": "正在获取网页", "sandbox.tool.todo_write": "正在管理任务", "sandbox.tool.ask_question": "正在询问问题", "sandbox.tool.switch_mode": "正在切换模式", "sandbox.tool.read_lints": "正在检查代码", "sandbox.tool.edit_notebook": "正在编辑笔记本", "sandbox.tool.unknown": "正在执行 {{name}}", // Content: content/image/image.go - Image processing messages "content.image.analyzing": "正在分析图片", // Content: content/pdf/pdf.go - PDF processing messages "content.pdf.analyzing_page": "正在分析 PDF 第 %d/%d 
页", // Search: assistant/search.go - Output messages "search.loading": "正在搜索", "search.success": "找到 %d 条参考资料", "search.success.one": "找到 1 条参考资料", "search.partial": "找到 %d 条参考资料(部分来源失败)", "search.failed": "搜索失败", "search.no_results": "未找到相关资料", // Search Intent: assistant/search.go - Intent detection messages "search.intent.loading": "检查是否需要查询资料", "search.intent.need_search": "正在查询相关资料", "search.intent.no_search": "无需查询资料", // Keyword Extraction: assistant/search.go - Keyword extraction messages "search.keyword.loading": "正在分析对话内容", "search.keyword.done": "分析完成", // Search: assistant/search.go - Trace labels "search.trace.label": "搜索", "search.trace.description": "搜索网络和知识库获取相关信息", "search.trace.web.label": "网页搜索", "search.trace.web.description": "搜索网页获取相关信息", "search.trace.kb.label": "知识库搜索", "search.trace.kb.description": "搜索知识库获取相关信息", "search.trace.db.label": "数据库搜索", "search.trace.db.description": "搜索数据库获取相关信息", }, } // Built-in Chinese (short code) - same as zh-cn Locales["__global__"]["zh"] = I18n{ Locale: "zh", Messages: map[string]any{ // Assistant: agent.go Stream() function "assistant.agent.stream.label": "{{name}}", "assistant.agent.stream.description": "{{name}} 正在处理请求", "assistant.agent.stream.history": "获取聊天历史", "assistant.agent.stream.capabilities": "获取连接器能力", "assistant.agent.stream.create_hook": "调用 Create Hook", "assistant.agent.stream.closing": "关闭输出(根调用)", "assistant.agent.stream.skipping": "跳过输出关闭(嵌套调用)", "assistant.agent.stream.close_error": "关闭输出失败", "assistant.agent.completion.label": "智能体完成", "assistant.agent.completion.description": "{{name}} 最终输出", // LLM: providers/openai/openai.go Stream() function "llm.openai.stream.label": "LLM %s", "llm.openai.stream.description": "LLM %s 正在处理请求", "llm.openai.stream.starting": "开始流式请求", "llm.openai.stream.request": "流式请求", "llm.openai.stream.retry": "流式请求失败,正在重试", "llm.openai.stream.api_error": "OpenAI API 返回错误响应", "llm.openai.stream.error": "OpenAI 流错误", "llm.openai.stream.no_data": "导致空响应的请求体", 
"llm.openai.stream.no_data_info": "请求详情", "llm.openai.post.api_error": "OpenAI API 错误响应", // LLM: handlers/stream.go (general LLM stream handler) "llm.handlers.stream.info": "LLM 流式输出", "llm.handlers.stream.raw_output": "LLM 原始输出", // Output: adapters/openai/writer.go "output.openai.writer.sending_chunk": "向客户端发送数据块", "output.openai.writer.sending_done": "向客户端发送 [DONE]", "output.openai.writer.adapt_error": "适配消息失败", "output.openai.writer.chunk_error": "发送数据块失败", "output.openai.writer.group_error": "写入消息组中的消息失败", "output.openai.writer.send_error": "发送数据到客户端失败", "output.openai.writer.marshal_error": "序列化数据块失败", "output.openai.writer.done_error": "发送 [DONE] 到客户端失败", // Output: adapters/cui/writer.go "output.cui.writer.sending_chunk": "向客户端发送数据块", "output.cui.writer.adapt_error": "适配消息失败", "output.cui.writer.chunk_error": "发送数据块失败", "output.cui.writer.group_error": "发送消息组失败", "output.cui.writer.send_error": "发送数据到客户端失败", "output.cui.writer.marshal_error": "序列化数据块失败", // Output: Stream event messages "output.stream_start": "智能体正在处理", "output.view_trace": "查看处理详情", // Common status messages "common.status.processing": "处理中", "common.status.completed": "已完成", "common.status.failed": "失败", "common.status.retrying": "重试中", // MCP: context/mcp.go - Resource operations "mcp.list_resources.label": "MCP: 列出资源", "mcp.list_resources.description": "从 MCP 客户端 '%s' 列出资源", "mcp.read_resource.label": "MCP: 读取资源", "mcp.read_resource.description": "从 MCP 客户端 '%s' 读取资源 '%s'", // MCP: context/mcp.go - Tool operations "mcp.list_tools.label": "MCP: 列出工具", "mcp.list_tools.description": "从 MCP 客户端 '%s' 列出工具", "mcp.call_tool.label": "MCP: 调用工具", "mcp.call_tool.description": "从 MCP 客户端 '%s' 调用工具 '%s'", "mcp.call_tools.label": "MCP: 调用工具", "mcp.call_tools.description": "从 MCP 客户端 '%s' 顺序调用 %d 个工具", "mcp.call_tools_parallel.label": "MCP: 调用工具(并行)", "mcp.call_tools_parallel.description": "从 MCP 客户端 '%s' 并行调用 %d 个工具", // MCP: context/mcp.go - Prompt operations "mcp.list_prompts.label": "MCP: 
列出提示词", "mcp.list_prompts.description": "从 MCP 客户端 '%s' 列出提示词", "mcp.get_prompt.label": "MCP: 获取提示词", "mcp.get_prompt.description": "从 MCP 客户端 '%s' 获取提示词 '%s'", // MCP: context/mcp.go - Sample operations "mcp.list_samples.label": "MCP: 列出示例", "mcp.list_samples.description": "从 MCP 客户端 '%s' 列出 '%s' 的示例", "mcp.get_sample.label": "MCP: 获取示例", "mcp.get_sample.description": "从 MCP 客户端 '%s' 获取 '%s' 的第 %d 个示例", // KB: Chat collection "kb.chat.name": "聊天知识库", "kb.chat.description": "自动为聊天会话创建的知识库集合", // Sandbox: assistant/sandbox.go - Sandbox status messages "sandbox.preparing": "正在准备...", "sandbox.ready": "就绪", "sandbox.working": "正在处理您的请求", "sandbox.completed": "处理完成", "sandbox.failed": "执行失败", "sandbox.starting": "正在启动工作区...", "sandbox.pulling_image": "正在准备运行环境(首次可能需要一点时间)", "sandbox.waiting_response": "等待 AI 响应...", // Sandbox: claude/executor.go - Tool execution messages "sandbox.tool.read": "正在读取文件", "sandbox.tool.write": "正在写入文件", "sandbox.tool.edit": "正在编辑文件", "sandbox.tool.bash": "正在执行命令", "sandbox.tool.glob": "正在查找文件", "sandbox.tool.grep": "正在搜索代码", "sandbox.tool.ls": "正在列出目录", "sandbox.tool.task": "正在执行子任务", "sandbox.tool.web_search": "正在搜索网页", "sandbox.tool.web_fetch": "正在获取网页", "sandbox.tool.todo_write": "正在管理任务", "sandbox.tool.ask_question": "正在询问问题", "sandbox.tool.switch_mode": "正在切换模式", "sandbox.tool.read_lints": "正在检查代码", "sandbox.tool.edit_notebook": "正在编辑笔记本", "sandbox.tool.unknown": "正在执行 {{name}}", // Content: content/image/image.go - Image processing messages "content.image.analyzing": "正在分析图片", // Content: content/pdf/pdf.go - PDF processing messages "content.pdf.analyzing_page": "正在分析 PDF 第 %d/%d 页", // Search: assistant/search.go - Output messages "search.loading": "正在搜索", "search.success": "找到 %d 条参考资料", "search.success.one": "找到 1 条参考资料", "search.partial": "找到 %d 条参考资料(部分来源失败)", "search.failed": "搜索失败", "search.no_results": "未找到相关资料", // Search Intent: assistant/search.go - Intent detection messages "search.intent.loading": "检查是否需要查询资料", 
// Locales holds every loaded language pack, keyed first by owner (an
// assistant ID, or "__global__" for the built-in messages) and then by
// locale code.
var Locales = map[string]map[string]I18n{}

// I18n is the message bundle of a single locale.
type I18n struct {
	Locale   string         `json:"locale,omitempty" yaml:"locale,omitempty"`
	Messages map[string]any `json:"messages,omitempty" yaml:"messages,omitempty"`
}

// Map maps a locale code to its message bundle.
type Map map[string]I18n

// templatePattern matches one {{ key }} placeholder; group 1 is the key.
// It is compiled once at package scope rather than on every parseString call.
var templatePattern = regexp.MustCompile(`\{\{\s*([^}]+?)\s*\}\}`)

// Parse translates input using the bundle's messages.
//
// Strings are translated with parseString. map[string]any, []any and []string
// values are copied with every element translated recursively; the input
// container is never modified. nil yields nil and any other type is returned
// unchanged.
func (i18n I18n) Parse(input any) any {
	switch in := input.(type) {
	case nil:
		return nil

	case string:
		return i18n.parseString(in)

	case map[string]any:
		out := make(map[string]any, len(in))
		for key, value := range in {
			out[key] = i18n.Parse(value)
		}
		return out

	case []any:
		out := make([]any, 0, len(in))
		for _, value := range in {
			out = append(out, i18n.Parse(value))
		}
		return out

	case []string:
		// parseString always returns a string and falls back to its input,
		// so every element maps to exactly one output element.
		out := make([]string, 0, len(in))
		for _, value := range in {
			out = append(out, i18n.parseString(value))
		}
		return out
	}
	return input
}

// message returns the message stored under key when it exists and is a
// string. Missing keys and non-string values both report false.
func (i18n I18n) message(key string) (string, bool) {
	text, ok := i18n.Messages[key].(string)
	return text, ok
}

// parseString translates a single string. Three input shapes are recognised:
//
//   - a bare message key ("hello"), surrounding whitespace ignored;
//   - a string that is exactly one placeholder ("{{ hello }}");
//   - text with embedded placeholders ("Hi {{hello}}, {{world}}").
//
// Keys that are missing, or whose message is not a string, leave the
// corresponding text (or placeholder) untouched.
func (i18n I18n) parseString(in string) string {
	trimmed := strings.TrimSpace(in)

	// No template markers at all: treat the whole string as a message key.
	if !strings.Contains(trimmed, "{{") && !strings.Contains(trimmed, "}}") {
		if text, ok := i18n.message(trimmed); ok {
			return text
		}
		return in
	}

	// The entire string is a single placeholder: return the message itself.
	if strings.HasPrefix(trimmed, "{{") && strings.HasSuffix(trimmed, "}}") {
		matches := templatePattern.FindAllStringSubmatch(trimmed, -1)
		if len(matches) == 1 && matches[0][0] == trimmed {
			if text, ok := i18n.message(strings.TrimSpace(matches[0][1])); ok {
				return text
			}
			return in
		}
	}

	// Embedded placeholders. Substitution is done in one pass over the
	// original input so that text inserted for one placeholder is never
	// mistaken for a later placeholder of the input.
	return templatePattern.ReplaceAllStringFunc(in, func(placeholder string) string {
		key := strings.TrimSpace(placeholder[2 : len(placeholder)-2])
		if text, ok := i18n.message(key); ok {
			return text
		}
		return placeholder
	})
}
map[string]any{} err = application.Parse(locale, localeData, &messages) if err != nil { return nil, err } name := strings.ToLower(strings.TrimSuffix(filepath.Base(locale), ".yml")) i18ns[name] = I18n{Locale: name, Messages: messages} } } return i18ns, nil } // Flatten flattens the map of locales by adding short language codes and region codes // e.g., "en-us" will also create "en" and "us" entries // If __global__ locales exist, they are merged (local/user messages override global built-in messages) func (m Map) Flatten() Map { flattened := make(Map) // First, process local messages with Dot() flattening for localeCode, i18n := range m { // Flatten nested messages to dot notation (e.g., {"local": {"key": "value"}} -> {"local.key": "value"}) flattened[localeCode] = I18n{ Locale: localeCode, Messages: maps.MapOf(i18n.Messages).Dot(), } // Add short language codes parts := strings.Split(localeCode, "-") if len(parts) > 1 { // Add short language code (e.g., "en" from "en-us") if _, ok := flattened[parts[0]]; !ok { flattened[parts[0]] = flattened[localeCode] } // Add region code (e.g., "us" from "en-us") if _, ok := flattened[parts[1]]; !ok { flattened[parts[1]] = flattened[localeCode] } } } // Merge with global locales if they exist // Strategy: Start with global (built-in), then override with local (user) globalLocales, hasGlobal := Locales["__global__"] if !hasGlobal { return flattened } for globalLocaleCode, globalI18n := range globalLocales { // Ensure global messages are also flattened (though builtin.go already uses flat keys) globalFlattened := maps.MapOf(globalI18n.Messages).Dot() if localI18n, ok := flattened[globalLocaleCode]; ok { // Both global and local exist: merge with local overriding global mergedMessages := make(map[string]any) // First copy all global messages for k, v := range globalFlattened { mergedMessages[k] = v } // Then override with local messages for k, v := range localI18n.Messages { mergedMessages[k] = v } flattened[globalLocaleCode] = 
I18n{ Locale: globalLocaleCode, Messages: mergedMessages, } } else { // Only global exists, add it with flattened messages flattened[globalLocaleCode] = I18n{ Locale: globalLocaleCode, Messages: globalFlattened, } } } return flattened } // FlattenWithGlobal is deprecated. Use Flatten() instead, which now automatically merges with global locales. // Kept for backward compatibility. func (m Map) FlattenWithGlobal() Map { return m.Flatten() } // Translate translate the input with recursive variable resolution // Fallback strategy: assistant locale -> assistant short codes -> global locale -> global short codes func Translate(assistantID string, locale string, input any) any { locale = strings.ToLower(strings.TrimSpace(locale)) // Helper function to try translation with a specific i18n object tryTranslate := func(i18n I18n, input any) (any, bool) { result := i18n.Parse(input) // For string input, check if translation was found by comparing with input // For other types, Parse always returns a result (transformed or original) if inputStr, ok := input.(string); ok { if resultStr, ok := result.(string); ok { // Translation found if result is different from input if resultStr != inputStr { return result, true } return input, false } } // For non-string inputs (maps, slices), Parse always processes them return result, true } // Helper function to process recursive templates processTemplates := func(result any, assistantID string, locale string) any { if resultStr, ok := result.(string); ok && strings.Contains(resultStr, "{{") && strings.Contains(resultStr, "}}") { re := regexp.MustCompile(`\{\{\s*([^}]+?)\s*\}\}`) resultStr = re.ReplaceAllStringFunc(resultStr, func(match string) string { varName := strings.TrimSpace(strings.TrimPrefix(strings.TrimSuffix(match, "}}"), "{{")) translated := Translate(assistantID, locale, varName) if translatedStr, ok := translated.(string); ok && translatedStr != varName { return translatedStr } return match }) return resultStr } return result 
} // Try assistant locale first if i18ns, has := Locales[assistantID]; has { // Try exact locale if i18n, hasLocale := i18ns[locale]; hasLocale { if result, found := tryTranslate(i18n, input); found { return processTemplates(result, assistantID, locale) } } // Try short codes parts := strings.Split(locale, "-") if len(parts) > 1 { if i18n, hasLocale := i18ns[parts[1]]; hasLocale { if result, found := tryTranslate(i18n, input); found { return processTemplates(result, assistantID, locale) } } if i18n, hasLocale := i18ns[parts[0]]; hasLocale { if result, found := tryTranslate(i18n, input); found { return processTemplates(result, assistantID, locale) } } } } // Fallback to global locales if globalI18ns, hasGlobal := Locales["__global__"]; hasGlobal { // Try exact locale if i18n, hasLocale := globalI18ns[locale]; hasLocale { if result, found := tryTranslate(i18n, input); found { return processTemplates(result, assistantID, locale) } } // Try short codes parts := strings.Split(locale, "-") if len(parts) > 1 { if i18n, hasLocale := globalI18ns[parts[1]]; hasLocale { if result, found := tryTranslate(i18n, input); found { return processTemplates(result, assistantID, locale) } } if i18n, hasLocale := globalI18ns[parts[0]]; hasLocale { if result, found := tryTranslate(i18n, input); found { return processTemplates(result, assistantID, locale) } } } } return input } // TranslateGlobal translate the input with global i18n func TranslateGlobal(locale string, input any) any { locale = strings.ToLower(strings.TrimSpace(locale)) i18ns, has := Locales["__global__"] if !has { i18ns = map[string]I18n{} } // Try the exact locale first i18n, has := i18ns[locale] if has { result := i18n.Parse(input) // If the result is the same as input (not translated), try fallback if result != input { return result } } // Fallback logic: for "en-us", try "en" parts := strings.Split(locale, "-") if len(parts) > 1 { // Try the language code (e.g., "en" for "en-us") if fallbackI18n, hasFallback := 
i18ns[parts[0]]; hasFallback { result := fallbackI18n.Parse(input) if result != input { return result } } // Try the country code (e.g., "us" for "en-us") if fallbackI18n, hasFallback := i18ns[parts[1]]; hasFallback { result := fallbackI18n.Parse(input) if result != input { return result } } } return input } // T is a short alias for TranslateGlobal that returns string // Usage: i18n.T(ctx.Locale, "assistant.agent.stream.label") // Variables in templates like {{variable}} will be recursively resolved from the global language pack func T(locale string, key string) string { result := TranslateGlobal(locale, key) if str, ok := result.(string); ok { return str } return key } // Tr translates with assistantID and returns string // Supports recursive translation of {{variable}} templates // Usage: i18n.Tr(assistantID, locale, "key") func Tr(assistantID string, locale string, key string) string { result := Translate(assistantID, locale, key) if str, ok := result.(string); ok { return str } return key } ================================================ FILE: agent/i18n/i18n_test.go ================================================ package i18n import ( "testing" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestParseString tests the parseString method func TestParseString(t *testing.T) { i18n := I18n{ Locale: "en", Messages: map[string]any{ "hello": "Hello", "world": "World", "greeting": "Hello, World!", "description": "This is a test", }, } tests := []struct { name string input string expected string }{ { name: "Template expression with match", input: "{{greeting}}", expected: "Hello, World!", }, { name: "Template expression with spaces", input: "{{ greeting }}", expected: "Hello, World!", }, { name: "Template expression without match", input: "{{notfound}}", expected: "{{notfound}}", }, { name: "Direct message key", input: "hello", expected: "Hello", }, { name: "Direct message key with spaces", input: " world ", expected: "World", }, { name: 
"Non-existent key", input: "notfound", expected: "notfound", }, { name: "Regular text", input: "Just some text", expected: "Just some text", }, { name: "Empty string", input: "", expected: "", }, // Embedded template tests (new feature) { name: "Embedded single template", input: "Hello {{hello}}", expected: "Hello Hello", }, { name: "Embedded multiple templates", input: "{{hello}} {{world}}!", expected: "Hello World!", }, { name: "Embedded template with spaces", input: "Say {{ hello }} to the {{ world }}", expected: "Say Hello to the World", }, { name: "Embedded template mixed with text", input: "Message: {{greeting}} - {{description}}", expected: "Message: Hello, World! - This is a test", }, { name: "Embedded template not found", input: "Hello {{notfound}} World", expected: "Hello {{notfound}} World", }, { name: "Embedded template partial match", input: "{{hello}} {{notfound}} {{world}}", expected: "Hello {{notfound}} World", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := i18n.parseString(tt.input) if result != tt.expected { t.Errorf("parseString(%q) = %q, want %q", tt.input, result, tt.expected) } }) } } // TestParseStringNonStringValue tests parseString when message value is not a string func TestParseStringNonStringValue(t *testing.T) { i18n := I18n{ Locale: "en", Messages: map[string]any{ "number": 123, "object": map[string]any{"key": "value"}, }, } tests := []struct { name string input string expected string }{ { name: "Template with number value", input: "{{number}}", expected: "{{number}}", }, { name: "Direct key with number value", input: "number", expected: "number", }, { name: "Template with object value", input: "{{object}}", expected: "{{object}}", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := i18n.parseString(tt.input) if result != tt.expected { t.Errorf("parseString(%q) = %q, want %q", tt.input, result, tt.expected) } }) } } // TestParse tests the Parse method with various input types 
func TestParse(t *testing.T) { i18n := I18n{ Locale: "en", Messages: map[string]any{ "name": "John", "description": "A developer", "title": "Welcome", }, } t.Run("Nil input", func(t *testing.T) { result := i18n.Parse(nil) if result != nil { t.Errorf("Parse(nil) = %v, want nil", result) } }) t.Run("String input", func(t *testing.T) { result := i18n.Parse("{{name}}") if result != "John" { t.Errorf("Parse({{name}}) = %v, want 'John'", result) } }) t.Run("Map input", func(t *testing.T) { input := map[string]any{ "name": "{{name}}", "description": "{{description}}", "age": 30, } result := i18n.Parse(input) if resultMap, ok := result.(map[string]any); ok { if resultMap["name"] != "John" { t.Errorf("Expected name 'John', got %v", resultMap["name"]) } if resultMap["description"] != "A developer" { t.Errorf("Expected description 'A developer', got %v", resultMap["description"]) } if resultMap["age"] != 30 { t.Errorf("Expected age 30, got %v", resultMap["age"]) } } else { t.Errorf("Expected map[string]any, got %T", result) } }) t.Run("Slice of any", func(t *testing.T) { input := []any{"{{name}}", "{{description}}", 123} result := i18n.Parse(input) if resultSlice, ok := result.([]any); ok { if len(resultSlice) != 3 { t.Errorf("Expected 3 elements, got %d", len(resultSlice)) } if resultSlice[0] != "John" { t.Errorf("Expected 'John', got %v", resultSlice[0]) } if resultSlice[1] != "A developer" { t.Errorf("Expected 'A developer', got %v", resultSlice[1]) } if resultSlice[2] != 123 { t.Errorf("Expected 123, got %v", resultSlice[2]) } } else { t.Errorf("Expected []any, got %T", result) } }) t.Run("Slice of strings", func(t *testing.T) { input := []string{"{{name}}", "{{description}}", "plain text"} result := i18n.Parse(input) if resultSlice, ok := result.([]string); ok { if len(resultSlice) != 3 { t.Errorf("Expected 3 elements, got %d", len(resultSlice)) } if resultSlice[0] != "John" { t.Errorf("Expected 'John', got %v", resultSlice[0]) } if resultSlice[1] != "A developer" { 
t.Errorf("Expected 'A developer', got %v", resultSlice[1]) } if resultSlice[2] != "plain text" { t.Errorf("Expected 'plain text', got %v", resultSlice[2]) } } else { t.Errorf("Expected []string, got %T", result) } }) t.Run("Nested structures", func(t *testing.T) { input := map[string]any{ "user": map[string]any{ "name": "{{name}}", "info": []string{"{{title}}", "{{description}}"}, }, } result := i18n.Parse(input) if resultMap, ok := result.(map[string]any); ok { if userMap, ok := resultMap["user"].(map[string]any); ok { if userMap["name"] != "John" { t.Errorf("Expected nested name 'John', got %v", userMap["name"]) } if infoSlice, ok := userMap["info"].([]any); ok { if infoSlice[0] != "Welcome" { t.Errorf("Expected 'Welcome', got %v", infoSlice[0]) } } } } }) t.Run("Other types pass through", func(t *testing.T) { input := 12345 result := i18n.Parse(input) if result != input { t.Errorf("Expected %v, got %v", input, result) } }) } // TestParseSliceStringWithNilAndNonString tests []string parsing edge cases func TestParseSliceStringWithNilAndNonString(t *testing.T) { i18n := I18n{ Locale: "en", Messages: map[string]any{ "key1": "value1", "key2": 123, // Non-string value "key3": nil, // Nil value }, } t.Run("String slice with fallback", func(t *testing.T) { input := []string{"{{key1}}", "{{key2}}", "{{notfound}}"} result := i18n.Parse(input) if resultSlice, ok := result.([]string); ok { if resultSlice[0] != "value1" { t.Errorf("Expected 'value1', got %v", resultSlice[0]) } // key2 has non-string value, should fallback to original if resultSlice[1] != "{{key2}}" { t.Errorf("Expected '{{key2}}', got %v", resultSlice[1]) } if resultSlice[2] != "{{notfound}}" { t.Errorf("Expected '{{notfound}}', got %v", resultSlice[2]) } } else { t.Errorf("Expected []string, got %T", result) } }) t.Run("String slice with nil parsed result", func(t *testing.T) { // This tests the case where Parse returns nil for a string input := []string{"{{key3}}", "normal"} result := i18n.Parse(input) if 
resultSlice, ok := result.([]string); ok { // When parsed is nil, should fallback to original if resultSlice[0] != "{{key3}}" { t.Errorf("Expected '{{key3}}' (tests nil parsed branch), got %v", resultSlice[0]) } if resultSlice[1] != "normal" { t.Errorf("Expected 'normal', got %v", resultSlice[1]) } } else { t.Errorf("Expected []string, got %T", result) } }) t.Run("String slice with non-string parsed result from map", func(t *testing.T) { i18nWithMap := I18n{ Locale: "en", Messages: map[string]any{ "map_key": map[string]any{"nested": "value"}, }, } // When Parse returns a non-string type (like a map), should fallback input := []string{"{{map_key}}", "text"} result := i18nWithMap.Parse(input) if resultSlice, ok := result.([]string); ok { // Should fallback to original when parsed is not string if resultSlice[0] != "{{map_key}}" { t.Errorf("Expected '{{map_key}}' (tests non-string parsed branch), got %v", resultSlice[0]) } } else { t.Errorf("Expected []string, got %T", result) } }) } // TestMapFlatten tests the Flatten method func TestMapFlatten(t *testing.T) { i18ns := Map{ "en-us": I18n{ Locale: "en-us", Messages: map[string]any{ "greeting": "Hello", }, }, "zh-cn": I18n{ Locale: "zh-cn", Messages: map[string]any{ "greeting": "你好", }, }, } flattened := i18ns.Flatten() // Should have original keys if _, ok := flattened["en-us"]; !ok { t.Error("Expected 'en-us' key in flattened map") } if _, ok := flattened["zh-cn"]; !ok { t.Error("Expected 'zh-cn' key in flattened map") } // Should have short lang codes if _, ok := flattened["en"]; !ok { t.Error("Expected 'en' short code in flattened map") } if _, ok := flattened["us"]; !ok { t.Error("Expected 'us' region code in flattened map") } if _, ok := flattened["zh"]; !ok { t.Error("Expected 'zh' short code in flattened map") } if _, ok := flattened["cn"]; !ok { t.Error("Expected 'cn' region code in flattened map") } // Verify messages are preserved if msg, ok := flattened["en"].Messages["greeting"].(string); !ok || msg != 
"Hello" { t.Errorf("Expected 'Hello', got %v", flattened["en"].Messages["greeting"]) } if msg, ok := flattened["zh"].Messages["greeting"].(string); !ok || msg != "你好" { t.Errorf("Expected '你好', got %v", flattened["zh"].Messages["greeting"]) } } // TestMapFlattenWithGlobal tests the FlattenWithGlobal method func TestMapFlattenWithGlobal(t *testing.T) { // Save and restore __global__ originalGlobal := Locales["__global__"] defer func() { Locales["__global__"] = originalGlobal }() // Setup global locales Locales["__global__"] = map[string]I18n{ "en": { Locale: "en", Messages: map[string]any{ "global.key": "Global Value", "common": "Common", }, }, } i18ns := Map{ "en": I18n{ Locale: "en", Messages: map[string]any{ "local.key": "Local Value", "common": "Local Common", // Should override global }, }, } flattened := i18ns.FlattenWithGlobal() if _, ok := flattened["en"]; !ok { t.Fatal("Expected 'en' key in flattened map") } // Should have local key if val, ok := flattened["en"].Messages["local.key"].(string); !ok || val != "Local Value" { t.Errorf("Expected 'Local Value', got %v", flattened["en"].Messages["local.key"]) } // Should have global key if val, ok := flattened["en"].Messages["global.key"].(string); !ok || val != "Global Value" { t.Errorf("Expected 'Global Value', got %v", flattened["en"].Messages["global.key"]) } // Local should override global if val, ok := flattened["en"].Messages["common"].(string); !ok || val != "Local Common" { t.Errorf("Expected 'Local Common', got %v", flattened["en"].Messages["common"]) } } // TestMapFlattenWithGlobalNoGlobal tests FlattenWithGlobal when no global exists func TestMapFlattenWithGlobalNoGlobal(t *testing.T) { // Save and restore __global__ originalGlobal := Locales["__global__"] defer func() { Locales["__global__"] = originalGlobal }() // Make sure no global exists delete(Locales, "__global__") i18ns := Map{ "en": I18n{ Locale: "en", Messages: map[string]any{ "key": "value", }, }, } flattened := i18ns.FlattenWithGlobal() if 
_, ok := flattened["en"]; !ok { t.Fatal("Expected 'en' key in flattened map") } if val, ok := flattened["en"].Messages["key"].(string); !ok || val != "value" { t.Errorf("Expected 'value', got %v", flattened["en"].Messages["key"]) } } // TestMapFlattenWithGlobalKeyConflict tests FlattenWithGlobal when local keys already exist func TestMapFlattenWithGlobalKeyConflict(t *testing.T) { // Save and restore __global__ originalGlobal := Locales["__global__"] defer func() { Locales["__global__"] = originalGlobal }() // Setup global with keys in flat format (after Dot()) Locales["__global__"] = map[string]I18n{ "en": { Locale: "en", Messages: map[string]any{ "shared.key": "Global Shared", "global.only": "Global Only", "local.key": "Global Local", // Will be overridden }, }, } // Local messages in nested format (will be flattened by Dot()) i18ns := Map{ "en": I18n{ Locale: "en", Messages: map[string]any{ "local": map[string]any{ "key": "Local Value", // After Dot() becomes "local.key", should override global }, "unique": map[string]any{ "key": "Local Unique", }, }, }, } flattened := i18ns.FlattenWithGlobal() if _, ok := flattened["en"]; !ok { t.Fatal("Expected 'en' key in flattened map") } // Local key should exist and NOT be overridden by global if val, ok := flattened["en"].Messages["local.key"].(string); !ok || val != "Local Value" { t.Errorf("Expected 'Local Value' (local should override global), got %v", flattened["en"].Messages["local.key"]) } // Global only key should exist if val, ok := flattened["en"].Messages["global.only"].(string); !ok || val != "Global Only" { t.Errorf("Expected 'Global Only', got %v", flattened["en"].Messages["global.only"]) } // Unique local key should exist if val, ok := flattened["en"].Messages["unique.key"].(string); !ok || val != "Local Unique" { t.Errorf("Expected 'Local Unique', got %v", flattened["en"].Messages["unique.key"]) } // Shared key from global should exist if val, ok := flattened["en"].Messages["shared.key"].(string); !ok || 
val != "Global Shared" { t.Errorf("Expected 'Global Shared', got %v", flattened["en"].Messages["shared.key"]) } } // TestTranslate tests the Translate function func TestTranslate(t *testing.T) { assistantID := "test-assistant" Locales[assistantID] = map[string]I18n{ "en": { Locale: "en", Messages: map[string]any{ "greeting": "Hello", "name": "John", }, }, "zh-cn": { Locale: "zh-cn", Messages: map[string]any{ "greeting": "你好", "name": "张三", }, }, } defer delete(Locales, assistantID) t.Run("Translate with exact locale match", func(t *testing.T) { result := Translate(assistantID, "en", "{{greeting}}") if result != "Hello" { t.Errorf("Expected 'Hello', got %v", result) } }) t.Run("Translate with locale variant", func(t *testing.T) { result := Translate(assistantID, "zh-CN", "{{greeting}}") if result != "你好" { t.Errorf("Expected '你好', got %v", result) } }) t.Run("Translate with short locale code", func(t *testing.T) { result := Translate(assistantID, "en-us", "{{name}}") if result != "John" { t.Errorf("Expected 'John', got %v", result) } }) t.Run("Translate without locale match", func(t *testing.T) { result := Translate(assistantID, "fr", "{{greeting}}") // Should return original when no locale found if result != "{{greeting}}" { t.Errorf("Expected '{{greeting}}', got %v", result) } }) t.Run("Translate non-existent assistant", func(t *testing.T) { result := Translate("nonexistent", "en", "{{greeting}}") if result != "{{greeting}}" { t.Errorf("Expected '{{greeting}}', got %v", result) } }) t.Run("Translate with fallback to global", func(t *testing.T) { // Save and restore __global__ originalGlobal := Locales["__global__"] defer func() { Locales["__global__"] = originalGlobal }() Locales["__global__"] = map[string]I18n{ "es": { Locale: "es", Messages: map[string]any{ "greeting": "Hola", }, }, } result := Translate(assistantID, "es", "{{greeting}}") if result != "Hola" { t.Errorf("Expected 'Hola', got %v", result) } }) t.Run("Translate complex structure", func(t 
*testing.T) { input := map[string]any{ "title": "{{greeting}}", "user": "{{name}}", } result := Translate(assistantID, "zh-cn", input) if resultMap, ok := result.(map[string]any); ok { if resultMap["title"] != "你好" { t.Errorf("Expected '你好', got %v", resultMap["title"]) } if resultMap["user"] != "张三" { t.Errorf("Expected '张三', got %v", resultMap["user"]) } } else { t.Errorf("Expected map[string]any, got %T", result) } }) } // TestTranslateGlobal tests the TranslateGlobal function with custom messages func TestTranslateGlobal(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Save existing __global__ and restore after test originalGlobal := make(map[string]I18n) if existing, ok := Locales["__global__"]; ok { for k, v := range existing { originalGlobal[k] = v } } defer func() { Locales["__global__"] = originalGlobal }() // Add custom test messages to existing global (not replacing) if Locales["__global__"] == nil { Locales["__global__"] = make(map[string]I18n) } // Extend existing English messages enMessages := make(map[string]any) if existing, ok := Locales["__global__"]["en"]; ok { for k, v := range existing.Messages { enMessages[k] = v } } enMessages["button.ok"] = "OK" enMessages["button.cancel"] = "Cancel" Locales["__global__"]["en"] = I18n{ Locale: "en", Messages: enMessages, } // Extend existing Chinese messages zhcnMessages := make(map[string]any) if existing, ok := Locales["__global__"]["zh-cn"]; ok { for k, v := range existing.Messages { zhcnMessages[k] = v } } zhcnMessages["button.ok"] = "确定" zhcnMessages["button.cancel"] = "取消" Locales["__global__"]["zh-cn"] = I18n{ Locale: "zh-cn", Messages: zhcnMessages, } // Extend existing Chinese short code messages zhMessages := make(map[string]any) if existing, ok := Locales["__global__"]["zh"]; ok { for k, v := range existing.Messages { zhMessages[k] = v } } zhMessages["button.ok"] = "确定" zhMessages["button.cancel"] = "取消" Locales["__global__"]["zh"] = I18n{ Locale: "zh", Messages: zhMessages, } 
t.Run("TranslateGlobal with match", func(t *testing.T) {
	result := TranslateGlobal("en", "{{button.ok}}")
	if result != "OK" {
		t.Errorf("Expected 'OK', got %v", result)
	}
})

t.Run("TranslateGlobal with Chinese", func(t *testing.T) {
	result := TranslateGlobal("zh-cn", "{{button.cancel}}")
	if result != "取消" {
		t.Errorf("Expected '取消', got %v", result)
	}
})

t.Run("TranslateGlobal with short code", func(t *testing.T) {
	result := TranslateGlobal("zh-TW", "{{button.ok}}")
	if result != "确定" {
		t.Errorf("Expected '确定', got %v", result)
	}
})

t.Run("TranslateGlobal without match", func(t *testing.T) {
	// An unknown locale must leave the template untouched.
	result := TranslateGlobal("fr", "{{button.ok}}")
	if result != "{{button.ok}}" {
		t.Errorf("Expected '{{button.ok}}', got %v", result)
	}
})

t.Run("TranslateGlobal no global", func(t *testing.T) {
	// Temporarily remove global
	temp := Locales["__global__"]
	delete(Locales, "__global__")

	result := TranslateGlobal("en", "{{button.ok}}")
	if result != "{{button.ok}}" {
		t.Errorf("Expected '{{button.ok}}', got %v", result)
	}

	// Restore
	Locales["__global__"] = temp
})

t.Run("TranslateGlobal fallback from en-us to en", func(t *testing.T) {
	// Create a scenario where en-us has limited messages, but en has more
	// This simulates the real-world case: en-us locale exists with 3 messages,
	// but en has 41 messages including llm.handlers.stream.info

	// Create en-us with only a few messages
	// NOTE(review): this entry is written into Locales["__global__"] and is not
	// removed within this subtest.
	enUSMessages := map[string]any{
		"button.ok":   "OK (US)", // en-us specific
		"app.name":    "My App",
		"app.version": "1.0",
	}
	Locales["__global__"]["en-us"] = I18n{
		Locale:   "en-us",
		Messages: enUSMessages,
	}

	// Test 1: Key exists in en-us - should use en-us
	result := TranslateGlobal("en-us", "button.ok")
	if result != "OK (US)" {
		t.Errorf("Expected 'OK (US)' from en-us, got %v", result)
	}

	// Test 2: Key does NOT exist in en-us but exists in en - should fallback to en
	result = TranslateGlobal("en-us", "button.cancel")
	if result != "Cancel" {
		t.Errorf("Expected 'Cancel' (fallback from en-us to en), got %v", result)
	}

	// Test 3: Built-in key that exists in en but not in en-us,
	// passed as a direct key (no {{ }} wrapper)
	result = TranslateGlobal("en-us", "llm.handlers.stream.info")
	expected := "LLM Stream"
	if result != expected {
		t.Errorf("Expected '%s' (fallback from en-us to en), got %v", expected, result)
	}

	// Test 4: The same key wrapped as a {{ }} template should also fallback
	result = TranslateGlobal("en-us", "{{llm.handlers.stream.info}}")
	if result != expected {
		t.Errorf("Expected '%s' (fallback from en-us to en with template), got %v", expected, result)
	}
})
}

// TestGetLocalesIntegration tests GetLocales with real assistant data.
// Each subtest is skipped (not failed) when GetLocales returns an error.
func TestGetLocalesIntegration(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// Use the real mohe assistant path (relative to app root)
	assistantPath := "/assistants/mohe"

	t.Run("Load real locale files", func(t *testing.T) {
		locales, err := GetLocales(assistantPath)
		if err != nil {
			t.Skipf("Skipping: %v", err)
			return
		}

		// Should have at least 2 locales (en-us and zh-cn)
		if len(locales) < 2 {
			t.Errorf("Expected at least 2 locales, got %d", len(locales))
		}

		// Check en-us locale
		if enUS, ok := locales["en-us"]; ok {
			if enUS.Locale != "en-us" {
				t.Errorf("Expected locale 'en-us', got %s", enUS.Locale)
			}

			// Check some messages (only verified when present with the expected type)
			if desc, ok := enUS.Messages["description"].(string); ok {
				if desc == "" {
					t.Error("Expected non-empty description")
				}
				t.Logf("English description: %s", desc)
			}

			if chat, ok := enUS.Messages["chat"].(map[string]interface{}); ok {
				if title, ok := chat["title"].(string); ok {
					if title != "New Chat" {
						t.Errorf("Expected 'New Chat', got %s", title)
					}
				}
			}
		} else {
			t.Error("Expected 'en-us' locale")
		}

		// Check zh-cn locale
		if zhCN, ok := locales["zh-cn"]; ok {
			if zhCN.Locale != "zh-cn" {
				t.Errorf("Expected locale 'zh-cn', got %s", zhCN.Locale)
			}

			// Check some messages (only verified when present with the expected type)
			if desc, ok := zhCN.Messages["description"].(string); ok {
				if desc == "" {
					t.Error("Expected non-empty description")
				}
				t.Logf("Chinese description: %s", desc)
			}

			if chat, ok := zhCN.Messages["chat"].(map[string]interface{}); ok {
				if title, ok := chat["title"].(string); ok {
					if title != "新对话" {
						t.Errorf("Expected '新对话', got %s", title)
					}
				}
			}
		} else {
			t.Error("Expected 'zh-cn' locale")
		}

		t.Logf("Loaded %d locales successfully", len(locales))
	})

	t.Run("Flatten loaded locales", func(t *testing.T) {
		locales, err := GetLocales(assistantPath)
		if err != nil {
			t.Skipf("Skipping: %v", err)
			return
		}

		flattened := locales.Flatten()

		// Should have short codes (language part and region part of each locale)
		if _, ok := flattened["en"]; !ok {
			t.Error("Expected 'en' short code after flatten")
		}
		if _, ok := flattened["zh"]; !ok {
			t.Error("Expected 'zh' short code after flatten")
		}
		if _, ok := flattened["us"]; !ok {
			t.Error("Expected 'us' region code after flatten")
		}
		if _, ok := flattened["cn"]; !ok {
			t.Error("Expected 'cn' region code after flatten")
		}

		// Verify flattened messages structure (nested keys joined with dots,
		// list items addressed by index)
		if en, ok := flattened["en"]; ok {
			if _, ok := en.Messages["chat.title"]; !ok {
				t.Error("Expected flattened 'chat.title' key")
			}
			if _, ok := en.Messages["chat.description"]; !ok {
				t.Error("Expected flattened 'chat.description' key")
			}
			if _, ok := en.Messages["chat.prompts.0"]; !ok {
				t.Error("Expected flattened 'chat.prompts.0' key")
			}
		}

		t.Logf("Flattened to %d locale codes", len(flattened))
	})
}

// TestEdgeCases tests various edge cases
func TestEdgeCases(t *testing.T) {
	t.Run("Empty Messages map", func(t *testing.T) {
		i18n := I18n{
			Locale:   "en",
			Messages: map[string]any{},
		}
		result := i18n.Parse("{{key}}")
		if result != "{{key}}" {
			t.Errorf("Expected '{{key}}', got %v", result)
		}
	})

	t.Run("Nil Messages map", func(t *testing.T) {
		i18n := I18n{
			Locale:   "en",
			Messages: nil,
		}
		result := i18n.Parse("{{key}}")
		if result != "{{key}}" {
			t.Errorf("Expected '{{key}}', got %v", result)
		}
	})

	t.Run("Empty locale string", func(t *testing.T) {
		Locales["test"] = map[string]I18n{
			"en": {
				Locale:   "en",
				Messages: map[string]any{"key": "value"},
			},
		}
		defer delete(Locales, "test")

		result := Translate("test", "", "{{key}}")
		// Should still work with empty string after trim
		// NOTE(review): this subtest only logs the result; it never fails,
		// so it only guards against a panic.
		if result != "{{key}}" {
			t.Logf("Result: %v", result)
		}
	})

	t.Run("Locale with only spaces", func(t *testing.T) {
		// The locale is registered under the empty key; a whitespace-only
		// locale argument is expected to resolve to it.
		Locales["test"] = map[string]I18n{
			"": {
				Locale:   "",
				Messages: map[string]any{"key": "value"},
			},
		}
		defer delete(Locales, "test")

		result := Translate("test", " ", "{{key}}")
		if result != "value" {
			t.Errorf("Expected 'value', got %v", result)
		}
	})
}

// TestBuiltinMessages tests the built-in global messages
func TestBuiltinMessages(t *testing.T) {
	// Save and restore __global__ to avoid test interference.
	// A shallow copy is taken so later writes to the live map do not leak out.
	// NOTE(review): if __global__ did not exist beforehand, the deferred restore
	// installs an empty map rather than deleting the key.
	originalGlobal := make(map[string]I18n)
	if existing, ok := Locales["__global__"]; ok {
		for k, v := range existing {
			originalGlobal[k] = v
		}
	}
	defer func() {
		Locales["__global__"] = originalGlobal
	}()

	t.Run("English built-in messages", func(t *testing.T) {
		// Test assistant messages
		// Updated: label now only shows {{name}} without "Assistant" prefix
		result := TranslateGlobal("en", "{{assistant.agent.stream.label}}")
		expected := "{{name}}"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		result = TranslateGlobal("en", "{{assistant.agent.stream.description}}")
		expected = "{{name}} is processing the request"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		result = TranslateGlobal("en", "{{assistant.agent.stream.history}}")
		expected = "Get Chat History"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		// Test LLM messages (note: LLM uses %s for fmt.Sprintf, not {{name}} for recursive translation)
		result = TranslateGlobal("en", "{{llm.openai.stream.label}}")
		expected = "LLM %s"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		result = TranslateGlobal("en", "{{llm.handlers.stream.info}}")
		expected = "LLM Stream"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		// Test common messages
		result = TranslateGlobal("en", "{{common.status.processing}}")
		expected = "Processing"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}
	})

	t.Run("Chinese (zh-cn) built-in messages", func(t *testing.T) {
		// Test assistant messages
		// Updated: label now only shows {{name}} without "助手" prefix
		result := TranslateGlobal("zh-cn", "{{assistant.agent.stream.label}}")
		expected := "{{name}}"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		result = TranslateGlobal("zh-cn", "{{assistant.agent.stream.description}}")
		expected = "{{name}} 正在处理请求"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		result = TranslateGlobal("zh-cn", "{{assistant.agent.stream.history}}")
		expected = "获取聊天历史"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		// Test LLM messages
		result = TranslateGlobal("zh-cn", "{{llm.handlers.stream.info}}")
		expected = "LLM 流式输出"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		// Test common messages
		result = TranslateGlobal("zh-cn", "{{common.status.processing}}")
		expected = "处理中"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}
	})

	t.Run("Chinese (zh) short code", func(t *testing.T) {
		// Updated: label now only shows {{name}} without "助手" prefix
		result := TranslateGlobal("zh", "{{assistant.agent.stream.label}}")
		expected := "{{name}}"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		result = TranslateGlobal("zh", "{{common.status.processing}}")
		expected = "处理中"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}
	})

	t.Run("Embedded template with built-in messages", func(t *testing.T) {
		// English
		result := TranslateGlobal("en", "Status: {{common.status.processing}}")
		expected := "Status: Processing"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		// Chinese
		result = TranslateGlobal("zh-cn", "状态: {{common.status.processing}}")
		expected = "状态: 处理中"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}
	})

	t.Run("Non-existent key in global", func(t *testing.T) {
		result := TranslateGlobal("en", "{{unknown.key}}")
		if result != "{{unknown.key}}" {
			t.Errorf("Expected '{{unknown.key}}', got '%v'", result)
		}
	})
}

// TestTr tests Tr: direct key lookup, recursive expansion of {{key}} references
// inside message values, and key-level fallback from an assistant locale to the
// global locale.
func TestTr(t *testing.T) {
	// Save original global locales
	originalGlobal := Locales["__global__"]
	defer func() {
		if originalGlobal != nil {
			Locales["__global__"] = originalGlobal
		} else {
			delete(Locales, "__global__")
		}
	}()

	// Setup test locales with nested templates
	Locales["__global__"] = map[string]I18n{
		"en": {
			Locale: "en",
			Messages: map[string]any{
				"assistant.label":       "Assistant {{assistant.name}}", // Use full key path
				"assistant.name":        "AI Helper",
				"assistant.description": "{{assistant.label}} is processing",
				"llm.label":             "LLM {{model.deepseek}}", // Use full key path
				"model.deepseek":        "DeepSeek",
				"deeply.nested":         "Level1 {{level2}}",
				"level2":                "Level2 {{level3}}",
				"level3":                "Level3 End",
				"simple.message":        "Hello World",
			},
		},
		"zh-cn": {
			Locale: "zh-cn",
			Messages: map[string]any{
				"assistant.label":       "助手 {{assistant.name}}", // Use full key path
				"assistant.name":        "智能助手",
				"assistant.description": "{{assistant.label}} 正在处理",
				"llm.label":             "模型 {{model.deepseek}}", // Use full key path
				"model.deepseek":        "深度求索",
				"deeply.nested":         "第一层 {{level2}}",
				"level2":                "第二层 {{level3}}",
				"level3":                "第三层结束",
				"simple.message":        "你好世界",
			},
		},
	}

	// Setup assistant-specific locale (overrides assistant.name, but inherits assistant.label from global)
	Locales["test-assistant"] = map[string]I18n{
		"en": {
			Locale: "en",
			Messages: map[string]any{
				"assistant.name": "Custom Assistant", // This will override global when assistant.label is resolved
			},
		},
	}
	defer delete(Locales, "test-assistant")

	t.Run("Simple translation without variables", func(t *testing.T) {
		result := Tr("__global__", "en", "simple.message")
		if result != "Hello World" {
			t.Errorf("Expected 'Hello World', got '%s'", result)
		}

		result = Tr("__global__", "zh-cn", "simple.message")
		if result != "你好世界" {
			t.Errorf("Expected '你好世界', got '%s'", result)
		}
	})

	t.Run("One level nested variable", func(t *testing.T) {
		// "Assistant {{assistant.name}}" -> "Assistant AI Helper"
		result := Tr("__global__", "en", "assistant.label")
		if result != "Assistant AI Helper" {
			t.Errorf("Expected 'Assistant AI Helper', got '%s'", result)
		}

		result = Tr("__global__", "zh-cn", "assistant.label")
		if result != "助手 智能助手" {
			t.Errorf("Expected '助手 智能助手', got '%s'", result)
		}
	})

	t.Run("Two levels nested variables", func(t *testing.T) {
		// "{{assistant.label}} is processing" -> "Assistant AI Helper is processing"
		result := Tr("__global__", "en", "assistant.description")
		if result != "Assistant AI Helper is processing" {
			t.Errorf("Expected 'Assistant AI Helper is processing', got '%s'", result)
		}

		result = Tr("__global__", "zh-cn", "assistant.description")
		if result != "助手 智能助手 正在处理" {
			t.Errorf("Expected '助手 智能助手 正在处理', got '%s'", result)
		}
	})

	t.Run("Three levels deeply nested", func(t *testing.T) {
		// "Level1 {{level2}}" -> "Level1 Level2 {{level3}}" -> "Level1 Level2 Level3 End"
		result := Tr("__global__", "en", "deeply.nested")
		if result != "Level1 Level2 Level3 End" {
			t.Errorf("Expected 'Level1 Level2 Level3 End', got '%s'", result)
		}

		result = Tr("__global__", "zh-cn", "deeply.nested")
		if result != "第一层 第二层 第三层结束" {
			t.Errorf("Expected '第一层 第二层 第三层结束', got '%s'", result)
		}
	})

	t.Run("Assistant-specific override", func(t *testing.T) {
		// When assistant locale exists but doesn't have a key, it WILL fallback to global
		// This is key-level fallback: try assistant first, then fallback to global
		result := Tr("test-assistant", "en", "assistant.label")
		// "Assistant {{assistant.name}}" from global, then {{assistant.name}} -> "Custom Assistant" from assistant
		if result != "Assistant Custom Assistant" {
			t.Errorf("Expected 'Assistant Custom Assistant' (fallback to global with assistant override), got '%s'", result)
		}

		// assistant has 'en' locale but doesn't have this key, fallback to global
		result = Tr("test-assistant", "en", "simple.message")
		if result != "Hello World" {
			t.Errorf("Expected 'Hello World' (fallback to global), got '%s'", result)
		}

		// If assistant locale has the key, it will use assistant's value
		result = Tr("test-assistant", "en", "assistant.name")
		if result != "Custom Assistant" {
			t.Errorf("Expected 'Custom Assistant' (from assistant locale), got '%s'", result)
		}
	})

	t.Run("Non-existent key returns original", func(t *testing.T) {
		result := Tr("__global__", "en", "non.existent.key")
		if result != "non.existent.key" {
			t.Errorf("Expected 'non.existent.key', got '%s'", result)
		}
	})

	t.Run("LLM with model variable", func(t *testing.T) {
		result := Tr("__global__", "en", "llm.label")
		if result != "LLM DeepSeek" {
			t.Errorf("Expected 'LLM DeepSeek', got '%s'", result)
		}

		result = Tr("__global__", "zh-cn", "llm.label")
		if result != "模型 深度求索" {
			t.Errorf("Expected '模型 深度求索', got '%s'", result)
		}
	})
}

// TestTAlias tests that T returns the same results as TranslateGlobal for the
// built-in global messages.
func TestTAlias(t *testing.T) {
	// Save and restore __global__ to avoid test interference.
	// NOTE(review): if __global__ did not exist beforehand, the deferred restore
	// installs an empty map rather than deleting the key.
	originalGlobal := make(map[string]I18n)
	if existing, ok := Locales["__global__"]; ok {
		for k, v := range existing {
			originalGlobal[k] = v
		}
	}
	defer func() {
		Locales["__global__"] = originalGlobal
	}()

	t.Run("T alias works like TranslateGlobal", func(t *testing.T) {
		// Test that T and TranslateGlobal return the same results
		input := "{{assistant.agent.stream.label}}"
		resultT := T("en", input)
		resultGlobal := TranslateGlobal("en", input)

		if resultT != resultGlobal {
			t.Errorf("T and TranslateGlobal should return same result. T: %v, TranslateGlobal: %v", resultT, resultGlobal)
		}

		// Updated: label now only shows {{name}} without "Assistant" prefix
		expected := "{{name}}"
		if resultT != expected {
			t.Errorf("Expected '%s', got '%v'", expected, resultT)
		}
	})

	t.Run("T alias with Chinese", func(t *testing.T) {
		result := T("zh-cn", "{{assistant.agent.stream.history}}")
		expected := "获取聊天历史"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}
	})

	t.Run("T alias with embedded template", func(t *testing.T) {
		result := T("en", "Status: {{common.status.completed}}")
		expected := "Status: Completed"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}
	})

	t.Run("T with nested template (template in template value)", func(t *testing.T) {
		// assistant.agent.stream.label = "{{name}}" (contains {{name}} template)
		// Updated: label now only shows {{name}} without prefix
		// This tests if we can get the template string itself
		result := T("en", "{{assistant.agent.stream.label}}")
		expected := "{{name}}"
		if result != expected {
			t.Errorf("Expected '%s', got '%v'", expected, result)
		}

		// Verify Chinese version too
		resultZh := T("zh-cn", "{{assistant.agent.stream.label}}")
		expectedZh := "{{name}}"
		if resultZh != expectedZh {
			t.Errorf("Expected '%s', got '%v'", expectedZh, resultZh)
		}
	})
}

================================================
FILE: agent/llm/adapters/adapter.go
================================================
package adapters

import (
	"github.com/yaoapp/yao/agent/context"
)

// CapabilityAdapter is the interface for capability-specific message and response processing
// Each adapter handles one capability dimension (tool calls, vision, audio, reasoning, etc.)
type CapabilityAdapter interface { // Name returns the adapter name for debugging Name() string // PreprocessMessages preprocesses messages before sending to LLM // Returns modified messages or error PreprocessMessages(messages []context.Message) ([]context.Message, error) // PreprocessOptions preprocesses completion options before sending to LLM // Returns modified options or error PreprocessOptions(options *context.CompletionOptions) (*context.CompletionOptions, error) // PostprocessResponse postprocesses the LLM response // Returns modified response or error PostprocessResponse(response *context.CompletionResponse) (*context.CompletionResponse, error) // ProcessStreamChunk processes a streaming chunk // Returns modified chunk type and data, or error ProcessStreamChunk(chunkType context.StreamChunkType, data []byte) (context.StreamChunkType, []byte, error) } // BaseAdapter provides default implementations for CapabilityAdapter // Adapters can embed this and override only the methods they need type BaseAdapter struct { name string } // NewBaseAdapter creates a new base adapter func NewBaseAdapter(name string) *BaseAdapter { return &BaseAdapter{name: name} } // Name returns the adapter name func (a *BaseAdapter) Name() string { return a.name } // PreprocessMessages default implementation (no-op) func (a *BaseAdapter) PreprocessMessages(messages []context.Message) ([]context.Message, error) { return messages, nil } // PreprocessOptions default implementation (no-op) func (a *BaseAdapter) PreprocessOptions(options *context.CompletionOptions) (*context.CompletionOptions, error) { return options, nil } // PostprocessResponse default implementation (no-op) func (a *BaseAdapter) PostprocessResponse(response *context.CompletionResponse) (*context.CompletionResponse, error) { return response, nil } // ProcessStreamChunk default implementation (pass through) func (a *BaseAdapter) ProcessStreamChunk(chunkType context.StreamChunkType, data []byte) (context.StreamChunkType, 
[]byte, error) { return chunkType, data, nil } ================================================ FILE: agent/llm/adapters/audio.go ================================================ package adapters import ( "github.com/yaoapp/yao/agent/context" ) // AudioAdapter handles audio capability // If model doesn't support audio, it removes or converts audio content type AudioAdapter struct { *BaseAdapter nativeSupport bool } // NewAudioAdapter creates a new audio adapter func NewAudioAdapter(nativeSupport bool) *AudioAdapter { return &AudioAdapter{ BaseAdapter: NewBaseAdapter("AudioAdapter"), nativeSupport: nativeSupport, } } // PreprocessMessages removes or converts audio content if not supported func (a *AudioAdapter) PreprocessMessages(messages []context.Message) ([]context.Message, error) { if a.nativeSupport { // Native support, no preprocessing needed return messages, nil } // Process messages to remove audio content processed := make([]context.Message, 0, len(messages)) for _, msg := range messages { processedMsg := msg // Handle multimodal content (array of ContentPart) if contentParts, ok := msg.Content.([]context.ContentPart); ok { filteredParts := make([]context.ContentPart, 0) for _, part := range contentParts { // Skip audio content if not supported if part.Type == context.ContentInputAudio { // TODO: Optionally convert to transcription text if available continue } filteredParts = append(filteredParts, part) } // If all parts were filtered out, add placeholder text if len(filteredParts) == 0 { processedMsg.Content = "[Audio content not supported by this model]" } else if len(filteredParts) == 1 && filteredParts[0].Type == context.ContentText { // Single text part, convert to string processedMsg.Content = filteredParts[0].Text } else { processedMsg.Content = filteredParts } } processed = append(processed, processedMsg) } return processed, nil } ================================================ FILE: agent/llm/adapters/reasoning.go 
================================================ package adapters import ( "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/context" ) // ReasoningFormat represents the reasoning content format type ReasoningFormat string const ( ReasoningFormatNone ReasoningFormat = "none" // No reasoning support ReasoningFormatOpenAI ReasoningFormat = "openai-o1" // OpenAI o1 format (hidden reasoning) ReasoningFormatGPT5 ReasoningFormat = "gpt-5" // GPT-5 format (hidden reasoning) ReasoningFormatDeepSeek ReasoningFormat = "deepseek-r1" // DeepSeek R1 format (visible reasoning) ) // ReasoningAdapter handles reasoning content capability // - Manages reasoning_effort parameter (o1, GPT-5) // - Manages temperature parameter constraints (reasoning models typically require temperature=1) // - Extracts reasoning_tokens from usage // - Parses visible reasoning content (DeepSeek R1) type ReasoningAdapter struct { *BaseAdapter format ReasoningFormat supportsEffort bool // Whether the model supports reasoning_effort parameter supportsTemperature bool // Whether the model supports temperature adjustment } // NewReasoningAdapter creates a new reasoning adapter // If cap.TemperatureAdjustable is provided, it overrides the default behavior func NewReasoningAdapter(format ReasoningFormat, cap *openai.Capabilities) *ReasoningAdapter { supportsEffort := false supportsTemperature := true // Set defaults based on reasoning format switch format { case ReasoningFormatOpenAI, ReasoningFormatGPT5: // OpenAI o1 and GPT-5: support reasoning_effort, but NOT temperature adjustment supportsEffort = true supportsTemperature = false case ReasoningFormatDeepSeek: // DeepSeek R1: no reasoning_effort, no temperature adjustment supportsEffort = false supportsTemperature = false case ReasoningFormatNone: // Non-reasoning models: no reasoning_effort, but support temperature supportsEffort = false supportsTemperature = true } // Override with explicit capability if provided if cap != nil { 
supportsTemperature = cap.TemperatureAdjustable } return &ReasoningAdapter{ BaseAdapter: NewBaseAdapter("ReasoningAdapter"), format: format, supportsEffort: supportsEffort, supportsTemperature: supportsTemperature, } } // PreprocessOptions handles reasoning_effort and temperature parameters func (a *ReasoningAdapter) PreprocessOptions(options *context.CompletionOptions) (*context.CompletionOptions, error) { if options == nil { return options, nil } newOptions := *options modified := false // 1. Handle reasoning_effort parameter if !a.supportsEffort && newOptions.ReasoningEffort != nil { // Model doesn't support reasoning_effort, remove the parameter newOptions.ReasoningEffort = nil modified = true } // 2. Handle temperature parameter if !a.supportsTemperature && newOptions.Temperature != nil { currentTemp := *newOptions.Temperature if currentTemp != 1.0 { // Model doesn't support temperature adjustment, reset to default (1.0) defaultTemp := 1.0 newOptions.Temperature = &defaultTemp modified = true } } if modified { return &newOptions, nil } // No modifications needed return options, nil } // ProcessStreamChunk processes streaming chunks with reasoning content func (a *ReasoningAdapter) ProcessStreamChunk(chunkType context.StreamChunkType, data []byte) (context.StreamChunkType, []byte, error) { if a.format == ReasoningFormatNone { // No reasoning support, pass through return chunkType, data, nil } // TODO: Parse reasoning_content based on format // - OpenAI o1: No visible reasoning in stream (reasoning happens internally) // - GPT-5: No visible reasoning in stream (reasoning happens internally) // - DeepSeek R1: May have ... 
tags or reasoning_content field return chunkType, data, nil } // PostprocessResponse extracts reasoning content and tokens from the final response func (a *ReasoningAdapter) PostprocessResponse(response *context.CompletionResponse) (*context.CompletionResponse, error) { if a.format == ReasoningFormatNone { // No reasoning support return response, nil } // Reasoning tokens are already extracted in Usage.CompletionTokensDetails.ReasoningTokens // by the OpenAI response parser, no additional processing needed for o1/GPT-5 // TODO: For DeepSeek R1, extract visible reasoning content // - Parse ... tags from content // - Set response.ReasoningContent // - Remove tags from response.Content (keep only final answer) return response, nil } ================================================ FILE: agent/llm/adapters/toolcall.go ================================================ package adapters import ( "github.com/yaoapp/yao/agent/context" ) // ToolCallAdapter handles tool calling capability // If model doesn't support native tool calls, it injects tool instructions into prompts type ToolCallAdapter struct { *BaseAdapter nativeSupport bool } // NewToolCallAdapter creates a new tool call adapter func NewToolCallAdapter(nativeSupport bool) *ToolCallAdapter { return &ToolCallAdapter{ BaseAdapter: NewBaseAdapter("ToolCallAdapter"), nativeSupport: nativeSupport, } } // PreprocessMessages injects tool calling instructions if not natively supported func (a *ToolCallAdapter) PreprocessMessages(messages []context.Message) ([]context.Message, error) { if a.nativeSupport { // Native support, no preprocessing needed return messages, nil } // TODO: Inject tool calling instructions into system prompt // - Generate tool description prompt // - Add to system message or create new system message // - Include tool schemas and usage instructions return messages, nil } // PreprocessOptions removes tool-related options if not natively supported func (a *ToolCallAdapter) PreprocessOptions(options 
*context.CompletionOptions) (*context.CompletionOptions, error) { if a.nativeSupport { // Native support, keep options as-is return options, nil } if options == nil { return options, nil } // Remove tool parameters for non-native models newOptions := *options newOptions.Tools = nil newOptions.ToolChoice = nil return &newOptions, nil } // PostprocessResponse extracts tool calls from text if not natively supported func (a *ToolCallAdapter) PostprocessResponse(response *context.CompletionResponse) (*context.CompletionResponse, error) { if a.nativeSupport { // Native support, response already has structured tool calls return response, nil } // TODO: Extract tool calls from text response // - Look for JSON blocks or specific patterns // - Parse tool name and arguments // - Add to response.ToolCalls return response, nil } ================================================ FILE: agent/llm/adapters/vision.go ================================================ package adapters import ( "encoding/base64" "fmt" "io" "net/http" "strings" "time" "github.com/yaoapp/yao/agent/context" ) // VisionAdapter handles vision (image) capability // If model doesn't support vision, it removes or converts image content type VisionAdapter struct { *BaseAdapter nativeSupport bool format context.VisionFormat } // NewVisionAdapter creates a new vision adapter func NewVisionAdapter(nativeSupport bool, format context.VisionFormat) *VisionAdapter { return &VisionAdapter{ BaseAdapter: NewBaseAdapter("VisionAdapter"), nativeSupport: nativeSupport, format: format, } } // PreprocessMessages removes or converts image content if not supported func (a *VisionAdapter) PreprocessMessages(messages []context.Message) ([]context.Message, error) { if !a.nativeSupport { // No vision support, remove image content return a.removeImageContent(messages), nil } // Check if we need to convert format needsConversion := a.format == context.VisionFormatClaude || a.format == context.VisionFormatBase64 if !needsConversion { // 
Native support with OpenAI format or default, no preprocessing needed return messages, nil } // Convert image_url format to Claude base64 format return a.convertToBase64Format(messages) } // removeImageContent removes image content from messages func (a *VisionAdapter) removeImageContent(messages []context.Message) []context.Message { processed := make([]context.Message, 0, len(messages)) for _, msg := range messages { processedMsg := msg // Handle multimodal content (array of map) if contentParts, ok := msg.Content.([]map[string]interface{}); ok { filteredParts := make([]map[string]interface{}, 0) for _, part := range contentParts { partType, _ := part["type"].(string) // Skip image content if partType != "image_url" && partType != "image" { filteredParts = append(filteredParts, part) } } // If all parts were filtered out, add placeholder text if len(filteredParts) == 0 { processedMsg.Content = "[Image content not supported by this model]" } else if len(filteredParts) == 1 { if textVal, ok := filteredParts[0]["text"].(string); ok { processedMsg.Content = textVal } else { processedMsg.Content = filteredParts } } else { processedMsg.Content = filteredParts } } processed = append(processed, processedMsg) } return processed } // convertToBase64Format converts image_url format to Claude base64 format func (a *VisionAdapter) convertToBase64Format(messages []context.Message) ([]context.Message, error) { processed := make([]context.Message, 0, len(messages)) for _, msg := range messages { processedMsg := msg // Handle multimodal content if contentParts, ok := msg.Content.([]map[string]interface{}); ok { convertedParts := make([]map[string]interface{}, 0) for _, part := range contentParts { partType, _ := part["type"].(string) if partType == "image_url" { // Convert to base64 format convertedPart, err := a.convertImageURLToBase64(part) if err != nil { // If conversion fails, skip this image continue } convertedParts = append(convertedParts, convertedPart) } else { // Keep 
non-image parts as-is convertedParts = append(convertedParts, part) } } processedMsg.Content = convertedParts } processed = append(processed, processedMsg) } return processed, nil } // convertImageURLToBase64 converts OpenAI image_url format to Claude base64 format func (a *VisionAdapter) convertImageURLToBase64(part map[string]interface{}) (map[string]interface{}, error) { // Extract URL from image_url object imageURLObj, ok := part["image_url"].(map[string]interface{}) if !ok { return nil, fmt.Errorf("invalid image_url format") } url, ok := imageURLObj["url"].(string) if !ok || url == "" { return nil, fmt.Errorf("missing or invalid URL in image_url") } // Check if already base64 data URL if strings.HasPrefix(url, "data:") { // Extract media type and base64 data from data URL // Format: data:image/jpeg;base64, parts := strings.SplitN(url, ",", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid data URL format") } // Extract media type from first part mediaParts := strings.Split(parts[0], ";") mediaType := strings.TrimPrefix(mediaParts[0], "data:") base64Data := parts[1] return map[string]interface{}{ "type": "image", "source": map[string]interface{}{ "type": "base64", "media_type": mediaType, "data": base64Data, }, }, nil } // Download image from URL and convert to base64 base64Data, mediaType, err := a.downloadAndEncodeImage(url) if err != nil { return nil, fmt.Errorf("failed to download image: %w", err) } // Return Claude/Anthropic format return map[string]interface{}{ "type": "image", "source": map[string]interface{}{ "type": "base64", "media_type": mediaType, "data": base64Data, }, }, nil } // downloadAndEncodeImage downloads an image from URL and returns base64 encoded data func (a *VisionAdapter) downloadAndEncodeImage(url string) (string, string, error) { // Create HTTP client with timeout client := &http.Client{ Timeout: 30 * time.Second, } // Download image resp, err := client.Get(url) if err != nil { return "", "", fmt.Errorf("failed to download 
image: %w", err) } defer resp.Body.Close() if resp.StatusCode != 200 { return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode) } // Read image data imageData, err := io.ReadAll(resp.Body) if err != nil { return "", "", fmt.Errorf("failed to read image data: %w", err) } // Detect media type from Content-Type header mediaType := resp.Header.Get("Content-Type") // Normalize media type (remove charset and other parameters) if mediaType != "" { // Split by semicolon to remove parameters like "; charset=utf-8" if idx := strings.Index(mediaType, ";"); idx != -1 { mediaType = strings.TrimSpace(mediaType[:idx]) } } if mediaType == "" { // Fallback to detecting from URL extension or default to jpeg urlLower := strings.ToLower(url) if strings.HasSuffix(urlLower, ".png") { mediaType = "image/png" } else if strings.HasSuffix(urlLower, ".gif") { mediaType = "image/gif" } else if strings.HasSuffix(urlLower, ".webp") { mediaType = "image/webp" } else if strings.Contains(urlLower, ".jpg") || strings.Contains(urlLower, ".jpeg") { mediaType = "image/jpeg" } else { // Default to jpeg mediaType = "image/jpeg" } } // Encode to base64 base64Data := base64.StdEncoding.EncodeToString(imageData) return base64Data, mediaType, nil } ================================================ FILE: agent/llm/capabilities.go ================================================ package llm import ( "github.com/yaoapp/gou/connector" goullm "github.com/yaoapp/gou/llm" ) // GetCapabilities get the capabilities of a connector by connector ID // Reads capabilities from connector's Setting()["capabilities"], with fallback to defaults. 
func GetCapabilities(connectorID string) *goullm.Capabilities { if connectorID == "" { return getDefaultCapabilities() } conn, err := connector.Select(connectorID) if err != nil { return getDefaultCapabilities() } return GetCapabilitiesFromConn(conn) } // GetCapabilitiesFromConn get the capabilities from a connector instance func GetCapabilitiesFromConn(conn connector.Connector) *goullm.Capabilities { if conn == nil { return getDefaultCapabilities() } settings := conn.Setting() if settings != nil { if caps, ok := settings["capabilities"]; ok { if capabilities, ok := caps.(*goullm.Capabilities); ok { return capabilities } if capabilities, ok := caps.(goullm.Capabilities); ok { return &capabilities } } } return getDefaultCapabilities() } // getDefaultCapabilities returns minimal default capabilities func getDefaultCapabilities() *goullm.Capabilities { return &goullm.Capabilities{ Vision: false, ToolCalls: false, Audio: false, Reasoning: false, Streaming: false, JSON: false, Multimodal: false, TemperatureAdjustable: true, } } // GetCapabilitiesMap get capabilities as map[string]interface{} for API responses func GetCapabilitiesMap(connectorID string) map[string]interface{} { caps := GetCapabilities(connectorID) if caps == nil { return nil } return ToMap(caps) } // ToMap converts Capabilities to map[string]interface{} func ToMap(caps *goullm.Capabilities) map[string]interface{} { if caps == nil { return nil } result := make(map[string]interface{}) if caps.Vision != nil { result["vision"] = caps.Vision } result["audio"] = caps.Audio result["stt"] = caps.STT result["tool_calls"] = caps.ToolCalls result["reasoning"] = caps.Reasoning result["streaming"] = caps.Streaming result["json"] = caps.JSON result["multimodal"] = caps.Multimodal result["temperature_adjustable"] = caps.TemperatureAdjustable return result } ================================================ FILE: agent/llm/interfaces.go ================================================ package llm import ( 
"github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" ) // LLM the LLM interface type LLM interface { Stream(ctx *context.Context, messages []context.Message, options *context.CompletionOptions, handler message.StreamFunc) (*context.CompletionResponse, error) Post(ctx *context.Context, messages []context.Message, options *context.CompletionOptions) (*context.CompletionResponse, error) } ================================================ FILE: agent/llm/jsapi.go ================================================ // Package llm provides the LLM JSAPI implementation package llm import ( "fmt" "github.com/yaoapp/gou/connector" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" ) // JSAPI implements LlmAPI interface for ctx.llm.* methods type JSAPI struct { ctx *agentContext.Context } // Ensure JSAPI implements both interfaces var _ agentContext.LlmAPI = (*JSAPI)(nil) var _ agentContext.LlmAPIWithCallback = (*JSAPI)(nil) // NewJSAPI creates a new JSAPI for the given context func NewJSAPI(ctx *agentContext.Context) *JSAPI { return &JSAPI{ctx: ctx} } // SetJSAPIFactory registers the JSAPI factory with the context package // This should be called during initialization func SetJSAPIFactory() { agentContext.LlmAPIFactory = func(ctx *agentContext.Context) agentContext.LlmAPI { return NewJSAPI(ctx) } } // Stream implements LlmAPI.Stream - calls LLM with streaming output to ctx.Writer func (api *JSAPI) Stream(connectorID string, messages []interface{}, opts map[string]interface{}) interface{} { return api.StreamWithHandler(connectorID, messages, opts, nil) } // StreamWithHandler implements LlmAPIWithCallback.StreamWithHandler - calls LLM with OnMessage handler func (api *JSAPI) StreamWithHandler(connectorID string, messages []interface{}, opts map[string]interface{}, handler agentContext.OnMessageFunc) interface{} { result := &Result{ Connector: connectorID, } // Validate context if api.ctx == nil { 
result.Error = "context is nil" return result } // Get connector conn, err := connector.Select(connectorID) if err != nil { result.Error = fmt.Sprintf("failed to select connector %s: %v", connectorID, err) return result } // Parse messages to context.Message format ctxMessages, err := parseMessages(messages) if err != nil { result.Error = fmt.Sprintf("failed to parse messages: %v", err) return result } // Build CompletionOptions from opts completionOptions := buildCompletionOptions(conn, opts) // Create LLM instance llmInstance, err := New(conn, completionOptions) if err != nil { result.Error = fmt.Sprintf("failed to create LLM instance: %v", err) return result } // Create stream handler with the provided callback // Note: We pass handler directly to the stream handler instead of setting ctx.Stack.Options.OnMessage // This avoids race conditions in concurrent batch calls where multiple goroutines // would otherwise overwrite the same ctx.Stack.Options.OnMessage streamHandler := createStreamHandlerWithCallback(api.ctx, handler) // Execute LLM stream call response, err := llmInstance.Stream(api.ctx, ctxMessages, completionOptions, streamHandler) if err != nil { result.Error = fmt.Sprintf("LLM stream failed: %v", err) return result } // Set response result.Response = response // Extract text content from response if response != nil { result.Content = extractContent(response) } return result } // createStreamHandlerWithCallback creates a stream handler that uses the provided callback directly // This is used instead of setting ctx.Stack.Options.OnMessage to avoid race conditions // in concurrent batch calls func createStreamHandlerWithCallback(ctx *agentContext.Context, handler agentContext.OnMessageFunc) message.StreamFunc { // Handle nil context if ctx == nil { return func(chunkType message.StreamChunkType, data []byte) int { return 0 // No-op handler when context is nil } } // Stream state for tracking message groups state := &streamState{ ctx: ctx, buffer: 
[]byte{}, handler: handler, // Store the handler directly in state } return func(chunkType message.StreamChunkType, data []byte) int { return state.handleChunk(chunkType, data) } } // parseMessages converts JS message array to context.Message slice func parseMessages(messages []interface{}) ([]agentContext.Message, error) { result := make([]agentContext.Message, 0, len(messages)) for i, msg := range messages { msgMap, ok := msg.(map[string]interface{}) if !ok { return nil, fmt.Errorf("message %d is not an object", i) } ctxMsg := agentContext.Message{} // Required: role if role, ok := msgMap["role"].(string); ok { ctxMsg.Role = agentContext.MessageRole(role) } else { return nil, fmt.Errorf("message %d missing role", i) } // Optional: content (can be string or array for multimodal) if content, ok := msgMap["content"]; ok { ctxMsg.Content = content } // Optional: name if name, ok := msgMap["name"].(string); ok { ctxMsg.Name = &name } // Optional: tool_calls if toolCalls, ok := msgMap["tool_calls"]; ok { if tcArray, ok := toolCalls.([]interface{}); ok { ctxMsg.ToolCalls = parseToolCalls(tcArray) } } // Optional: tool_call_id (for tool response messages) if toolCallID, ok := msgMap["tool_call_id"].(string); ok { ctxMsg.ToolCallID = &toolCallID } result = append(result, ctxMsg) } return result, nil } // parseToolCalls converts JS tool_calls array to context.ToolCall slice func parseToolCalls(toolCalls []interface{}) []agentContext.ToolCall { result := make([]agentContext.ToolCall, 0, len(toolCalls)) for _, tc := range toolCalls { tcMap, ok := tc.(map[string]interface{}) if !ok { continue } toolCall := agentContext.ToolCall{} if id, ok := tcMap["id"].(string); ok { toolCall.ID = id } if typ, ok := tcMap["type"].(string); ok { toolCall.Type = agentContext.ToolCallType(typ) } if fn, ok := tcMap["function"].(map[string]interface{}); ok { toolCall.Function = agentContext.Function{} if name, ok := fn["name"].(string); ok { toolCall.Function.Name = name } if args, ok := 
fn["arguments"].(string); ok { toolCall.Function.Arguments = args } } result = append(result, toolCall) } return result } // buildCompletionOptions creates CompletionOptions from JS opts map // BuildCompletionOptions builds CompletionOptions from a connector and raw opts map. // Exported for reuse by gRPC handlers. func BuildCompletionOptions(conn connector.Connector, opts map[string]interface{}) *agentContext.CompletionOptions { return buildCompletionOptions(conn, opts) } func buildCompletionOptions(conn connector.Connector, opts map[string]interface{}) *agentContext.CompletionOptions { // Get capabilities from connector capabilities := GetCapabilitiesFromConn(conn) completionOptions := &agentContext.CompletionOptions{ Capabilities: capabilities, } if opts == nil { return completionOptions } // Temperature if temp, ok := opts["temperature"].(float64); ok { completionOptions.Temperature = &temp } // Max tokens if maxTokens, ok := opts["max_tokens"].(float64); ok { mt := int(maxTokens) completionOptions.MaxTokens = &mt } if maxCompletionTokens, ok := opts["max_completion_tokens"].(float64); ok { mct := int(maxCompletionTokens) completionOptions.MaxCompletionTokens = &mct } // Top P if topP, ok := opts["top_p"].(float64); ok { completionOptions.TopP = &topP } // Presence penalty if presencePenalty, ok := opts["presence_penalty"].(float64); ok { completionOptions.PresencePenalty = &presencePenalty } // Frequency penalty if frequencyPenalty, ok := opts["frequency_penalty"].(float64); ok { completionOptions.FrequencyPenalty = &frequencyPenalty } // Stop sequences if stop, ok := opts["stop"]; ok { completionOptions.Stop = stop } // User if user, ok := opts["user"].(string); ok { completionOptions.User = user } // Seed if seed, ok := opts["seed"].(float64); ok { s := int(seed) completionOptions.Seed = &s } // Tools if tools, ok := opts["tools"].([]interface{}); ok { completionOptions.Tools = make([]map[string]interface{}, 0, len(tools)) for _, tool := range tools { if 
toolMap, ok := tool.(map[string]interface{}); ok { completionOptions.Tools = append(completionOptions.Tools, toolMap) } } } // Tool choice if toolChoice, ok := opts["tool_choice"]; ok { completionOptions.ToolChoice = toolChoice } // Response format if responseFormat, ok := opts["response_format"].(map[string]interface{}); ok { rf := &agentContext.ResponseFormat{} if rfType, ok := responseFormat["type"].(string); ok { rf.Type = agentContext.ResponseFormatType(rfType) } if jsonSchema, ok := responseFormat["json_schema"].(map[string]interface{}); ok { rf.JSONSchema = &agentContext.JSONSchema{} if name, ok := jsonSchema["name"].(string); ok { rf.JSONSchema.Name = name } if desc, ok := jsonSchema["description"].(string); ok { rf.JSONSchema.Description = desc } if schema, ok := jsonSchema["schema"]; ok { rf.JSONSchema.Schema = schema } if strict, ok := jsonSchema["strict"].(bool); ok { rf.JSONSchema.Strict = &strict } } completionOptions.ResponseFormat = rf } // Reasoning effort (for reasoning models) if reasoningEffort, ok := opts["reasoning_effort"].(string); ok { completionOptions.ReasoningEffort = &reasoningEffort } return completionOptions } // streamState manages stream handler state type streamState struct { ctx *agentContext.Context inMessage bool currentMsgID string currentMsgType string buffer []byte msgCounter int // Counter for generating message IDs when IDGenerator is nil chunkCounter int // Counter for generating chunk IDs when IDGenerator is nil handler agentContext.OnMessageFunc // Direct handler reference (avoids race condition via ctx.Stack.Options) } // generateMessageID generates a unique message ID func (s *streamState) generateMessageID() string { if s.ctx != nil && s.ctx.IDGenerator != nil { return s.ctx.IDGenerator.GenerateMessageID() } s.msgCounter++ return fmt.Sprintf("M%d", s.msgCounter) } // generateChunkID generates a unique chunk ID func (s *streamState) generateChunkID() string { if s.ctx != nil && s.ctx.IDGenerator != nil { return 
s.ctx.IDGenerator.GenerateChunkID() } s.chunkCounter++ return fmt.Sprintf("C%d", s.chunkCounter) } // handleChunk processes a single stream chunk func (s *streamState) handleChunk(chunkType message.StreamChunkType, data []byte) int { switch chunkType { case message.ChunkMessageStart: s.inMessage = true s.currentMsgID = s.generateMessageID() s.buffer = []byte{} return 0 case message.ChunkText: if !s.inMessage { s.inMessage = true s.currentMsgID = s.generateMessageID() } s.currentMsgType = message.TypeText s.buffer = append(s.buffer, data...) // Create message msg := &message.Message{ ChunkID: s.generateChunkID(), MessageID: s.currentMsgID, Type: message.TypeText, Delta: true, Props: map[string]interface{}{ "content": string(data), }, } // Call handler directly if provided (for batch calls and single calls with callback) // We use direct handler instead of ctx.Stack.Options.OnMessage to avoid race conditions // in concurrent batch calls where multiple goroutines would overwrite the shared OnMessage if s.handler != nil { if ret := s.handler(msg); ret != 0 { return ret } } // Send to output for actual message delivery to client // Note: ctx.Send may also call ctx.Stack.Options.OnMessage if set (for agent calls), // but for LLM calls we don't set OnMessage, so no double callback occurs if err := s.ctx.Send(msg); err != nil { // Log error but continue streaming return 0 } return 0 case message.ChunkThinking: if !s.inMessage { s.inMessage = true s.currentMsgID = s.generateMessageID() } s.currentMsgType = message.TypeThinking s.buffer = append(s.buffer, data...) 
msg := &message.Message{ ChunkID: s.generateChunkID(), MessageID: s.currentMsgID, Type: message.TypeThinking, Delta: true, Props: map[string]interface{}{ "content": string(data), }, } // Call handler directly if provided if s.handler != nil { if ret := s.handler(msg); ret != 0 { return ret } } if err := s.ctx.Send(msg); err != nil { return 0 } return 0 case message.ChunkToolCall: if !s.inMessage { s.inMessage = true s.currentMsgID = s.generateMessageID() } s.currentMsgType = message.TypeToolCall s.buffer = append(s.buffer, data...) // Tool call chunks are more complex - parse and forward msg := &message.Message{ ChunkID: s.generateChunkID(), MessageID: s.currentMsgID, Type: message.TypeToolCall, Delta: true, Props: map[string]interface{}{ "raw": string(data), }, } // Call handler directly if provided if s.handler != nil { if ret := s.handler(msg); ret != 0 { return ret } } if err := s.ctx.Send(msg); err != nil { return 0 } return 0 case message.ChunkMessageEnd: if s.inMessage { s.inMessage = false s.currentMsgID = "" s.buffer = []byte{} } return 0 case message.ChunkError: // Send error and stop msg := &message.Message{ Type: message.TypeError, Props: map[string]interface{}{ "error": string(data), }, } // Call handler directly if provided if s.handler != nil { s.handler(msg) } _ = s.ctx.Send(msg) // Ignore error on error message return 1 // Stop on error default: // Other chunk types (stream_start, stream_end, metadata) - ignore return 0 } } // extractContent extracts text content from CompletionResponse func extractContent(response *agentContext.CompletionResponse) string { if response == nil || response.Content == nil { return "" } switch content := response.Content.(type) { case string: return content case []interface{}: // Multimodal response - extract text parts var text string for _, part := range content { if partMap, ok := part.(map[string]interface{}); ok { if partMap["type"] == "text" { if t, ok := partMap["text"].(string); ok { text += t } } } } return 
text default: return "" } } // ============================================================================ // Batch LLM Methods: All, Any, Race // ============================================================================ // All executes all LLM requests concurrently and returns all results func (api *JSAPI) All(requests []interface{}) []interface{} { return api.AllWithHandler(requests, nil) } // Any executes LLM requests concurrently and returns first successful result func (api *JSAPI) Any(requests []interface{}) []interface{} { return api.AnyWithHandler(requests, nil) } // Race executes LLM requests concurrently and returns first completed result func (api *JSAPI) Race(requests []interface{}) []interface{} { return api.RaceWithHandler(requests, nil) } // AllWithHandler executes all LLM requests with global handler func (api *JSAPI) AllWithHandler(requests []interface{}, globalHandler agentContext.LlmBatchOnMessageFunc) []interface{} { parsedRequests := api.parseRequests(requests, globalHandler) return api.executeAll(parsedRequests) } // AnyWithHandler executes LLM requests and returns first success with handler func (api *JSAPI) AnyWithHandler(requests []interface{}, globalHandler agentContext.LlmBatchOnMessageFunc) []interface{} { parsedRequests := api.parseRequests(requests, globalHandler) return api.executeAny(parsedRequests) } // RaceWithHandler executes LLM requests and returns first completion with handler func (api *JSAPI) RaceWithHandler(requests []interface{}, globalHandler agentContext.LlmBatchOnMessageFunc) []interface{} { parsedRequests := api.parseRequests(requests, globalHandler) return api.executeRace(parsedRequests) } // parseRequests converts JS request array to internal Request slice func (api *JSAPI) parseRequests(requests []interface{}, globalHandler agentContext.LlmBatchOnMessageFunc) []*Request { result := make([]*Request, 0, len(requests)) for i, req := range requests { reqMap, ok := req.(map[string]interface{}) if !ok { continue } 
request := &Request{} // Required: connector if connector, ok := reqMap["connector"].(string); ok { request.Connector = connector } else { continue // Skip invalid request } // Required: messages if messages, ok := reqMap["messages"].([]interface{}); ok { request.Messages = messages } else { continue // Skip invalid request } // Optional: options if options, ok := reqMap["options"].(map[string]interface{}); ok { // Remove onChunk from options if present (handled via globalHandler) delete(options, "onChunk") request.Options = options } // Set handler based on globalHandler if globalHandler != nil { index := i connectorID := request.Connector request.Handler = func(msg *message.Message) int { return globalHandler(connectorID, index, msg) } } result = append(result, request) } return result } // executeAll executes all requests concurrently and waits for all to complete // Each request uses a forked context to avoid race conditions on shared state func (api *JSAPI) executeAll(requests []*Request) []interface{} { if len(requests) == 0 { return []interface{}{} } results := make([]interface{}, len(requests)) done := make(chan struct{}) remaining := len(requests) for i, req := range requests { go func(index int, request *Request) { defer func() { if err := recover(); err != nil { results[index] = &Result{ Connector: request.Connector, Error: fmt.Sprintf("panic: %v", err), } } done <- struct{}{} }() // Use forked context to avoid race conditions results[index] = api.executeSingleRequestWithForkedContext(request) }(i, req) } // Wait for all to complete for remaining > 0 { <-done remaining-- } return results } // executeAny executes requests and returns first successful result // Each request uses a forked context to avoid race conditions on shared state func (api *JSAPI) executeAny(requests []*Request) []interface{} { if len(requests) == 0 { return []interface{}{} } type indexedResult struct { index int result *Result } resultChan := make(chan indexedResult, len(requests)) 
remaining := len(requests) for i, req := range requests { go func(index int, request *Request) { defer func() { if err := recover(); err != nil { resultChan <- indexedResult{ index: index, result: &Result{ Connector: request.Connector, Error: fmt.Sprintf("panic: %v", err), }, } } }() // Use forked context to avoid race conditions res := api.executeSingleRequestWithForkedContext(request) resultChan <- indexedResult{index: index, result: res.(*Result)} }(i, req) } // Wait for first success or all failures var firstSuccess *indexedResult errors := make([]*indexedResult, 0) for remaining > 0 { ir := <-resultChan remaining-- if ir.result.Error == "" { // Success! firstSuccess = &ir break } errors = append(errors, &ir) } // Drain remaining results in background (don't block) if remaining > 0 { go func(count int) { for i := 0; i < count; i++ { <-resultChan } }(remaining) } if firstSuccess != nil { return []interface{}{firstSuccess.result} } // All failed - return all errors results := make([]interface{}, len(errors)) for i, e := range errors { results[i] = e.result } return results } // executeRace executes requests and returns first completed result (success or failure) // Each request uses a forked context to avoid race conditions on shared state func (api *JSAPI) executeRace(requests []*Request) []interface{} { if len(requests) == 0 { return []interface{}{} } resultChan := make(chan *Result, len(requests)) for _, req := range requests { go func(request *Request) { defer func() { if err := recover(); err != nil { resultChan <- &Result{ Connector: request.Connector, Error: fmt.Sprintf("panic: %v", err), } } }() // Use forked context to avoid race conditions res := api.executeSingleRequestWithForkedContext(request) resultChan <- res.(*Result) }(req) } // Return first result result := <-resultChan // Drain remaining results in background (don't block) remaining := len(requests) - 1 if remaining > 0 { go func(count int) { for i := 0; i < count; i++ { <-resultChan } 
}(remaining) } return []interface{}{result} } // executeSingleRequestWithForkedContext executes a single LLM request with a forked context // This is used by batch operations (All/Any/Race) to avoid race conditions // when multiple goroutines access shared context state func (api *JSAPI) executeSingleRequestWithForkedContext(request *Request) interface{} { // Fork the context to get independent resources (IDGenerator, Logger, etc.) forkedCtx := api.ctx.Fork() // Create a temporary JSAPI with the forked context forkedAPI := &JSAPI{ctx: forkedCtx} return forkedAPI.StreamWithHandler(request.Connector, request.Messages, request.Options, request.Handler) } ================================================ FILE: agent/llm/jsapi_types.go ================================================ // Package llm provides types and utilities for LLM JSAPI package llm import ( agentContext "github.com/yaoapp/yao/agent/context" ) // Request represents a request to call an LLM connector type Request struct { Connector string `json:"connector"` // Target connector ID Messages []interface{} `json:"messages"` // Messages to send Options map[string]interface{} `json:"options,omitempty"` // LLM call options (temperature, max_tokens, etc.) 
Handler agentContext.OnMessageFunc `json:"-"` // OnMessage handler for this request (not serialized) } // Result represents the result of a LLM call via JSAPI type Result struct { Connector string `json:"connector"` // Connector ID that was used Response *agentContext.CompletionResponse `json:"response,omitempty"` // Full LLM response Content string `json:"content,omitempty"` // Extracted text content Error string `json:"error,omitempty"` // Error message if call failed } // Note: LlmBatchOnMessageFunc is defined in agent/context/jsapi_llm.go // to avoid circular dependencies ================================================ FILE: agent/llm/llm.go ================================================ package llm import ( "github.com/yaoapp/gou/connector" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/llm/providers" ) // New create a new LLM instance // conn: connector object from connector.Select() // options: completion options containing capabilities and other settings func New(conn connector.Connector, options *context.CompletionOptions) (LLM, error) { // Select appropriate provider based on capabilities return providers.SelectProvider(conn, options) } ================================================ FILE: agent/llm/process.go ================================================ package llm import ( "context" "encoding/json" "fmt" "github.com/yaoapp/gou/connector" gouHTTP "github.com/yaoapp/gou/http" "github.com/yaoapp/gou/process" "github.com/yaoapp/gou/runtime/v8/bridge" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/openapi/oauth/authorized" ) func init() { process.Register("llm.ChatCompletions", ProcessChatCompletions) } // ProcessChatCompletions implements the llm.ChatCompletions Process. // A universal replacement for openai.chat.Completions that auto-detects // connector type (openai, anthropic, etc.) and routes accordingly. 
// // Usage: // // Process("llm.ChatCompletions", connector, messages) // Process("llm.ChatCompletions", connector, messages, opts) // Process("llm.ChatCompletions", connector, messages, opts, callback) // // Args: // - connector (string): Connector ID, any type (openai / anthropic / ...) // - messages ([]map): Message array, supports multimodal content (image_url, etc.) // - opts (map): Optional. temperature, max_tokens, etc. // - callback (func): Optional. Streaming callback func(data []byte) int // // Returns: OpenAI-compatible format { choices: [{ message: { role, content } }], ... } func ProcessChatCompletions(p *process.Process) interface{} { p.ValidateArgNums(2) // 1. Parse connector ID connectorID := p.ArgsString(0) if connectorID == "" { return newErrorResponse("llm.ChatCompletions: connector is required") } // 2. Parse messages rawMessages := p.ArgsArray(1) messages := make([]map[string]interface{}, 0, len(rawMessages)) for i, v := range rawMessages { msg, ok := v.(map[string]interface{}) if !ok { return newErrorResponse(fmt.Sprintf("llm.ChatCompletions: message %d is not an object", i)) } messages = append(messages, msg) } // 3. Parse optional opts var opts map[string]interface{} if p.NumOfArgs() > 2 && p.Args[2] != nil { if o, ok := p.Args[2].(map[string]interface{}); ok { opts = o } } // 4. Parse optional callback (for streaming) var callback func(data []byte) int if p.NumOfArgs() > 3 && p.Args[3] != nil { switch cb := p.Args[3].(type) { case func(data []byte) int: callback = cb case bridge.FunctionT: callback = func(data []byte) int { v, err := cb.Call(string(data)) if err != nil { return gouHTTP.HandlerReturnError } ret, ok := v.(int) if !ok { return gouHTTP.HandlerReturnError } return ret } } } // 5. Select connector conn, err := connector.Select(connectorID) if err != nil { return newErrorResponse(fmt.Sprintf("llm.ChatCompletions: connector %s not found: %v", connectorID, err)) } // 6. 
Build completion options (reuse jsapi.go logic) completionOptions := buildCompletionOptions(conn, opts) // 7. Create LLM instance (auto-selects openai/anthropic provider) llmInstance, err := New(conn, completionOptions) if err != nil { return newErrorResponse(fmt.Sprintf("llm.ChatCompletions: failed to create LLM: %v", err)) } // 8. Parse messages to context.Message format (reuse jsapi.go logic) interfaceMessages := make([]interface{}, len(messages)) for i, m := range messages { interfaceMessages[i] = m } ctxMessages, err := parseMessages(interfaceMessages) if err != nil { return newErrorResponse(fmt.Sprintf("llm.ChatCompletions: invalid messages: %v", err)) } // 8.1 Normalize multimodal content: convert []interface{} maps to []ContentPart // so that providers (especially Anthropic) can type-assert correctly. for i := range ctxMessages { if parts, ok := ctxMessages[i].Content.([]interface{}); ok { ctxMessages[i].Content = normalizeContentParts(parts) } } // 9. Build a minimal headless context for LLM call parent := p.Context if parent == nil { parent = context.Background() } authInfo := authorized.ProcessAuthInfo(p) chatID := agentContext.GenChatID() ctx := agentContext.New(parent, authInfo, chatID) defer ctx.Release() // 10. Create stream handler var streamHandler message.StreamFunc if callback != nil { // With callback: forward raw chunks to caller streamHandler = func(chunkType message.StreamChunkType, data []byte) int { if chunkType == message.ChunkText || chunkType == message.ChunkThinking { return callback(data) } return 0 } } else { // No callback: no-op handler, just collect final response streamHandler = func(chunkType message.StreamChunkType, data []byte) int { return 0 } } // 11. Execute LLM stream call response, err := llmInstance.Stream(ctx, ctxMessages, completionOptions, streamHandler) if err != nil { return newErrorResponse(fmt.Sprintf("llm.ChatCompletions: LLM call failed: %v", err)) } // 12. 
Convert CompletionResponse to OpenAI-compatible format // { choices: [{ message: { role, content } }], id, model, ... } return toOpenAIFormat(response) } // toOpenAIFormat converts CompletionResponse to OpenAI chat.completions format // for backward compatibility with code that consumed openai.chat.Completions. func toOpenAIFormat(resp *agentContext.CompletionResponse) map[string]interface{} { if resp == nil { return map[string]interface{}{ "choices": []interface{}{}, } } msgMap := map[string]interface{}{ "role": resp.Role, "content": resp.Content, } if len(resp.ToolCalls) > 0 { msgMap["tool_calls"] = resp.ToolCalls } choice := map[string]interface{}{ "index": 0, "message": msgMap, "finish_reason": "stop", } result := map[string]interface{}{ "id": resp.ID, "object": "chat.completion", "created": resp.Created, "model": resp.Model, "choices": []interface{}{choice}, } if resp.Usage != nil { result["usage"] = resp.Usage } return result } // newErrorResponse creates an error response in OpenAI-compatible format func newErrorResponse(errMsg string) map[string]interface{} { return map[string]interface{}{ "error": map[string]interface{}{ "message": errMsg, "type": "invalid_request_error", }, } } // normalizeContentParts converts []interface{} (raw maps from Process args) // to []agentContext.ContentPart (strongly typed) via JSON round-trip. // This is essential for providers (e.g. Anthropic) that type-assert on // []ContentPart to apply format-specific conversions (image_url → image). 
func normalizeContentParts(parts []interface{}) []agentContext.ContentPart { raw, err := json.Marshal(parts) if err != nil { return nil } var typed []agentContext.ContentPart if err := json.Unmarshal(raw, &typed); err != nil { return nil } return typed } ================================================ FILE: agent/llm/providers/ANTHROPIC_PROVIDER_PROPOSAL.md ================================================ # Anthropic Provider Implementation Proposal ## Overview This proposal outlines the implementation plan for native Anthropic Claude API support in Yao Agent. Currently, all LLM connectors use `type: "openai"`, and Anthropic detection relies on URL pattern matching, which is unreliable and architecturally incorrect. ## Current Architecture ``` gou/connector/ ├── openai/ # type: "openai" - handles all OpenAI-compatible APIs ├── moapi/ # type: "moapi" ├── redis/ # type: "redis" └── ... yao/agent/llm/ ├── providers/ │ ├── factory.go # SelectProvider() - selects provider based on connector type │ ├── base/ # Base provider implementation │ └── openai/ # OpenAI-compatible provider ``` **Current Flow:** 1. All LLM connectors declare `"type": "openai"` 2. `factory.go` uses `conn.Is(connector.OPENAI)` → always true for LLMs 3. `DetectAPIFormat()` guesses API format by URL patterns (unreliable) ## Problem Statement 1. **No type distinction**: Cannot differentiate Anthropic from OpenAI at connector level 2. **URL-based detection is fragile**: Relies on hardcoded patterns like `"anthropic.com"` 3. 
**API incompatibility**: Anthropic API uses different: - Endpoint: `/messages` vs `/chat/completions` - Auth header: `x-api-key` vs `Bearer` token - Request format: `system` as separate field, `max_tokens` required - Response format: Different structure ## Proposed Solution ### Phase 1: gou/connector - Add Anthropic Connector Type **New files:** ``` gou/connector/ ├── anthropic/ │ ├── anthropic.go # Connector implementation │ ├── types.go # Options, Capabilities structs │ └── defaults.go # Default model capabilities ``` **connector/types.go changes:** ```go const ( // ... existing types ANTHROPIC = 7 // New connector type ) ``` **connector/anthropic/anthropic.go:** ```go package anthropic type Connector struct { id string file string Name string `json:"name"` Options Options `json:"options"` } type Options struct { Host string `json:"host,omitempty"` // Default: https://api.anthropic.com Model string `json:"model,omitempty"` // e.g., claude-sonnet-4-5 Key string `json:"key"` // API key Version string `json:"version,omitempty"` // API version, default: 2024-01-01 Capabilities *Capabilities `json:"capabilities,omitempty"` } type Capabilities struct { Vision interface{} `json:"vision,omitempty"` ToolCalls bool `json:"tool_calls,omitempty"` Streaming bool `json:"streaming,omitempty"` // ... 
same as openai.Capabilities } ``` **DSL Example:** ```json { "label": "Claude Sonnet 4.5", "type": "anthropic", "options": { "model": "claude-sonnet-4-5", "key": "$ENV.ANTHROPIC_API_KEY", "capabilities": { "vision": "claude", "tool_calls": true, "streaming": true } } } ``` ### Phase 2: yao/agent/llm - Add Anthropic Provider **New files:** ``` yao/agent/llm/providers/ ├── anthropic/ │ ├── anthropic.go # Provider implementation │ └── types.go # Request/Response types ``` **anthropic/anthropic.go:** ```go package anthropic type Provider struct { *base.Provider adapters []adapters.CapabilityAdapter } func New(conn connector.Connector, capabilities *Capabilities) *Provider func (p *Provider) Stream(ctx, messages, options, handler) (*CompletionResponse, error) func (p *Provider) Post(ctx, messages, options) (*CompletionResponse, error) // Internal methods func (p *Provider) buildRequestBody(messages, options, streaming) (map[string]interface{}, error) func (p *Provider) convertMessages(messages []context.Message) []map[string]interface{} ``` **Key Implementation Details:** 1. **Message Conversion** (OpenAI format → Anthropic format): ```go // OpenAI format: // {"role": "system", "content": "..."} // {"role": "user", "content": "..."} // Anthropic format: // system: "..." (separate field) // messages: [{"role": "user", "content": "..."}] ``` 2. **Request Building:** ```go body := map[string]interface{}{ "model": model, "max_tokens": maxTokens, // Required in Anthropic "messages": convertedMessages, } if systemPrompt != "" { body["system"] = systemPrompt } ``` 3. **HTTP Headers:** ```go req.SetHeader("Content-Type", "application/json") req.SetHeader("x-api-key", apiKey) // Not Bearer token req.SetHeader("anthropic-version", "2024-01-01") ``` 4. 
**SSE Parsing** (different from OpenAI): ```go // Anthropic SSE events: // event: message_start // event: content_block_start // event: content_block_delta // event: content_block_stop // event: message_delta // event: message_stop ``` **factory.go changes:** ```go func SelectProvider(conn connector.Connector, options *CompletionOptions) (LLM, error) { // ... // Check connector type directly if conn.Is(connector.ANTHROPIC) { return anthropic.New(conn, options.Capabilities), nil } if conn.Is(connector.OPENAI) { return openai.New(conn, options.Capabilities), nil } // Default fallback return openai.New(conn, options.Capabilities), nil } ``` ## Implementation Effort | Component | Files | Lines (est.) | Effort | |-----------|-------|--------------|--------| | gou/connector/anthropic | 3 | ~250 | 2-3 hours | | yao/agent/llm/providers/anthropic | 2 | ~600 | 4-6 hours | | Tests | 4 | ~400 | 2-3 hours | | **Total** | **9** | **~1250** | **8-12 hours** | ## Migration Path 1. **Backward Compatible**: Existing `type: "openai"` connectors continue to work 2. **New Connectors**: Use `type: "anthropic"` for direct Anthropic API access 3. **Proxy Services**: OpenRouter, AWS Bedrock still use `type: "openai"` (they provide OpenAI-compatible endpoints) ## Testing Strategy 1. **Unit Tests**: Message conversion, request building 2. **Integration Tests**: Real API calls (with test API key) 3. **Connector Tests**: gou connector parsing and validation ## Alternative Considered **URL-based detection in yao layer only** (current approach): - Pros: No gou changes needed - Cons: Fragile, architecturally incorrect, no connector-level validation **Conclusion**: Rejected. Proper connector type is the cleaner solution. ## References - [Anthropic API Documentation](https://docs.anthropic.com/en/api) - [Anthropic Go SDK](https://github.com/anthropics/anthropic-sdk-go) - [OpenAI Compatibility Guide](https://platform.claude.com/docs/en/api/openai-sdk) ## Next Steps 1. 
Review and approve this proposal 2. Implement gou/connector/anthropic (Phase 1) 3. Implement yao/agent/llm/providers/anthropic (Phase 2) 4. Update yao-init connectors to use `type: "anthropic"` 5. Write tests and documentation ================================================ FILE: agent/llm/providers/README.md ================================================ # LLM Providers Architecture (New) ## Overview This directory contains LLM provider implementations using the **Capability Adapters** pattern. The new architecture separates API format handling from capability handling. ## Architecture Design ``` ┌─────────────────────────────────────────────────┐ │ LLM Provider (API Format) │ │ - OpenAI-compatible │ │ - Claude (TODO) │ │ - Custom (TODO) │ └──────────────┬──────────────────────────────────┘ │ ↓ ┌─────────────────────────────────────────────────┐ │ Capability Adapters (Modular) │ │ - ToolCallAdapter (native or prompt eng.) │ │ - VisionAdapter (native or removal) │ │ - AudioAdapter (native or removal) │ │ - ReasoningAdapter (o1/R1/GPT-Think) │ └─────────────────────────────────────────────────┘ ``` ## Key Concepts ### 1. Provider = API Format Providers handle the **API communication format**: - OpenAI-compatible API (`/v1/chat/completions`) - Claude API (TODO) - Custom API formats (TODO) ### 2. 
Adapters = Capabilities Adapters handle **model capabilities** independently: - **ToolCallAdapter**: Tool calling (native or prompt engineering) - **VisionAdapter**: Image input (native or removal/conversion) - **AudioAdapter**: Audio input (native or removal/conversion) - **ReasoningAdapter**: Reasoning content (o1/DeepSeek R1/GPT-4o thinking) ## Provider Selection ```go // factory.go func SelectProvider(conn connector.Connector, options *context.CompletionOptions) (LLM, error) { apiFormat := DetectAPIFormat(conn) switch apiFormat { case "openai": // Adapters automatically configured based on capabilities return openai.New(conn, options.Capabilities), nil case "claude": return claude.New(conn, options.Capabilities), nil default: return openai.New(conn, options.Capabilities), nil } } ``` ## Directory Structure ``` providers/ ├── factory.go # Provider selection based on API format ├── base/ # Common functionality │ └── base.go ├── openai/ # OpenAI-compatible API provider │ └── openai.go # Includes adapter integration └── README.md # This file ../adapters/ # Capability adapters (separate package) ├── adapter.go # Base interface ├── toolcall.go # Tool calling adapter ├── vision.go # Vision adapter ├── audio.go # Audio adapter └── reasoning.go # Reasoning adapter ``` ## OpenAI Provider The OpenAI provider supports **all capabilities** through adapters: ```go type Provider struct { *base.Provider adapters []adapters.CapabilityAdapter } func New(conn connector.Connector, capabilities *context.ModelCapabilities) *Provider { return &Provider{ Provider: base.NewProvider(conn, capabilities), adapters: buildAdapters(capabilities), // Auto-configured } } ``` ### Adapter Pipeline **Preprocessing** (before API call): ``` Messages → ToolCallAdapter → VisionAdapter → AudioAdapter → API Request ``` **Streaming** (during API call): ``` API Chunk → ReasoningAdapter → ToolCallAdapter → Output ``` **Postprocessing** (after API call): ``` API Response → All Adapters → Final Response ``` 
## Model Examples ### Full-Featured Model (GPT-4o) ```yaml # connectors.yml gpt-4o: vision: true tool_calls: true audio: true reasoning: false ``` **Adapters created**: - ToolCallAdapter(native=true) - VisionAdapter(native=true) - AudioAdapter(native=true) ### Reasoning Model with Tools (OpenAI o1) ```yaml o1-preview: reasoning: true tool_calls: true ``` **Adapters created**: - ToolCallAdapter(native=true) - ReasoningAdapter(format=openai-o1) ### Reasoning Model without Tools (DeepSeek R1) ```yaml deepseek-reasoner: reasoning: true tool_calls: false ``` **Adapters created**: - ToolCallAdapter(native=false) → Uses prompt engineering - ReasoningAdapter(format=deepseek-r1) ### Legacy Model (GPT-3.5-instruct) ```yaml gpt-3.5-turbo-instruct: tool_calls: false vision: false audio: false ``` **Adapters created**: - ToolCallAdapter(native=false) → Prompt engineering - VisionAdapter(native=false) → Removes images - AudioAdapter(native=false) → Removes audio ## Capability Adapters ### ToolCallAdapter **When native=true**: - Passes tool definitions to API - Parses structured tool_calls from response **When native=false**: - Injects tool schemas into system prompt - Extracts tool calls from text response (JSON parsing) ### VisionAdapter **When native=true**: - Passes image URLs/data directly to API **When native=false**: - Removes image content from messages - Optionally converts to text descriptions ### AudioAdapter **When native=true**: - Passes audio data directly to API **When native=false**: - Removes audio content from messages - Optionally converts to text transcriptions ### ReasoningAdapter Handles different reasoning formats: **OpenAI o1** (`reasoning_content` field): ```json { "delta": { "reasoning_content": "Let me think...", "content": "The answer is 42" } } ``` **DeepSeek R1** (may have different format): ```json { "delta": { "content": "Let me think...The answer is 42" } } ``` **GPT-4o thinking** (future): ```json { "delta": { "thinking": "Let me think...", 
"content": "The answer is 42" } } ``` ## Adding New Capabilities 1. Create new adapter in `../adapters/`: ```go type NewCapabilityAdapter struct { *BaseAdapter nativeSupport bool } ``` 2. Implement CapabilityAdapter interface 3. Add to `buildAdapters()` in `openai/openai.go`: ```go if cap.NewCapability != nil { result = append(result, adapters.NewNewCapabilityAdapter(*cap.NewCapability)) } ``` ## Adding New API Format Provider 1. Create new directory: `providers/newapi/` 2. Implement LLM interface: ```go type Provider struct { *base.Provider adapters []adapters.CapabilityAdapter } func (p *Provider) Stream(...) (*CompletionResponse, error) { // Apply adapter preprocessing // Make API call // Apply adapter postprocessing } ``` 3. Update `factory.go`: ```go case "newapi": return newapi.New(conn, options.Capabilities), nil ``` ## Benefits of New Architecture 1. **Separation of Concerns**: - Providers handle API format - Adapters handle capabilities 2. **Code Reuse**: - Same adapters work across different providers - No duplication of capability logic 3. **Easy Extension**: - Add new capability = add one adapter - Add new API = add one provider 4. **Flexible Combinations**: - Any provider can use any adapter combination - Capabilities are composable 5. 
**Clear Responsibility**: - Each adapter handles exactly one capability dimension - Easy to test and maintain ## Testing Strategy ### Unit Tests (per adapter) - Test preprocessing logic - Test postprocessing logic - Test stream chunk processing ### Integration Tests (per provider) - Test with different adapter combinations - Test full request/response flow - Test error handling ### End-to-End Tests - Test real API calls with different models - Verify capability detection - Verify adapter selection ## Migration Notes ### Old Architecture → New Architecture **Before**: ``` reasoning.Provider → Reasoning models (o1, R1) openai.Provider → Full-featured models (GPT-4o) legacy.Provider → Old models (GPT-3) ``` **After**: ``` openai.Provider + adapters → ALL models ``` The same OpenAI provider now handles all cases through different adapter combinations. ================================================ FILE: agent/llm/providers/anthropic/anthropic.go ================================================ package anthropic import ( gocontext "context" "fmt" "sort" "strings" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/http" goullm "github.com/yaoapp/gou/llm" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/llm/adapters" "github.com/yaoapp/yao/agent/llm/providers/base" "github.com/yaoapp/yao/agent/output/message" ) // Provider Anthropic Messages API provider type Provider struct { *base.Provider adapters []adapters.CapabilityAdapter } // New create a new Anthropic provider func New(conn connector.Connector, capabilities *goullm.Capabilities) *Provider { return &Provider{ Provider: base.NewProvider(conn, capabilities), adapters: buildAdapters(capabilities), } } // buildAdapters builds capability adapters based on model capabilities func buildAdapters(cap *goullm.Capabilities) []adapters.CapabilityAdapter { if cap == nil { return 
[]adapters.CapabilityAdapter{} } result := make([]adapters.CapabilityAdapter, 0) // Tool call adapter result = append(result, adapters.NewToolCallAdapter(cap.ToolCalls)) // Vision adapter visionSupport, visionFormat := context.GetVisionSupport(cap) if visionSupport { result = append(result, adapters.NewVisionAdapter(true, visionFormat)) } else if cap.Vision != nil { result = append(result, adapters.NewVisionAdapter(false, context.VisionFormatNone)) } // Audio adapter result = append(result, adapters.NewAudioAdapter(cap.Audio)) // Reasoning adapter if cap.Reasoning { result = append(result, adapters.NewReasoningAdapter(adapters.ReasoningFormatOpenAI, cap)) } else { result = append(result, adapters.NewReasoningAdapter(adapters.ReasoningFormatNone, cap)) } return result } // Stream stream completion from Anthropic API func (p *Provider) Stream(ctx *context.Context, messages []context.Message, options *context.CompletionOptions, handler message.StreamFunc) (*context.CompletionResponse, error) { trace, _ := ctx.Trace() if trace != nil { trace.Debug("Anthropic Stream: Starting stream request", map[string]any{ "message_count": len(messages), }) } maxRetries := 3 var lastErr error goCtx := ctx.Context if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Context != nil { goCtx = ctx.Stack.Options.Context } if goCtx == nil { goCtx = gocontext.Background() } currentMessages := make([]context.Message, len(messages)) copy(currentMessages, messages) for attempt := 0; attempt < maxRetries; attempt++ { select { case <-goCtx.Done(): return nil, fmt.Errorf("context cancelled: %w", goCtx.Err()) default: } if ctx.Interrupt != nil { if signal := ctx.Interrupt.Peek(); signal != nil && signal.Type == context.InterruptForce { return nil, fmt.Errorf("force interrupted by user") } } if attempt > 0 { backoff := time.Duration(1< // data: var currentEventType string streamHandler := func(data []byte) int { select { case <-goCtx.Done(): return http.HandlerReturnBreak default: } 
if ctx.Interrupt != nil { if signal := ctx.Interrupt.Peek(); signal != nil && signal.Type == context.InterruptForce { return http.HandlerReturnBreak } } if len(data) == 0 { return http.HandlerReturnOk } dataStr := string(data) trimmed := strings.TrimSpace(dataStr) if trimmed == "" { return http.HandlerReturnOk } // Parse event type line // Support both "event: type" (with space) and "event:type" (without space) formats if strings.HasPrefix(trimmed, "event:") { currentEventType = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")) return http.HandlerReturnOk } // Parse data line // Support both "data: {...}" (with space) and "data:{...}" (without space) formats if !strings.HasPrefix(trimmed, "data:") { // Check for error response if strings.HasPrefix(trimmed, "{") && strings.Contains(trimmed, `"error"`) { var apiErr APIError if err := jsoniter.UnmarshalFromString(trimmed, &apiErr); err == nil && apiErr.Error.Message != "" { if handler != nil { handler(message.ChunkError, []byte(apiErr.Error.Message)) } } } return http.HandlerReturnOk } jsonStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) if jsonStr == "" { return http.HandlerReturnOk } // Process based on event type switch currentEventType { case "message_start": var event MessageStartEvent if err := jsoniter.UnmarshalFromString(jsonStr, &event); err == nil { accumulator.id = event.Message.ID accumulator.model = event.Message.Model accumulator.role = event.Message.Role if event.Message.Usage != nil { accumulator.usage = &message.UsageInfo{ PromptTokens: event.Message.Usage.InputTokens, TotalTokens: event.Message.Usage.InputTokens, } } } case "content_block_start": var event ContentBlockStartEvent if err := jsoniter.UnmarshalFromString(jsonStr, &event); err == nil { accumulator.currentBlockIndex = event.Index accumulator.currentBlockType = event.ContentBlock.Type switch event.ContentBlock.Type { case "thinking": startMessage(msgTracker, message.ChunkThinking, handler) case "text": 
startMessage(msgTracker, message.ChunkText, handler) case "tool_use": accumulator.toolCalls[event.Index] = &accumulatedToolCall{ id: event.ContentBlock.ID, name: event.ContentBlock.Name, } toolCallInfo := &message.EventToolCallInfo{ ID: event.ContentBlock.ID, Name: event.ContentBlock.Name, Index: event.Index, } startToolCallMessage(msgTracker, toolCallInfo, handler) // Send initial ChunkToolCall with id and function name // to match OpenAI format so CUI can resolve tool name from stored chunks if handler != nil { toolCallData, _ := jsoniter.Marshal([]map[string]interface{}{ { "index": event.Index, "id": event.ContentBlock.ID, "type": "function", "function": map[string]interface{}{ "name": event.ContentBlock.Name, }, }, }) handler(message.ChunkToolCall, toolCallData) incrementChunk(msgTracker) } } } case "content_block_delta": var event ContentBlockDeltaEvent if err := jsoniter.UnmarshalFromString(jsonStr, &event); err == nil { switch event.Delta.Type { case "thinking_delta": if event.Delta.Thinking != "" { accumulator.thinkingContent += event.Delta.Thinking if handler != nil { handler(message.ChunkThinking, []byte(event.Delta.Thinking)) incrementChunk(msgTracker) } } case "text_delta": if event.Delta.Text != "" { accumulator.content += event.Delta.Text if handler != nil { handler(message.ChunkText, []byte(event.Delta.Text)) incrementChunk(msgTracker) } } case "input_json_delta": if event.Delta.PartialJSON != "" { if tc, exists := accumulator.toolCalls[event.Index]; exists { tc.inputJSON += event.Delta.PartialJSON // Update tracker if msgTracker.active && msgTracker.toolCallInfo != nil { msgTracker.toolCallInfo.Arguments = tc.inputJSON } } if handler != nil { // Send tool call delta toolCallData, _ := jsoniter.Marshal([]map[string]interface{}{ { "index": event.Index, "function": map[string]interface{}{ "arguments": event.Delta.PartialJSON, }, }, }) handler(message.ChunkToolCall, toolCallData) incrementChunk(msgTracker) } } case "signature_delta": // Handle thinking 
signature delta (for extended thinking) // The signature is accumulated but not sent to handler var sigDelta struct { Type string `json:"type"` Signature string `json:"signature"` } if err := jsoniter.UnmarshalFromString(jsonStr, &struct { Delta *struct { Signature string `json:"signature"` } `json:"delta"` }{Delta: &struct { Signature string `json:"signature"` }{}}); err == nil { _ = sigDelta // signature tracking if needed } } } case "content_block_stop": endMessage(msgTracker, handler) case "message_delta": var event MessageDeltaEvent if err := jsoniter.UnmarshalFromString(jsonStr, &event); err == nil { accumulator.stopReason = event.Delta.StopReason if event.Usage != nil { if accumulator.usage == nil { accumulator.usage = &message.UsageInfo{} } accumulator.usage.CompletionTokens = event.Usage.OutputTokens accumulator.usage.TotalTokens = accumulator.usage.PromptTokens + event.Usage.OutputTokens } } case "message_stop": // Message complete endMessage(msgTracker, handler) case "ping": // Keep-alive, ignore case "error": var apiErr struct { Type string `json:"type"` Error struct { Type string `json:"type"` Message string `json:"message"` } `json:"error"` } if err := jsoniter.UnmarshalFromString(jsonStr, &apiErr); err == nil && apiErr.Error.Message != "" { if handler != nil { handler(message.ChunkError, []byte(apiErr.Error.Message)) } } } return http.HandlerReturnOk } // Log request if trace != nil { if requestBodyJSON, marshalErr := jsoniter.Marshal(requestBody); marshalErr == nil { trace.Debug("Anthropic Stream Request", map[string]any{ "url": url, "body": string(requestBodyJSON), }) } } // Error buffer for non-SSE error responses var errorBuffer strings.Builder errorDetected := false wrappedHandler := func(data []byte) int { dataStr := string(data) trimmed := strings.TrimSpace(dataStr) if trimmed == "" { return http.HandlerReturnOk } // SSE event/data lines - pass to stream handler // Support both "event: type" (with space) and "event:type" (without space) 
formats if strings.HasPrefix(trimmed, "event:") || strings.HasPrefix(trimmed, "data:") { return streamHandler(data) } // Detect JSON error response if strings.HasPrefix(trimmed, "{") && strings.Contains(dataStr, `"error"`) { errorDetected = true } if errorDetected { errorBuffer.Write(data) errorBuffer.WriteString("\n") return http.HandlerReturnOk } return streamHandler(data) } // Make streaming request log.Trace("[LLM] Starting Anthropic Stream request: url=%s", url) err = req.Stream(goCtx, "POST", requestBody, wrappedHandler) _ = streamStartTime // Check for captured error response if errorDetected && errorBuffer.Len() > 0 { errorJSON := errorBuffer.String() if trace != nil { trace.Error(i18n.T(ctx.Locale, "llm.anthropic.stream.api_error"), map[string]any{"response": errorJSON}) } var apiErr APIError if parseErr := jsoniter.UnmarshalFromString(errorJSON, &apiErr); parseErr == nil && apiErr.Error.Message != "" { err = fmt.Errorf("Anthropic API error: %s (type: %s)", apiErr.Error.Message, apiErr.Error.Type) } else { err = fmt.Errorf("Anthropic API error: %s", strings.TrimSpace(errorJSON)) } } // Handle context cancellation if err != nil && goCtx.Err() != nil { return nil, fmt.Errorf("stream cancelled: %w", goCtx.Err()) } if err != nil { endMessage(msgTracker, handler) if handler != nil { handler(message.ChunkError, []byte(err.Error())) } return nil, fmt.Errorf("streaming request failed: %w", err) } // Check for empty response if accumulator.id == "" { endMessage(msgTracker, handler) errMsg := fmt.Errorf("no data received from Anthropic API") if handler != nil { handler(message.ChunkError, []byte(errMsg.Error())) } return nil, errMsg } // Build final response (convert to unified CompletionResponse) response := &context.CompletionResponse{ ID: accumulator.id, Object: "message", Model: accumulator.model, Role: accumulator.role, Content: accumulator.content, ReasoningContent: accumulator.thinkingContent, FinishReason: mapStopReason(accumulator.stopReason), Usage: 
accumulator.usage, } // Convert accumulated tool calls // Note: tool call indices may not start at 0 (e.g. if text blocks precede tool_use blocks) if len(accumulator.toolCalls) > 0 { // Collect all indices and sort them to ensure deterministic order indices := make([]int, 0, len(accumulator.toolCalls)) for idx := range accumulator.toolCalls { indices = append(indices, idx) } sort.Ints(indices) toolCalls := make([]context.ToolCall, 0, len(accumulator.toolCalls)) for _, idx := range indices { tc := accumulator.toolCalls[idx] toolCalls = append(toolCalls, context.ToolCall{ ID: tc.id, Type: "function", Function: context.Function{ Name: tc.name, Arguments: tc.inputJSON, }, }) } response.ToolCalls = toolCalls } endMessage(msgTracker, handler) return response, nil } // Post non-streaming completion request to Anthropic API func (p *Provider) Post(ctx *context.Context, messages []context.Message, options *context.CompletionOptions) (*context.CompletionResponse, error) { trace, _ := ctx.Trace() maxRetries := 3 var lastErr error goCtx := ctx.Context if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Context != nil { goCtx = ctx.Stack.Options.Context } if goCtx == nil { goCtx = gocontext.Background() } currentMessages := make([]context.Message, len(messages)) copy(currentMessages, messages) for attempt := 0; attempt < maxRetries; attempt++ { select { case <-goCtx.Done(): return nil, fmt.Errorf("context cancelled: %w", goCtx.Err()) default: } if attempt > 0 { backoff := time.Duration(1< 0 { contentBlocks := make([]map[string]interface{}, 0) // Add text content if present if contentStr, ok := msg.Content.(string); ok && contentStr != "" { contentBlocks = append(contentBlocks, map[string]interface{}{ "type": "text", "text": contentStr, }) } // Add tool_use blocks for _, tc := range msg.ToolCalls { var input interface{} if tc.Function.Arguments != "" { jsoniter.UnmarshalFromString(tc.Function.Arguments, &input) } if input == nil { input = 
map[string]interface{}{} } contentBlocks = append(contentBlocks, map[string]interface{}{ "type": "tool_use", "id": tc.ID, "name": tc.Function.Name, "input": input, }) } apiMsg["content"] = contentBlocks } apiMessages = append(apiMessages, apiMsg) } // Build request body body := map[string]interface{}{ "model": model, } if len(apiMessages) > 0 { body["messages"] = apiMessages } if systemContent != "" { body["system"] = systemContent } if streaming { body["stream"] = true } // max_tokens is required for Anthropic maxTokens := 4096 // default if options.MaxTokens != nil { maxTokens = *options.MaxTokens } else if options.MaxCompletionTokens != nil { maxTokens = *options.MaxCompletionTokens } else if mt, ok := setting["max_tokens"].(int); ok && mt > 0 { maxTokens = mt } body["max_tokens"] = maxTokens // Temperature if options.Temperature != nil { body["temperature"] = *options.Temperature } if options.TopP != nil { body["top_p"] = *options.TopP } if options.Stop != nil { body["stop_sequences"] = options.Stop } // Tools (convert from OpenAI format to Anthropic format) if len(options.Tools) > 0 { anthropicTools := convertTools(options.Tools) if len(anthropicTools) > 0 { body["tools"] = anthropicTools } } if options.ToolChoice != nil { body["tool_choice"] = convertToolChoice(options.ToolChoice) } // Thinking configuration from connector settings if thinking, exists := setting["thinking"]; exists && thinking != nil { body["thinking"] = thinking } return body, nil } // convertTools converts OpenAI-format tools to Anthropic format func convertTools(tools []map[string]interface{}) []map[string]interface{} { result := make([]map[string]interface{}, 0, len(tools)) for _, tool := range tools { function, ok := tool["function"].(map[string]interface{}) if !ok { continue } anthropicTool := map[string]interface{}{ "name": function["name"], } if desc, ok := function["description"]; ok { anthropicTool["description"] = desc } if params, ok := function["parameters"]; ok { 
anthropicTool["input_schema"] = params } result = append(result, anthropicTool) } return result } // convertToolChoice converts OpenAI tool_choice to Anthropic format func convertToolChoice(choice interface{}) interface{} { switch v := choice.(type) { case string: switch v { case "auto": return map[string]interface{}{"type": "auto"} case "none": return map[string]interface{}{"type": "none"} case "required": return map[string]interface{}{"type": "any"} } case map[string]interface{}: if fn, ok := v["function"].(map[string]interface{}); ok { if name, ok := fn["name"].(string); ok { return map[string]interface{}{ "type": "tool", "name": name, } } } } return map[string]interface{}{"type": "auto"} } // convertImagePart converts an OpenAI image_url content part to Anthropic image format func convertImagePart(part context.ContentPart) map[string]interface{} { if part.ImageURL == nil { return map[string]interface{}{"type": "text", "text": "[image not available]"} } url := part.ImageURL.URL // Check if it's a base64 data URL if strings.HasPrefix(url, "data:") { // Parse data URL: data:image/jpeg;base64, parts := strings.SplitN(url, ",", 2) if len(parts) == 2 { mediaInfo := strings.TrimPrefix(parts[0], "data:") mediaInfo = strings.TrimSuffix(mediaInfo, ";base64") return map[string]interface{}{ "type": "image", "source": map[string]interface{}{ "type": "base64", "media_type": mediaInfo, "data": parts[1], }, } } } // URL-based image (Anthropic supports URL images) return map[string]interface{}{ "type": "image", "source": map[string]interface{}{ "type": "url", "url": url, }, } } // buildAPIURL builds the API URL for Anthropic func buildAPIURL(host, endpoint string) string { return connector.BuildAPIURL(host, endpoint) } // mapStopReason maps Anthropic stop_reason to OpenAI finish_reason func mapStopReason(stopReason string) string { switch stopReason { case "end_turn": return "stop" case "max_tokens": return "length" case "tool_use": return "tool_calls" case "stop_sequence": 
return "stop" default: return stopReason } } // Message tracker helper functions func startMessage(mt *messageTracker, messageType message.StreamChunkType, handler message.StreamFunc) { if mt.active { endMessage(mt, handler) } mt.active = true if mt.idGenerator != nil { mt.messageID = mt.idGenerator.GenerateMessageID() } else { mt.messageID = message.GenerateNanoID() } mt.messageType = messageType mt.startTime = time.Now().UnixMilli() mt.chunkCount = 0 mt.toolCallInfo = nil if handler != nil { startData := &message.EventMessageStartData{ MessageID: mt.messageID, Type: string(messageType), Timestamp: mt.startTime, } if startJSON, err := jsoniter.Marshal(startData); err == nil { handler(message.ChunkMessageStart, startJSON) } } } func startToolCallMessage(mt *messageTracker, toolCallInfo *message.EventToolCallInfo, handler message.StreamFunc) { if mt.active { endMessage(mt, handler) } mt.active = true if mt.idGenerator != nil { mt.messageID = mt.idGenerator.GenerateMessageID() } else { mt.messageID = message.GenerateNanoID() } mt.messageType = message.ChunkToolCall mt.startTime = time.Now().UnixMilli() mt.chunkCount = 0 mt.toolCallInfo = toolCallInfo if handler != nil { startData := &message.EventMessageStartData{ MessageID: mt.messageID, Type: string(message.ChunkToolCall), Timestamp: mt.startTime, ToolCall: toolCallInfo, } if startJSON, err := jsoniter.Marshal(startData); err == nil { handler(message.ChunkMessageStart, startJSON) } } } func incrementChunk(mt *messageTracker) { if mt.active { mt.chunkCount++ } } func endMessage(mt *messageTracker, handler message.StreamFunc) { if !mt.active { return } if handler != nil { endData := &message.EventMessageEndData{ MessageID: mt.messageID, Type: string(mt.messageType), Timestamp: time.Now().UnixMilli(), DurationMs: time.Now().UnixMilli() - mt.startTime, ChunkCount: mt.chunkCount, Status: "completed", } if mt.toolCallInfo != nil { endData.ToolCall = mt.toolCallInfo } if endJSON, err := jsoniter.Marshal(endData); err == 
// isRetryableError reports whether err looks like a transient failure that is
// worth retrying.
//
// The decision is a case-insensitive substring match of err.Error() against a
// fixed list of patterns (timeouts, connection refused/reset, EOF, HTTP
// 429/500/502/503/504 and "overloaded"). A nil error is never retryable.
func isRetryableError(err error) bool {
	if err == nil {
		return false
	}

	retryablePatterns := []string{
		"timeout",
		"connection refused",
		"connection reset",
		"EOF",
		"HTTP 429",
		"HTTP 500",
		"HTTP 502",
		"HTTP 503",
		"HTTP 504",
		"overloaded",
	}

	// Fold the message once instead of once per pattern.
	errStr := strings.ToLower(err.Error())
	for _, pattern := range retryablePatterns {
		if strings.Contains(errStr, strings.ToLower(pattern)) {
			return true
		}
	}
	return false
}
:= newTestContext("test-anthropic-stream", testConnectorID) var chunks []string handler := func(chunkType message.StreamChunkType, data []byte) int { chunks = append(chunks, string(data)) t.Logf("Stream chunk [%s]: %s", chunkType, string(data)) return 0 } response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream failed: %v", err) } if response == nil { t.Fatal("Response is nil") } if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.Content == "" { t.Error("Response content is empty") } if response.FinishReason == "" { t.Error("FinishReason is empty") } if response.Usage == nil { t.Error("Response Usage is nil") } else { t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } if len(chunks) == 0 { t.Error("No streaming chunks received") } t.Logf("Final response content: %s", response.Content) t.Logf("Total chunks received: %d", len(chunks)) } // TestAnthropicStreamWithToolCalls tests streaming with tool calls via Anthropic API func TestAnthropicStreamWithToolCalls(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select(testConnectorID) if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &goullm.Capabilities{ Streaming: true, ToolCalls: true, }, } weatherTool := map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": "get_weather", "description": "Get the current weather for a location", "parameters": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "location": map[string]interface{}{ "type": "string", "description": "The city name, e.g. 
Tokyo", }, }, "required": []string{"location"}, }, }, } options.Tools = []map[string]interface{}{weatherTool} options.ToolChoice = "auto" llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "What's the weather in Tokyo?", }, } ctx := newTestContext("test-anthropic-tool", testConnectorID) var toolCallChunks int handler := func(chunkType message.StreamChunkType, data []byte) int { if chunkType == message.ChunkToolCall { toolCallChunks++ } t.Logf("Stream chunk [%s]: %s", chunkType, string(data)) return 0 } response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream with tool calls failed: %v", err) } if response == nil { t.Fatal("Response is nil") } if len(response.ToolCalls) == 0 { t.Error("Expected tool calls but got none") } else { t.Logf("Received %d tool call(s)", len(response.ToolCalls)) for i, tc := range response.ToolCalls { t.Logf("Tool call %d: %s(%s)", i, tc.Function.Name, tc.Function.Arguments) if tc.ID == "" { t.Errorf("Tool call %d missing ID", i) } if tc.Function.Name == "" { t.Errorf("Tool call %d missing function name", i) } if tc.Function.Name != "get_weather" { t.Errorf("Tool call %d expected 'get_weather', got '%s'", i, tc.Function.Name) } if tc.Function.Arguments == "" { t.Errorf("Tool call %d missing arguments", i) } // Verify arguments contain location var args map[string]interface{} if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err == nil { if _, hasLocation := args["location"]; !hasLocation { t.Errorf("Tool call %d arguments missing 'location'", i) } } } } if response.FinishReason != context.FinishReasonToolCalls { t.Logf("Warning: Expected finish_reason='tool_calls', got '%s'", response.FinishReason) } t.Logf("Final response: %+v", response) } // TestAnthropicStreamRetry tests error handling with invalid API key func TestAnthropicStreamRetry(t *testing.T) { 
test.Prepare(t, config.Conf) defer test.Clean() connDSL := `{ "type": "anthropic", "options": { "model": "claude-3-haiku-20240307", "key": "sk-ant-invalid-key-should-fail" } }` conn, err := connector.New("anthropic", "test-anthropic-retry", []byte(connDSL)) if err != nil { t.Fatalf("Failed to create test connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &goullm.Capabilities{ Streaming: true, ToolCalls: true, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Test", }, } ctx := newTestContext("test-anthropic-retry", "test-anthropic-retry") _, err = llmInstance.Stream(ctx, messages, options, nil) if err == nil { t.Fatal("Expected error due to invalid API key, but got success") } errMsg := strings.ToLower(err.Error()) hasExpectedError := strings.Contains(errMsg, "401") || strings.Contains(errMsg, "authentication") || strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "no data received") if !hasExpectedError { t.Errorf("Expected authentication error, got: %v", err) } t.Logf("Failed as expected with error: %v", err) } // ============================================================================ // Helper Functions // ============================================================================ func newTestContext(chatID, connectorID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: "test-session-id", Constraints: types.DataConstraints{ TeamOnly: true, Extra: map[string]interface{}{ "test": "anthropic-provider", }, }, } ctx := context.New(gocontext.Background(), authorized, chatID) ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "AnthropicProviderTest/1.0", IP: "127.0.0.1", 
} ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptStandard ctx.Route = "/api/test" ctx.Metadata = make(map[string]interface{}) return ctx } ================================================ FILE: agent/llm/providers/anthropic/types.go ================================================ package anthropic import ( "github.com/yaoapp/yao/agent/output/message" ) // ============================================================ // Anthropic Messages API types // Reference: https://docs.anthropic.com/en/api/messages // ============================================================ // StreamEvent represents an SSE event from Anthropic streaming API type StreamEvent struct { Type string `json:"type"` } // MessageStartEvent represents the message_start SSE event type MessageStartEvent struct { Type string `json:"type"` Message MessageStart `json:"message"` } // MessageStart represents the message object in message_start event type MessageStart struct { ID string `json:"id"` Type string `json:"type"` Role string `json:"role"` Content []ContentBlock `json:"content"` Model string `json:"model"` StopReason *string `json:"stop_reason"` StopSequence *string `json:"stop_sequence"` Usage *UsageInfo `json:"usage,omitempty"` } // ContentBlockStartEvent represents the content_block_start SSE event type ContentBlockStartEvent struct { Type string `json:"type"` Index int `json:"index"` ContentBlock ContentBlock `json:"content_block"` } // ContentBlockDeltaEvent represents the content_block_delta SSE event type ContentBlockDeltaEvent struct { Type string `json:"type"` Index int `json:"index"` Delta DeltaBlock `json:"delta"` } // ContentBlockStopEvent represents the content_block_stop SSE event type ContentBlockStopEvent struct { Type string `json:"type"` Index int `json:"index"` } // MessageDeltaEvent represents the message_delta SSE event type MessageDeltaEvent struct { Type string `json:"type"` Delta MessageDelta `json:"delta"` Usage *DeltaUsage `json:"usage,omitempty"` } // 
// DeltaBlock represents a delta block in streaming.
//
// It is the "delta" payload of a content_block_delta event. Exactly which of
// the optional fields is populated depends on Type. Signature carries the
// payload of "signature_delta" events, which the Anthropic streaming API emits
// for thinking blocks; without the field that payload is dropped on decode.
type DeltaBlock struct {
	Type        string `json:"type"`                   // "text_delta", "thinking_delta", "input_json_delta", "signature_delta"
	Text        string `json:"text,omitempty"`         // for type "text_delta"
	Thinking    string `json:"thinking,omitempty"`     // for type "thinking_delta"
	PartialJSON string `json:"partial_json,omitempty"` // for type "input_json_delta"
	Signature   string `json:"signature,omitempty"`    // for type "signature_delta"
}
struct { Type string `json:"type"` Message string `json:"message"` } `json:"error"` } // streamAccumulator accumulates streaming response data type streamAccumulator struct { id string model string role string content string thinkingContent string thinkingSignature string toolCalls map[int]*accumulatedToolCall stopReason string usage *message.UsageInfo // Current content block tracking currentBlockIndex int currentBlockType string } // accumulatedToolCall accumulates a single tool call from streaming type accumulatedToolCall struct { id string name string inputJSON string } // messageTracker tracks message lifecycle for stream events type messageTracker struct { active bool messageID string messageType message.StreamChunkType startTime int64 chunkCount int toolCallInfo *message.EventToolCallInfo idGenerator *message.IDGenerator } ================================================ FILE: agent/llm/providers/base/base.go ================================================ package base import ( "fmt" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/llm" "github.com/yaoapp/yao/agent/context" ) // Provider base provider implementation // Provides common functionality for all LLM providers type Provider struct { Connector connector.Connector Capabilities *llm.Capabilities } // NewProvider create a new base provider func NewProvider(conn connector.Connector, capabilities *llm.Capabilities) *Provider { return &Provider{ Connector: conn, Capabilities: capabilities, } } // PreprocessMessages preprocess messages before sending to LLM // Handles vision messages, audio messages, tool messages, etc. 
// Filters out unsupported content types based on model capabilities func (p *Provider) PreprocessMessages(messages []context.Message) ([]context.Message, error) { processed := make([]context.Message, 0, len(messages)) for _, msg := range messages { processedMsg := msg // Handle multimodal content (array of ContentPart) if contentParts, ok := msg.Content.([]context.ContentPart); ok { filteredParts := make([]context.ContentPart, 0, len(contentParts)) for _, part := range contentParts { // Filter vision content if not supported if part.Type == context.ContentImageURL { if !p.SupportsVision() { // Skip image content if vision not supported continue } } // Filter audio content if not supported if part.Type == context.ContentInputAudio { if !p.SupportsAudio() { // Skip audio content if audio not supported continue } } filteredParts = append(filteredParts, part) } // If all parts were filtered out, convert to text message if len(filteredParts) == 0 { processedMsg.Content = "[Content not supported by this model]" } else { processedMsg.Content = filteredParts } } processed = append(processed, processedMsg) } return processed, nil } // SupportsVision check if this provider supports vision func (p *Provider) SupportsVision() bool { if p.Capabilities == nil { return false } supported, _ := context.GetVisionSupport(p.Capabilities) return supported } // SupportsAudio check if this provider supports audio func (p *Provider) SupportsAudio() bool { return p.Capabilities != nil && p.Capabilities.Audio } // SupportsTools check if this provider supports tool calls func (p *Provider) SupportsTools() bool { return p.Capabilities != nil && p.Capabilities.ToolCalls } // SupportsStreaming check if this provider supports streaming func (p *Provider) SupportsStreaming() bool { return p.Capabilities != nil && p.Capabilities.Streaming } // SupportsJSON check if this provider supports JSON mode func (p *Provider) SupportsJSON() bool { return p.Capabilities != nil && p.Capabilities.JSON } // 
import (
	"fmt"
	"strings"

	"github.com/yaoapp/gou/connector"
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/llm/providers/anthropic"
	"github.com/yaoapp/yao/agent/llm/providers/openai"
	"github.com/yaoapp/yao/agent/output/message"
)
// contains reports whether substr occurs anywhere in s, ignoring letter case.
//
// It is used to match host names in connector settings, which are
// case-insensitive, so "API.Anthropic.com" matches "anthropic.com".
// An empty substr matches every s.
func contains(s, substr string) bool {
	return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
connector: %v", err) // } // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: true, // Reasoning: false, // Claude Sonnet 4 (non-thinking) doesn't expose reasoning // ToolCalls: true, // Vision: "claude", // Claude requires base64 format // Multimodal: true, // }, // } // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "What is 3+3? Reply with just the number.", // }, // } // maxTokens := 100 // options.MaxTokens = &maxTokens // ctx := newClaudeTestContext("test-claude-sonnet4-basic", "claude.sonnet-4_0") // var chunks []string // handler := func(chunkType message.StreamChunkType, data []byte) int { // chunks = append(chunks, string(data)) // t.Logf("Stream chunk [%s]: %s", chunkType, string(data)) // return 0 // } // response, err := llmInstance.Stream(ctx, messages, options, handler) // if err != nil { // t.Fatalf("Stream failed: %v", err) // } // if response == nil { // t.Fatal("Response is nil") // } // // Basic validation // if response.ID == "" { // t.Error("Response ID is empty") // } // if response.Model == "" { // t.Error("Response Model is empty") // } // // Validate content // contentStr, ok := response.Content.(string) // if !ok { // t.Errorf("Content is not a string: %T", response.Content) // } // if len(contentStr) == 0 { // t.Error("Content is empty") // } // t.Logf("Response content: %v", response.Content) // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // t.Logf("Final response: %+v", response) // t.Logf("Total chunks received: %d", len(chunks)) // } // // TestClaudeSonnet4PostBasic tests non-streaming completion // func TestClaudeSonnet4PostBasic(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // conn, err := 
connector.Select("claude.sonnet-4_0") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: false, // Reasoning: false, // ToolCalls: true, // Vision: "claude", // Claude requires base64 format // Multimodal: true, // }, // } // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "What is 4+4? Reply with just the number.", // }, // } // maxTokens := 100 // options.MaxTokens = &maxTokens // ctx := newClaudeTestContext("test-claude-sonnet4-post", "claude.sonnet-4_0") // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // if response == nil { // t.Fatal("Response is nil") // } // // Validate content // contentStr, ok := response.Content.(string) // if !ok { // t.Fatalf("Content is not a string: %T", response.Content) // } // t.Logf("Response content: %s", contentStr) // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // // Basic content validation // if len(contentStr) == 0 { // t.Error("Content is empty") // } // t.Logf("Response: %+v", response) // } // // TestClaudeSonnet4WithToolCalls tests tool calling capability // func TestClaudeSonnet4WithToolCalls(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // conn, err := connector.Select("claude.sonnet-4_0") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: false, // Reasoning: false, // ToolCalls: true, // Vision: "claude", // Claude requires base64 format // Multimodal: true, // }, // } // // Define a simple tool with minimal parameters // simpleTool := 
map[string]interface{}{ // "type": "function", // "function": map[string]interface{}{ // "name": "get_info", // "description": "Get information", // "parameters": map[string]interface{}{ // "type": "object", // "properties": map[string]interface{}{ // "query": map[string]interface{}{ // "type": "string", // "description": "Query string (single letter)", // }, // "count": map[string]interface{}{ // "type": "number", // "description": "Count (single digit)", // }, // }, // "required": []string{"query", "count"}, // }, // }, // } // options.Tools = []map[string]interface{}{simpleTool} // options.ToolChoice = "auto" // // Set enough tokens for tool call response // maxTokens := 150 // options.MaxTokens = &maxTokens // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "Please use the get_info function to retrieve information. Pass 'A' as the query parameter and 1 as the count parameter.", // }, // } // ctx := newClaudeTestContext("test-claude-sonnet4-tools", "claude.sonnet-4_0") // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // if response == nil { // t.Fatal("Response is nil") // } // // Validate tool calls // if len(response.ToolCalls) == 0 { // t.Error("No tool calls in response") // } else { // tc := response.ToolCalls[0] // t.Logf("✓ Tool call: %s(%s)", tc.Function.Name, tc.Function.Arguments) // if tc.Function.Name != "get_info" { // t.Errorf("Expected tool name 'get_info', got '%s'", tc.Function.Name) // } // } // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // t.Logf("Response: %+v", response) // } // // TestClaudeSonnet4Vision tests vision capability with image input // func TestClaudeSonnet4Vision(t *testing.T) { // test.Prepare(t, 
config.Conf) // defer test.Clean() // conn, err := connector.Select("claude.sonnet-4_0") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: false, // Reasoning: false, // ToolCalls: true, // Vision: "claude", // Claude requires base64 format // Multimodal: true, // }, // } // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // Use a test image URL // imageURL := "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" // messages := []context.Message{ // { // Role: context.RoleUser, // Content: []map[string]interface{}{ // { // "type": "text", // "text": "Describe this image in one sentence.", // }, // { // "type": "image_url", // "image_url": map[string]string{ // "url": imageURL, // }, // }, // }, // }, // } // maxTokens := 150 // options.MaxTokens = &maxTokens // ctx := newClaudeTestContext("test-claude-sonnet4-vision", "claude.sonnet-4_0") // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // if response == nil { // t.Fatal("Response is nil") // } // // Validate content // contentStr, ok := response.Content.(string) // if !ok { // t.Fatalf("Content is not a string: %T", response.Content) // } // if len(contentStr) == 0 { // t.Error("Image description is empty") // } // t.Logf("Image description: %s", contentStr) // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // } // // TestClaudeSonnet4ThinkingStream tests Claude Sonnet 4 Thinking with streaming // func TestClaudeSonnet4ThinkingStream(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // conn, err := 
connector.Select("claude.sonnet-4_0-thinking") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: true, // Reasoning: true, // Claude Thinking mode exposes reasoning // ToolCalls: false, // Vision: "claude", // Claude requires base64 format // Multimodal: true, // }, // } // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "If Sally has 3 apples and gives 2 to John, how many does she have left? Think through this step by step.", // }, // } // maxTokens := 500 // options.MaxTokens = &maxTokens // ctx := newClaudeTestContext("test-claude-thinking-stream", "claude.sonnet-4_0-thinking") // var thinkingChunks []string // var textChunks []string // handler := func(chunkType message.StreamChunkType, data []byte) int { // t.Logf("Stream chunk [%s]: %s", chunkType, string(data)) // if chunkType == message.ChunkThinking { // thinkingChunks = append(thinkingChunks, string(data)) // } else if chunkType == message.ChunkText { // textChunks = append(textChunks, string(data)) // } // return 0 // } // response, err := llmInstance.Stream(ctx, messages, options, handler) // if err != nil { // t.Fatalf("Stream failed: %v", err) // } // if response == nil { // t.Fatal("Response is nil") // } // // Validate response // contentStr, ok := response.Content.(string) // if !ok { // t.Errorf("Content is not a string: %T", response.Content) // } // t.Logf("Reasoning/Thinking content length: %d characters", len(response.ReasoningContent)) // t.Logf("Response content: %v", contentStr) // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // if response.Usage != nil && response.Usage.CompletionTokensDetails != nil { // t.Logf("Reasoning 
tokens: %d", response.Usage.CompletionTokensDetails.ReasoningTokens) // } // t.Logf("Received %d thinking chunks", len(thinkingChunks)) // t.Logf("Received %d text chunks", len(textChunks)) // t.Logf("Final response: %+v", response) // } // // TestClaudeSonnet4ThinkingPost tests Claude Sonnet 4 Thinking in non-streaming mode // func TestClaudeSonnet4ThinkingPost(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // conn, err := connector.Select("claude.sonnet-4_0-thinking") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: false, // Reasoning: true, // ToolCalls: false, // Vision: "claude", // Claude requires base64 format // Multimodal: true, // }, // } // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "Is 7 greater than 5? 
Explain your reasoning.", // }, // } // maxTokens := 500 // options.MaxTokens = &maxTokens // ctx := newClaudeTestContext("test-claude-thinking-post", "claude.sonnet-4_0-thinking") // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // if response == nil { // t.Fatal("Response is nil") // } // // Validate content // contentStr, ok := response.Content.(string) // if !ok { // t.Fatalf("Content is not a string: %T", response.Content) // } // t.Logf("Reasoning content: %s", response.ReasoningContent) // t.Logf("Final answer: %s", contentStr) // // Check for reasoning content // if len(response.ReasoningContent) > 0 { // t.Logf("✓ Reasoning content present: %d characters", len(response.ReasoningContent)) // } // if response.Usage != nil && response.Usage.CompletionTokensDetails != nil && response.Usage.CompletionTokensDetails.ReasoningTokens > 0 { // t.Logf("✓ Reasoning tokens: %d", response.Usage.CompletionTokensDetails.ReasoningTokens) // } // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // t.Logf("Response: %+v", response) // } // // TestClaudeTemperatureHandling tests that Claude models handle temperature parameter correctly // func TestClaudeTemperatureHandling(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // tests := []struct { // name string // connector string // temperature float64 // reasoning bool // }{ // { // name: "Sonnet 4 with temperature 0.7", // connector: "claude.sonnet-4_0", // temperature: 0.7, // reasoning: false, // }, // { // name: "Sonnet 4 Thinking with temperature 0.5", // connector: "claude.sonnet-4_0-thinking", // temperature: 0.5, // reasoning: true, // }, // } // for _, tt := range tests { // t.Run(tt.name, func(t *testing.T) { // conn, err := connector.Select(tt.connector) // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // 
options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: false, // Reasoning: tt.reasoning, // ToolCalls: true, // Vision: true, // }, // } // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "Say 'hello'.", // }, // } // maxTokens := 50 // options.MaxTokens = &maxTokens // options.Temperature = &tt.temperature // ctx := newClaudeTestContext("test-claude-temp-"+tt.connector, tt.connector) // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // if response == nil { // t.Fatal("Response is nil") // } // t.Logf("✓ %s completed successfully with temperature=%.1f", tt.name, tt.temperature) // }) // } // } ================================================ FILE: agent/llm/providers/openai/deepseek_r1_test.go ================================================ package openai_test import ( gocontext "context" "strings" "testing" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/llm" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" ) // TestDeepSeekR1StreamBasic tests basic streaming completion with DeepSeek R1 func TestDeepSeekR1StreamBasic(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector from real configuration conn, err := connector.Select("deepseek.r1") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Create LLM instance with capabilities options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, Reasoning: true, // DeepSeek R1 supports reasoning ToolCalls: false, // R1 doesn't support native tool calls Vision: 
false, Audio: false, Multimodal: false, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Prepare messages with reasoning prompt (simple question for faster reasoning) messages := []context.Message{ { Role: context.RoleUser, Content: "What is 2 + 2?", }, } // Set max tokens (higher for reasoning models to allow full reasoning + answer) maxTokens := 500 options.MaxTokens = &maxTokens // Create context ctx := newDeepSeekTestContext("test-deepseek-r1-basic", "deepseek.r1") // Track streaming chunks and group events var reasoningChunks []string var contentChunks []string var thinkingGroupEnded bool var textGroupEnded bool handler := func(chunkType message.StreamChunkType, data []byte) int { dataStr := string(data) t.Logf("Stream chunk [%s]: %s", chunkType, dataStr) // Track different chunk types switch chunkType { case message.ChunkThinking: reasoningChunks = append(reasoningChunks, dataStr) case message.ChunkText: contentChunks = append(contentChunks, dataStr) } // Track group_end events to verify type field if chunkType == message.ChunkMessageEnd { // Parse the group_end data to check the type field var groupEndData struct { GroupID string `json:"group_id"` Type string `json:"type"` Timestamp int64 `json:"timestamp"` DurationMs int64 `json:"duration_ms"` ChunkCount int `json:"chunk_count"` Status string `json:"status"` } if err := jsoniter.Unmarshal(data, &groupEndData); err == nil { t.Logf("✓ group_end received: type=%s, chunks=%d, duration=%dms", groupEndData.Type, groupEndData.ChunkCount, groupEndData.DurationMs) // Verify the type field matches expected group types switch groupEndData.Type { case "thinking": thinkingGroupEnded = true if groupEndData.ChunkCount == 0 { t.Error("thinking group_end should have chunk_count > 0") } case "text": textGroupEnded = true if groupEndData.ChunkCount == 0 { t.Error("text group_end should have chunk_count > 0") } } } else { t.Errorf("Failed to parse group_end 
data: %v", err) } } return 0 // Continue } // Call Stream response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.Content == "" { t.Error("Response content is empty") } if response.FinishReason == "" { t.Error("FinishReason is empty") } // DeepSeek R1 should have reasoning content if response.ReasoningContent == "" { t.Error("Expected reasoning_content but got empty") } else { t.Logf("Reasoning content length: %d characters", len(response.ReasoningContent)) } // Check reasoning tokens in usage if response.Usage == nil { t.Error("Response Usage is nil") } else { if response.Usage.TotalTokens == 0 { t.Error("Response Usage.TotalTokens is 0") } if response.Usage.CompletionTokensDetails != nil { if response.Usage.CompletionTokensDetails.ReasoningTokens == 0 { t.Error("Expected reasoning_tokens > 0 for DeepSeek R1") } else { t.Logf("Reasoning tokens: %d", response.Usage.CompletionTokensDetails.ReasoningTokens) } } t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } // Should have received reasoning chunks if len(reasoningChunks) == 0 { t.Error("Expected reasoning chunks (ChunkThinking) but got none") } else { t.Logf("Received %d reasoning chunks", len(reasoningChunks)) } // Should have received content chunks if len(contentChunks) == 0 { t.Error("Expected content chunks (ChunkText) but got none") } else { t.Logf("Received %d content chunks", len(contentChunks)) } // Verify group_end events were received with correct types if !thinkingGroupEnded { t.Error("❌ Expected thinking group_end event but didn't receive it") } else { t.Log("✅ Thinking group_end event received with type='thinking'") } if !textGroupEnded { 
t.Error("❌ Expected text group_end event but didn't receive it") } else { t.Log("✅ Text group_end event received with type='text'") } t.Logf("Final response: %+v", response) } // TestDeepSeekR1PostBasic tests basic non-streaming completion func TestDeepSeekR1PostBasic(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector conn, err := connector.Select("deepseek.r1") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Create LLM instance options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: true, ToolCalls: false, Vision: false, Audio: false, Multimodal: false, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Prepare messages (very simple question) messages := []context.Message{ { Role: context.RoleUser, Content: "What is 1+1?", }, } // Set max tokens (enough for reasoning + answer) maxTokens := 500 options.MaxTokens = &maxTokens // Create context ctx := newDeepSeekTestContext("test-deepseek-r1-post", "deepseek.r1") // Call Post response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.Content == "" { t.Error("Response content is empty") } // DeepSeek R1 should have reasoning content if response.ReasoningContent == "" { t.Error("Expected reasoning_content but got empty") } else { t.Logf("Reasoning content: %s", response.ReasoningContent) t.Logf("Final answer: %s", response.Content) } // Check usage if response.Usage == nil { t.Error("Response Usage is nil") } else { if response.Usage.TotalTokens == 0 { t.Error("Response Usage.TotalTokens is 0") } if response.Usage.CompletionTokensDetails != nil && response.Usage.CompletionTokensDetails.ReasoningTokens > 0 { 
t.Logf("Reasoning tokens: %d", response.Usage.CompletionTokensDetails.ReasoningTokens) } t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } t.Logf("Response: %+v", response) } // TestDeepSeekR1LogicPuzzle tests DeepSeek R1's reasoning with a logic puzzle func TestDeepSeekR1LogicPuzzle(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("deepseek.r1") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, Reasoning: true, ToolCalls: false, Vision: false, Audio: false, Multimodal: false, }, } maxTokens := 800 options.MaxTokens = &maxTokens llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Use a simpler logic question messages := []context.Message{ { Role: context.RoleUser, Content: "Is 5 greater than 3? 
Explain your reasoning.", }, } ctx := newDeepSeekTestContext("test-deepseek-r1-logic", "deepseek.r1") // Track reasoning and content separately var hasReasoning, hasContent bool handler := func(chunkType message.StreamChunkType, data []byte) int { if chunkType == message.ChunkThinking && len(data) > 0 { hasReasoning = true } else if chunkType == message.ChunkText && len(data) > 0 { hasContent = true } return 0 } response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Should have both reasoning and content if !hasReasoning { t.Error("Expected to receive reasoning chunks but didn't") } if !hasContent { t.Error("Expected to receive content chunks but didn't") } // Validate reasoning content exists and is substantial if response.ReasoningContent == "" { t.Error("Expected reasoning_content but got empty") } else if len(response.ReasoningContent) < 50 { t.Errorf("Reasoning content too short (%d chars), expected detailed thinking", len(response.ReasoningContent)) } else { t.Logf("✓ Reasoning content length: %d characters", len(response.ReasoningContent)) } // Validate final answer contentStr := "" if response.Content != nil { if str, ok := response.Content.(string); ok { contentStr = str } } if len(contentStr) == 0 { t.Error("Content is empty") } else { // Should mention "Yes" or affirm that 5 > 3 if !strings.Contains(strings.ToLower(contentStr), "yes") && !strings.Contains(strings.ToLower(contentStr), "greater") { t.Logf("Warning: Content might not contain expected answer. 
Content: %s", contentStr) } else { t.Logf("✓ Final answer: %s", contentStr) } } // Check reasoning tokens if response.Usage != nil && response.Usage.CompletionTokensDetails != nil { reasoningTokens := response.Usage.CompletionTokensDetails.ReasoningTokens if reasoningTokens == 0 { t.Error("Expected reasoning_tokens > 0") } else { t.Logf("✓ Reasoning tokens: %d", reasoningTokens) } } t.Log("Logic puzzle test completed successfully") } // ============================================================================ // Helper Functions // ============================================================================ // newDeepSeekTestContext creates a real Context for testing DeepSeek provider func newDeepSeekTestContext(chatID, connectorID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: "test-session-id", Constraints: types.DataConstraints{ TeamOnly: true, Extra: map[string]interface{}{ "test": "deepseek-provider", }, }, } ctx := context.New(gocontext.Background(), authorized, chatID) ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "DeepSeekProviderTest/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptStandard ctx.Route = "/api/test" ctx.Metadata = make(map[string]interface{}) return ctx } ================================================ FILE: agent/llm/providers/openai/deepseek_v3_test.go ================================================ package openai_test import ( gocontext "context" "testing" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/llm" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" ) // 
TestDeepSeekV3StreamBasic tests basic streaming completion with DeepSeek V3 func TestDeepSeekV3StreamBasic(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("deepseek.v3") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, Reasoning: false, // V3 doesn't support reasoning ToolCalls: true, // V3 supports tool calls Vision: false, Audio: false, Multimodal: false, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Simple math question messages := []context.Message{ { Role: context.RoleUser, Content: "What is 5 + 3?", }, } // Set max tokens maxTokens := 100 options.MaxTokens = &maxTokens ctx := newDeepSeekV3TestContext("test-deepseek-v3-basic", "deepseek.v3") // Track streaming chunks var contentChunks []string handler := func(chunkType message.StreamChunkType, data []byte) int { dataStr := string(data) t.Logf("Stream chunk [%s]: %s", chunkType, dataStr) if chunkType == message.ChunkText { contentChunks = append(contentChunks, dataStr) } return 0 // Continue } // Call Stream response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } // Should have content (V3 is not a reasoning model) contentStr, ok := response.Content.(string) if !ok || contentStr == "" { t.Error("Expected content but got empty") } else { t.Logf("Response content: %s", contentStr) } // Should NOT have reasoning content (V3 doesn't support reasoning) if response.ReasoningContent != "" { t.Errorf("Expected no reasoning_content for V3, but got: %s", response.ReasoningContent) } // Check usage if response.Usage == nil { t.Error("Response 
Usage is nil") } else { t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // Should have 0 reasoning tokens if response.Usage.CompletionTokensDetails != nil { reasoningTokens := response.Usage.CompletionTokensDetails.ReasoningTokens if reasoningTokens != 0 { t.Errorf("Expected reasoning_tokens=0 for V3, got %d", reasoningTokens) } } } if len(contentChunks) == 0 { t.Error("Expected content chunks but got none") } else { t.Logf("Received %d content chunks", len(contentChunks)) } t.Logf("Final response: %+v", response) } // TestDeepSeekV3PostBasic tests basic non-streaming completion func TestDeepSeekV3PostBasic(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("deepseek.v3") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: false, ToolCalls: true, Vision: false, Audio: false, Multimodal: false, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Simple question messages := []context.Message{ { Role: context.RoleUser, Content: "What is 2 * 4?", }, } // Set max tokens maxTokens := 100 options.MaxTokens = &maxTokens ctx := newDeepSeekV3TestContext("test-deepseek-v3-post", "deepseek.v3") // Call Post response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } // Should have content contentStr, ok := response.Content.(string) if !ok || contentStr == "" { t.Error("Expected content but got empty") } else { t.Logf("Response content: %s", contentStr) } // Should NOT have reasoning content if response.ReasoningContent != "" { t.Errorf("V3 
should not have reasoning_content, but got: %s", response.ReasoningContent) } // Check usage if response.Usage == nil { t.Error("Response Usage is nil") } else { t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // Should have 0 reasoning tokens if response.Usage.CompletionTokensDetails != nil { reasoningTokens := response.Usage.CompletionTokensDetails.ReasoningTokens if reasoningTokens != 0 { t.Errorf("Expected reasoning_tokens=0 for V3, got %d", reasoningTokens) } } } t.Logf("Response: %+v", response) } // TestDeepSeekV3WithToolCalls tests V3 with tool calls func TestDeepSeekV3WithToolCalls(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("deepseek.v3") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: false, ToolCalls: true, }, } // Define a simple tool with minimal parameters simpleTool := map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": "get_info", "description": "Get information", "parameters": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "query": map[string]interface{}{ "type": "string", "description": "Query string (single letter)", }, "count": map[string]interface{}{ "type": "number", "description": "Count (single digit)", }, }, "required": []string{"query", "count"}, }, }, } options.Tools = []map[string]interface{}{simpleTool} options.ToolChoice = "auto" // Set lower max_tokens for faster response maxTokens := 50 options.MaxTokens = &maxTokens llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Call get_info with query='A' and count=1", }, } ctx := newDeepSeekV3TestContext("test-deepseek-v3-tools", "deepseek.v3") response, 
err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post with tool calls failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Should have tool calls if len(response.ToolCalls) == 0 { t.Error("Expected tool calls but got none") } else { tc := response.ToolCalls[0] t.Logf("✓ Tool call: %s(%s)", tc.Function.Name, tc.Function.Arguments) if tc.Function.Name != "get_info" { t.Errorf("Expected tool name 'get_info', got '%s'", tc.Function.Name) } } if response.Usage != nil { t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } t.Logf("Response: %+v", response) } // TestDeepSeekV3NoReasoningEffort tests that V3 ignores reasoning_effort parameter func TestDeepSeekV3NoReasoningEffort(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("deepseek.v3") if err != nil { t.Fatalf("Failed to select connector: %v", err) } effort := "high" options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: false, // V3 doesn't support reasoning ToolCalls: true, }, ReasoningEffort: &effort, // Should be ignored by adapter } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Reply with just: OK", }, } maxTokens := 20 options.MaxTokens = &maxTokens ctx := newDeepSeekV3TestContext("test-deepseek-v3-no-reasoning", "deepseek.v3") // Should succeed (adapter removes reasoning_effort parameter) response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Should have 0 reasoning tokens if response.Usage != nil && response.Usage.CompletionTokensDetails != nil { reasoningTokens := response.Usage.CompletionTokensDetails.ReasoningTokens if reasoningTokens != 0 { t.Errorf("Expected 
reasoning_tokens=0 for V3, got %d", reasoningTokens) } else { t.Log("✓ V3 correctly shows reasoning_tokens=0") } } t.Log("✓ ReasoningAdapter correctly removed reasoning_effort parameter for V3") } // ============================================================================ // Helper Functions // ============================================================================ // newDeepSeekV3TestContext creates a real Context for testing DeepSeek V3 provider func newDeepSeekV3TestContext(chatID, connectorID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: "test-session-id", Constraints: types.DataConstraints{ TeamOnly: true, Extra: map[string]interface{}{ "test": "deepseek-v3-provider", }, }, } ctx := context.New(gocontext.Background(), authorized, chatID) ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "DeepSeekV3ProviderTest/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptStandard ctx.Route = "/api/test" ctx.Metadata = make(map[string]interface{}) return ctx } ================================================ FILE: agent/llm/providers/openai/gpt5_test.go ================================================ package openai_test // GPT-5 tests temporarily commented out // // import ( // gocontext "context" // "testing" // // "github.com/yaoapp/gou/connector" // "github.com/yaoapp/gou/connector/openai" // "github.com/yaoapp/yao/agent/context" // "github.com/yaoapp/yao/agent/llm" // "github.com/yaoapp/yao/agent/output/message" // "github.com/yaoapp/yao/config" // "github.com/yaoapp/yao/openapi/oauth/types" // "github.com/yaoapp/yao/test" // ) // // // TestGPT5StreamBasic tests basic streaming completion with GPT-5 // func TestGPT5StreamBasic(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() 
// // conn, err := connector.Select("openai.gpt-5") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Streaming: true, // Reasoning: true, // GPT-5 supports reasoning // ToolCalls: true, // Vision: true, // Multimodal: true, // }, // } // // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "What is 1+1? Reply with just the number.", // }, // } // // maxTokens := 100 // options.MaxCompletionTokens = &maxTokens // // ctx := newGPT5TestContext("test-gpt5-basic", "openai.gpt-5") // // var chunks []string // handler := func(chunkType message.StreamChunkType, data []byte) int { // chunks = append(chunks, string(data)) // t.Logf("Stream chunk [%s]: %s", chunkType, string(data)) // return 0 // } // // response, err := llmInstance.Stream(ctx, messages, options, handler) // if err != nil { // t.Fatalf("Stream failed: %v", err) // } // // if response == nil { // t.Fatal("Response is nil") // } // // // Basic validation // if response.ID == "" { // t.Error("Response ID is empty") // } // if response.Model == "" { // t.Error("Response Model is empty") // } // // // GPT-5 may use all tokens for reasoning, so content could be empty // // Just log the content instead of failing // t.Logf("Response content: %v", response.Content) // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // // if response.Usage != nil && response.Usage.CompletionTokensDetails != nil { // t.Logf("Reasoning tokens: %d", response.Usage.CompletionTokensDetails.ReasoningTokens) // } // // t.Logf("Final response: %+v", response) // t.Logf("Total chunks received: %d", len(chunks)) // } // // // TestGPT5ReasoningEffort tests reasoning_effort parameter with 
different levels // func TestGPT5ReasoningEffort(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // // conn, err := connector.Select("openai.gpt-5") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // // // Test with different reasoning effort levels // effortLevels := []string{"low", "medium", "high"} // // for _, effort := range effortLevels { // t.Run("effort_"+effort, func(t *testing.T) { // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Reasoning: true, // ToolCalls: true, // }, // ReasoningEffort: &effort, // } // // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "Solve: If all Bloops are Razzies and all Razzies are Lazzies, are all Bloops Lazzies?", // }, // } // // maxTokens := 1000 // options.MaxCompletionTokens = &maxTokens // // ctx := newGPT5TestContext("test-gpt5-reasoning-"+effort, "openai.gpt-5") // // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed with effort=%s: %v", effort, err) // } // // if response == nil { // t.Fatal("Response is nil") // } // // // Check reasoning tokens // var reasoningTokens int // if response.Usage != nil && response.Usage.CompletionTokensDetails != nil { // reasoningTokens = response.Usage.CompletionTokensDetails.ReasoningTokens // } // // t.Logf("Reasoning effort: %s", effort) // t.Logf("Reasoning tokens: %d", reasoningTokens) // t.Logf("Total tokens: %d", response.Usage.TotalTokens) // t.Logf("Content: %s", response.Content) // // // GPT-5 reasoning is hidden (no reasoning_content field) // // But should have reasoning_tokens in usage // if effort != "low" { // if reasoningTokens == 0 { // t.Logf("Warning: Expected reasoning_tokens > 0 for effort='%s', got 0", effort) // } // } // }) // } // } // // // TestGPT5PostWithToolCalls tests 
GPT-5 with tool calls // func TestGPT5PostWithToolCalls(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // // conn, err := connector.Select("openai.gpt-5") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Reasoning: true, // ToolCalls: true, // }, // } // // // Define a calculation tool // calcTool := map[string]interface{}{ // "type": "function", // "function": map[string]interface{}{ // "name": "calculate", // "description": "Perform a mathematical calculation", // "parameters": map[string]interface{}{ // "type": "object", // "properties": map[string]interface{}{ // "expression": map[string]interface{}{ // "type": "string", // "description": "The mathematical expression to evaluate", // }, // }, // "required": []string{"expression"}, // }, // }, // } // // options.Tools = []map[string]interface{}{calcTool} // options.ToolChoice = "auto" // // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "Use the calculate function to compute 2 * 3", // }, // } // // ctx := newGPT5TestContext("test-gpt5-tools", "openai.gpt-5") // // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post with tool calls failed: %v", err) // } // // if response == nil { // t.Fatal("Response is nil") // } // // // GPT-5 reasoning models may not always use tool calls // // Log what we got instead of failing // if len(response.ToolCalls) == 0 { // t.Logf("No tool calls returned. 
Content: %v", response.Content) // } else { // tc := response.ToolCalls[0] // t.Logf("✓ Tool call: %s(%s)", tc.Function.Name, tc.Function.Arguments) // // if tc.Function.Name != "calculate" { // t.Logf("Warning: Expected tool name 'calculate', got '%s'", tc.Function.Name) // } // } // // if response.Usage != nil { // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // if response.Usage.CompletionTokensDetails != nil { // t.Logf("Reasoning tokens: %d", response.Usage.CompletionTokensDetails.ReasoningTokens) // } // } // // t.Logf("Response: %+v", response) // } // // // TestGPT5Vision tests GPT-5 with image input // func TestGPT5Vision(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // // conn, err := connector.Select("openai.gpt-5") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Reasoning: true, // Vision: true, // Multimodal: true, // }, // } // // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // // Message with image content // messages := []context.Message{ // { // Role: context.RoleUser, // Content: []context.ContentPart{ // { // Type: context.ContentText, // Text: "What is in this image? 
Describe briefly.", // }, // { // Type: context.ContentImageURL, // ImageURL: &context.ImageURL{ // URL: "https://raw.githubusercontent.com/YaoApp/yao/refs/heads/main/yao/data/icons/icon.png", // }, // }, // }, // }, // } // // maxTokens := 200 // options.MaxCompletionTokens = &maxTokens // // ctx := newGPT5TestContext("test-gpt5-vision", "openai.gpt-5") // // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post with vision failed: %v", err) // } // // if response == nil { // t.Fatal("Response is nil") // } // // // Should have content describing the image // // Content can be string or []ContentPart for multimodal responses // var contentStr string // switch v := response.Content.(type) { // case string: // contentStr = v // case []interface{}: // // Handle []ContentPart serialized as []interface{} // for _, part := range v { // if partMap, ok := part.(map[string]interface{}); ok { // if text, ok := partMap["text"].(string); ok { // contentStr += text // } // } // } // case []context.ContentPart: // for _, part := range v { // if part.Type == context.ContentText { // contentStr += part.Text // } // } // case nil: // // GPT-5 reasoning models may use all tokens for reasoning, leaving no content // t.Log("Content is nil (reasoning model may have used all tokens for reasoning)") // default: // t.Logf("Unexpected content type: %T", response.Content) // } // // if contentStr != "" { // t.Logf("Image description: %s", contentStr) // } else if response.Content != nil { // t.Logf("Warning: Expected text content describing the image, got empty or non-text content") // } // // if response.Usage != nil { // t.Logf("Usage: prompt=%d, completion=%d, total=%d", // response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) // } // } // // // TestGPT5ReasoningEffortWithGPT4o tests that GPT-4o ignores reasoning_effort // func TestGPT5ReasoningEffortWithGPT4o(t *testing.T) { // test.Prepare(t, config.Conf) // 
defer test.Clean() // // // Use GPT-4o which doesn't support reasoning // conn, err := connector.Select("openai.gpt-4o") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // // effort := "high" // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Reasoning: false, // GPT-4o doesn't support reasoning // ToolCalls: true, // }, // ReasoningEffort: &effort, // Should be ignored by adapter // } // // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "Say 'OK'", // }, // } // // maxTokens := 10 // options.MaxCompletionTokens = &maxTokens // // ctx := newGPT5TestContext("test-gpt4o-no-reasoning", "openai.gpt-4o") // // // Should succeed (adapter removes reasoning_effort parameter) // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // // if response == nil { // t.Fatal("Response is nil") // } // // // Should have 0 reasoning tokens (GPT-4o doesn't do reasoning) // if response.Usage != nil && response.Usage.CompletionTokensDetails != nil { // reasoningTokens := response.Usage.CompletionTokensDetails.ReasoningTokens // if reasoningTokens != 0 { // t.Errorf("Expected reasoning_tokens=0 for GPT-4o, got %d", reasoningTokens) // } else { // t.Log("✓ GPT-4o correctly shows reasoning_tokens=0") // } // } // // t.Log("✓ ReasoningAdapter correctly removed reasoning_effort parameter for GPT-4o") // } // // // ============================================================================ // // Helper Functions // // ============================================================================ // // // newGPT5TestContext creates a real Context for testing GPT-5 provider // func newGPT5TestContext(chatID, connectorID string) *context.Context { // authorized := &types.AuthorizedInfo{ // Subject: "test-user", // 
ClientID: "test-client", // UserID: "test-user-123", // TeamID: "test-team-456", // TenantID: "test-tenant-789", // SessionID: "test-session-id", // Constraints: types.DataConstraints{ // TeamOnly: true, // Extra: map[string]interface{}{ // "test": "gpt5-provider", // }, // }, // } // // ctx := context.New(gocontext.Background(), authorized, chatID) // ctx.AssistantID = "test-assistant" // ctx.Locale = "en-us" // ctx.Theme = "light" // ctx.Client = context.Client{ // Type: "web", // UserAgent: "GPT5ProviderTest/1.0", // IP: "127.0.0.1", // } // ctx.Referer = context.RefererAPI // ctx.Accept = context.AcceptStandard // ctx.Route = "/api/test" // ctx.Metadata = make(map[string]interface{}) // return ctx // } ================================================ FILE: agent/llm/providers/openai/openai.go ================================================ package openai import ( gocontext "context" "fmt" "strings" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/http" goullm "github.com/yaoapp/gou/llm" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/llm/adapters" "github.com/yaoapp/yao/agent/llm/providers/base" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/utils/jsonschema" ) // startMessage starts a new message and sends group_start event // Note: group_start/group_end events are used for backward compatibility // but at LLM level they represent message boundaries, not Agent-level blocks func (mt *messageTracker) startMessage(messageType message.StreamChunkType, handler message.StreamFunc) { if mt.active { // End previous message first mt.endMessage(handler) } mt.active = true // Generate message ID using context's ID generator if mt.idGenerator != nil { mt.messageID = mt.idGenerator.GenerateMessageID() // M1, M2, M3... 
} else { // Fallback to global generator if no context generator mt.messageID = message.GenerateNanoID() } mt.messageType = messageType mt.startTime = time.Now().UnixMilli() mt.chunkCount = 0 mt.toolCallInfo = nil if handler != nil { startData := &message.EventMessageStartData{ MessageID: mt.messageID, Type: string(messageType), Timestamp: mt.startTime, } if startJSON, err := jsoniter.Marshal(startData); err == nil { handler(message.ChunkMessageStart, startJSON) } } } // startToolCallMessage starts a new tool call message with tool call info func (mt *messageTracker) startToolCallMessage(toolCallInfo *message.EventToolCallInfo, handler message.StreamFunc) { if mt.active { mt.endMessage(handler) } mt.active = true // Generate message ID using context's ID generator if mt.idGenerator != nil { mt.messageID = mt.idGenerator.GenerateMessageID() // M1, M2, M3... } else { // Fallback to global generator if no context generator mt.messageID = message.GenerateNanoID() } mt.messageType = message.ChunkToolCall mt.startTime = time.Now().UnixMilli() mt.chunkCount = 0 mt.toolCallInfo = toolCallInfo if handler != nil { startData := &message.EventMessageStartData{ MessageID: mt.messageID, Type: string(message.ChunkToolCall), Timestamp: mt.startTime, ToolCall: toolCallInfo, } if startJSON, err := jsoniter.Marshal(startData); err == nil { handler(message.ChunkMessageStart, startJSON) } } } // incrementChunk increments the chunk count for the current message func (mt *messageTracker) incrementChunk() { if mt.active { mt.chunkCount++ } } // endMessage ends the current message and sends group_end event // Note: group_end event is used for backward compatibility // but at LLM level it represents message completion, not Agent-level block func (mt *messageTracker) endMessage(handler message.StreamFunc) { if !mt.active { return } if handler != nil { endData := &message.EventMessageEndData{ MessageID: mt.messageID, Type: string(mt.messageType), Timestamp: time.Now().UnixMilli(), DurationMs: 
time.Now().UnixMilli() - mt.startTime, ChunkCount: mt.chunkCount, Status: "completed", } if mt.toolCallInfo != nil { endData.ToolCall = mt.toolCallInfo } if endJSON, err := jsoniter.Marshal(endData); err == nil { handler(message.ChunkMessageEnd, endJSON) } } mt.active = false mt.messageID = "" mt.toolCallInfo = nil } // Provider OpenAI-compatible provider with capability adapters // Supports: vision, tool calls, streaming, JSON mode, reasoning type Provider struct { *base.Provider adapters []adapters.CapabilityAdapter } // buildAPIURL builds the complete API URL from host and endpoint. // Delegates to the shared connector.BuildAPIURL for consistent URL building // across the agent LLM path and the sandbox proxy path. func buildAPIURL(host, endpoint string) string { return connector.BuildAPIURL(host, endpoint) } // New create a new OpenAI provider with capability adapters func New(conn connector.Connector, capabilities *goullm.Capabilities) *Provider { return &Provider{ Provider: base.NewProvider(conn, capabilities), adapters: buildAdapters(capabilities), } } // buildAdapters builds capability adapters based on model capabilities func buildAdapters(cap *goullm.Capabilities) []adapters.CapabilityAdapter { if cap == nil { return []adapters.CapabilityAdapter{} } result := make([]adapters.CapabilityAdapter, 0) // Tool call adapter result = append(result, adapters.NewToolCallAdapter(cap.ToolCalls)) // Vision adapter visionSupport, visionFormat := context.GetVisionSupport(cap) if visionSupport { result = append(result, adapters.NewVisionAdapter(true, visionFormat)) } else if cap.Vision != nil { // Vision explicitly disabled, add adapter to remove image content result = append(result, adapters.NewVisionAdapter(false, context.VisionFormatNone)) } // Audio adapter result = append(result, adapters.NewAudioAdapter(cap.Audio)) // Reasoning adapter (always add to handle reasoning_effort and temperature parameters) // Even if the model doesn't support reasoning, we need the 
adapter to strip reasoning_effort if cap.Reasoning { // Detect reasoning format based on capabilities format := detectReasoningFormat(cap) result = append(result, adapters.NewReasoningAdapter(format, cap)) } else { // Model doesn't support reasoning, use None format to strip reasoning parameters result = append(result, adapters.NewReasoningAdapter(adapters.ReasoningFormatNone, cap)) } return result } // detectReasoningFormat detects the reasoning format based on capabilities func detectReasoningFormat(cap *goullm.Capabilities) adapters.ReasoningFormat { // TODO: Implement better detection logic // For now, default to OpenAI o1 format if reasoning is supported if cap.Reasoning { return adapters.ReasoningFormatOpenAI } return adapters.ReasoningFormatNone } // Stream stream completion from OpenAI API func (p *Provider) Stream(ctx *context.Context, messages []context.Message, options *context.CompletionOptions, handler message.StreamFunc) (*context.CompletionResponse, error) { // Add debug log trace, _ := ctx.Trace() if trace != nil { trace.Debug("OpenAI Stream: Starting stream request", map[string]any{ "message_count": len(messages), }) } maxRetries := 3 var lastErr error // Get Go context for cancellation support // Read from Stack.Options if available (call-level override) goCtx := ctx.Context if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Context != nil { goCtx = ctx.Stack.Options.Context } if goCtx == nil { goCtx = gocontext.Background() } // Make a copy of messages to avoid modifying the original currentMessages := make([]context.Message, len(messages)) copy(currentMessages, messages) // Outer loop: handle network/API errors with exponential backoff for attempt := 0; attempt < maxRetries; attempt++ { // Check if context is cancelled before retry select { case <-goCtx.Done(): return nil, fmt.Errorf("context cancelled: %w", goCtx.Err()) default: } // Check for force interrupt before retry if ctx.Interrupt != nil { if signal := 
ctx.Interrupt.Peek(); signal != nil && signal.Type == context.InterruptForce { return nil, fmt.Errorf("force interrupted by user") } } if attempt > 0 { // Exponential backoff: 1s, 2s, 4s backoff := time.Duration(1< 0 { choice := chunk.Choices[0] delta := choice.Delta // Update accumulator metadata if accumulator.id == "" { accumulator.id = chunk.ID accumulator.model = chunk.Model accumulator.created = chunk.Created } // Handle role if delta.Role != "" { accumulator.role = delta.Role } // Handle reasoning content (DeepSeek R1) if delta.ReasoningContent != "" { // Start thinking message if not active if !messageTracker.active || messageTracker.messageType != message.ChunkThinking { messageTracker.startMessage(message.ChunkThinking, handler) } accumulator.reasoningContent += delta.ReasoningContent if handler != nil { handler(message.ChunkThinking, []byte(delta.ReasoningContent)) messageTracker.incrementChunk() } } // Handle content if delta.Content != "" { // Start text message if not active if !messageTracker.active || messageTracker.messageType != message.ChunkText { messageTracker.startMessage(message.ChunkText, handler) } accumulator.content += delta.Content if handler != nil { handler(message.ChunkText, []byte(delta.Content)) messageTracker.incrementChunk() } } // Handle refusal if delta.Refusal != "" { // Start refusal message if not active if !messageTracker.active || messageTracker.messageType != message.ChunkRefusal { messageTracker.startMessage(message.ChunkRefusal, handler) } accumulator.refusal += delta.Refusal if handler != nil { handler(message.ChunkRefusal, []byte(delta.Refusal)) messageTracker.incrementChunk() } } // Handle tool calls if len(delta.ToolCalls) > 0 { for _, tc := range delta.ToolCalls { if _, exists := accumulator.toolCalls[tc.Index]; !exists { accumulator.toolCalls[tc.Index] = &accumulatedToolCall{} // Start new tool call message when we first see this tool call if tc.ID != "" { toolCallInfo := &message.EventToolCallInfo{ ID: tc.ID, 
Name: tc.Function.Name, // May be partial or empty initially Index: tc.Index, } messageTracker.startToolCallMessage(toolCallInfo, handler) } } accTC := accumulator.toolCalls[tc.Index] if tc.ID != "" { accTC.id = tc.ID } if tc.Type != "" { accTC.typ = tc.Type } if tc.Function.Name != "" { accTC.functionName = tc.Function.Name // Update tool call info in tracker if messageTracker.active && messageTracker.toolCallInfo != nil { messageTracker.toolCallInfo.Name = tc.Function.Name } } if tc.Function.Arguments != "" { accTC.functionArgs += tc.Function.Arguments // Update tool call info in tracker if messageTracker.active && messageTracker.toolCallInfo != nil { messageTracker.toolCallInfo.Arguments = accTC.functionArgs } } } // Notify handler of tool call progress // Send the raw delta from OpenAI (as JSON bytes) // Handler will convert to object for frontend merge if handler != nil { toolCallData, _ := jsoniter.Marshal(delta.ToolCalls) handler(message.ChunkToolCall, toolCallData) messageTracker.incrementChunk() } } // Handle finish reason if choice.FinishReason != nil && *choice.FinishReason != "" { accumulator.finishReason = *choice.FinishReason } // Handle usage (in choices, for older API versions) if chunk.Usage != nil { accumulator.usage = &message.UsageInfo{ PromptTokens: chunk.Usage.PromptTokens, CompletionTokens: chunk.Usage.CompletionTokens, TotalTokens: chunk.Usage.TotalTokens, } } } // Check for usage at the top level (newer API versions with stream_options) if chunk.Usage != nil && accumulator.usage == nil { accumulator.usage = &message.UsageInfo{ PromptTokens: chunk.Usage.PromptTokens, CompletionTokens: chunk.Usage.CompletionTokens, TotalTokens: chunk.Usage.TotalTokens, } } return http.HandlerReturnOk } // Log request for debugging if trace != nil { if requestBodyJSON, marshalErr := jsoniter.Marshal(requestBody); marshalErr == nil { trace.Debug("OpenAI Stream Request", map[string]any{ "url": url, "body": string(requestBodyJSON), }) } } // Buffer to capture 
non-SSE error responses var errorBuffer strings.Builder errorDetected := false // Wrap streamHandler to detect JSON error responses // Note: API error responses are raw JSON without "data: " prefix // Normal SSE data always starts with "data: " prefix wrappedHandler := func(data []byte) int { dataStr := string(data) trimmed := strings.TrimSpace(dataStr) // Skip empty lines if trimmed == "" { return http.HandlerReturnOk } // Normal SSE data starts with "data: " - pass to streamHandler if strings.HasPrefix(dataStr, "data: ") { return streamHandler(data) } // Detect if this looks like a JSON error response (raw JSON without "data: " prefix) // API errors are returned as raw JSON: {"error": {...}} if strings.HasPrefix(trimmed, "{") && strings.Contains(dataStr, `"error"`) { errorDetected = true } // If error detected, accumulate all data for parsing if errorDetected { errorBuffer.Write(data) errorBuffer.WriteString("\n") return http.HandlerReturnOk } // Unknown format, pass to streamHandler (it will skip non-SSE data) return streamHandler(data) } // Make streaming request (goCtx already set at function start) log.Trace("[LLM] Starting HTTP Stream request: url=%s", url) err = req.Stream(goCtx, "POST", requestBody, wrappedHandler) log.Trace("[LLM] HTTP Stream request returned: err=%v", err) // Check if we captured an error response if errorDetected && errorBuffer.Len() > 0 { errorJSON := errorBuffer.String() if trace != nil { trace.Error(i18n.T(ctx.Locale, "llm.openai.stream.api_error"), map[string]any{"response": errorJSON}) // "OpenAI API returned error response" } // Try to parse error var apiError struct { Error struct { Message string `json:"message"` Type string `json:"type"` Param string `json:"param"` Code string `json:"code"` } `json:"error"` } if parseErr := jsoniter.UnmarshalFromString(errorJSON, &apiError); parseErr == nil && apiError.Error.Message != "" { err = fmt.Errorf("OpenAI API error: %s (type: %s, param: %s, code: %s)", apiError.Error.Message, 
apiError.Error.Type, apiError.Error.Param, apiError.Error.Code) } else { err = fmt.Errorf("OpenAI API error: %s", strings.TrimSpace(errorJSON)) } } // Check if error is due to context cancellation FIRST (before logging) // This prevents blocking on trace operations when context is cancelled if err != nil && goCtx.Err() != nil { log.Trace("[LLM] Context cancelled detected, skipping handler calls and returning") // NOTE: Do NOT call handler or groupTracker.endGroup here // The connection is already closed, calling handler may block indefinitely // Just return the error immediately return nil, fmt.Errorf("stream cancelled: %w", goCtx.Err()) } // Log any error from streaming (only if not cancelled) if err != nil && trace != nil { trace.Error(i18n.T(ctx.Locale, "llm.openai.stream.error"), map[string]any{"error": err.Error()}) // "OpenAI Stream Error" } if err != nil { // End current message if active messageTracker.endMessage(handler) // Notify handler of error if provided if handler != nil { errData := []byte(err.Error()) handler(message.ChunkError, errData) } return nil, fmt.Errorf("streaming request failed: %w", err) } // Check if we received any data if accumulator.id == "" { if trace != nil { trace.Warn("OpenAI stream completed but no data was received") // Log request details for debugging if requestBodyJSON, err := jsoniter.Marshal(requestBody); err == nil { trace.Error(i18n.T(ctx.Locale, "llm.openai.stream.no_data"), map[string]any{"body": string(requestBodyJSON)}) // "Request body that caused empty response" } trace.Error(i18n.T(ctx.Locale, "llm.openai.stream.no_data_info"), map[string]any{ // "Request details" "url": url, "model": accumulator.model, "created": accumulator.created, }) } err := fmt.Errorf("no data received from OpenAI API") // End current message if active messageTracker.endMessage(handler) // Notify handler of error if provided if handler != nil { errData := []byte(err.Error()) handler(message.ChunkError, errData) } return nil, err } // Build 
final response response := &context.CompletionResponse{ ID: accumulator.id, Object: "chat.completion", Created: accumulator.created, Model: accumulator.model, Role: accumulator.role, Content: accumulator.content, ReasoningContent: accumulator.reasoningContent, Refusal: accumulator.refusal, FinishReason: accumulator.finishReason, Usage: accumulator.usage, } // Convert accumulated tool calls to ToolCall slice if len(accumulator.toolCalls) > 0 { toolCalls := make([]context.ToolCall, 0, len(accumulator.toolCalls)) for i := 0; i < len(accumulator.toolCalls); i++ { if tc, exists := accumulator.toolCalls[i]; exists { toolCalls = append(toolCalls, context.ToolCall{ ID: tc.id, Type: context.ToolCallType(tc.typ), Function: context.Function{ Name: tc.functionName, Arguments: tc.functionArgs, }, }) } } response.ToolCalls = toolCalls // Validate tool call results if schema is provided // Note: If validation fails, we log the error but DO NOT return error // Instead, we let the response through so Agent layer can handle it // Agent layer will re-validate and provide better error feedback to LLM if err := p.validateToolCallResults(options, toolCalls); err != nil { // Log validation error if trace, _ := ctx.Trace(); trace != nil { trace.Warn("Tool call validation failed at LLM layer, passing to Agent layer for handling", map[string]any{ "error": err.Error(), }) } // End current message messageTracker.endMessage(handler) // Continue and return response (don't return error) // Agent layer will handle validation and retry } } // End final message if still active messageTracker.endMessage(handler) return response, nil } // Post post completion request to OpenAI API func (p *Provider) Post(ctx *context.Context, messages []context.Message, options *context.CompletionOptions) (*context.CompletionResponse, error) { // Add debug log trace, _ := ctx.Trace() if trace != nil { trace.Debug("OpenAI Post: Starting non-stream request", map[string]any{ "message_count": len(messages), }) } 
maxRetries := 3 var lastErr error // Get Go context for cancellation support // Read from Stack.Options if available (call-level override) goCtx := ctx.Context if ctx.Stack != nil && ctx.Stack.Options != nil && ctx.Stack.Options.Context != nil { goCtx = ctx.Stack.Options.Context } if goCtx == nil { goCtx = gocontext.Background() } // Make a copy of messages to avoid modifying the original currentMessages := make([]context.Message, len(messages)) copy(currentMessages, messages) // Outer loop: handle network/API errors with exponential backoff for attempt := 0; attempt < maxRetries; attempt++ { // Check if context is cancelled before retry select { case <-goCtx.Done(): return nil, fmt.Errorf("context cancelled: %w", goCtx.Err()) default: } if attempt > 0 { // Exponential backoff backoff := time.Duration(1< 0 { if err := p.validateToolCallResults(options, response.ToolCalls); err != nil { return nil, fmt.Errorf("tool call validation failed: %w", err) } } return response, nil } // buildRequestBody builds the request body for OpenAI API func (p *Provider) buildRequestBody(messages []context.Message, options *context.CompletionOptions, streaming bool) (map[string]interface{}, error) { if options == nil { return nil, fmt.Errorf("options are required") } // Get model and other settings from connector setting := p.Connector.Setting() model, ok := setting["model"].(string) if !ok || model == "" { return nil, fmt.Errorf("model is not set in connector") } // Get thinking setting from connector (for models that support reasoning/thinking mode) var thinkingSetting interface{} if thinking, exists := setting["thinking"]; exists { thinkingSetting = thinking } // Convert messages to API format apiMessages := make([]map[string]interface{}, 0, len(messages)) for _, msg := range messages { apiMsg := map[string]interface{}{ "role": string(msg.Role), } if msg.Content != nil { // Check if Content is []context.ContentPart and convert to API format if parts, ok := 
msg.Content.([]context.ContentPart); ok { apiParts := make([]map[string]interface{}, 0, len(parts)) for _, part := range parts { apiPart := map[string]interface{}{ "type": string(part.Type), } switch part.Type { case context.ContentText: apiPart["text"] = part.Text case context.ContentImageURL: if part.ImageURL != nil { apiPart["image_url"] = map[string]interface{}{ "url": part.ImageURL.URL, } if part.ImageURL.Detail != "" { apiPart["image_url"].(map[string]interface{})["detail"] = part.ImageURL.Detail } } case context.ContentInputAudio: if part.InputAudio != nil { apiPart["input_audio"] = part.InputAudio } } apiParts = append(apiParts, apiPart) } apiMsg["content"] = apiParts } else { // Content is string or already in map format, use as is apiMsg["content"] = msg.Content } } if msg.Name != nil { apiMsg["name"] = *msg.Name } if msg.ToolCallID != nil { apiMsg["tool_call_id"] = *msg.ToolCallID } if len(msg.ToolCalls) > 0 { apiMsg["tool_calls"] = msg.ToolCalls } if msg.Refusal != nil { apiMsg["refusal"] = *msg.Refusal } apiMessages = append(apiMessages, apiMsg) } // Build request body body := map[string]interface{}{ "model": model, "messages": apiMessages, "stream": streaming, } // Add optional parameters if options.Temperature != nil { body["temperature"] = *options.Temperature } // Use max_completion_tokens (modern API parameter for GPT-5+) // GPT-5 models only support max_completion_tokens (not max_tokens) if options.MaxCompletionTokens != nil { body["max_completion_tokens"] = *options.MaxCompletionTokens } else if options.MaxTokens != nil { // Fallback: convert MaxTokens to max_completion_tokens for compatibility body["max_completion_tokens"] = *options.MaxTokens } if options.TopP != nil { body["top_p"] = *options.TopP } if options.N != nil { body["n"] = *options.N } if options.Stop != nil { body["stop"] = options.Stop } if options.PresencePenalty != nil { body["presence_penalty"] = *options.PresencePenalty } if options.FrequencyPenalty != nil { 
body["frequency_penalty"] = *options.FrequencyPenalty } if len(options.LogitBias) > 0 { body["logit_bias"] = options.LogitBias } if options.User != "" { body["user"] = options.User } if options.ResponseFormat != nil { // Build response_format according to OpenAI API requirements responseFormat := map[string]interface{}{ "type": options.ResponseFormat.Type, } // For json_schema type, include the schema details if options.ResponseFormat.Type == context.ResponseFormatJSONSchema && options.ResponseFormat.JSONSchema != nil { responseFormat["json_schema"] = options.ResponseFormat.JSONSchema } body["response_format"] = responseFormat } if options.Seed != nil { body["seed"] = *options.Seed } if len(options.Tools) > 0 { body["tools"] = options.Tools } if options.ToolChoice != nil { body["tool_choice"] = options.ToolChoice } // Reasoning effort (o1 and GPT-5 models) if options.ReasoningEffort != nil { body["reasoning_effort"] = *options.ReasoningEffort } // For streaming, include usage info by default if streaming { if options.StreamOptions != nil { body["stream_options"] = options.StreamOptions } else { // Default: include usage info in streaming response body["stream_options"] = map[string]interface{}{ "include_usage": true, } } } if options.Audio != nil { body["audio"] = options.Audio } // Add thinking parameter for models that support reasoning/thinking mode if thinkingSetting != nil { body["thinking"] = thinkingSetting } return body, nil } // validateToolCallResults validates tool call arguments against JSON schema func (p *Provider) validateToolCallResults(options *context.CompletionOptions, toolCalls []context.ToolCall) error { if options == nil || options.Tools == nil || len(options.Tools) == 0 { return nil } // Build tool schema map for quick lookup toolSchemas := make(map[string]interface{}) for _, tool := range options.Tools { if function, ok := tool["function"].(map[string]interface{}); ok { if name, ok := function["name"].(string); ok { if parameters, ok := 
// isToolCallValidationError reports whether err describes a tool call
// validation failure. Detection is by message text: the error string must
// contain "tool call validation failed" or "arguments validation failed".
// A nil error is never a validation error.
func isToolCallValidationError(err error) bool {
	if err == nil {
		return false
	}

	msg := err.Error()
	for _, marker := range []string{
		"tool call validation failed",
		"arguments validation failed",
	} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}

// isRetryableError reports whether err looks like a transient failure that is
// worth retrying: network errors, timeouts, rate limits and server errors.
// Matching is a case-insensitive substring search on the error message.
// A nil error is not retryable.
func isRetryableError(err error) bool {
	if err == nil {
		return false
	}

	// Lowercase the message once; the patterns are stored in lower case so no
	// per-pattern conversion is needed inside the loop.
	msg := strings.ToLower(err.Error())

	retryablePatterns := []string{
		"timeout",
		"connection refused",
		"connection reset",
		"eof",
		"http 429", // Rate limit
		"http 500", // Internal server error
		"http 502", // Bad gateway
		"http 503", // Service unavailable
		"http 504", // Gateway timeout
	}

	for _, pattern := range retryablePatterns {
		if strings.Contains(msg, pattern) {
			return true
		}
	}
	return false
}
"github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" ) // TestOpenAIStreamBasic tests basic streaming completion with short output func TestOpenAIStreamBasic(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector from real configuration conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Create LLM instance with capabilities options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Prepare messages with concise prompt messages := []context.Message{ { Role: context.RoleUser, Content: "Say 'Hello' in one word.", }, } // Set short max tokens to ensure quick response maxTokens := 5 options.MaxTokens = &maxTokens // Create context ctx := newTestContext("test-stream-basic", "openai.gpt-4o") // Track streaming chunks var chunks []string handler := func(chunkType message.StreamChunkType, data []byte) int { chunks = append(chunks, string(data)) t.Logf("Stream chunk [%s]: %s", chunkType, string(data)) return 0 // Continue } // Call Stream response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.Content == "" { t.Error("Response content is empty") } if response.FinishReason == "" { t.Error("FinishReason is empty") } if response.Usage == nil { t.Error("Response Usage is nil") } else { if response.Usage.TotalTokens == 0 { t.Error("Response Usage.TotalTokens is 0") } t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) 
} if len(chunks) == 0 { t.Error("No streaming chunks received") } t.Logf("Final response: %+v", response) t.Logf("Total chunks received: %d", len(chunks)) } // TestOpenAIPostBasic tests basic non-streaming completion func TestOpenAIPostBasic(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Create LLM instance options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ ToolCalls: true, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Prepare messages with concise prompt messages := []context.Message{ { Role: context.RoleUser, Content: "Reply with only the word 'OK'.", }, } // Set short max tokens maxTokens := 5 options.MaxTokens = &maxTokens // Create context ctx := newTestContext("test-stream-basic", "openai.gpt-4o") // Call Post response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.Content == "" { t.Error("Response content is empty") } if response.FinishReason == "" { t.Error("FinishReason is empty") } if response.Usage == nil { t.Error("Response Usage is nil") } else { if response.Usage.TotalTokens == 0 { t.Error("Response Usage.TotalTokens is 0") } t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } t.Logf("Response: %+v", response) } // TestOpenAIStreamWithToolCalls tests streaming with tool calls and JSON schema validation func TestOpenAIStreamWithToolCalls(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector conn, err := 
connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Create LLM instance with tool call capabilities options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, } // Define a simple weather tool with JSON schema weatherTool := map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": "get_weather", "description": "Get the current weather for a location", "parameters": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "location": map[string]interface{}{ "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, "unit": map[string]interface{}{ "type": "string", "enum": []string{"celsius", "fahrenheit"}, }, }, "required": []string{"location"}, }, }, } options.Tools = []map[string]interface{}{weatherTool} options.ToolChoice = "auto" llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Prepare messages that should trigger tool call messages := []context.Message{ { Role: context.RoleUser, Content: "What's the weather in Tokyo? 
Use celsius.", }, } // Create context ctx := newTestContext("test-stream-basic", "openai.gpt-4o") // Track streaming chunks var toolCallChunks int handler := func(chunkType message.StreamChunkType, data []byte) int { if chunkType == message.ChunkToolCall { toolCallChunks++ } t.Logf("Stream chunk [%s]: %s", chunkType, string(data)) return 0 // Continue } // Call Stream response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream with tool calls failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } // Should have tool calls if len(response.ToolCalls) == 0 { t.Error("Expected tool calls but got none") } else { t.Logf("Received %d tool call(s)", len(response.ToolCalls)) for i, tc := range response.ToolCalls { t.Logf("Tool call %d: %s(%s)", i, tc.Function.Name, tc.Function.Arguments) // Validate tool call has required fields if tc.ID == "" { t.Errorf("Tool call %d missing ID", i) } if tc.Function.Name == "" { t.Errorf("Tool call %d missing function name", i) } if tc.Function.Arguments == "" { t.Errorf("Tool call %d missing arguments", i) } } } if response.FinishReason != context.FinishReasonToolCalls { t.Logf("Warning: Expected finish_reason='tool_calls', got '%s'", response.FinishReason) } if toolCallChunks == 0 { t.Error("No tool call chunks received during streaming") } t.Logf("Final response: %+v", response) } // TestOpenAIPostWithToolCalls tests non-streaming with tool calls func TestOpenAIPostWithToolCalls(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Create LLM instance with tool call capabilities options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ ToolCalls: true, }, } // Define a calculation tool calcTool := map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": "calculate", 
"description": "Perform a mathematical calculation", "parameters": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "expression": map[string]interface{}{ "type": "string", "description": "The mathematical expression to evaluate", }, }, "required": []string{"expression"}, }, }, } options.Tools = []map[string]interface{}{calcTool} options.ToolChoice = "auto" llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Prepare messages messages := []context.Message{ { Role: context.RoleUser, Content: "Calculate 15 * 8", }, } // Create context ctx := newTestContext("test-stream-basic", "openai.gpt-4o") // Call Post response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post with tool calls failed: %v", err) } // Validate response if response == nil { t.Fatal("Response is nil") } // Validate response metadata if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.FinishReason != "tool_calls" { t.Errorf("FinishReason is %s, expected tool_calls", response.FinishReason) } if response.Usage == nil { t.Error("Response Usage is nil") } else { if response.Usage.TotalTokens == 0 { t.Error("Response Usage.TotalTokens is 0") } t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } // Should have tool calls if len(response.ToolCalls) == 0 { t.Error("Expected tool calls but got none") } else { tc := response.ToolCalls[0] // Validate tool call structure if tc.ID == "" { t.Error("Tool call ID is empty") } if tc.Type != context.ToolTypeFunction { t.Errorf("Tool call Type is %s, expected %s", tc.Type, context.ToolTypeFunction) } if tc.Function.Name != "calculate" { t.Errorf("Tool call function name is %s, expected calculate", tc.Function.Name) } if tc.Function.Arguments == "" { t.Error("Tool call arguments are 
empty") } t.Logf("Received %d tool call(s)", len(response.ToolCalls)) for i, tc := range response.ToolCalls { t.Logf("Tool call %d: %s(%s)", i, tc.Function.Name, tc.Function.Arguments) } } t.Logf("Response: %+v", response) } // TestOpenAIStreamWithInvalidToolCall tests that invalid tool calls trigger validation error func TestOpenAIStreamWithInvalidToolCall(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Create LLM instance options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, } // Define a strict tool that requires specific format strictTool := map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": "send_email", "description": "Send an email", "parameters": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "to": map[string]interface{}{ "type": "string", "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", }, "subject": map[string]interface{}{ "type": "string", "minLength": 1, }, "body": map[string]interface{}{ "type": "string", "minLength": 1, }, }, "required": []string{"to", "subject", "body"}, }, }, } options.Tools = []map[string]interface{}{strictTool} options.ToolChoice = map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": "send_email", }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } // Prepare messages with incomplete information (should cause validation error) messages := []context.Message{ { Role: context.RoleUser, Content: "Send email to invalid-email without subject", }, } // Create context ctx := newTestContext("test-stream-basic", "openai.gpt-4o") handler := func(chunkType message.StreamChunkType, data []byte) int { return 0 // Continue } // Call Stream - should succeed 
but may trigger validation if tool call is malformed response, err := llmInstance.Stream(ctx, messages, options, handler) // The API might return a valid tool call despite the bad prompt, // so we just log the result if err != nil { t.Logf("Stream failed as expected with validation error: %v", err) } else { t.Logf("Stream succeeded, response: %+v", response) if len(response.ToolCalls) > 0 { t.Logf("Tool calls: %v", response.ToolCalls) } } } // TestOpenAIStreamRetry tests the retry mechanism with invalid API key func TestOpenAIStreamRetry(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector with invalid API key to trigger 401 error (non-retryable) connDSL := `{ "type": "openai", "options": { "model": "gpt-4o", "key": "sk-invalid-key-should-fail-auth", "host": "https://api.openai.com" } }` conn, err := connector.New("openai", "test-retry", []byte(connDSL)) if err != nil { t.Fatalf("Failed to create test connector: %v", err) } // Create LLM instance options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, // Need this to select OpenAI provider }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Test", }, } ctx := newTestContext("test-retry", "test-retry") // This should fail quickly without retry (401 is non-retryable) _, err = llmInstance.Stream(ctx, messages, options, nil) if err == nil { t.Fatal("Expected error due to invalid API key, but got success") } // Verify it's an error related to invalid API key // Could be: 401, unauthorized, authentication error, or no data (empty response) errMsg := err.Error() hasExpectedError := strings.Contains(strings.ToLower(errMsg), "401") || strings.Contains(strings.ToLower(errMsg), "unauthorized") || strings.Contains(strings.ToLower(errMsg), "authentication") || strings.Contains(strings.ToLower(errMsg), "incorrect api 
key") || strings.Contains(strings.ToLower(errMsg), "no data received") if !hasExpectedError { t.Errorf("Expected authentication or empty response error, got: %v", err) } // Should mention non-retryable (these errors should not trigger retry) if !strings.Contains(strings.ToLower(errMsg), "non-retryable") { t.Errorf("Error should indicate non-retryable: %v", err) } t.Logf("Failed as expected with error: %v", err) } // TestOpenAIStreamChunkTypes tests that stream handler receives correct chunk types func TestOpenAIStreamChunkTypes(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Say 'test' in one word.", }, } ctx := newTestContext("test-chunk-types", "openai.gpt-4o") // Track chunk types chunkTypes := make(map[message.StreamChunkType]int) handler := func(chunkType message.StreamChunkType, data []byte) int { chunkTypes[chunkType]++ t.Logf("Received chunk type: %s, data length: %d", chunkType, len(data)) return 1 // Continue } response, err := llmInstance.Stream(ctx, messages, options, handler) if err != nil { t.Fatalf("Stream failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Validate chunk types received if chunkTypes[message.ChunkText] == 0 { t.Error("Expected to receive ChunkText, but got 0") } t.Logf("Chunk types received: %+v", chunkTypes) } // TestOpenAIStreamErrorCallback tests that errors are sent to stream handler func TestOpenAIStreamErrorCallback(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create connector with invalid API key to trigger error connDSL := `{ "type": "openai", "options": { "model": 
"gpt-4o", "key": "sk-invalid-for-error-test", "host": "https://api.openai.com" } }` conn, err := connector.New("openai", "test-error-callback", []byte(connDSL)) if err != nil { t.Fatalf("Failed to create test connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Test", }, } ctx := newTestContext("test-error-callback", "test-error-callback") // Track if error chunk was received receivedError := false var errorMessage string handler := func(chunkType message.StreamChunkType, data []byte) int { if chunkType == message.ChunkError { receivedError = true errorMessage = string(data) t.Logf("Received error chunk: %s", errorMessage) } return 1 // Continue } // This should fail and send error to handler _, err = llmInstance.Stream(ctx, messages, options, handler) if err == nil { t.Fatal("Expected error due to invalid API key") } // Verify error was sent to handler if !receivedError { t.Error("Expected to receive ChunkError in handler, but didn't") } if errorMessage == "" { t.Error("Error message in chunk is empty") } t.Logf("Error callback test passed. 
Error: %v", err) } // TestOpenAIToolCallValidationRetry tests automatic tool call validation retry with LLM feedback func TestOpenAIToolCallValidationRetry(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, Tools: []map[string]interface{}{ { "type": "function", "function": map[string]interface{}{ "name": "test_strict_validation", "description": "A function with very strict validation rules", "parameters": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "status": map[string]interface{}{ "type": "string", "description": "Must be exactly 'active' or 'inactive'", "enum": []string{"active", "inactive"}, }, "priority": map[string]interface{}{ "type": "integer", "description": "Must be between 1 and 5", "minimum": 1, "maximum": 5, }, }, "required": []string{"status", "priority"}, }, }, }, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } ctx := newTestContext("test-tool-validation-retry", "openai.gpt-4o") // Try to make LLM call with intentionally unclear requirements // This may or may not trigger validation, depending on LLM behavior messages := []context.Message{ { Role: context.RoleUser, Content: "Call test_strict_validation function with status='pending' and priority=10", }, } // The Provider will automatically: // 1. Call LLM // 2. If validation fails, add error feedback to conversation // 3. Retry up to 3 times with feedback // 4. 
Return success or validation error after max retries response, err := llmInstance.Stream(ctx, messages, options, nil) if err != nil { // Check if it's a validation error after retries if strings.Contains(err.Error(), "tool call validation failed after") && strings.Contains(err.Error(), "retries") { t.Logf("✓ Automatic validation retry exhausted: %v", err) } else if strings.Contains(err.Error(), "validation") { t.Logf("✓ Validation failed: %v", err) } else { t.Logf("Request failed (non-validation): %v", err) } } else if response != nil { if len(response.ToolCalls) > 0 { t.Logf("✓ Tool call succeeded (possibly after auto-retry): %+v", response.ToolCalls[0]) // Verify the tool call arguments are valid tc := response.ToolCalls[0] var args map[string]interface{} if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err == nil { if status, ok := args["status"].(string); ok { if status != "active" && status != "inactive" { t.Errorf("Status should be 'active' or 'inactive', got: %s", status) } } if priority, ok := args["priority"].(float64); ok { if priority < 1 || priority > 5 { t.Errorf("Priority should be between 1-5, got: %v", priority) } } } } else { t.Log("✓ Response returned but no tool calls") } } t.Log("Automatic tool call validation retry test completed") } // TestOpenAIJSONMode tests JSON mode response formatting func TestOpenAIJSONMode(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, ResponseFormat: &context.ResponseFormat{ Type: context.ResponseFormatJSON, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Generate a JSON object with fields: name (string), age (number), city (string). 
Use values: John, 30, New York", }, } ctx := newTestContext("test-json-mode", "openai.gpt-4o") // Test streaming with JSON mode response, err := llmInstance.Stream(ctx, messages, options, nil) if err != nil { t.Fatalf("Stream with JSON mode failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Validate response contentStr, ok := response.Content.(string) if !ok || contentStr == "" { t.Error("Response content is empty or not a string") } // Try to parse as JSON var jsonData map[string]interface{} if err := json.Unmarshal([]byte(contentStr), &jsonData); err != nil { t.Errorf("Response is not valid JSON: %v\nContent: %s", err, contentStr) } else { t.Logf("✓ Response is valid JSON: %+v", jsonData) // Verify expected fields exist if _, hasName := jsonData["name"]; !hasName { t.Error("JSON response missing 'name' field") } if _, hasAge := jsonData["age"]; !hasAge { t.Error("JSON response missing 'age' field") } if _, hasCity := jsonData["city"]; !hasCity { t.Error("JSON response missing 'city' field") } } // Validate metadata if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.Usage == nil { t.Error("Response Usage is nil") } else { t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } t.Log("JSON mode test completed successfully") } // TestOpenAIJSONModePost tests JSON mode with non-streaming func TestOpenAIJSONModePost(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ ToolCalls: true, }, ResponseFormat: &context.ResponseFormat{ Type: context.ResponseFormatJSON, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages 
:= []context.Message{ { Role: context.RoleUser, Content: "Return a JSON with: status='success', count=42", }, } ctx := newTestContext("test-json-mode-post", "openai.gpt-4o") // Test non-streaming with JSON mode response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post with JSON mode failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Validate response content is JSON contentStr, ok := response.Content.(string) if !ok || contentStr == "" { t.Error("Response content is empty or not a string") } var jsonData map[string]interface{} if err := json.Unmarshal([]byte(contentStr), &jsonData); err != nil { t.Errorf("Response is not valid JSON: %v\nContent: %s", err, contentStr) } else { t.Logf("✓ Response is valid JSON: %+v", jsonData) } // Validate metadata if response.Usage == nil { t.Error("Response Usage is nil") } else { if response.Usage.TotalTokens == 0 { t.Error("Response Usage.TotalTokens is 0") } t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } t.Log("JSON mode Post test completed successfully") } // TestOpenAIJSONSchema tests JSON mode with strict schema validation func TestOpenAIJSONSchema(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Define a strict JSON schema // Note: For OpenAI strict mode, 'required' must include ALL properties schema := map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "user": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "name": map[string]interface{}{ "type": "string", "description": "User's full name", }, "email": map[string]interface{}{ "type": "string", "description": "User's email address", }, "age": map[string]interface{}{ "type": "integer", "description": "User's age", }, "isActive": 
map[string]interface{}{ "type": "boolean", "description": "Whether user is active", }, }, "required": []string{"name", "email", "age", "isActive"}, "additionalProperties": false, }, }, "required": []string{"user"}, "additionalProperties": false, } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, ResponseFormat: &context.ResponseFormat{ Type: context.ResponseFormatJSONSchema, JSONSchema: &context.JSONSchema{ Name: "user_info", Description: "User information schema", Schema: schema, Strict: func() *bool { v := true; return &v }(), }, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Create user info for: Alice Smith, alice@example.com, age 28, active user", }, } ctx := newTestContext("test-json-schema", "openai.gpt-4o") // Test streaming with JSON schema response, err := llmInstance.Stream(ctx, messages, options, nil) if err != nil { t.Fatalf("Stream with JSON schema failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Validate response content contentStr, ok := response.Content.(string) if !ok || contentStr == "" { t.Fatal("Response content is empty or not a string") } // Parse as JSON var jsonData map[string]interface{} if err := json.Unmarshal([]byte(contentStr), &jsonData); err != nil { t.Fatalf("Response is not valid JSON: %v\nContent: %s", err, contentStr) } t.Logf("✓ Response is valid JSON: %+v", jsonData) // Verify structure matches schema user, hasUser := jsonData["user"].(map[string]interface{}) if !hasUser { t.Fatal("JSON response missing 'user' object") } // Verify required fields if _, hasName := user["name"]; !hasName { t.Error("User object missing required 'name' field") } if _, hasEmail := user["email"]; !hasEmail { t.Error("User object missing required 'email' field") } // Verify field types if name, ok := user["name"].(string); ok { 
t.Logf("✓ name: %s (string)", name) } else { t.Error("name is not a string") } if email, ok := user["email"].(string); ok { t.Logf("✓ email: %s (string)", email) } else { t.Error("email is not a string") } if age, ok := user["age"].(float64); ok { if age < 0 || age > 150 { t.Errorf("age %v is out of range [0, 150]", age) } t.Logf("✓ age: %v (integer, in range)", age) } if isActive, ok := user["isActive"].(bool); ok { t.Logf("✓ isActive: %v (boolean)", isActive) } // Validate metadata if response.Usage == nil { t.Error("Response Usage is nil") } else { t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } t.Log("JSON schema test completed successfully") } // TestOpenAIJSONSchemaPost tests JSON schema with non-streaming func TestOpenAIJSONSchemaPost(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } // Simple schema for testing // Note: For OpenAI strict mode, 'required' must include ALL properties schema := map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "status": map[string]interface{}{ "type": "string", "enum": []string{"success", "error", "pending"}, }, "message": map[string]interface{}{ "type": "string", }, "code": map[string]interface{}{ "type": "integer", }, }, "required": []string{"status", "message", "code"}, "additionalProperties": false, } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ ToolCalls: true, }, ResponseFormat: &context.ResponseFormat{ Type: context.ResponseFormatJSONSchema, JSONSchema: &context.JSONSchema{ Name: "api_response", Description: "API response format", Schema: schema, Strict: func() *bool { v := true; return &v }(), }, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ 
{ Role: context.RoleUser, Content: "Generate an API response with status 'success', message 'Operation completed', and code 200", }, } ctx := newTestContext("test-json-schema-post", "openai.gpt-4o") // Test non-streaming with JSON schema response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post with JSON schema failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Validate response content contentStr, ok := response.Content.(string) if !ok || contentStr == "" { t.Fatal("Response content is empty or not a string") } // Parse and validate JSON var jsonData map[string]interface{} if err := json.Unmarshal([]byte(contentStr), &jsonData); err != nil { t.Fatalf("Response is not valid JSON: %v\nContent: %s", err, contentStr) } t.Logf("✓ Response is valid JSON: %+v", jsonData) // Verify required fields status, hasStatus := jsonData["status"].(string) if !hasStatus { t.Fatal("Missing required 'status' field") } // Verify enum constraint validStatuses := map[string]bool{"success": true, "error": true, "pending": true} if !validStatuses[status] { t.Errorf("status '%s' is not in enum [success, error, pending]", status) } if _, hasMessage := jsonData["message"].(string); !hasMessage { t.Error("Missing required 'message' field") } // Validate metadata if response.Usage == nil { t.Error("Response Usage is nil") } else { t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } t.Log("JSON schema Post test completed successfully") } // TestOpenAIProxySupport tests that HTTP proxy configuration is respected func TestOpenAIProxySupport(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // This test verifies proxy support exists in the connector configuration // Actual proxy testing requires a real proxy server setup conn, err := connector.Select("openai.gpt-4o") if err != nil { t.Fatalf("Failed to select connector: %v", err) } settings 
:= conn.Setting()
	t.Logf("Connector settings: %+v", settings)

	// Verify host field exists in settings (host is the API endpoint)
	if host, hasHost := settings["host"]; hasHost {
		t.Logf("API host configured: %v", host)
	} else {
		t.Log("Host field not in settings (will use default)")
	}

	// The actual HTTP proxy functionality is implemented via environment variables
	// (HTTP_PROXY, HTTPS_PROXY, NO_PROXY) and handled by http.GetTransport
	t.Log("HTTP proxy support is implemented via http.GetTransport using environment variables")
}

// TestOpenAIStreamLifecycleEvents tests that LLM-level lifecycle events are sent correctly
// LLM layer sends group_start/end for individual messages (thinking, text, tool_call)
// Note: stream_start/end and Agent-level blocks are handled at Agent level
//
// The handler records every chunk type in arrival order together with the
// position of the first message-start and first message-end event, so the test
// can assert both that the events arrived and that start preceded end.
func TestOpenAIStreamLifecycleEvents(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	conn, err := connector.Select("openai.gpt-4o")
	if err != nil {
		t.Fatalf("Failed to select connector: %v", err)
	}

	options := &context.CompletionOptions{
		Capabilities: &openai.Capabilities{
			Streaming: true,
			ToolCalls: true,
		},
	}

	llmInstance, err := llm.New(conn, options)
	if err != nil {
		t.Fatalf("Failed to create LLM instance: %v", err)
	}

	messages := []context.Message{
		{
			Role:    context.RoleUser,
			Content: "Say 'hello' in one word",
		},
	}

	ctx := newTestContext("test-lifecycle", "openai.gpt-4o")

	// Track lifecycle events (group_start/end at LLM layer represent message boundaries)
	var events []string

	// Position within events of the first message start / end; -1 until seen.
	firstStart, firstEnd := -1, -1

	handler := func(chunkType message.StreamChunkType, data []byte) int {
		position := len(events)
		events = append(events, string(chunkType))

		switch chunkType {
		case message.ChunkStreamStart:
			t.Error("❌ LLM layer should NOT send stream_start (now sent at Agent level)")

		case message.ChunkStreamEnd:
			t.Error("❌ LLM layer should NOT send stream_end (now sent at Agent level)")

		case message.ChunkMessageStart:
			if firstStart < 0 {
				firstStart = position
			}
			var startData message.EventMessageStartData
			if err := json.Unmarshal(data, &startData); err == nil {
				t.Logf("✓ group_start (message start): type=%s, id=%s", startData.Type, startData.MessageID)
				if startData.MessageID == "" {
					t.Error("group_start missing message_id")
				}
			} else {
				t.Errorf("Failed to parse group_start data: %v", err)
			}

		case message.ChunkMessageEnd:
			if firstEnd < 0 {
				firstEnd = position
			}
			var endData message.EventMessageEndData
			if err := json.Unmarshal(data, &endData); err == nil {
				t.Logf("✓ group_end (message end): type=%s, chunks=%d, duration=%dms",
					endData.Type, endData.ChunkCount, endData.DurationMs)
				if endData.ChunkCount <= 0 {
					t.Error("group_end should have chunk_count > 0")
				}
			} else {
				t.Errorf("Failed to parse group_end data: %v", err)
			}

		case message.ChunkText:
			t.Logf("  text chunk: %s", string(data))
		}

		return 0
	}

	response, err := llmInstance.Stream(ctx, messages, options, handler)
	if err != nil {
		t.Fatalf("Stream failed: %v", err)
	}

	if response == nil {
		t.Fatal("Response is nil")
	}

	// Validate that LLM-level message lifecycle events were received
	if firstStart < 0 {
		t.Error("group_start (message start) event was not received")
	}

	if firstEnd < 0 {
		t.Error("group_end (message end) event was not received")
	}

	// Validate event order: group_start should come before group_end
	if firstStart >= 0 && firstEnd >= 0 && firstEnd < firstStart {
		t.Errorf("group_end (event #%d) was received before group_start (event #%d)", firstEnd, firstStart)
	}

	if len(events) < 2 {
		t.Errorf("Expected at least 2 events (message start/end), got %d", len(events))
	}

	t.Logf("Total events received: %d", len(events))
	t.Log("LLM message lifecycle events test completed successfully")
	t.Log("Note: LLM layer group_start/end represent message boundaries (thinking, text, tool_call)")
	t.Log("      Agent-level block boundaries and stream_start/end are handled at Agent level")
}

// TestOpenAIStreamContextCancellation tests that stream respects context cancellation
func TestOpenAIStreamContextCancellation(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	conn, err := connector.Select("openai.gpt-4o")
	if err != nil {
		t.Fatalf("Failed to select connector: %v", err)
	}

	options := &context.CompletionOptions{
Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, }, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Write a very long essay about the history of computing", // Long task }, } // Create a context with a very short timeout ctx := newTestContext("test-cancel", "openai.gpt-4o") goCtx, cancel := gocontext.WithTimeout(gocontext.Background(), 100*time.Millisecond) defer cancel() ctx.Context = goCtx var receivedChunks int handler := func(chunkType message.StreamChunkType, data []byte) int { if chunkType == message.ChunkText || chunkType == message.ChunkToolCall { receivedChunks++ } // Note: stream_end is now sent at Agent level, not LLM level if chunkType == message.ChunkStreamEnd { t.Error("❌ LLM layer should NOT send stream_end (now sent at Agent level)") } return 0 } response, err := llmInstance.Stream(ctx, messages, options, handler) // Should get an error due to context cancellation if err == nil { t.Error("Expected error due to context cancellation, but got nil") } else { t.Logf("✓ Got expected cancellation error: %v", err) // Check if error message indicates cancellation errStr := err.Error() if !strings.Contains(errStr, "context") && !strings.Contains(errStr, "cancel") { t.Errorf("Error should mention context/cancellation: %v", err) } } // Response should be nil due to cancellation if response != nil { t.Logf("Warning: Response is not nil despite cancellation (partial response)") } t.Logf("Received %d chunks before cancellation", receivedChunks) t.Log("Context cancellation test completed successfully") t.Log("Note: stream_end for cancellation is now sent at Agent level") } // TestOpenAIStreamWithTemperature tests different temperature settings func TestOpenAIStreamWithTemperature(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { 
t.Fatalf("Failed to select connector: %v", err) } temperature := 0.7 // Moderate temperature options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Streaming: true, ToolCalls: true, // Need this to select OpenAI provider }, Temperature: &temperature, } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Say 'yes' in one word.", }, } ctx := newTestContext("test-temperature", "openai.gpt-4o") // Use callback to collect chunks chunkCount := 0 var callback message.StreamFunc = func(chunkType message.StreamChunkType, data []byte) int { chunkCount++ return 1 // Continue } response, err := llmInstance.Stream(ctx, messages, options, callback) if err != nil { t.Fatalf("Stream failed: %v", err) } if response == nil { t.Fatal("Response is nil") } // Validate response data if response.ID == "" { t.Error("Response ID is empty") } if response.Model == "" { t.Error("Response Model is empty") } if response.Content == "" { t.Error("Response Content is empty") } if response.FinishReason == "" { t.Error("Response FinishReason is empty") } if response.Usage == nil { t.Error("Response Usage is nil") } else { if response.Usage.TotalTokens == 0 { t.Error("Response Usage.TotalTokens is 0") } t.Logf("Usage: prompt=%d, completion=%d, total=%d", response.Usage.PromptTokens, response.Usage.CompletionTokens, response.Usage.TotalTokens) } if chunkCount == 0 { t.Error("No chunks received") } t.Logf("Response with temperature=0.7: %+v", response) t.Logf("Total chunks received: %d", chunkCount) } // ============================================================================ // Helper Functions // ============================================================================ // newTestContext creates a real Context for testing OpenAI provider func newTestContext(chatID, connectorID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: 
"test-user", ClientID: "test-client", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: "test-session-id", Constraints: types.DataConstraints{ TeamOnly: true, Extra: map[string]interface{}{ "test": "openai-provider", }, }, } ctx := context.New(gocontext.Background(), authorized, chatID) ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "OpenAIProviderTest/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptStandard ctx.Route = "/api/test" ctx.Metadata = make(map[string]interface{}) return ctx } ================================================ FILE: agent/llm/providers/openai/temperature_test.go ================================================ package openai_test import ( gocontext "context" "testing" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/connector/openai" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/llm" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/openapi/oauth/types" "github.com/yaoapp/yao/test" ) // TestTemperatureGPT5AutoReset tests that GPT-5 automatically resets temperature to 1.0 // Temporarily commented out // func TestTemperatureGPT5AutoReset(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // // conn, err := connector.Select("openai.gpt-5") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // // invalidTemp := 0.7 // GPT-5 doesn't support this // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Reasoning: true, // }, // Temperature: &invalidTemp, // Should be reset to 1.0 // } // // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "Say 'OK'", // }, // } // // maxTokens := 10 // options.MaxCompletionTokens = &maxTokens // // ctx := 
newTemperatureTestContext("test-gpt5-temp", "openai.gpt-5") // // // Should succeed (temperature automatically reset to 1.0) // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // // if response == nil { // t.Fatal("Response is nil") // } // // t.Log("✓ GPT-5 successfully handled invalid temperature by resetting to 1.0") // t.Logf("Response: %v", response.Content) // } // TestTemperatureDeepSeekR1AutoReset tests that DeepSeek R1 automatically resets temperature to 1.0 func TestTemperatureDeepSeekR1AutoReset(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("deepseek.r1") if err != nil { t.Fatalf("Failed to select connector: %v", err) } invalidTemp := 0.5 // DeepSeek R1 doesn't support this options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: true, }, Temperature: &invalidTemp, // Should be reset to 1.0 } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Say 'Hello'", }, } maxTokens := 100 options.MaxCompletionTokens = &maxTokens ctx := newTemperatureTestContext("test-deepseek-r1-temp", "deepseek.r1") // Should succeed (temperature automatically reset to 1.0) response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } if response == nil { t.Fatal("Response is nil") } t.Log("✓ DeepSeek R1 successfully handled invalid temperature by resetting to 1.0") t.Logf("Response content: %v", response.Content) if response.ReasoningContent != "" { t.Logf("Reasoning content length: %d", len(response.ReasoningContent)) } } // TestTemperatureGPT4oPreserved tests that GPT-4o preserves custom temperature func TestTemperatureGPT4oPreserved(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("openai.gpt-4o") if err != nil { 
t.Fatalf("Failed to select connector: %v", err) } customTemp := 0.3 // GPT-4o should preserve this options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: false, // Not a reasoning model ToolCalls: true, }, Temperature: &customTemp, // Should be preserved } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Say 'OK'", }, } maxTokens := 10 options.MaxCompletionTokens = &maxTokens ctx := newTemperatureTestContext("test-gpt4o-temp", "openai.gpt-4o") // Should succeed with custom temperature preserved response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } if response == nil { t.Fatal("Response is nil") } t.Log("✓ GPT-4o successfully preserved custom temperature (0.3)") t.Logf("Response: %v", response.Content) } // TestTemperatureDeepSeekV3Preserved tests that DeepSeek V3 preserves custom temperature func TestTemperatureDeepSeekV3Preserved(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() conn, err := connector.Select("deepseek.v3") if err != nil { t.Fatalf("Failed to select connector: %v", err) } customTemp := 0.8 // DeepSeek V3 should preserve this options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: false, // Not a reasoning model ToolCalls: true, }, Temperature: &customTemp, // Should be preserved } llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Say 'Hello World'", }, } maxTokens := 20 options.MaxCompletionTokens = &maxTokens ctx := newTemperatureTestContext("test-deepseek-v3-temp", "deepseek.v3") // Should succeed with custom temperature preserved response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } if response == nil { 
t.Fatal("Response is nil") } t.Log("✓ DeepSeek V3 successfully preserved custom temperature (0.8)") t.Logf("Response: %v", response.Content) } // TestTemperatureGPT5Default tests that GPT-5 with temperature=1.0 works fine // Temporarily commented out // func TestTemperatureGPT5Default(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // // conn, err := connector.Select("openai.gpt-5") // if err != nil { // t.Fatalf("Failed to select connector: %v", err) // } // // defaultTemp := 1.0 // GPT-5's valid temperature // options := &context.CompletionOptions{ // Capabilities: &openai.Capabilities{ // Reasoning: true, // }, // Temperature: &defaultTemp, // Should work fine // } // // llmInstance, err := llm.New(conn, options) // if err != nil { // t.Fatalf("Failed to create LLM instance: %v", err) // } // // messages := []context.Message{ // { // Role: context.RoleUser, // Content: "What is 2+2? Reply with just the number.", // }, // } // // maxTokens := 10 // options.MaxCompletionTokens = &maxTokens // // ctx := newTemperatureTestContext("test-gpt5-temp-default", "openai.gpt-5") // // // Should succeed with default temperature // response, err := llmInstance.Post(ctx, messages, options) // if err != nil { // t.Fatalf("Post failed: %v", err) // } // // if response == nil { // t.Fatal("Response is nil") // } // // t.Log("✓ GPT-5 successfully handled default temperature (1.0)") // t.Logf("Response: %v", response.Content) // } // TestTemperatureNoTemperatureProvided tests that models work when no temperature is provided func TestTemperatureNoTemperatureProvided(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() testCases := []struct { name string connector string reasoning bool }{ // {"GPT-5 No Temp", "openai.gpt-5", true}, // Temporarily commented out {"GPT-4o No Temp", "openai.gpt-4o", false}, {"DeepSeek R1 No Temp", "deepseek.r1", true}, {"DeepSeek V3 No Temp", "deepseek.v3", false}, } for _, tc := range testCases { t.Run(tc.name, func(t 
*testing.T) { conn, err := connector.Select(tc.connector) if err != nil { t.Fatalf("Failed to select connector: %v", err) } options := &context.CompletionOptions{ Capabilities: &openai.Capabilities{ Reasoning: false, ToolCalls: true, }, } if tc.reasoning { options.Capabilities.Reasoning = true } // Temperature not set - should use API default llmInstance, err := llm.New(conn, options) if err != nil { t.Fatalf("Failed to create LLM instance: %v", err) } messages := []context.Message{ { Role: context.RoleUser, Content: "Say 'OK'", }, } maxTokens := 10 options.MaxCompletionTokens = &maxTokens ctx := newTemperatureTestContext("test-no-temp-"+tc.connector, tc.connector) response, err := llmInstance.Post(ctx, messages, options) if err != nil { t.Fatalf("Post failed: %v", err) } if response == nil { t.Fatal("Response is nil") } t.Logf("✓ %s works fine without temperature parameter", tc.connector) }) } } // ============================================================================ // Helper Functions // ============================================================================ // newTemperatureTestContext creates a real Context for testing temperature handling func newTemperatureTestContext(chatID, connectorID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: "test-user", ClientID: "test-client", UserID: "test-user-123", TeamID: "test-team-456", TenantID: "test-tenant-789", SessionID: "test-session-id", Constraints: types.DataConstraints{ TeamOnly: true, Extra: map[string]interface{}{ "test": "temperature", }, }, } ctx := context.New(gocontext.Background(), authorized, chatID) ctx.AssistantID = "test-assistant" ctx.Locale = "en-us" ctx.Theme = "light" ctx.Client = context.Client{ Type: "web", UserAgent: "TemperatureTest/1.0", IP: "127.0.0.1", } ctx.Referer = context.RefererAPI ctx.Accept = context.AcceptStandard ctx.Route = "/api/test" ctx.Metadata = make(map[string]interface{}) return ctx } ================================================ FILE: 
agent/llm/providers/openai/types.go
================================================
package openai

import (
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/output/message"
)

// StreamChunk represents a chunk from OpenAI's streaming response.
// Usage is a pointer tagged omitempty, so it is nil for chunks that carry no
// token accounting.
type StreamChunk struct {
	ID      string  `json:"id"`
	Object  string  `json:"object"`
	Created int64   `json:"created"`
	Model   string  `json:"model"`
	Choices []Delta `json:"choices"`
	Usage   *struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage,omitempty"`
}

// Delta represents the delta in a streaming chunk (one entry of
// StreamChunk.Choices).
type Delta struct {
	Index int          `json:"index"`
	Delta DeltaContent `json:"delta"`
	// FinishReason is a pointer so that a JSON null (or an absent key) decodes
	// to nil rather than to an empty string.
	FinishReason *string `json:"finish_reason"`
}

// DeltaContent represents the content in a delta. Every field is optional
// (omitempty).
type DeltaContent struct {
	Role             string          `json:"role,omitempty"`
	Content          string          `json:"content,omitempty"`
	ReasoningContent string          `json:"reasoning_content,omitempty"` // DeepSeek R1 reasoning
	ToolCalls        []ToolCallDelta `json:"tool_calls,omitempty"`
	Refusal          string          `json:"refusal,omitempty"`
}

// ToolCallDelta represents a tool call delta in streaming.
type ToolCallDelta struct {
	Index    int               `json:"index"`
	ID       string            `json:"id,omitempty"`
	Type     string            `json:"type,omitempty"`
	Function FunctionCallDelta `json:"function,omitempty"`
}

// FunctionCallDelta represents a function call delta.
type FunctionCallDelta struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// CompletionResponseFull represents the full non-streaming response.
type CompletionResponseFull struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Index   int `json:"index"`
		Message struct {
			Role             context.MessageRole `json:"role"`
			Content          interface{}         `json:"content,omitempty"`           // string or array
			ReasoningContent string              `json:"reasoning_content,omitempty"` // DeepSeek R1 reasoning
			ToolCalls        []context.ToolCall  `json:"tool_calls,omitempty"`
			Refusal          *string             `json:"refusal,omitempty"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage             *message.UsageInfo `json:"usage,omitempty"`
	SystemFingerprint string             `json:"system_fingerprint,omitempty"`
}

// streamAccumulator accumulates streaming response data.
type streamAccumulator struct {
	id               string
	model            string
	created          int64
	role             string
	content          string
	reasoningContent string // DeepSeek R1 reasoning content
	refusal          string
	// NOTE(review): the int key is presumably ToolCallDelta.Index - confirm
	// against the code that fills this map.
	toolCalls    map[int]*accumulatedToolCall
	finishReason string
	usage        *message.UsageInfo
}

// accumulatedToolCall accumulates a single tool call.
type accumulatedToolCall struct {
	id           string
	typ          string
	functionName string
	functionArgs string
}

// messageTracker tracks the current message state for lifecycle events.
type messageTracker struct {
	active       bool                       // Whether a message is currently active
	messageID    string                     // Current message ID
	messageType  message.StreamChunkType    // Current message type (thinking, text, tool_call)
	startTime    int64                      // Message start timestamp
	chunkCount   int                        // Number of chunks in this message
	toolCallInfo *message.EventToolCallInfo // Tool call info if message is tool_call type
	idGenerator  *message.IDGenerator       // ID generator from context
}

================================================
FILE: agent/load.go
================================================
package agent

import (
	"fmt"
	"path/filepath"

	"github.com/yaoapp/gou/application"
	"github.com/yaoapp/gou/connector"
	"github.com/yaoapp/gou/helper"
	"github.com/yaoapp/yao/agent/assistant"
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/i18n"
	searchDefaults "github.com/yaoapp/yao/agent/search/defaults"
	searchTypes "github.com/yaoapp/yao/agent/search/types"
	storeMongo "github.com/yaoapp/yao/agent/store/mongo"
	storeRedis "github.com/yaoapp/yao/agent/store/redis"
	store "github.com/yaoapp/yao/agent/store/types"
	"github.com/yaoapp/yao/agent/store/xun"
"github.com/yaoapp/yao/agent/types" "github.com/yaoapp/yao/config" ) var agentDSL *types.DSL // Load load AIGC func Load(cfg config.Config) error { setting := types.DSL{ Cache: "__yao.agent.cache", // default is "__yao.agent.cache" StoreSetting: store.Setting{ MaxSize: 20, TTL: 90 * 24 * 60 * 60, // 90 days in seconds }, } bytes, err := application.App.Read(filepath.Join("agent", "agent.yml")) if err != nil { return err } err = application.Parse("agent.yml", bytes, &setting) if err != nil { return err } if setting.StoreSetting.MaxSize == 0 { setting.StoreSetting.MaxSize = 20 // default is 20 } // Resolve $ENV.XXX references in system and uses fields resolveEnvStrings(&setting) // Default Assistant, Agent is the developer name, Mohe is the brand name of the assistant if setting.Uses == nil { setting.Uses = &types.Uses{Default: "mohe"} // Agent is the developer name, Mohe is the brand name of the assistant } // Title Assistant (default to system agent) if setting.Uses.Title == "" { setting.Uses.Title = "__yao.title" } // Prompt Assistant (default to system agent) if setting.Uses.Prompt == "" { setting.Uses.Prompt = "__yao.prompt" } // RobotPrompt Assistant (default to system agent) if setting.Uses.RobotPrompt == "" { setting.Uses.RobotPrompt = "__yao.robot_prompt" } agentDSL = &setting // Store Setting err = initStore() if err != nil { return err } // Initialize Global I18n err = initGlobalI18n() if err != nil { return err } // Initialize Global Prompts err = initGlobalPrompts() if err != nil { return err } // Initialize KB Configuration err = initKBConfig() if err != nil { return err } // Initialize Search Configuration err = initSearchConfig() if err != nil { return err } // Initialize Assistant err = initAssistant() if err != nil { return err } return nil } // GetAgent returns the Agent settings func GetAgent() *types.DSL { return agentDSL } // initGlobalI18n initialize the global i18n func initGlobalI18n() error { locales, err := i18n.GetLocales("agent") if err 
!= nil { return err } i18n.Locales["__global__"] = locales.Flatten() return nil } // initGlobalPrompts initialize the global prompts from agent/prompts.yml func initGlobalPrompts() error { prompts, _, err := store.LoadGlobalPrompts() if err != nil { return err } agentDSL.GlobalPrompts = prompts return nil } // GetGlobalPrompts returns the global prompts // ctx: context variables for parsing $CTX.* variables func GetGlobalPrompts(ctx map[string]string) []store.Prompt { if agentDSL == nil || len(agentDSL.GlobalPrompts) == 0 { return nil } return store.Prompts(agentDSL.GlobalPrompts).Parse(ctx) } // initStore initialize the store func initStore() error { var err error if agentDSL.StoreSetting.Connector == "default" || agentDSL.StoreSetting.Connector == "" { agentDSL.Store, err = xun.NewXun(agentDSL.StoreSetting) return err } // other connector conn, err := connector.Select(agentDSL.StoreSetting.Connector) if err != nil { return fmt.Errorf("load connectors error: %s", err.Error()) } if conn.Is(connector.DATABASE) { agentDSL.Store, err = xun.NewXun(agentDSL.StoreSetting) return err } else if conn.Is(connector.REDIS) { agentDSL.Store = storeRedis.NewRedis() return nil } else if conn.Is(connector.MONGO) { agentDSL.Store = storeMongo.NewMongo() return nil } return fmt.Errorf("Agent store connector %s not support", agentDSL.StoreSetting.Connector) } // initAssistant initialize the assistant func initAssistant() error { // Set Storage assistant.SetStorage(agentDSL.Store) // Set Store Setting (MaxSize, TTL, etc.) 
assistant.SetStoreSetting(&agentDSL.StoreSetting) // Set global Uses configuration if agentDSL.Uses != nil { globalUses := &context.Uses{ Vision: agentDSL.Uses.Vision, Audio: agentDSL.Uses.Audio, Search: agentDSL.Uses.Search, Fetch: agentDSL.Uses.Fetch, Web: agentDSL.Uses.Web, Keyword: agentDSL.Uses.Keyword, QueryDSL: agentDSL.Uses.QueryDSL, Rerank: agentDSL.Uses.Rerank, } assistant.SetGlobalUses(globalUses) } // Set global prompts if len(agentDSL.GlobalPrompts) > 0 { assistant.SetGlobalPrompts(agentDSL.GlobalPrompts) } if agentDSL.KB != nil { assistant.SetGlobalKBSetting(agentDSL.KB) } if agentDSL.Search != nil { assistant.SetGlobalSearchConfig(agentDSL.Search) } // Set system agents configuration if agentDSL.System != nil { assistant.SetSystemConfig(&assistant.SystemConfig{ Default: agentDSL.System.Default, Keyword: agentDSL.System.Keyword, QueryDSL: agentDSL.System.QueryDSL, Title: agentDSL.System.Title, Prompt: agentDSL.System.Prompt, NeedSearch: agentDSL.System.NeedSearch, Entity: agentDSL.System.Entity, }) } // Load System Agents (from bindata: __yao.keyword, __yao.querydsl, etc.) 
if err := assistant.LoadSystemAgents(); err != nil { return err } // Load Built-in Assistants (from application /assistants directory) err := assistant.LoadBuiltIn() if err != nil { return err } // Default Assistant defaultAssistant, err := defaultAssistant() if err != nil { return err } agentDSL.Assistant = defaultAssistant return nil } // initKBConfig initialize the knowledge base configuration from agent/kb.yml func initKBConfig() error { path := filepath.Join("agent", "kb.yml") if exists, _ := application.App.Exists(path); !exists { return nil // KB config is optional } // Read the KB configuration bytes, err := application.App.Read(path) if err != nil { return err } var kbSetting store.KBSetting err = application.Parse("kb.yml", bytes, &kbSetting) if err != nil { return err } agentDSL.KB = &kbSetting return nil } // initSearchConfig initialize the search configuration from agent/search.yml func initSearchConfig() error { // Start with system defaults agentDSL.Search = searchDefaults.SystemDefaults path := filepath.Join("agent", "search.yml") if exists, _ := application.App.Exists(path); !exists { return nil // Search config is optional, use defaults } // Read the search configuration bytes, err := application.App.Read(path) if err != nil { return err } var searchConfig searchTypes.Config err = application.Parse("search.yml", bytes, &searchConfig) if err != nil { return err } // Merge with defaults agentDSL.Search = mergeSearchConfig(searchDefaults.SystemDefaults, &searchConfig) return nil } // mergeSearchConfig merges two search configs (base < override) func mergeSearchConfig(base, override *searchTypes.Config) *searchTypes.Config { if base == nil { return override } if override == nil { return base } result := *base // Copy base // Merge Web config if override.Web != nil { if result.Web == nil { result.Web = override.Web } else { if override.Web.Provider != "" { result.Web.Provider = override.Web.Provider } if override.Web.APIKeyEnv != "" { 
result.Web.APIKeyEnv = override.Web.APIKeyEnv } if override.Web.MaxResults > 0 { result.Web.MaxResults = override.Web.MaxResults } } } // Merge KB config if override.KB != nil { if result.KB == nil { result.KB = override.KB } else { if len(override.KB.Collections) > 0 { result.KB.Collections = override.KB.Collections } if override.KB.Threshold > 0 { result.KB.Threshold = override.KB.Threshold } if override.KB.Graph { result.KB.Graph = override.KB.Graph } } } // Merge DB config if override.DB != nil { if result.DB == nil { result.DB = override.DB } else { if len(override.DB.Models) > 0 { result.DB.Models = override.DB.Models } if override.DB.MaxResults > 0 { result.DB.MaxResults = override.DB.MaxResults } } } // Merge Keyword config if override.Keyword != nil { if result.Keyword == nil { result.Keyword = override.Keyword } else { if override.Keyword.MaxKeywords > 0 { result.Keyword.MaxKeywords = override.Keyword.MaxKeywords } if override.Keyword.Language != "" { result.Keyword.Language = override.Keyword.Language } } } // Merge QueryDSL config if override.QueryDSL != nil { result.QueryDSL = override.QueryDSL } // Merge Rerank config if override.Rerank != nil { if result.Rerank == nil { result.Rerank = override.Rerank } else { if override.Rerank.TopN > 0 { result.Rerank.TopN = override.Rerank.TopN } } } // Merge Citation config if override.Citation != nil { if result.Citation == nil { result.Citation = override.Citation } else { if override.Citation.Format != "" { result.Citation.Format = override.Citation.Format } // AutoInjectPrompt is a bool, need to check if explicitly set result.Citation.AutoInjectPrompt = override.Citation.AutoInjectPrompt if override.Citation.CustomPrompt != "" { result.Citation.CustomPrompt = override.Citation.CustomPrompt } } } // Merge Weights config if override.Weights != nil { if result.Weights == nil { result.Weights = override.Weights } else { if override.Weights.User > 0 { result.Weights.User = override.Weights.User } if 
override.Weights.Hook > 0 { result.Weights.Hook = override.Weights.Hook } if override.Weights.Auto > 0 { result.Weights.Auto = override.Weights.Auto } } } // Merge Options config if override.Options != nil { if result.Options == nil { result.Options = override.Options } else { if override.Options.SkipThreshold > 0 { result.Options.SkipThreshold = override.Options.SkipThreshold } } } return &result } // GetSearchConfig returns the global search configuration func GetSearchConfig() *searchTypes.Config { if agentDSL == nil { return searchDefaults.SystemDefaults } return agentDSL.Search } // defaultAssistant get the default assistant func defaultAssistant() (*assistant.Assistant, error) { if agentDSL.Uses == nil || agentDSL.Uses.Default == "" { return nil, fmt.Errorf("default assistant not found") } return assistant.Get(agentDSL.Uses.Default) } // resolveEnvStrings resolves $ENV.XXX references in agent.yml string fields. // agent.yml is parsed via yaml.Unmarshal which does not handle $ENV substitution, // unlike connector files which call helper.EnvString explicitly during Register. 
func resolveEnvStrings(setting *types.DSL) { if setting.System != nil { setting.System.Default = helper.EnvString(setting.System.Default) setting.System.Keyword = helper.EnvString(setting.System.Keyword) setting.System.QueryDSL = helper.EnvString(setting.System.QueryDSL) setting.System.Title = helper.EnvString(setting.System.Title) setting.System.Prompt = helper.EnvString(setting.System.Prompt) setting.System.RobotPrompt = helper.EnvString(setting.System.RobotPrompt) setting.System.NeedSearch = helper.EnvString(setting.System.NeedSearch) setting.System.Entity = helper.EnvString(setting.System.Entity) } if setting.Uses != nil { setting.Uses.Default = helper.EnvString(setting.Uses.Default) setting.Uses.Title = helper.EnvString(setting.Uses.Title) setting.Uses.Prompt = helper.EnvString(setting.Uses.Prompt) setting.Uses.RobotPrompt = helper.EnvString(setting.Uses.RobotPrompt) setting.Uses.Vision = helper.EnvString(setting.Uses.Vision) setting.Uses.Audio = helper.EnvString(setting.Uses.Audio) setting.Uses.Search = helper.EnvString(setting.Uses.Search) setting.Uses.Fetch = helper.EnvString(setting.Uses.Fetch) setting.Uses.Web = helper.EnvString(setting.Uses.Web) setting.Uses.Keyword = helper.EnvString(setting.Uses.Keyword) setting.Uses.QueryDSL = helper.EnvString(setting.Uses.QueryDSL) setting.Uses.Rerank = helper.EnvString(setting.Uses.Rerank) } setting.Cache = helper.EnvString(setting.Cache) } ================================================ FILE: agent/load_test.go ================================================ package agent import ( "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/types" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func prepare(t *testing.T) { test.Prepare(t, config.Conf) err := Load(config.Conf) require.NoError(t, err) } func TestLoad(t *testing.T) { prepare(t) defer test.Clean() agent := GetAgent() 
require.NotNil(t, agent) t.Run("LoadAgentSettings", func(t *testing.T) { // Cache setting assert.NotEmpty(t, agent.Cache) // Store setting assert.NotNil(t, agent.Store) assert.Greater(t, agent.StoreSetting.MaxSize, 0) // Uses setting assert.NotNil(t, agent.Uses) assert.NotEmpty(t, agent.Uses.Default) }) t.Run("LoadDefaultAssistant", func(t *testing.T) { assert.NotNil(t, agent.Assistant) }) t.Run("LoadGlobalPrompts", func(t *testing.T) { // Global prompts should be loaded from agent/prompts.yml assert.NotNil(t, agent.GlobalPrompts) assert.Greater(t, len(agent.GlobalPrompts), 0) // First prompt should be system role assert.Equal(t, "system", agent.GlobalPrompts[0].Role) // Content should contain system context info (with variables not yet parsed) assert.Contains(t, agent.GlobalPrompts[0].Content, "$SYS.") }) t.Run("LoadKBConfig", func(t *testing.T) { // KB configuration should be loaded from agent/kb.yml assert.NotNil(t, agent.KB) assert.NotNil(t, agent.KB.Chat) // Verify chat KB settings assert.Equal(t, "__yao.openai", agent.KB.Chat.EmbeddingProviderID) assert.Equal(t, "text-embedding-3-small", agent.KB.Chat.EmbeddingOptionID) assert.Equal(t, "zh-CN", agent.KB.Chat.Locale) // Verify config assert.NotNil(t, agent.KB.Chat.Config) assert.Equal(t, "hnsw", agent.KB.Chat.Config.IndexType.String()) assert.Equal(t, "cosine", agent.KB.Chat.Config.Distance.String()) // Verify metadata assert.NotNil(t, agent.KB.Chat.Metadata) assert.Equal(t, "chat_session", agent.KB.Chat.Metadata["category"]) assert.Equal(t, true, agent.KB.Chat.Metadata["auto_created"]) // Verify document defaults assert.NotNil(t, agent.KB.Chat.DocumentDefaults) assert.NotNil(t, agent.KB.Chat.DocumentDefaults.Chunking) assert.Equal(t, "__yao.structured", agent.KB.Chat.DocumentDefaults.Chunking.ProviderID) assert.Equal(t, "standard", agent.KB.Chat.DocumentDefaults.Chunking.OptionID) assert.NotNil(t, agent.KB.Chat.DocumentDefaults.Extraction) assert.Equal(t, "__yao.openai", 
agent.KB.Chat.DocumentDefaults.Extraction.ProviderID) assert.Equal(t, "gpt-4o-mini", agent.KB.Chat.DocumentDefaults.Extraction.OptionID) assert.NotNil(t, agent.KB.Chat.DocumentDefaults.Converter) assert.Equal(t, "__yao.utf8", agent.KB.Chat.DocumentDefaults.Converter.ProviderID) assert.Equal(t, "standard-text", agent.KB.Chat.DocumentDefaults.Converter.OptionID) }) t.Run("LoadSearchConfig", func(t *testing.T) { // Search configuration should be loaded from agent/search.yml assert.NotNil(t, agent.Search) // Verify web config assert.NotNil(t, agent.Search.Web) assert.Equal(t, "tavily", agent.Search.Web.Provider) assert.Equal(t, 10, agent.Search.Web.MaxResults) // Verify KB config assert.NotNil(t, agent.Search.KB) assert.Equal(t, 0.7, agent.Search.KB.Threshold) assert.False(t, agent.Search.KB.Graph) // Verify DB config assert.NotNil(t, agent.Search.DB) assert.Equal(t, 20, agent.Search.DB.MaxResults) // Verify keyword config assert.NotNil(t, agent.Search.Keyword) assert.Equal(t, 10, agent.Search.Keyword.MaxKeywords) assert.Equal(t, "auto", agent.Search.Keyword.Language) // Verify rerank config assert.NotNil(t, agent.Search.Rerank) assert.Equal(t, 10, agent.Search.Rerank.TopN) // Verify citation config assert.NotNil(t, agent.Search.Citation) assert.Equal(t, "#ref:{id}", agent.Search.Citation.Format) assert.True(t, agent.Search.Citation.AutoInjectPrompt) // Verify weights config assert.NotNil(t, agent.Search.Weights) assert.Equal(t, 1.0, agent.Search.Weights.User) assert.Equal(t, 0.8, agent.Search.Weights.Hook) assert.Equal(t, 0.6, agent.Search.Weights.Auto) // Verify options config assert.NotNil(t, agent.Search.Options) assert.Equal(t, 5, agent.Search.Options.SkipThreshold) }) } func TestGetGlobalPrompts(t *testing.T) { prepare(t) defer test.Clean() t.Run("ParseWithoutContext", func(t *testing.T) { prompts := GetGlobalPrompts(nil) require.NotNil(t, prompts) require.Greater(t, len(prompts), 0) // $SYS.* variables should be replaced content := prompts[0].Content 
assert.NotContains(t, content, "$SYS.DATETIME") assert.NotContains(t, content, "$SYS.TIMEZONE") assert.NotContains(t, content, "$SYS.WEEKDAY") // Should contain actual time values now := time.Now() assert.Contains(t, content, now.Format("2006-01-02")) }) t.Run("ParseWithContext", func(t *testing.T) { ctx := map[string]string{ "USER_ID": "test-user-123", "LOCALE": "zh-CN", } prompts := GetGlobalPrompts(ctx) require.NotNil(t, prompts) require.Greater(t, len(prompts), 0) // $SYS.* variables should be replaced content := prompts[0].Content assert.NotContains(t, content, "$SYS.DATETIME") }) t.Run("ParseSystemTimeVariables", func(t *testing.T) { prompts := GetGlobalPrompts(nil) require.NotNil(t, prompts) content := prompts[0].Content now := time.Now() // Should contain current date assert.Contains(t, content, now.Format("2006-01-02")) // Should contain timezone assert.Contains(t, content, now.Location().String()) // Should contain weekday assert.Contains(t, content, now.Weekday().String()) }) } func TestGetGlobalPromptsWithDisableFlag(t *testing.T) { prepare(t) defer test.Clean() agent := GetAgent() require.NotNil(t, agent) t.Run("GlobalPromptsExist", func(t *testing.T) { // Verify global prompts are loaded assert.NotNil(t, agent.GlobalPrompts) assert.Greater(t, len(agent.GlobalPrompts), 0) }) t.Run("AssistantCanDisableGlobalPrompts", func(t *testing.T) { // The fullfields test assistant has disable_global_prompts: true // This test verifies the flag is properly loaded // The actual merging logic is in the assistant module prompts := GetGlobalPrompts(nil) assert.NotNil(t, prompts) // Global prompts should still be available // The assistant decides whether to use them based on DisableGlobalPrompts flag }) } func TestResolveEnvStrings(t *testing.T) { t.Setenv("TEST_CONNECTOR", "openai.gpt-5") t.Setenv("TEST_ASSISTANT", "my-assistant") t.Setenv("TEST_CACHE", "my-cache") t.Run("SystemFields", func(t *testing.T) { setting := &types.DSL{ System: &types.System{ Default: 
"$ENV.TEST_CONNECTOR", Keyword: "$ENV.TEST_CONNECTOR", QueryDSL: "$ENV.TEST_CONNECTOR", Title: "$ENV.TEST_CONNECTOR", Prompt: "$ENV.TEST_CONNECTOR", RobotPrompt: "$ENV.TEST_CONNECTOR", NeedSearch: "$ENV.TEST_CONNECTOR", Entity: "$ENV.TEST_CONNECTOR", }, } resolveEnvStrings(setting) assert.Equal(t, "openai.gpt-5", setting.System.Default) assert.Equal(t, "openai.gpt-5", setting.System.Keyword) assert.Equal(t, "openai.gpt-5", setting.System.QueryDSL) assert.Equal(t, "openai.gpt-5", setting.System.Title) assert.Equal(t, "openai.gpt-5", setting.System.Prompt) assert.Equal(t, "openai.gpt-5", setting.System.RobotPrompt) assert.Equal(t, "openai.gpt-5", setting.System.NeedSearch) assert.Equal(t, "openai.gpt-5", setting.System.Entity) }) t.Run("UsesFields", func(t *testing.T) { setting := &types.DSL{ Uses: &types.Uses{ Default: "$ENV.TEST_ASSISTANT", Title: "$ENV.TEST_ASSISTANT", Prompt: "$ENV.TEST_ASSISTANT", RobotPrompt: "$ENV.TEST_ASSISTANT", Vision: "$ENV.TEST_ASSISTANT", Audio: "$ENV.TEST_ASSISTANT", Search: "$ENV.TEST_ASSISTANT", Fetch: "$ENV.TEST_ASSISTANT", Web: "$ENV.TEST_ASSISTANT", Keyword: "$ENV.TEST_ASSISTANT", QueryDSL: "$ENV.TEST_ASSISTANT", Rerank: "$ENV.TEST_ASSISTANT", }, } resolveEnvStrings(setting) assert.Equal(t, "my-assistant", setting.Uses.Default) assert.Equal(t, "my-assistant", setting.Uses.Title) assert.Equal(t, "my-assistant", setting.Uses.Prompt) assert.Equal(t, "my-assistant", setting.Uses.RobotPrompt) assert.Equal(t, "my-assistant", setting.Uses.Vision) assert.Equal(t, "my-assistant", setting.Uses.Audio) assert.Equal(t, "my-assistant", setting.Uses.Search) assert.Equal(t, "my-assistant", setting.Uses.Fetch) assert.Equal(t, "my-assistant", setting.Uses.Web) assert.Equal(t, "my-assistant", setting.Uses.Keyword) assert.Equal(t, "my-assistant", setting.Uses.QueryDSL) assert.Equal(t, "my-assistant", setting.Uses.Rerank) }) t.Run("CacheField", func(t *testing.T) { setting := &types.DSL{Cache: "$ENV.TEST_CACHE"} resolveEnvStrings(setting) 
assert.Equal(t, "my-cache", setting.Cache) }) t.Run("PlainStringsUnchanged", func(t *testing.T) { setting := &types.DSL{ Cache: "plain-cache", System: &types.System{ Default: "openai.gpt-5", }, Uses: &types.Uses{ Default: "mohe", Title: "__yao.title", }, } resolveEnvStrings(setting) assert.Equal(t, "plain-cache", setting.Cache) assert.Equal(t, "openai.gpt-5", setting.System.Default) assert.Equal(t, "mohe", setting.Uses.Default) assert.Equal(t, "__yao.title", setting.Uses.Title) }) t.Run("NilSystemAndUses", func(t *testing.T) { setting := &types.DSL{Cache: "test"} assert.NotPanics(t, func() { resolveEnvStrings(setting) }) }) t.Run("UndefinedEnvReturnsEmpty", func(t *testing.T) { setting := &types.DSL{ System: &types.System{ Default: "$ENV.UNDEFINED_VAR_12345", }, } resolveEnvStrings(setting) assert.Equal(t, "", setting.System.Default) }) } func TestGlobalPromptsContent(t *testing.T) { prepare(t) defer test.Clean() agent := GetAgent() require.NotNil(t, agent) require.NotNil(t, agent.GlobalPrompts) require.Greater(t, len(agent.GlobalPrompts), 0) t.Run("SystemContextPrompt", func(t *testing.T) { // Find system prompt var systemPrompt string for _, p := range agent.GlobalPrompts { if p.Role == "system" { systemPrompt = p.Content break } } assert.NotEmpty(t, systemPrompt) assert.Contains(t, systemPrompt, "System Context") }) t.Run("VariablesInRawPrompts", func(t *testing.T) { // Raw prompts should contain unparsed variables content := agent.GlobalPrompts[0].Content assert.True(t, strings.Contains(content, "$SYS.") || strings.Contains(content, "$ENV.") || strings.Contains(content, "$CTX."), "Raw prompts should contain variable placeholders") }) } func TestAssistantGlobalPrompts(t *testing.T) { prepare(t) defer test.Clean() t.Run("AssistantModuleReceivesGlobalPrompts", func(t *testing.T) { // Verify assistant module has global prompts prompts := assistant.GetGlobalPrompts(nil) require.NotNil(t, prompts) require.Greater(t, len(prompts), 0) // Should be parsed (no $SYS.* 
// Accessor defines the interface for accessing memory from agent context
// This is the primary interface used by agent hooks and tools
//
// It exposes one NamespaceAccessor per memory space (user, team, chat,
// context), a lookup by Space value, and aggregate statistics.
//
// NOTE(review): *Memory exposes GetUser/GetTeam/GetChat/GetContext/GetSpace/
// GetStats and has struct fields named User/Team/Chat/Context, so it cannot
// satisfy this interface as declared — confirm which type is meant to
// implement Accessor.
type Accessor interface {
	// User returns the user-level memory namespace
	User() NamespaceAccessor

	// Team returns the team-level memory namespace
	Team() NamespaceAccessor

	// Chat returns the chat-level memory namespace
	Chat() NamespaceAccessor

	// Context returns the context-level memory namespace
	Context() NamespaceAccessor

	// Space returns a memory namespace by space type
	Space(space Space) NamespaceAccessor

	// Stats returns memory statistics
	Stats() *Stats
}
*NamespaceStats } // Factory defines the interface for creating memory instances type Factory interface { // Create creates a new memory instance with the given configuration Create(config *Config) (Manager, error) // CreateWithDefaults creates a new memory instance with default configuration CreateWithDefaults() (Manager, error) } ================================================ FILE: agent/memory/manager.go ================================================ package memory import ( "sync" ) // Global manager instance var globalManager Manager // Init initializes the global memory manager with the given configuration // Called by agent.Load() after loading agent DSL func Init(config *Config) { globalManager = NewManager(config) } // GetMemory returns a memory instance for the given identifiers using the global manager // This is the main entry point for creating Memory instances from agent/context func GetMemory(userID, teamID, chatID, contextID string) (*Memory, error) { if globalManager == nil { // Initialize with defaults if not configured globalManager = NewManagerWithDefaults() } return globalManager.Memory(userID, teamID, chatID, contextID) } // Close closes the global manager and releases resources func Close() error { if globalManager != nil { err := globalManager.Close() globalManager = nil return err } return nil } // DefaultManager is the default memory manager implementation type DefaultManager struct { config *Config memories sync.Map // map[string]*Memory, key is composite of userID:teamID:chatID:contextID } // NewManager creates a new memory manager with the given configuration func NewManager(config *Config) Manager { if config == nil { config = &Config{} } return &DefaultManager{ config: config, } } // NewManagerWithDefaults creates a new memory manager with default configuration func NewManagerWithDefaults() Manager { return NewManager(&Config{ User: DefaultUserStore, Team: DefaultTeamStore, Chat: DefaultChatStore, Context: DefaultContextStore, }) } 
// memoryKey generates a unique key for the memory instance
//
// The four identifiers are joined with ':'. Any ':' or '\' inside an
// identifier is escaped with a leading '\', so two different identifier
// tuples can never produce the same key (a plain join would map
// ("a:b", "c") and ("a", "b:c") to the same cache entry). Identifiers that
// contain neither character yield exactly "user:team:chat:context".
func memoryKey(userID, teamID, chatID, contextID string) string {
	parts := [...]string{userID, teamID, chatID, contextID}

	// Separators plus raw lengths; escapes may grow the buffer slightly.
	size := len(parts) - 1
	for _, p := range parts {
		size += len(p)
	}

	buf := make([]byte, 0, size)
	for i, p := range parts {
		if i > 0 {
			buf = append(buf, ':')
		}
		for j := 0; j < len(p); j++ {
			c := p[j]
			if c == ':' || c == '\\' {
				buf = append(buf, '\\')
			}
			buf = append(buf, c)
		}
	}
	return string(buf)
}
!= nil { return nil, fmt.Errorf("failed to create user namespace: %w", err) } m.User = ns } // Initialize team namespace if teamID != "" { ns, err := newNamespace(SpaceTeam, teamID, cfg.Team, DefaultTeamTTL) if err != nil { return nil, fmt.Errorf("failed to create team namespace: %w", err) } m.Team = ns } // Initialize chat namespace if chatID != "" { ns, err := newNamespace(SpaceChat, chatID, cfg.Chat, DefaultChatTTL) if err != nil { return nil, fmt.Errorf("failed to create chat namespace: %w", err) } m.Chat = ns } // Initialize context namespace if contextID != "" { ns, err := newNamespace(SpaceContext, contextID, cfg.Context, DefaultContextTTL) if err != nil { return nil, fmt.Errorf("failed to create context namespace: %w", err) } m.Context = ns } return m, nil } // newNamespace creates a new Namespace with the given parameters func newNamespace(space Space, id, storeID string, defaultTTL time.Duration) (*Namespace, error) { // Use default store ID if not specified if storeID == "" { switch space { case SpaceUser: storeID = DefaultUserStore case SpaceTeam: storeID = DefaultTeamStore case SpaceChat: storeID = DefaultChatStore case SpaceContext: storeID = DefaultContextStore } } // Get store instance s, err := store.Get(storeID) if err != nil { return nil, fmt.Errorf("failed to get store %s: %w", storeID, err) } return &Namespace{ Space: space, ID: id, Store: s, StoreID: storeID, Prefix: fmt.Sprintf("%s:%s:", space, id), Default: defaultTTL, }, nil } // GetUser returns the user-level memory namespace accessor func (m *Memory) GetUser() NamespaceAccessor { if m.User == nil { return nil } return m.User } // GetTeam returns the team-level memory namespace accessor func (m *Memory) GetTeam() NamespaceAccessor { if m.Team == nil { return nil } return m.Team } // GetChat returns the chat-level memory namespace accessor func (m *Memory) GetChat() NamespaceAccessor { if m.Chat == nil { return nil } return m.Chat } // GetContext returns the context-level memory namespace 
accessor func (m *Memory) GetContext() NamespaceAccessor { if m.Context == nil { return nil } return m.Context } // GetSpace returns a memory namespace by space type func (m *Memory) GetSpace(space Space) NamespaceAccessor { switch space { case SpaceUser: return m.GetUser() case SpaceTeam: return m.GetTeam() case SpaceChat: return m.GetChat() case SpaceContext: return m.GetContext() default: return nil } } // GetStats returns memory statistics for all namespaces func (m *Memory) GetStats() *Stats { stats := &Stats{} if m.User != nil { stats.User = m.User.Stats() } if m.Team != nil { stats.Team = m.Team.Stats() } if m.Chat != nil { stats.Chat = m.Chat.Stats() } if m.Context != nil { stats.Context = m.Context.Stats() } return stats } // Clear clears all memory in all namespaces for this memory instance func (m *Memory) Clear() { if m.User != nil { m.User.Clear() } if m.Team != nil { m.Team.Clear() } if m.Chat != nil { m.Chat.Clear() } if m.Context != nil { m.Context.Clear() } } // Fork creates a new Memory instance with an independent Context namespace // but sharing the User, Team, and Chat namespaces with the parent. // This is used for parallel agent calls (ctx.agent.All/Any/Race) to prevent // context state from being shared between concurrent sub-agent executions. // // The new Context namespace uses the provided newContextID. // If newContextID is empty, returns a shallow copy with shared Context. 
func (m *Memory) Fork(newContextID string) (*Memory, error) { if m == nil { return nil, nil } // If no new context ID provided, share everything (shallow copy) if newContextID == "" { return &Memory{ UserID: m.UserID, TeamID: m.TeamID, ChatID: m.ChatID, ContextID: m.ContextID, User: m.User, Team: m.Team, Chat: m.Chat, Context: m.Context, Config: m.Config, }, nil } // Create new Memory with independent Context namespace forked := &Memory{ UserID: m.UserID, TeamID: m.TeamID, ChatID: m.ChatID, ContextID: newContextID, User: m.User, // Shared Team: m.Team, // Shared Chat: m.Chat, // Shared Context: nil, // Will be created below Config: m.Config, } // Create new Context namespace with independent ID storeID := "" if m.Config != nil { storeID = m.Config.Context } ns, err := newNamespace(SpaceContext, newContextID, storeID, DefaultContextTTL) if err != nil { return nil, fmt.Errorf("failed to create forked context namespace: %w", err) } forked.Context = ns return forked, nil } ================================================ FILE: agent/memory/memory_test.go ================================================ package memory_test import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/memory" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestMemoryNew(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create memory with default stores mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) require.NotNil(t, mem) // Verify all namespaces are initialized assert.NotNil(t, mem.User) assert.NotNil(t, mem.Team) assert.NotNil(t, mem.Chat) assert.NotNil(t, mem.Context) // Verify IDs assert.Equal(t, "user1", mem.UserID) assert.Equal(t, "team1", mem.TeamID) assert.Equal(t, "chat1", mem.ChatID) assert.Equal(t, "ctx1", mem.ContextID) } func TestMemoryPartialIDs(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create memory with 
only user and chat mem, err := memory.New(nil, "user1", "", "chat1", "") require.NoError(t, err) require.NotNil(t, mem) // Only user and chat namespaces should be initialized assert.NotNil(t, mem.User) assert.Nil(t, mem.Team) assert.NotNil(t, mem.Chat) assert.Nil(t, mem.Context) } func TestNamespaceBasicOperations(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) // Test User namespace t.Run("User namespace", func(t *testing.T) { ns := mem.GetUser() require.NotNil(t, ns) // Set and Get err := ns.Set("name", "John", 0) require.NoError(t, err) val, ok := ns.Get("name") assert.True(t, ok) assert.Equal(t, "John", val) // Has assert.True(t, ns.Has("name")) assert.False(t, ns.Has("nonexistent")) // Del err = ns.Del("name") require.NoError(t, err) assert.False(t, ns.Has("name")) }) // Test Team namespace t.Run("Team namespace", func(t *testing.T) { ns := mem.GetTeam() require.NotNil(t, ns) err := ns.Set("setting", "value", 0) require.NoError(t, err) val, ok := ns.Get("setting") assert.True(t, ok) assert.Equal(t, "value", val) }) // Test Chat namespace t.Run("Chat namespace", func(t *testing.T) { ns := mem.GetChat() require.NotNil(t, ns) err := ns.Set("topic", "AI", 0) require.NoError(t, err) val, ok := ns.Get("topic") assert.True(t, ok) assert.Equal(t, "AI", val) }) // Test Context namespace t.Run("Context namespace", func(t *testing.T) { ns := mem.GetContext() require.NotNil(t, ns) err := ns.Set("temp", "data", 0) require.NoError(t, err) val, ok := ns.Get("temp") assert.True(t, ok) assert.Equal(t, "data", val) }) } func TestNamespaceIsolation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() t.Run("User isolation", func(t *testing.T) { // Create two memory instances with different user IDs mem1, err := memory.New(nil, "user1", "", "", "") require.NoError(t, err) mem2, err := memory.New(nil, "user2", "", "", "") require.NoError(t, err) // Set value in user1's 
namespace err = mem1.GetUser().Set("key", "user1_value", 0) require.NoError(t, err) // Set value in user2's namespace err = mem2.GetUser().Set("key", "user2_value", 0) require.NoError(t, err) // Verify isolation - each user sees their own value val1, ok := mem1.GetUser().Get("key") assert.True(t, ok) assert.Equal(t, "user1_value", val1) val2, ok := mem2.GetUser().Get("key") assert.True(t, ok) assert.Equal(t, "user2_value", val2) // Delete from user1 should not affect user2 err = mem1.GetUser().Del("key") require.NoError(t, err) _, ok = mem1.GetUser().Get("key") assert.False(t, ok, "user1's key should be deleted") val2, ok = mem2.GetUser().Get("key") assert.True(t, ok, "user2's key should still exist") assert.Equal(t, "user2_value", val2) // Clear user1 should not affect user2 mem1.GetUser().Clear() val2, ok = mem2.GetUser().Get("key") assert.True(t, ok, "user2's key should still exist after user1 clear") assert.Equal(t, "user2_value", val2) }) t.Run("Team isolation", func(t *testing.T) { memA, err := memory.New(nil, "", "teamA", "", "") require.NoError(t, err) memB, err := memory.New(nil, "", "teamB", "", "") require.NoError(t, err) // Set same key in different teams memA.GetTeam().Set("config", "teamA_config", 0) memB.GetTeam().Set("config", "teamB_config", 0) // Verify isolation valA, ok := memA.GetTeam().Get("config") assert.True(t, ok) assert.Equal(t, "teamA_config", valA) valB, ok := memB.GetTeam().Get("config") assert.True(t, ok) assert.Equal(t, "teamB_config", valB) }) t.Run("Chat isolation", func(t *testing.T) { mem1, err := memory.New(nil, "", "", "chat1", "") require.NoError(t, err) mem2, err := memory.New(nil, "", "", "chat2", "") require.NoError(t, err) // Set same key in different chats mem1.GetChat().Set("topic", "chat1_topic", 0) mem2.GetChat().Set("topic", "chat2_topic", 0) // Verify isolation val1, ok := mem1.GetChat().Get("topic") assert.True(t, ok) assert.Equal(t, "chat1_topic", val1) val2, ok := mem2.GetChat().Get("topic") assert.True(t, ok) 
assert.Equal(t, "chat2_topic", val2) }) t.Run("Context isolation", func(t *testing.T) { mem1, err := memory.New(nil, "", "", "", "ctx1") require.NoError(t, err) mem2, err := memory.New(nil, "", "", "", "ctx2") require.NoError(t, err) // Set same key in different contexts mem1.GetContext().Set("temp", "ctx1_temp", 0) mem2.GetContext().Set("temp", "ctx2_temp", 0) // Verify isolation val1, ok := mem1.GetContext().Get("temp") assert.True(t, ok) assert.Equal(t, "ctx1_temp", val1) val2, ok := mem2.GetContext().Get("temp") assert.True(t, ok) assert.Equal(t, "ctx2_temp", val2) }) t.Run("Keys and Len isolation", func(t *testing.T) { mem1, err := memory.New(nil, "userA", "", "", "") require.NoError(t, err) mem2, err := memory.New(nil, "userB", "", "", "") require.NoError(t, err) // Clear first mem1.GetUser().Clear() mem2.GetUser().Clear() // Set keys in userA mem1.GetUser().Set("a", 1, 0) mem1.GetUser().Set("b", 2, 0) mem1.GetUser().Set("c", 3, 0) // Set keys in userB mem2.GetUser().Set("x", 10, 0) mem2.GetUser().Set("y", 20, 0) // Verify Keys isolation keys1 := mem1.GetUser().Keys() assert.Equal(t, 3, len(keys1), "userA should have 3 keys") keys2 := mem2.GetUser().Keys() assert.Equal(t, 2, len(keys2), "userB should have 2 keys") // Verify Len isolation assert.Equal(t, 3, mem1.GetUser().Len(), "userA Len should be 3") assert.Equal(t, 2, mem2.GetUser().Len(), "userB Len should be 2") // Keys should not contain prefix for _, k := range keys1 { assert.NotContains(t, k, "user:", "Key should not contain prefix") } }) t.Run("Incr/Decr isolation", func(t *testing.T) { mem1, err := memory.New(nil, "userX", "", "", "") require.NoError(t, err) mem2, err := memory.New(nil, "userY", "", "", "") require.NoError(t, err) // Incr counter in userX val1, err := mem1.GetUser().Incr("counter", 10) require.NoError(t, err) assert.Equal(t, int64(10), val1) // Incr counter in userY val2, err := mem2.GetUser().Incr("counter", 5) require.NoError(t, err) assert.Equal(t, int64(5), val2) // Incr again - 
should be independent val1, err = mem1.GetUser().Incr("counter", 1) require.NoError(t, err) assert.Equal(t, int64(11), val1) val2, err = mem2.GetUser().Incr("counter", 1) require.NoError(t, err) assert.Equal(t, int64(6), val2) }) t.Run("List operations isolation", func(t *testing.T) { mem1, err := memory.New(nil, "listUser1", "", "", "") require.NoError(t, err) mem2, err := memory.New(nil, "listUser2", "", "", "") require.NoError(t, err) // Push to user1's list err = mem1.GetUser().Push("items", "a", "b", "c") require.NoError(t, err) // Push to user2's list err = mem2.GetUser().Push("items", "x", "y") require.NoError(t, err) // Verify isolation assert.Equal(t, 3, mem1.GetUser().ArrayLen("items")) assert.Equal(t, 2, mem2.GetUser().ArrayLen("items")) all1, _ := mem1.GetUser().ArrayAll("items") all2, _ := mem2.GetUser().ArrayAll("items") assert.Equal(t, 3, len(all1)) assert.Equal(t, 2, len(all2)) // Pop from user1 should not affect user2 mem1.GetUser().Pop("items", 1) assert.Equal(t, 2, mem1.GetUser().ArrayLen("items")) assert.Equal(t, 2, mem2.GetUser().ArrayLen("items")) }) t.Run("Del pattern isolation", func(t *testing.T) { mem1, err := memory.New(nil, "patternUser1", "", "", "") require.NoError(t, err) mem2, err := memory.New(nil, "patternUser2", "", "", "") require.NoError(t, err) // Set keys with pattern in both users mem1.GetUser().Set("file:1", "data1", 0) mem1.GetUser().Set("file:2", "data2", 0) mem1.GetUser().Set("other", "other1", 0) mem2.GetUser().Set("file:1", "data1", 0) mem2.GetUser().Set("file:2", "data2", 0) mem2.GetUser().Set("other", "other2", 0) // Delete pattern from user1 err = mem1.GetUser().Del("file:*") require.NoError(t, err) // user1's file:* keys should be deleted assert.False(t, mem1.GetUser().Has("file:1")) assert.False(t, mem1.GetUser().Has("file:2")) assert.True(t, mem1.GetUser().Has("other")) // user2's keys should be unaffected assert.True(t, mem2.GetUser().Has("file:1")) assert.True(t, mem2.GetUser().Has("file:2")) assert.True(t, 
mem2.GetUser().Has("other")) }) } func TestNamespaceIncrDecr(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "", "", "") require.NoError(t, err) ns := mem.GetUser() // Incr on non-existent key val, err := ns.Incr("counter", 1) require.NoError(t, err) assert.Equal(t, int64(1), val) // Incr again val, err = ns.Incr("counter", 5) require.NoError(t, err) assert.Equal(t, int64(6), val) // Decr val, err = ns.Decr("counter", 2) require.NoError(t, err) assert.Equal(t, int64(4), val) } func TestNamespaceListOperations(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "", "", "") require.NoError(t, err) ns := mem.GetUser() // Push values err = ns.Push("list", "a", "b", "c") require.NoError(t, err) // ArrayLen assert.Equal(t, 3, ns.ArrayLen("list")) // ArrayAll all, err := ns.ArrayAll("list") require.NoError(t, err) assert.Len(t, all, 3) // Pop from end val, err := ns.Pop("list", 1) require.NoError(t, err) assert.Equal(t, "c", val) // ArrayLen after pop assert.Equal(t, 2, ns.ArrayLen("list")) } func TestNamespaceSetOperations(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "", "", "") require.NoError(t, err) ns := mem.GetUser() // AddToSet err = ns.AddToSet("tags", "go", "rust", "go") // "go" should only appear once require.NoError(t, err) all, err := ns.ArrayAll("tags") require.NoError(t, err) assert.Len(t, all, 2) // Only "go" and "rust" } func TestNamespaceTTL(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "", "", "", "ctx1") require.NoError(t, err) ns := mem.GetContext() // Set with short TTL err = ns.Set("temp", "value", 100*time.Millisecond) require.NoError(t, err) // Should exist immediately val, ok := ns.Get("temp") assert.True(t, ok) assert.Equal(t, "value", val) // Wait for expiration time.Sleep(150 * time.Millisecond) // Should be expired _, ok = ns.Get("temp") 
assert.False(t, ok) } func TestMemoryClear(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) // Set values in all namespaces mem.GetUser().Set("key", "user_value", 0) mem.GetTeam().Set("key", "team_value", 0) mem.GetChat().Set("key", "chat_value", 0) mem.GetContext().Set("key", "ctx_value", 0) // Clear all mem.Clear() // All should be empty _, ok := mem.GetUser().Get("key") assert.False(t, ok) _, ok = mem.GetTeam().Get("key") assert.False(t, ok) _, ok = mem.GetChat().Get("key") assert.False(t, ok) _, ok = mem.GetContext().Get("key") assert.False(t, ok) } func TestMemoryStats(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) // Set some values mem.GetUser().Set("k1", "v1", 0) mem.GetUser().Set("k2", "v2", 0) mem.GetTeam().Set("k1", "v1", 0) stats := mem.GetStats() require.NotNil(t, stats) assert.Equal(t, 2, stats.User.KeyCount) assert.Equal(t, 1, stats.Team.KeyCount) assert.Equal(t, 0, stats.Chat.KeyCount) assert.Equal(t, 0, stats.Context.KeyCount) } func TestManager(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mgr := memory.NewManagerWithDefaults() defer mgr.Close() // Get memory instance mem1, err := mgr.Memory("user1", "team1", "chat1", "ctx1") require.NoError(t, err) require.NotNil(t, mem1) // Set a value err = mem1.GetUser().Set("key", "value", 0) require.NoError(t, err) // Get same memory instance again mem2, err := mgr.Memory("user1", "team1", "chat1", "ctx1") require.NoError(t, err) // Should be the same instance (cached) val, ok := mem2.GetUser().Get("key") assert.True(t, ok) assert.Equal(t, "value", val) } func TestGetSpace(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "team1", "chat1", "ctx1") require.NoError(t, err) // Test GetSpace assert.NotNil(t, mem.GetSpace(memory.SpaceUser)) 
assert.NotNil(t, mem.GetSpace(memory.SpaceTeam)) assert.NotNil(t, mem.GetSpace(memory.SpaceChat)) assert.NotNil(t, mem.GetSpace(memory.SpaceContext)) // Invalid space assert.Nil(t, mem.GetSpace(memory.Space("invalid"))) } func TestNamespaceGetMultiSetMulti(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "", "", "") require.NoError(t, err) ns := mem.GetUser() // SetMulti ns.SetMulti(map[string]interface{}{ "a": 1, "b": 2, "c": 3, }, 0) // GetMulti result := ns.GetMulti([]string{"a", "b", "c"}) assert.Equal(t, 1, result["a"]) assert.Equal(t, 2, result["b"]) assert.Equal(t, 3, result["c"]) // DelMulti ns.DelMulti([]string{"a", "b"}) assert.False(t, ns.Has("a")) assert.False(t, ns.Has("b")) assert.True(t, ns.Has("c")) } func TestNamespaceGetDel(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() mem, err := memory.New(nil, "user1", "", "", "") require.NoError(t, err) ns := mem.GetUser() // Set a value err = ns.Set("key", "value", 0) require.NoError(t, err) // GetDel val, ok := ns.GetDel("key") assert.True(t, ok) assert.Equal(t, "value", val) // Should be deleted _, ok = ns.Get("key") assert.False(t, ok) } ================================================ FILE: agent/memory/namespace.go ================================================ package memory import ( "time" ) // Ensure Namespace implements NamespaceAccessor var _ NamespaceAccessor = (*Namespace)(nil) // GetID returns the namespace identifier func (ns *Namespace) GetID() string { return ns.ID } // GetSpace returns the space type of this namespace func (ns *Namespace) GetSpace() Space { return ns.Space } // prefixKey adds the namespace prefix to a key func (ns *Namespace) prefixKey(key string) string { return ns.Prefix + key } // Get retrieves a value by key func (ns *Namespace) Get(key string) (interface{}, bool) { return ns.Store.Get(ns.prefixKey(key)) } // Set stores a value with the default TTL for this namespace func (ns *Namespace) Set(key 
string, value interface{}, ttl time.Duration) error { if ttl == 0 { ttl = ns.Default } return ns.Store.Set(ns.prefixKey(key), value, ttl) } // Has checks if a key exists func (ns *Namespace) Has(key string) bool { return ns.Store.Has(ns.prefixKey(key)) } // Del deletes a key (supports wildcards) func (ns *Namespace) Del(key string) error { return ns.Store.Del(ns.prefixKey(key)) } // Keys returns all keys in this namespace // Uses pattern-based query for efficiency func (ns *Namespace) Keys(pattern ...string) []string { // Build pattern with namespace prefix var storePattern string if len(pattern) > 0 && pattern[0] != "" { storePattern = ns.Prefix + pattern[0] } else { storePattern = ns.Prefix + "*" } allKeys := ns.Store.Keys(storePattern) prefixLen := len(ns.Prefix) // Remove prefix from keys result := make([]string, 0, len(allKeys)) for _, key := range allKeys { if len(key) >= prefixLen { result = append(result, key[prefixLen:]) } } return result } // Len returns the number of keys in this namespace // Uses pattern-based query for efficiency func (ns *Namespace) Len(pattern ...string) int { // Build pattern with namespace prefix var storePattern string if len(pattern) > 0 && pattern[0] != "" { storePattern = ns.Prefix + pattern[0] } else { storePattern = ns.Prefix + "*" } return ns.Store.Len(storePattern) } // Clear deletes all keys in this namespace func (ns *Namespace) Clear() { ns.Store.Del(ns.Prefix + "*") } // GetSet retrieves a value and sets a new value if not exists func (ns *Namespace) GetSet(key string, ttl time.Duration, getValue func(key string) (interface{}, error)) (interface{}, error) { if ttl == 0 { ttl = ns.Default } return ns.Store.GetSet(ns.prefixKey(key), ttl, getValue) } // GetDel retrieves a value and deletes it atomically func (ns *Namespace) GetDel(key string) (interface{}, bool) { return ns.Store.GetDel(ns.prefixKey(key)) } // GetMulti retrieves multiple values by keys func (ns *Namespace) GetMulti(keys []string) map[string]interface{} { 
prefixedKeys := make([]string, len(keys)) for i, key := range keys { prefixedKeys[i] = ns.prefixKey(key) } result := ns.Store.GetMulti(prefixedKeys) // Remove prefix from result keys unprefixed := make(map[string]interface{}) prefixLen := len(ns.Prefix) for k, v := range result { if len(k) > prefixLen { unprefixed[k[prefixLen:]] = v } else { unprefixed[k] = v } } return unprefixed } // SetMulti stores multiple values func (ns *Namespace) SetMulti(values map[string]interface{}, ttl time.Duration) { if ttl == 0 { ttl = ns.Default } prefixed := make(map[string]interface{}) for k, v := range values { prefixed[ns.prefixKey(k)] = v } ns.Store.SetMulti(prefixed, ttl) } // DelMulti deletes multiple keys func (ns *Namespace) DelMulti(keys []string) { prefixedKeys := make([]string, len(keys)) for i, key := range keys { prefixedKeys[i] = ns.prefixKey(key) } ns.Store.DelMulti(prefixedKeys) } // GetSetMulti retrieves multiple values and sets new values if not exists func (ns *Namespace) GetSetMulti(keys []string, ttl time.Duration, getValue func(key string) (interface{}, error)) map[string]interface{} { if ttl == 0 { ttl = ns.Default } prefixedKeys := make([]string, len(keys)) for i, key := range keys { prefixedKeys[i] = ns.prefixKey(key) } result := ns.Store.GetSetMulti(prefixedKeys, ttl, getValue) // Remove prefix from result keys unprefixed := make(map[string]interface{}) prefixLen := len(ns.Prefix) for k, v := range result { if len(k) > prefixLen { unprefixed[k[prefixLen:]] = v } else { unprefixed[k] = v } } return unprefixed } // Incr increments a numeric value func (ns *Namespace) Incr(key string, delta int64) (int64, error) { return ns.Store.Incr(ns.prefixKey(key), delta) } // Decr decrements a numeric value func (ns *Namespace) Decr(key string, delta int64) (int64, error) { return ns.Store.Decr(ns.prefixKey(key), delta) } // Push appends values to a list func (ns *Namespace) Push(key string, values ...interface{}) error { return ns.Store.Push(ns.prefixKey(key), 
values...) } // Pop removes and returns an element from a list func (ns *Namespace) Pop(key string, position int) (interface{}, error) { return ns.Store.Pop(ns.prefixKey(key), position) } // Pull removes the first occurrence of a value from a list func (ns *Namespace) Pull(key string, value interface{}) error { return ns.Store.Pull(ns.prefixKey(key), value) } // PullAll removes all occurrences of values from a list func (ns *Namespace) PullAll(key string, values []interface{}) error { return ns.Store.PullAll(ns.prefixKey(key), values) } // AddToSet adds values to a set (no duplicates) func (ns *Namespace) AddToSet(key string, values ...interface{}) error { return ns.Store.AddToSet(ns.prefixKey(key), values...) } // ArrayLen returns the length of a list func (ns *Namespace) ArrayLen(key string) int { return ns.Store.ArrayLen(ns.prefixKey(key)) } // ArrayGet retrieves an element from a list by index func (ns *Namespace) ArrayGet(key string, index int) (interface{}, error) { return ns.Store.ArrayGet(ns.prefixKey(key), index) } // ArraySet sets an element in a list by index func (ns *Namespace) ArraySet(key string, index int, value interface{}) error { return ns.Store.ArraySet(ns.prefixKey(key), index, value) } // ArraySlice returns a slice of a list func (ns *Namespace) ArraySlice(key string, skip, limit int) ([]interface{}, error) { return ns.Store.ArraySlice(ns.prefixKey(key), skip, limit) } // ArrayPage returns a page of a list func (ns *Namespace) ArrayPage(key string, page, pageSize int) ([]interface{}, error) { return ns.Store.ArrayPage(ns.prefixKey(key), page, pageSize) } // ArrayAll returns all elements of a list func (ns *Namespace) ArrayAll(key string) ([]interface{}, error) { return ns.Store.ArrayAll(ns.prefixKey(key)) } // Stats returns statistics for this namespace func (ns *Namespace) Stats() *NamespaceStats { return &NamespaceStats{ Space: ns.Space, ID: ns.ID, KeyCount: ns.Len(), StoreID: ns.StoreID, } } // Snapshot returns all key-value pairs in this 
namespace // Used for recovery/resume functionality
//
// The keys come from ns.Keys() and each value is read with ns.Get; a key for
// which Get reports a miss is left out of the returned map.
func (ns *Namespace) Snapshot() map[string]interface{} {
	keys := ns.Keys()
	snapshot := make(map[string]interface{}, len(keys))
	for _, key := range keys {
		if value, ok := ns.Get(key); ok {
			snapshot[key] = value
		}
	}
	return snapshot
}

================================================ FILE: agent/memory/types.go ================================================
package memory

import (
	"time"

	"github.com/yaoapp/gou/store"
)

// Space defines the memory space type
type Space string

const (
	// SpaceUser user-level memory, persists across all chats for a user
	// Use case: user preferences, long-term knowledge, personal settings
	SpaceUser Space = "user"

	// SpaceTeam team-level memory, shared across all users in a team
	// Use case: team knowledge, shared settings, collaborative data
	SpaceTeam Space = "team"

	// SpaceChat chat-level memory, persists within a single chat session
	// Use case: conversation context, chat-specific settings, accumulated knowledge
	SpaceChat Space = "chat"

	// SpaceContext context-level memory, temporary within a single request context
	// Use case: intermediate results, temporary variables, request-scoped cache
	SpaceContext Space = "context"
)

// Config represents the memory configuration
// Each field is a Store ID referencing gou/store, empty string uses built-in default
// All spaces use xun-based storage by default for persistence and reliability
type Config struct {
	User    string `json:"user,omitempty" yaml:"user,omitempty"`       // Store ID for user-level memory (default: xun-based)
	Team    string `json:"team,omitempty" yaml:"team,omitempty"`       // Store ID for team-level memory (default: xun-based)
	Chat    string `json:"chat,omitempty" yaml:"chat,omitempty"`       // Store ID for chat-level memory (default: xun-based)
	Context string `json:"context,omitempty" yaml:"context,omitempty"` // Store ID for context-level memory (default: xun-based, shorter TTL)
}

// DefaultStoreID constants for built-in stores
// One ID per memory space (user, team, chat, context).
const (
	DefaultUserStore    = "__yao.agent.memory.user"
	DefaultTeamStore    = "__yao.agent.memory.team"
	DefaultChatStore    = "__yao.agent.memory.chat"
	DefaultContextStore = "__yao.agent.memory.context"
)

// Entry represents a memory entry
// Optional fields (Metadata, TTL, ExpiresAt) are omitted from JSON when empty.
type Entry struct {
	Key       string                 `json:"key"`
	Value     interface{}            `json:"value"`
	Space     Space                  `json:"space"`
	Metadata  map[string]interface{} `json:"metadata,omitempty"`
	TTL       time.Duration          `json:"ttl,omitempty"`
	CreatedAt time.Time              `json:"created_at"`
	UpdatedAt time.Time              `json:"updated_at"`
	ExpiresAt *time.Time             `json:"expires_at,omitempty"`
}

// Namespace represents a memory namespace for a specific space
// Only Space and ID are serialized to JSON; the store handle, store ID,
// prefix and default TTL are tagged json:"-".
type Namespace struct {
	Space   Space         `json:"space"`
	ID      string        `json:"id"` // UserID, TeamID, ChatID, or ContextID depending on space
	Store   store.Store   `json:"-"`  // Underlying store
	StoreID string        `json:"-"`  // Store ID
	Prefix  string        `json:"-"`  // Computed key prefix (e.g., "user:123:", "team:456:")
	Default time.Duration `json:"-"`  // Default TTL for this namespace
}

// Memory represents the complete memory system for an agent
// It manages four separate namespaces: User, Team, Chat, and Context
// Only the four ID fields are serialized to JSON.
type Memory struct {
	UserID    string     `json:"user_id"`
	TeamID    string     `json:"team_id"`
	ChatID    string     `json:"chat_id"`
	ContextID string     `json:"context_id"`
	User      *Namespace `json:"-"` // User-level memory namespace
	Team      *Namespace `json:"-"` // Team-level memory namespace
	Chat      *Namespace `json:"-"` // Chat-level memory namespace
	Context   *Namespace `json:"-"` // Context-level memory namespace
	Config    *Config    `json:"-"` // Memory configuration
}

// Stats represents memory statistics
// Each field holds the statistics of one namespace and is omitted from JSON when nil.
type Stats struct {
	User    *NamespaceStats `json:"user,omitempty"`
	Team    *NamespaceStats `json:"team,omitempty"`
	Chat    *NamespaceStats `json:"chat,omitempty"`
	Context *NamespaceStats `json:"context,omitempty"`
}

// NamespaceStats represents statistics for a single memory namespace
type NamespaceStats struct {
	Space    Space  `json:"space"`
	ID       string `json:"id"`
	KeyCount int    `json:"key_count"`
	StoreID
string `json:"store_id"` } ================================================ FILE: agent/output/BUILTIN_TYPES.md ================================================ # Built-in Message Types Built-in message types are standardized types that all adapters must support. These types have predefined Props structures to ensure consistency across different output formats. ## Type Constants Defined in `types.go`: ```go const ( TypeUserInput = "user_input" // User input message (frontend display only) TypeText = "text" // Plain text or Markdown content TypeThinking = "thinking" // Reasoning/thinking process TypeLoading = "loading" // Loading/processing indicator TypeToolCall = "tool_call" // LLM tool/function call TypeRetrieval = "retrieval" // KB/Web search results (for feedback & analytics) TypeError = "error" // Error message TypeImage = "image" // Image content TypeAudio = "audio" // Audio content TypeVideo = "video" // Video content TypeAction = "action" // System action (silent in standard clients) TypeEvent = "event" // Lifecycle event (silent in standard clients) ) ``` ## Standard Props Structures ### 1. User Input (`user_input`) **Purpose:** User input message (for frontend display only) **Props Structure:** ```go type UserInputProps struct { Content interface{} `json:"content"` // User input (text string or multimodal ContentPart[]) Role string `json:"role,omitempty"` // User role: "user", "system", "developer" (default: "user") Name string `json:"name,omitempty"` // Optional participant name } ``` **Example:** ```json { "type": "user_input", "props": { "content": "Hello, can you help me?", "role": "user" } } ``` **Multimodal Example:** ```json { "type": "user_input", "props": { "content": [ { "type": "text", "text": "What's in this image?" 
}, { "type": "image_url", "image_url": { "url": "https://example.com/photo.jpg" } } ], "role": "user" } } ``` **Helper:** ```go // Simple text input msg := output.NewUserInputMessage("Hello, can you help me?", "user", "") // With name msg := output.NewUserInputMessage("I need assistance", "user", "John") // Multimodal content content := []map[string]interface{}{ { "type": "text", "text": "What's in this image?", }, { "type": "image_url", "image_url": map[string]string{ "url": "https://example.com/photo.jpg", }, }, } msg := output.NewUserInputMessage(content, "user", "") ``` **Important Notes:** - **Frontend display only**: This type is used by the frontend to display user input in the chat UI - **Not sent to backend**: User input is sent to backend as `UserMessage` (OpenAI format), not as `Message` - **Preserves role**: Unlike `text` type, preserves the original user role (`user`, `system`, `developer`) - **Supports multimodal**: Can contain text, images, audio, or files **Data Flow:** ``` User types → UserMessage (sent to API) → Backend processes → Message types (AI response) ↓ UserInputMessage (frontend display) ``` --- ### 2. Text (`text`) **Purpose:** Plain text or Markdown content (AI responses) **Props Structure:** ```go type TextProps struct { Content string `json:"content"` // Text content (supports Markdown) } ``` **Example:** ```json { "type": "text", "props": { "content": "Hello **world**!" } } ``` **Helper:** ```go msg := output.NewTextMessage("Hello **world**!") ``` --- ### 3. Thinking (`thinking`) **Purpose:** Reasoning or thinking process (used by o1 models, DeepSeek R1, etc.) **Props Structure:** ```go type ThinkingProps struct { Content string `json:"content"` // Reasoning/thinking content } ``` **Example:** ```json { "type": "thinking", "props": { "content": "Let me analyze this step by step..." } } ``` **Helper:** ```go msg := output.NewThinkingMessage("Let me analyze this step by step...") ``` --- ### 4. 
Loading (`loading`) **Purpose:** Loading or processing indicator (preprocessing, knowledge base search, data fetching, etc.) **Props Structure:** ```go type LoadingProps struct { Message string `json:"message"` // Loading message } ``` **Example:** ```json { "type": "loading", "props": { "message": "Searching knowledge base..." } } ``` **Helper:** ```go msg := output.NewLoadingMessage("Searching knowledge base...") ``` **Use Cases:** - Knowledge base search: `"Searching knowledge base..."` - Data preprocessing: `"Processing uploaded file..."` - External API calls: `"Fetching data from API..."` - Database queries: `"Querying database..."` **Example in Hook:** ```go // In Create hook, show preprocessing steps func Create(ctx *context.Context, messages []context.Message) (*context.HookCreateResponse, error) { // Send loading message for knowledge base search output.Send(ctx, output.NewLoadingMessage("Searching knowledge base...")) // Do the actual search results := searchKnowledgeBase(messages) // Send another loading message for processing output.Send(ctx, output.NewLoadingMessage("Processing results...")) // Process and return return &context.HookCreateResponse{ Messages: buildMessages(results), }, nil } ``` **Result in OpenAI Client:** - Shows as thinking/reasoning process - User sees "Searching knowledge base..." and "Processing results..." - Provides transparency into what's happening --- ### 5. Tool Call (`tool_call`) **Purpose:** LLM tool or function call **Props Structure:** ```go type ToolCallProps struct { ID string `json:"id"` // Tool call ID Name string `json:"name"` // Function/tool name Arguments string `json:"arguments,omitempty"` // JSON string of arguments } ``` **Example:** ```json { "type": "tool_call", "props": { "id": "call_abc123", "name": "get_weather", "arguments": "{\"location\": \"San Francisco\"}" } } ``` **Helper:** ```go msg := output.NewToolCallMessage( "call_abc123", "get_weather", "{\"location\": \"San Francisco\"}", ) ``` --- ### 6. 
Retrieval (`retrieval`) **Purpose:** Knowledge base and web search results (for feedback, analytics, and source attribution) **Props Structure:** ```go type RetrievalProps struct { Query string `json:"query"` // Search query Sources []RetrievalSource `json:"sources"` // Retrieved sources TotalResults int `json:"total_results,omitempty"` // Total matching results QueryTimeMs int64 `json:"query_time_ms,omitempty"` // Query execution time Provider string `json:"provider,omitempty"` // Search provider (e.g., "tavily", "bing") } type RetrievalSource struct { ID string `json:"id"` // Unique source ID within this retrieval Type string `json:"type"` // Source type: "kb", "web", "file", "api", "mcp" Title string `json:"title,omitempty"` // Source title Content string `json:"content"` // Retrieved content/snippet Score float64 `json:"score,omitempty"` // Relevance score URL string `json:"url,omitempty"` // URL for web sources CollectionID string `json:"collection_id,omitempty"` // KB collection ID DocumentID string `json:"document_id,omitempty"` // KB document ID ChunkID string `json:"chunk_id,omitempty"` // KB chunk ID Metadata map[string]interface{} `json:"metadata,omitempty"` // Additional metadata } ``` **Example (Knowledge Base):** ```json { "type": "retrieval", "props": { "query": "How to configure Yao models?", "sources": [ { "id": "src_001", "type": "kb", "collection_id": "col_docs", "document_id": "doc_123", "chunk_id": "chunk_456", "title": "Model Configuration Guide", "content": "To configure a model in Yao, create a .mod.yao file...", "score": 0.92, "metadata": { "file_path": "/docs/model.md", "page": 3 } } ], "total_results": 15, "query_time_ms": 120 } } ``` **Example (Web Search):** ```json { "type": "retrieval", "props": { "query": "latest AI news 2024", "sources": [ { "id": "src_001", "type": "web", "url": "https://example.com/ai-news", "title": "AI Breakthroughs in 2024", "content": "Summary of the article...", "score": 0.95, "metadata": { "domain": 
"example.com", "published_at": "2024-01-10" } } ], "provider": "tavily", "total_results": 10, "query_time_ms": 850 } } ``` **Helper:** ```go msg := output.NewRetrievalMessage( "How to configure Yao models?", []output.RetrievalSource{ { ID: "src_001", Type: "kb", CollectionID: "col_docs", DocumentID: "doc_123", ChunkID: "chunk_456", Title: "Model Configuration Guide", Content: "To configure a model in Yao...", Score: 0.92, }, }, ) ``` **Source Types:** | Type | Description | Key Fields | | ------ | ----------------------- | ------------------------------------------ | | `kb` | Knowledge base document | `collection_id`, `document_id`, `chunk_id` | | `web` | Web search result | `url` | | `file` | Uploaded file | `file_id`, `file_path` | | `api` | External API result | `api_name`, `endpoint` | | `mcp` | MCP tool result | `server`, `tool` | **Use Cases:** - **Source Attribution**: Display citations in the chat UI - **User Feedback**: Allow users to rate individual sources (👍/👎) - **Analytics**: Track which documents/sources are most useful - **RAG Optimization**: Improve retrieval based on feedback data **Adapter Behavior:** - **CUI**: Renders as expandable source cards with feedback buttons - **OpenAI**: Converts to markdown citations or footnotes --- ### 7. Error (`error`) **Purpose:** Error message **Props Structure:** ```go type ErrorProps struct { Message string `json:"message"` // Error message Code string `json:"code,omitempty"` // Error code Details string `json:"details,omitempty"` // Additional error details } ``` **Example:** ```json { "type": "error", "props": { "message": "Connection timeout", "code": "TIMEOUT", "details": "Failed to connect to database after 30s" } } ``` **Helper:** ```go msg := output.NewErrorMessage("Connection timeout", "TIMEOUT") ``` --- ### 8. 
Action (`action`) **Purpose:** System-level action/command (not displayed to user, only processed by client) **Props Structure:** ```go type ActionProps struct { Name string `json:"name"` // Action name Payload map[string]interface{} `json:"payload,omitempty"` // Action parameters } ``` **Example:** ```json { "type": "action", "props": { "name": "open_panel", "payload": { "panel_id": "user_profile", "user_id": "123" } } } ``` **Helper:** ```go msg := output.NewActionMessage("open_panel", map[string]interface{}{ "panel_id": "user_profile", "user_id": "123", }) ``` **Use Cases:** - Open sidebar/panel: `"open_panel"` - Navigate to page: `"navigate"` - Trigger UI update: `"refresh_view"` - Close modal: `"close_modal"` - Scroll to element: `"scroll_to"` **Important Notes:** - **Silent in OpenAI clients**: Action messages are NOT sent to standard chat clients - **CUI clients only**: Only CUI clients process action messages - **System-level**: Used for controlling the UI/application, not chat content **Example in Hook:** ```go // Send action to open a panel with user details output.Send(ctx, output.NewActionMessage("open_panel", map[string]interface{}{ "panel_id": "user_details", "user_id": user.ID, })) // Send text message (visible to user) output.Send(ctx, output.NewTextMessage("I've opened the user details panel for you.")) ``` **Result:** - **CUI client**: Panel opens, text message displays - **OpenAI client**: Only text message displays (action is silent) --- ### 9. Event (`event`) **Purpose:** Lifecycle event messages (stream_start, stream_end, connecting, etc.) 
**Props Structure:** ```go type EventProps struct { Event string `json:"event"` // Event type Message string `json:"message,omitempty"` // Human-readable message Data map[string]interface{} `json:"data,omitempty"` // Additional event data } ``` **Example:** ```json { "type": "event", "props": { "event": "stream_start", "message": "Starting stream...", "data": { "model": "gpt-4", "session_id": "sess_123" } } } ``` **Helper:** ```go msg := output.NewEventMessage("stream_start", "Starting stream...", map[string]interface{}{ "model": "gpt-4", "session_id": "sess_123", }) ``` **Use Cases:** - Stream lifecycle: `"stream_start"`, `"stream_end"` - Connection status: `"connecting"`, `"connected"`, `"disconnected"` - Processing stages: `"preprocessing"`, `"postprocessing"` - Agent state: `"thinking"`, `"executing"`, `"completed"` **Important Notes:** - **Converted in OpenAI clients**: Event messages are typically NOT sent to OpenAI clients, **except** `stream_start`: - `stream_start`: Converted to a clickable trace link in either `reasoning_content` (thinking models) or `content` (regular models) - Other events: Silent (not sent to OpenAI clients) - **CUI clients**: All event messages are processed and may show status indicators - **Lifecycle tracking**: Used for tracking agent/stream lifecycle - **Non-blocking**: Events don't interrupt the main message flow **Example in Hook:** ```go // Send stream start event (automatically generated by assistant) // This is typically handled by the framework, not manually sent startData := message.EventStreamStartData{ RequestID: ctx.RequestID, Timestamp: time.Now().UnixMilli(), TraceID: ctx.Stack.TraceID, ChatID: ctx.ChatID, } output.Send(ctx, output.NewEventMessage("stream_start", "Stream started", startData)) // Do processing processData() // Send stream end event endData := message.EventStreamEndData{ RequestID: ctx.RequestID, Timestamp: time.Now().UnixMilli(), DurationMs: 1500, Status: "completed", } output.Send(ctx, 
output.NewEventMessage("stream_end", "Stream completed", endData)) ``` **Result:** - **CUI client**: Tracks lifecycle, may show status indicators - **OpenAI client (stream_start only)**: - Reasoning models: Shows as 🔍 with trace link in `reasoning_content` field - Regular models: Shows as 🚀 with trace link in `content` field - Example (zh-CN locale, i.e. "Agent is processing - View processing details"): "🔍 智能体正在处理 - [查看处理详情](baseURL/trace/traceID/view)" - **OpenAI client (other events)**: Silent (not sent) --- ### 10. Image (`image`) **Purpose:** Image content **Props Structure:** ```go type ImageProps struct { URL string // Required: Image URL or base64 data Alt string // Alternative text Width int // Image width in pixels Height int // Image height in pixels Detail string // OpenAI detail level: "auto", "low", "high" } ``` **Example:** ```json { "type": "image", "props": { "url": "https://example.com/avatar.jpg", "alt": "User avatar", "width": 200, "height": 200 } } ``` **Helper:** ```go msg := output.NewImageMessage("https://example.com/avatar.jpg", "User avatar") ``` **Adapter Behavior:** - **CUI**: Renders image directly with `<img>` tag - **OpenAI**: Converts to Markdown `![alt](url)` - **displays inline** in Markdown-supporting clients --- ### 11.
Audio (`audio`) **Purpose:** Audio content **Props Structure:** ```go type AudioProps struct { URL string // Required: Audio URL or base64 data Format string // Audio format: "mp3", "wav", "ogg" Duration float64 // Duration in seconds Transcript string // Audio transcript text Autoplay bool // Whether to autoplay Controls bool // Whether to show controls } ``` **Example:** ```json { "type": "audio", "props": { "url": "https://example.com/audio.mp3", "format": "mp3", "duration": 120.5, "transcript": "This is the audio content...", "controls": true } } ``` **Helper:** ```go msg := output.NewAudioMessage("https://example.com/audio.mp3", "mp3") ``` **Adapter Behavior:** - **CUI**: Renders audio player with controls - **OpenAI**: Converts to link `🔊 [Play Audio](url)` - can't display inline --- ### 12. Video (`video`) **Purpose:** Video content **Props Structure:** ```go type VideoProps struct { URL string // Required: Video URL Format string // Video format: "mp4", "webm" Duration float64 // Duration in seconds Thumbnail string // Thumbnail/poster image URL Width int // Video width in pixels Height int // Video height in pixels Autoplay bool // Whether to autoplay Controls bool // Whether to show controls Loop bool // Whether to loop } ``` **Example:** ```json { "type": "video", "props": { "url": "https://example.com/video.mp4", "format": "mp4", "thumbnail": "https://example.com/poster.jpg", "width": 640, "height": 360, "controls": true } } ``` **Helper:** ```go msg := output.NewVideoMessage("https://example.com/video.mp4") ``` **Adapter Behavior:** - **CUI**: Renders video player with controls - **OpenAI**: Converts to link `🎬 [Watch Video](url)` - can't display inline --- ## Adapter Requirements All adapters (CUI, OpenAI, etc.) **must** support these built-in types with their standard Props structures. 
### CUI Adapter CUI adapter passes built-in types through without transformation: ```json { "type": "text", "props": { "content": "Hello world" } } ``` ### OpenAI Adapter OpenAI adapter converts built-in types to OpenAI format: | Type | OpenAI Format | Field | Note | | ------------ | ------------------------- | ----------------------------- | -------------------------------------------------------------------- | | `user_input` | (not sent) | - | Frontend display only - not sent to OpenAI clients | | `text` | `delta.content` | `props.content` | | | `thinking` | `delta.reasoning_content` | `props.content` | Reasoning content (o1 models) | | `loading` | `delta.reasoning_content` | `props.message` | Shows as thinking in OpenAI clients | | `tool_call` | `delta.tool_calls` | `props.{id, name, arguments}` | | | `retrieval` | `delta.content` | `props.sources` | Markdown citations/footnotes with source links | | `error` | `error` | `props.{message, code}` | | | `image` | `delta.content` | `props.{url, alt}` | Markdown: `![alt](url)` - displays inline | | `audio` | `delta.content` | `props.url` | Markdown link (can't display inline) | | `video` | `delta.content` | `props.url` | Markdown link (can't display inline) | | `action` | (not sent) | - | Silent - system actions only | | `event` | (conditional) | `props.{event, data}` | Most events silent; `stream_start` converted to trace link with i18n | --- ## Custom Types Any type **not** in the built-in list is considered a custom type. Adapters may handle custom types differently: - **CUI:** Pass through as-is - **OpenAI:** Convert to Markdown link Example custom type (`chart` is not a built-in type, so it is handled as custom): ```json { "type": "chart", "props": { "title": "Monthly sales", "url": "https://example.com/chart.json" } } ``` --- ## Checking Built-in Types ```go // Check if a type is built-in if output.IsBuiltinType(msg.Type) { // Handle as standard type } else { // Handle as custom type } ``` --- ## Guidelines for New Built-in Types When adding new built-in types: 1.
✅ Add constant to `types.go` 2. ✅ Define Props structure 3. ✅ Add helper function in `builtin.go` 4. ✅ Update all adapters to support it 5. ✅ Document in this file 6. ✅ Add tests **Only add built-in types for:** - Universal concepts (text, errors, etc.) - LLM-specific features (thinking, tool_calls) - Types that need cross-adapter consistency **Do NOT add built-in types for:** - UI components (buttons, forms, etc.) - Application-specific widgets - Domain-specific data types These should remain custom types. ================================================ FILE: agent/output/README.md ================================================ # Output Module The output module provides a unified API for sending messages to different client types (CUI, OpenAI-compatible, etc.) with support for streaming and rich media content. ## Architecture ``` agent/output/ ├── message/ # Core types and interfaces (no dependencies) │ ├── types.go # Message, Group, Props structures │ └── interfaces.go # Writer, Adapter, Factory interfaces ├── adapters/ # Client-specific adapters │ ├── cui/ # CUI adapter (native DSL) │ │ ├── adapter.go │ │ └── writer.go │ └── openai/ # OpenAI adapter (converts to OpenAI format) │ ├── adapter.go │ ├── converter.go │ ├── writer.go │ ├── types.go │ └── factory.go ├── output.go # Main API (Send, GetWriter, etc.) 
├── builtin.go # Helper functions for built-in types └── BUILTIN_TYPES.md # Documentation for built-in types ``` ## DSL Structure ### Message Structure The universal message DSL is a JSON structure that supports streaming, rich media, and incremental updates: ```go type Message struct { // Core fields Type string `json:"type"` // Message type (e.g., "text", "image", "action") Props map[string]interface{} `json:"props,omitempty"` // Type-specific properties // Streaming control - Hierarchical structure for Agent/LLM/MCP streaming ChunkID string `json:"chunk_id,omitempty"` // Unique chunk ID (C1, C2, C3...; for dedup/ordering/debugging) MessageID string `json:"message_id,omitempty"` // Logical message ID (M1, M2, M3...; delta merge target; multiple chunks → one message) BlockID string `json:"block_id,omitempty"` // Block ID (B1, B2, B3...; Agent-level grouping for UI sections) ThreadID string `json:"thread_id,omitempty"` // Thread ID (T1, T2, T3...; optional; for concurrent streams) // Delta control Delta bool `json:"delta,omitempty"` // Whether this is an incremental update DeltaPath string `json:"delta_path,omitempty"` // Which field to update (e.g., "content", "items.0.name") DeltaAction string `json:"delta_action,omitempty"` // How to update ("append", "replace", "merge", "set") // Type correction (for streaming type inference) TypeChange bool `json:"type_change,omitempty"` // Marks this as a type correction message // Metadata Metadata *Metadata `json:"metadata,omitempty"` // Timestamp, sequence, trace ID } ``` ### Field Descriptions #### Core Fields - **`Type`** (required): Determines how the message should be rendered - Built-in types: `user_input`, `text`, `thinking`, `loading`, `tool_call`, `retrieval`, `error`, `image`, `audio`, `video`, `action`, `event` (see [BUILTIN_TYPES.md](./BUILTIN_TYPES.md)) - Custom types: Any string (frontend must have corresponding component) - **`Props`** (optional): Type-specific properties passed to the rendering component - For `text`: `{"content": "Hello"}` - For `image`: `{"url": "...", "alt":
"..."}` - For custom types: Any JSON-serializable data #### Streaming Control Hierarchical structure for fine-grained control over streaming in complex Agent/LLM/MCP scenarios: - **`ChunkID`** (optional): Unique chunk identifier - Auto-generated (C1, C2, C3...) - For deduplication, ordering, and debugging - Each raw stream fragment gets a unique ChunkID - **`MessageID`** (optional): Logical message identifier - Auto-generated (M1, M2, M3...) - Delta merge target - multiple chunks with same MessageID are merged - Represents one complete logical message (e.g., one thinking output, one text response) - Example: `"M1"` - **`BlockID`** (optional): Output block identifier - Auto-generated (B1, B2, B3...) - Agent-level grouping for UI sections - One LLM call, one MCP call, or one Agent sub-task - Used for rendering blocks/sections in the UI - **`ThreadID`** (optional): Thread identifier - Auto-generated (T1, T2, T3...) - For concurrent Agent/LLM/MCP calls - Distinguishes multiple parallel output streams - **`Delta`** (optional): Marks this as an incremental update - `true`: Append/update to existing message with same MessageID - `false`: Complete message (default) - Used for streaming LLM responses #### Delta Update Control For complex, structured messages that need field-level updates: - **`DeltaPath`** (optional): JSON path to the field being updated - Simple: `"content"` (updates `props.content`) - Nested: `"user.name"` (updates `props.user.name`) - Array: `"items.0.title"` (updates `props.items[0].title`) - **`DeltaAction`** (optional): How to apply the delta update - `"append"`: Concatenate to existing string/array - `"replace"`: Replace entire value - `"merge"`: Merge objects (shallow merge) - `"set"`: Set new field (if doesn't exist) #### Type Correction - **`TypeChange`** (optional): Indicates message type was corrected - Used when initial type inference was wrong - Frontend should re-render with new type - Example: Initially sent as `text`, corrected to 
`thinking` #### Metadata - **`Metadata`** (optional): Additional message metadata ```go type Metadata struct { Timestamp int64 // Unix nanoseconds Sequence int // Message sequence number TraceID string // For debugging/logging } ``` ### Message Examples #### Simple Text Message ```json { "type": "text", "props": { "content": "Hello, world!" } } ``` #### Streaming Text (Delta Updates) ```json // First chunk { "chunk_id": "C1", "message_id": "M1", "type": "text", "delta": true, "props": { "content": "Hello" } } // Second chunk (appends) { "chunk_id": "C2", "message_id": "M1", "type": "text", "delta": true, "props": { "content": ", world" } } // Third chunk { "chunk_id": "C3", "message_id": "M1", "type": "text", "delta": true, "props": { "content": "!" } } // Completion signaled by message_end event (sent separately) { "type": "event", "props": { "event": "message_end", "data": { "message_id": "M1", "type": "text", "chunk_count": 3, "status": "completed" } } } ``` #### Complex Type with Nested Updates ```json // Initial message { "message_id": "M2", "type": "table", "props": { "columns": ["Name", "Age"], "rows": [] } } // Add first row { "chunk_id": "C4", "message_id": "M2", "type": "table", "delta": true, "delta_path": "rows", "delta_action": "append", "props": { "rows": [{"name": "Alice", "age": 30}] } } // Add second row { "chunk_id": "C5", "message_id": "M2", "type": "table", "delta": true, "delta_path": "rows", "delta_action": "append", "props": { "rows": [{"name": "Bob", "age": 25}] } } ``` #### Type Correction ```json // Initial guess (text) { "chunk_id": "C6", "message_id": "M3", "type": "text", "delta": true, "props": { "content": "Let me think..." } } // Correction (actually thinking) { "chunk_id": "C7", "message_id": "M3", "type": "thinking", "type_change": true, "props": { "content": "Let me think..." 
} } ``` #### Block Grouping (Agent-level) ```json // Block start event { "type": "event", "props": { "event": "block_start", "data": { "block_id": "B1", "type": "llm", "label": "Analyzing image" } } } // Thinking message in block { "message_id": "M4", "block_id": "B1", "type": "thinking", "props": { "content": "Let me analyze this image..." } } // Text message in block { "message_id": "M5", "block_id": "B1", "type": "text", "props": { "content": "This is a beautiful sunset at Golden Gate Bridge" } } // Block end event { "type": "event", "props": { "event": "block_end", "data": { "block_id": "B1", "message_count": 2, "status": "completed" } } } ``` ## Key Design Decisions ### 1. Separate `message` Package To avoid circular dependencies, all core types and interfaces are defined in the `message` sub-package: - `message.Message` - Universal message DSL - `message.Writer` - Interface for writing messages - `message.Adapter` - Interface for format conversion This allows: - `handlers` → `output` → `message` ✅ - `output/adapters` → `message` ✅ - No circular dependencies! ### 2. Adapter Pattern Different clients require different formats: **CUI Clients:** ```json { "type": "text", "props": { "content": "Hello" } } ``` **OpenAI Clients:** ```json { "choices": [ { "delta": { "content": "Hello" } } ] } ``` Adapters handle the transformation automatically based on `ctx.Accept`. ### 3. 
Built-in Types 12 standardized message types with defined Props structures (see [BUILTIN_TYPES.md](./BUILTIN_TYPES.md) for details): | Type | Purpose | CUI | OpenAI | | ----------- | ------------------ | ------- | -------------------------------------------- | | `user_input` | User input display | Direct | Not sent | | `text` | Text content | Direct | `delta.content` | | `thinking` | LLM reasoning | Direct | `delta.reasoning_content` | | `loading` | Progress indicator | Direct | `delta.reasoning_content` | | `tool_call` | Function calls | Direct | `delta.tool_calls` | | `retrieval` | KB/web search results | Source cards | Markdown citations/footnotes | | `error` | Error messages | Direct | `error` | | `image` | Images | Render | `![](url)` markdown | | `audio` | Audio | Player | Link | | `video` | Video | Player | Link | | `action` | System commands | Execute | Silent | | `event` | Lifecycle events | Track | Conditional (stream_start converted to link) | ## Usage ### Basic Usage ```go import ( "github.com/yaoapp/yao/agent/output" "github.com/yaoapp/yao/agent/output/message" ) // Send a text message msg := output.NewTextMessage("Hello world") output.Send(ctx, msg) // Send a loading indicator loading := output.NewLoadingMessage("Searching knowledge base...") output.Send(ctx, loading) // Send an image img := output.NewImageMessage("https://example.com/image.jpg", "Description") output.Send(ctx, img) // Send an error err := output.NewErrorMessage("Connection failed", "TIMEOUT") output.Send(ctx, err) ``` ### Streaming Messages ```go // Get ID generator from context idGen := ctx.IDGenerator // Send delta (incremental) updates msg := &message.Message{ ChunkID: idGen.GenerateChunkID(), // C1 MessageID: idGen.GenerateMessageID(), // M1 Type: message.TypeText, Delta: true, // Incremental update Props: map[string]interface{}{ "content": "Hello", }, } output.Send(ctx, msg) // Send more delta updates (same MessageID for merging) msg2 := &message.Message{ ChunkID: idGen.GenerateChunkID(), // C2 MessageID: msg.MessageID, // M1 (same as before) Type: message.TypeText, Delta: true, Props: map[string]interface{}{ "content": " world", }, } output.Send(ctx, msg2) // Mark completion
with message_end event endData := message.EventMessageEndData{ MessageID: msg.MessageID, // M1 Type: "text", Status: "completed", ChunkCount: 2, Extra: map[string]interface{}{ "content": "Hello world!", // Full content }, } eventMsg := output.NewEventMessage(message.EventMessageEnd, "Message completed", endData) output.Send(ctx, eventMsg) ``` ### Custom Writers ```go // Register a custom writer factory factory := &MyCustomFactory{} output.SetWriterFactory(factory) // Now all calls to output.Send will use your custom writer ``` ## Integration with Handlers The `handlers` package uses the output module for streaming: ```go func DefaultStreamHandler(ctx *context.Context) context.StreamFunc { return func(chunkType context.StreamChunkType, data []byte) int { switch chunkType { case context.ChunkText: msg := output.NewTextMessage(string(data)) output.Send(ctx, msg) case context.ChunkThinking: msg := output.NewThinkingMessage(string(data)) output.Send(ctx, msg) // ... handle other types } return 0 // Continue } } ``` ## Context-based Routing The output module automatically selects the right writer based on `ctx.Accept`: | `ctx.Accept` | Writer | Format | | ------------- | ------ | --------------------- | | `standard` | OpenAI | OpenAI-compatible SSE | | `cui-web` | CUI | Universal DSL JSON | | `cui-native` | CUI | Universal DSL JSON | | `cui-desktop` | CUI | Universal DSL JSON | ## Writer Caching Writers are cached per context to avoid recreating them: ```go // Get or create writer (cached) writer := output.GetWriter(ctx) // Clear cache when done output.Close(ctx) // Also closes the writer ``` ## See Also - [BUILTIN_TYPES.md](./BUILTIN_TYPES.md) - Complete documentation of built-in message types - [adapters/openai/README.md](./adapters/openai/README.md) - OpenAI adapter documentation ================================================ FILE: agent/output/adapters/cui/adapter.go ================================================ package cui import 
"github.com/yaoapp/yao/agent/output/message" // Adapter implements the message.Adapter interface for CUI clients. // It performs no conversion and outputs messages as-is, as CUI clients // are designed to directly consume the universal DSL. type Adapter struct{} // NewAdapter creates a new CUI adapter. func NewAdapter() *Adapter { return &Adapter{} } // Adapt converts a universal Message to one or more client-specific chunks. // For CUI, it simply returns the original message as a single chunk. func (a *Adapter) Adapt(msg *message.Message) ([]interface{}, error) { // CUI clients consume the universal DSL directly, so no conversion is needed. // This includes all message types like text, thinking, loading, events, etc. // CUI clients can choose to display or ignore event messages. return []interface{}{msg}, nil } // SupportsType checks if the adapter explicitly supports a given message type. // CUI adapter supports all types as it renders them directly. func (a *Adapter) SupportsType(msgType string) bool { return true } ================================================ FILE: agent/output/adapters/cui/writer.go ================================================ package cui import ( "encoding/json" "net/http" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" traceTypes "github.com/yaoapp/yao/trace/types" ) // Writer implements the message.Writer interface for CUI clients type Writer struct { Writer http.ResponseWriter Trace traceTypes.Manager Locale string adapter *Adapter } // NewWriter creates a new CUI writer // The Writer should already be wrapped in SafeWriter by the context layer // to ensure thread-safe concurrent writes for SSE streaming. 
func NewWriter(options message.Options) (*Writer, error) { return &Writer{ Writer: options.Writer, Trace: options.Trace, Locale: options.Locale, adapter: NewAdapter(), }, nil } // Write writes a single message to the output stream func (w *Writer) Write(msg *message.Message) error { // CUI adapter passes messages through as-is chunks, err := w.adapter.Adapt(msg) if err != nil { if w.Trace != nil { w.Trace.Error(i18n.T(w.Locale, "output.cui.writer.adapt_error"), map[string]any{ // "CUI Writer: Failed to adapt message" "error": err.Error(), "message_type": msg.Type, }) } return err } // Send each chunk for _, chunk := range chunks { if err := w.sendChunk(chunk); err != nil { if w.Trace != nil { w.Trace.Error(i18n.T(w.Locale, "output.cui.writer.chunk_error"), map[string]any{"error": err.Error()}) // "CUI Writer: Failed to send chunk" } return err } } return nil } // WriteGroup writes a message group to the output stream func (w *Writer) WriteGroup(group *message.Group) error { // For CUI, we send a group start message, all messages, then a group end message // The group structure itself is also sent for clients that want it // Send the group if err := w.sendChunk(group); err != nil { if w.Trace != nil { w.Trace.Error(i18n.T(w.Locale, "output.cui.writer.group_error"), map[string]any{ // "CUI Writer: Failed to send message group" "error": err.Error(), "group_id": group.ID, }) } return err } return nil } // Flush flushes any buffered data to the output stream func (w *Writer) Flush() error { // For SSE, we don't need explicit flushing // The underlying connection handles it return nil } // Close closes the writer and cleans up resources func (w *Writer) Close() error { // Nothing to clean up for CUI writer // SafeWriter cleanup is handled by the context layer return nil } // sendChunk sends a chunk to the output stream func (w *Writer) sendChunk(chunk interface{}) error { // Convert chunk to JSON data, err := json.Marshal(chunk) if err != nil { if w.Trace != nil { 
w.Trace.Error(i18n.T(w.Locale, "output.cui.writer.marshal_error"), map[string]any{"error": err.Error()}) // "CUI Writer: Failed to marshal chunk" } return err } // Log outgoing data to trace for debugging if w.Trace != nil { w.Trace.Debug("CUI Writer: Sending chunk to client", map[string]any{ "data": string(data), }) } // Format as SSE (Server-Sent Events) format: "data: {json}\n\n" sseData := []byte("data: ") sseData = append(sseData, data...) sseData = append(sseData, '\n', '\n') // Send via context's writer // The context knows how to send data based on the connection type (SSE, WebSocket, etc.) if err := w.sendData(sseData); err != nil { if w.Trace != nil { w.Trace.Error(i18n.T(w.Locale, "output.cui.writer.send_error"), map[string]any{"error": err.Error()}) // "CUI Writer: Failed to send data to client" } return err } w.flush() return nil } func (w *Writer) flush() error { if w.Writer == nil { return nil // No writer, silently ignore } if flusher, ok := w.Writer.(interface{ Flush() }); ok { flusher.Flush() } return nil } func (w *Writer) sendData(data []byte) error { if w.Writer == nil { return nil // No writer, silently ignore } _, err := w.Writer.Write(data) return err } ================================================ FILE: agent/output/adapters/openai/README.md ================================================ # OpenAI Adapter OpenAI adapter converts universal DSL messages to OpenAI-compatible format. 
## Conversion Rules ### Built-in Types (Standard) These types are defined in `output.types.go` and have standardized Props structures that all adapters must support: | Message Type | Constant | Props Structure | OpenAI Format | Description | | ------------ | --------------------- | --------------- | ------------------------- | ----------------------------------------- | | `text` | `output.TypeText` | `TextProps` | `delta.content` | Plain text or Markdown | | `thinking` | `output.TypeThinking` | `ThinkingProps` | `delta.reasoning_content` | Reasoning process (o1 models) | | `loading` | `output.TypeLoading` | `LoadingProps` | `delta.reasoning_content` | Loading indicator (shows as thinking) | | `tool_call` | `output.TypeToolCall` | `ToolCallProps` | `delta.tool_calls` | Tool/function calls | | `error` | `output.TypeError` | `ErrorProps` | `error` | Error messages | | `action` | `output.TypeAction` | `ActionProps` | (not sent) | System actions (silent) | | `event` | `output.TypeEvent` | `EventProps` | (conditional) | Lifecycle events (stream_start converted) | ### Event Type (Lifecycle Events) The `event` type has special handling in the OpenAI adapter: | Event Name | Conversion | Example Output | | -------------- | ------------------------------------------- | --------------------------------------------------- | | `stream_start` | Converted to trace link (with i18n support) | 🔍 智能体正在处理 - [查看处理详情](/trace/xxx/view) | | Other events | Silent (not sent) | - | **Conversion Logic for `stream_start`:** 1. **Extract trace data**: Gets `TraceID` from event data 2. **Check model capabilities**: Determines if model supports reasoning 3. **Format based on capabilities**: - **Reasoning models** (o1, DeepSeek R1): Uses `reasoning_content` field with 🔍 icon - **Regular models**: Uses `content` field with 🚀 icon 4. **Apply i18n**: Uses locale from context for localized text 5. 
**Generate trace link**: Creates clickable link to `/trace/{traceID}/view` for standalone viewing **Example Conversion:** ```go // Input (event message) { "type": "event", "props": { "event": "stream_start", "message": "Stream started", "data": { "trace_id": "20251122779905354593", "request_id": "ctx-1763779905679380000", "chat_id": "uP4CWZCMHy84nCw7" } } } // Output (reasoning model - Chinese locale) { "choices": [{ "delta": { "reasoning_content": "🔍 智能体正在处理 - [查看处理详情](http://localhost:8000/__yao_admin_root/trace/20251122779905354593/view)\n" } }] } // Output (regular model - English locale) { "choices": [{ "delta": { "content": "🚀 Assistant is processing - [View process](http://localhost:8000/__yao_admin_root/trace/20251122779905354593/view)\n" } }] } ``` **Internationalization:** The adapter uses `i18n.T()` to provide localized text: | Key | English (en-us) | Chinese (zh-cn) | | --------------------- | ----------------------- | --------------- | | `output.stream_start` | Assistant is processing | 智能体正在处理 | | `output.view_trace` | View process | 查看处理详情 | ### Custom Types All other message types (not in the built-in list) are converted to Markdown links: | Format | Example | | ---------------------- | ------------------------------- | | `delta.content` (link) | `"🖼️ [View Image](https://...)` | ## Usage ### Basic Usage ```go import ( "github.com/yaoapp/yao/agent/output/adapters/openai" ) // Create adapter with default config adapter := openai.NewAdapter() // Convert message chunks, err := adapter.Adapt(msg) ``` ### With Custom Configuration ```go // Create adapter with options adapter := openai.NewAdapter( openai.WithBaseURL("https://api.example.com"), openai.WithModel("gpt-4"), openai.WithLinkTemplate("image", "🖼️ [View Image](%s)"), openai.WithLinkTransformer(myOTPTransformer), ) ``` ### With Link Transformer (OTP) ```go // Define OTP transformer func otpTransformer(url string, msgType string, msgID string) (string, error) { // Generate OTP token otp := 
generateOTP(msgID, 3600) // 1 hour expiry // Create short link with OTP shortURL := fmt.Sprintf("https://api.example.com/s/%s?t=%s", msgID, otp) return shortURL, nil } // Use transformer adapter := openai.NewAdapter( openai.WithLinkTransformer(otpTransformer), ) ``` ### Custom Converter ```go // Register custom converter for a specific type adapter := openai.NewAdapter( openai.WithConverter("my_widget", func(msg *message.Message, config *openai.AdapterConfig) ([]interface{}, error) { // Custom conversion logic return []interface{}{ // OpenAI format chunk }, nil }), ) ``` ## Examples ### Text Message (Built-in Type) **Input (DSL):** ```json { "type": "text", "props": { "content": "Hello world" } } ``` Or using helper: ```go msg := output.NewTextMessage("Hello world") ``` **Output (OpenAI):** ```json { "id": "M1", "object": "chat.completion.chunk", "model": "yao-agent", "choices": [ { "delta": { "content": "Hello world" } } ] } ``` ### Image Message **Input (DSL):** ```json { "message_id": "M2", "type": "image", "props": { "url": "https://example.com/avatar.jpg" } } ``` **Output (OpenAI):** Images are rendered as inline Markdown images; the alt text comes from `props.alt` and defaults to `Image`. ```json { "id": "M2", "object": "chat.completion.chunk", "model": "yao-agent", "choices": [ { "delta": { "content": "![Image](https://api.example.com/s/M2?t=abc123)" } } ] } ``` ### Button Message **Input (DSL):** ```json { "message_id": "M3", "type": "button", "props": { "text": "Approve", "action": "workflow.approve" } } ``` **Output (OpenAI):** ```json { "id": "M3", "object": "chat.completion.chunk", "model": "yao-agent", "choices": [ { "delta": { "content": "🔘 [Approve](https://api.example.com/s/M3?t=abc123)" } } ] } ``` ## Link Templates Default templates: ```go "image": "![%s](%s)" // Special: alt text + URL (inline Markdown image) "audio": "🔊 [Play Audio](%s)" "video": "🎬 [Watch Video](%s)" "file": "📎 [Download File](%s)" "page": "📄 [Open Page](%s)" "table": "📊 [View Table](%s)" "chart": "📈 [View Chart](%s)" "list": "📋 [View List](%s)" "form": "📝 [Fill Form](%s)" "button": "🔘 [%s](%s)" // Special: button
text + link ``` Customize templates: ```go adapter := openai.NewAdapter( openai.WithLinkTemplate("image", "📷 Image: %s"), openai.WithLinkTemplate("video", "🎥 Watch: %s"), ) ``` ## Link Transformer (TODO) The link transformer is currently left empty for future implementation of OTP/short link functionality. **Planned features:** - Generate one-time password (OTP) for secure access - Create short URLs for better readability - Set expiration time for links - Track link access for analytics **Example implementation:** ```go func otpTransformer(url string, msgType string, msgID string) (string, error) { // TODO: Implement OTP generation // 1. Generate OTP token with expiry // 2. Store mapping: token -> (url, msgType, msgID, expiry) // 3. Create short URL with token // 4. Return short URL return url, nil // Currently pass-through } ``` ================================================ FILE: agent/output/adapters/openai/adapter.go ================================================ package openai import "github.com/yaoapp/yao/agent/output/message" // Adapter is the OpenAI adapter that converts messages to OpenAI format type Adapter struct { config *AdapterConfig registry *ConverterRegistry } // NewAdapter creates a new OpenAI adapter with default configuration func NewAdapter(options ...Option) *Adapter { adapter := &Adapter{ config: DefaultAdapterConfig(), registry: NewConverterRegistry(), } // Apply options for _, opt := range options { opt(adapter) } return adapter } // Option is a function that configures the adapter type Option func(*Adapter) // WithBaseURL sets the base URL for generating view links func WithBaseURL(baseURL string) Option { return func(a *Adapter) { a.config.BaseURL = baseURL } } // WithLinkTemplate sets a custom link template for a message type func WithLinkTemplate(msgType string, template string) Option { return func(a *Adapter) { a.config.LinkTemplates[msgType] = template } } // WithLinkTransformer sets the link transformer function func 
WithLinkTransformer(transformer LinkTransformer) Option { return func(a *Adapter) { a.config.LinkTransformer = transformer } } // WithModel sets the model name for OpenAI responses func WithModel(model string) Option { return func(a *Adapter) { a.config.Model = model } } // WithCapabilities sets the model capabilities func WithCapabilities(capabilities *ModelCapabilities) Option { return func(a *Adapter) { a.config.Capabilities = capabilities } } // WithLocale sets the locale for internationalization func WithLocale(locale string) Option { return func(a *Adapter) { a.config.Locale = locale } } // WithConverter registers a custom converter for a message type func WithConverter(msgType string, converter ConverterFunc) Option { return func(a *Adapter) { a.registry.Register(msgType, converter) } } // Adapt converts a universal Message to OpenAI-compatible format func (a *Adapter) Adapt(msg *message.Message) ([]interface{}, error) { // Handle event messages specially if msg.Type == message.TypeEvent { // Check if this is a stream_start event if event, ok := msg.Props["event"].(string); ok && event == message.EventStreamStart { // Use the stream_start converter if converter, exists := a.registry.GetConverter(message.EventStreamStart); exists { return converter(msg, a.config) } } // Other event messages are CUI-only, skip them return []interface{}{}, nil // Return empty array, nothing to send } // Get converter for this message type converter, exists := a.registry.GetConverter(msg.Type) if !exists { // Use default converter for unknown types (convert to link) converter = convertToLink } // Convert the message return converter(msg, a.config) } // SupportsType checks if the adapter explicitly supports a given message type func (a *Adapter) SupportsType(msgType string) bool { _, exists := a.registry.GetConverter(msgType) return exists } // GetConfig returns the adapter configuration func (a *Adapter) GetConfig() *AdapterConfig { return a.config } // GetRegistry returns the 
converter registry func (a *Adapter) GetRegistry() *ConverterRegistry { return a.registry } ================================================ FILE: agent/output/adapters/openai/converter.go ================================================ package openai import ( "fmt" "time" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" ) // ConverterRegistry manages message type converters type ConverterRegistry struct { converters map[string]ConverterFunc } // NewConverterRegistry creates a new converter registry with default converters func NewConverterRegistry() *ConverterRegistry { return &ConverterRegistry{ converters: map[string]ConverterFunc{ message.TypeText: convertText, message.TypeThinking: convertThinking, message.TypeLoading: convertLoading, message.TypeToolCall: convertToolCall, message.TypeError: convertError, message.TypeImage: convertImage, message.TypeAudio: convertToLink, message.TypeVideo: convertToLink, message.TypeAction: convertAction, message.EventStreamStart: convertStreamStart, // Handle stream_start events }, } } // Register registers a custom converter for a message type func (r *ConverterRegistry) Register(msgType string, converter ConverterFunc) { r.converters[msgType] = converter } // GetConverter retrieves a converter for a given message type. 
func (r *ConverterRegistry) GetConverter(msgType string) (ConverterFunc, bool) { converter, exists := r.converters[msgType] return converter, exists } // Convert converts a message using registered converters // If no converter is found, converts to link format func (r *ConverterRegistry) Convert(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { // Check for registered converter if converter, exists := r.converters[msg.Type]; exists { return converter(msg, config) } // Fallback: convert to link format return convertToLink(msg, config) } // convertText converts text messages to OpenAI format func convertText(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { content := getStringProp(msg.Props, "content", "") return []interface{}{ createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "content": content, }), }, nil } // convertThinking converts thinking messages to OpenAI reasoning format (o1 series) func convertThinking(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { content := getStringProp(msg.Props, "content", "") return []interface{}{ createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "reasoning_content": content, }), }, nil } // convertLoading converts loading messages to OpenAI reasoning format // This makes loading messages visible in standard OpenAI clients as thinking process func convertLoading(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { message := getStringProp(msg.Props, "message", "Processing...") // Convert loading to reasoning_content so it shows in OpenAI clients return []interface{}{ createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "reasoning_content": message, }), }, nil } // convertToolCall converts tool_call messages to OpenAI format func convertToolCall(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { // Tool call format varies, pass through the props toolCalls := 
[]map[string]interface{}{} // If props contain tool call data, use it if id, ok := msg.Props["id"].(string); ok { toolCall := map[string]interface{}{ "id": id, "type": "function", } if function, ok := msg.Props["function"].(map[string]interface{}); ok { toolCall["function"] = function } toolCalls = append(toolCalls, toolCall) } return []interface{}{ createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "tool_calls": toolCalls, }), }, nil } // convertError converts error messages to OpenAI error format func convertError(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { message := getStringProp(msg.Props, "message", "An error occurred") code := getStringProp(msg.Props, "code", "server_error") return []interface{}{ map[string]interface{}{ "error": map[string]interface{}{ "message": message, "type": code, "code": code, }, }, }, nil } // convertAction converts action messages to nothing (silent in OpenAI clients) // Action messages are system-level commands (open panel, navigate, etc.) 
// and should not be sent to standard chat clients func convertAction(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { // Return empty slice - no output for action messages in OpenAI format return []interface{}{}, nil } // convertStreamStart converts stream_start event to OpenAI format // If model supports reasoning: converts to reasoning_content (thinking) // Otherwise: converts to regular Markdown text with trace link func convertStreamStart(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { // Extract stream_start data from props data, ok := msg.Props["data"] if !ok { // No data, skip this message return []interface{}{}, nil } // Try to convert to EventStreamStartData var startData message.EventStreamStartData switch v := data.(type) { case message.EventStreamStartData: startData = v case map[string]interface{}: // If it's a map, try to extract traceID if traceID, ok := v["trace_id"].(string); ok { startData.TraceID = traceID } if requestID, ok := v["request_id"].(string); ok { startData.RequestID = requestID } default: // Unknown data type, skip return []interface{}{}, nil } // Check if we have a trace ID to link to if startData.TraceID == "" { // No trace ID, skip this message return []interface{}{}, nil } // Generate trace link traceLink := generateTraceLink(startData.TraceID, config) // Check if model supports reasoning supportsReasoning := false if config.Capabilities != nil && config.Capabilities.Reasoning != nil { supportsReasoning = *config.Capabilities.Reasoning } // Get localized text using i18n streamStartText := i18n.T(config.Locale, "output.stream_start") viewTraceText := i18n.T(config.Locale, "output.view_trace") // Convert based on reasoning support if supportsReasoning { // Convert to thinking format (reasoning_content) // Reasoning models display this as part of the thinking process content := fmt.Sprintf("🔍 %s - [%s](%s)\n", streamStartText, viewTraceText, traceLink) chunk := 
createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "reasoning_content": content, }) return []interface{}{chunk}, nil } // Convert to regular Markdown text content := fmt.Sprintf("🚀 %s - [%s](%s)\n", streamStartText, viewTraceText, traceLink) chunk := createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "content": content, }) return []interface{}{chunk}, nil } // generateTraceLink generates a trace link URL // Uses 'view' mode (clean page without sidebar) for better viewing experience in chat func generateTraceLink(traceID string, config *AdapterConfig) string { baseURL := config.BaseURL if baseURL == "" { // If no base URL, return a relative link return fmt.Sprintf("/trace/%s/view", traceID) } return fmt.Sprintf("%s/trace/%s/view", baseURL, traceID) } // convertImage converts image messages to Markdown image format // Uses ![alt](url) which displays inline in Markdown-supporting clients func convertImage(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { // Get URL url, ok := msg.Props["url"].(string) if !ok || url == "" { return nil, fmt.Errorf("image message missing url") } // Transform URL if transformer is provided if config.LinkTransformer != nil { transformedURL, err := config.LinkTransformer(url, msg.Type, msg.MessageID) if err != nil { return nil, err } url = transformedURL } // Get alt text (default to "Image") alt := getStringProp(msg.Props, "alt", "Image") // Format as Markdown image: ![alt](url) template := getLinkTemplate(msg.Type, config) text := fmt.Sprintf(template, alt, url) return []interface{}{ createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "content": text, }), }, nil } // convertToLink converts any message type to a Markdown link format func convertToLink(msg *message.Message, config *AdapterConfig) ([]interface{}, error) { // Generate link link, err := generateViewLink(msg, config) if err != nil { return nil, err } // Get template template := 
getLinkTemplate(msg.Type, config) // Format text var text string if msg.Type == "button" { // Button is special: needs button text buttonText := getStringProp(msg.Props, "text", "Button") text = fmt.Sprintf(template, buttonText, link) } else { text = fmt.Sprintf(template, link) } return []interface{}{ createOpenAIChunk(msg.MessageID, config.Model, map[string]interface{}{ "content": text, }), }, nil } // generateViewLink generates a view link for a message func generateViewLink(msg *message.Message, config *AdapterConfig) (string, error) { // If Props contains a URL, use it if url, ok := msg.Props["url"].(string); ok { // Transform URL if transformer is provided if config.LinkTransformer != nil { return config.LinkTransformer(url, msg.Type, msg.MessageID) } return url, nil } // Generate view link: {baseURL}/agent/view/{type}/{id} baseURL := config.BaseURL if baseURL == "" { baseURL = "" // TODO: Get from environment or context } viewURL := fmt.Sprintf("%s/agent/view/%s/%s", baseURL, msg.Type, msg.MessageID) // Transform URL if transformer is provided if config.LinkTransformer != nil { return config.LinkTransformer(viewURL, msg.Type, msg.MessageID) } return viewURL, nil } // getLinkTemplate gets the link template for a message type func getLinkTemplate(msgType string, config *AdapterConfig) string { if template, exists := config.LinkTemplates[msgType]; exists { return template } // Default fallback template return "📎 [View %s](" + msgType + ")" } // createOpenAIChunk creates an OpenAI chat completion chunk func createOpenAIChunk(id string, model string, delta map[string]interface{}) map[string]interface{} { return map[string]interface{}{ "id": id, "object": "chat.completion.chunk", "created": time.Now().Unix(), "model": model, "choices": []map[string]interface{}{ { "index": 0, "delta": delta, "finish_reason": nil, }, }, } } // getStringProp safely gets a string property from props func getStringProp(props map[string]interface{}, key string, defaultValue string) 
string { if val, ok := props[key].(string); ok { return val } return defaultValue } ================================================ FILE: agent/output/adapters/openai/types.go ================================================ package openai import "github.com/yaoapp/yao/agent/output/message" // ConverterFunc converts a message to OpenAI format chunks type ConverterFunc func(msg *message.Message, config *AdapterConfig) ([]interface{}, error) // LinkTransformer transforms a URL to a secure link (with OTP, short URL, etc.) // Returns the transformed link or error type LinkTransformer func(url string, msgType string, msgID string) (string, error) // AdapterConfig holds the configuration for OpenAI adapter type AdapterConfig struct { // BaseURL is the base URL for generating view links // Example: "https://api.example.com" BaseURL string // LinkTemplates defines the Markdown template for each message type // %s will be replaced with the link // Example: "🖼️ [View Image](%s)" LinkTemplates map[string]string // LinkTransformer transforms URLs to secure links with OTP // If nil, URLs are used as-is LinkTransformer LinkTransformer // Model name to include in OpenAI responses Model string // Capabilities holds the model capabilities // Used to determine how to convert certain message types (e.g., stream_start) Capabilities *ModelCapabilities // Locale for internationalization (e.g., "en-US", "zh-CN") Locale string } // ModelCapabilities is a simplified version of openai.Capabilities // We use a local type to avoid circular dependencies type ModelCapabilities struct { Reasoning *bool // Supports reasoning/thinking mode (o1, DeepSeek R1) } // DefaultLinkTemplates provides default Markdown templates for non-text message types var DefaultLinkTemplates = map[string]string{ "image": "![%s](%s)", // Markdown image: ![alt](url) - displays inline "audio": "🔊 [Play Audio](%s)", // Link (audio can't display inline in Markdown) "video": "🎬 [Watch Video](%s)", // Link (video can't 
display inline in Markdown) "file": "📎 [Download File](%s)", "page": "📄 [Open Page](%s)", "table": "📊 [View Table](%s)", "chart": "📈 [View Chart](%s)", "list": "📋 [View List](%s)", "form": "📝 [Fill Form](%s)", "button": "🔘 [%s](%s)", // Special: button needs two params (text, link) } // DefaultAdapterConfig returns a default adapter configuration func DefaultAdapterConfig() *AdapterConfig { return &AdapterConfig{ BaseURL: "", // Will be set from environment or context LinkTemplates: copyLinkTemplates(DefaultLinkTemplates), LinkTransformer: nil, // No transformation by default Model: "yao-agent", } } // copyLinkTemplates creates a copy of link templates func copyLinkTemplates(templates map[string]string) map[string]string { copy := make(map[string]string, len(templates)) for k, v := range templates { copy[k] = v } return copy } ================================================ FILE: agent/output/adapters/openai/writer.go ================================================ package openai import ( "encoding/json" "net/http" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" traceTypes "github.com/yaoapp/yao/trace/types" ) // Writer implements the message.Writer interface for OpenAI-compatible clients type Writer struct { Writer http.ResponseWriter Trace traceTypes.Manager Locale string adapter *Adapter firstChunk bool // Track if this is the first chunk to add role } // NewWriter creates a new OpenAI writer func NewWriter(options message.Options) (*Writer, error) { // Get model capabilities from context (set by assistant) var capabilities *ModelCapabilities if options.Capabilities != nil && options.Capabilities.Reasoning { v := true capabilities = &ModelCapabilities{ Reasoning: &v, } } // Create adapter with capabilities, base URL, and locale adapter := NewAdapter( WithCapabilities(capabilities), WithBaseURL(getBaseURL(options.BaseURL)), WithLocale(options.Locale), ) return &Writer{ adapter: adapter, Writer: options.Writer, Locale: 
options.Locale,
		Trace:      options.Trace, // Without this, every `w.Trace != nil` guard in this writer is always false
		firstChunk: true,          // First chunk should include role
	}, nil
}

// getBaseURL returns the base URL used to build trace/view links.
//
// A non-empty baseURL supplied by the caller is returned unchanged. When it is
// empty the function falls back to the local development admin root.
func getBaseURL(baseURL string) string {
	if baseURL != "" {
		return baseURL
	}

	// @todo: get from context metadata, environment variable or config
	return "http://localhost:8000/__yao_admin_root"
}

// Write writes a single message to the output stream.
//
// The message is converted by the adapter into zero or more OpenAI chunks and
// each chunk is sent in order. The first text or thinking chunk sent by this
// writer additionally gets delta.role = "assistant". Adapter and send errors
// are reported to the trace (when one is set) and returned.
func (w *Writer) Write(msg *message.Message) error {
	// Convert message to OpenAI format using adapter
	chunks, err := w.adapter.Adapt(msg)
	if err != nil {
		if w.Trace != nil {
			w.Trace.Error(i18n.T(w.Locale, "output.openai.writer.adapt_error"), map[string]any{ // "OpenAI Writer: Failed to adapt message"
				"error":        err.Error(),
				"message_type": msg.Type,
			})
		}
		return err
	}

	// Send each chunk
	for _, chunk := range chunks {
		// Add role to first text chunk
		if w.firstChunk && (msg.Type == message.TypeText || msg.Type == message.TypeThinking) {
			if chunkMap, ok := chunk.(map[string]interface{}); ok {
				if choices, ok := chunkMap["choices"].([]map[string]interface{}); ok && len(choices) > 0 {
					if delta, ok := choices[0]["delta"].(map[string]interface{}); ok {
						delta["role"] = "assistant"
						w.firstChunk = false
					}
				}
			}
		}

		if err := w.sendChunk(chunk); err != nil {
			if w.Trace != nil {
				w.Trace.Error(i18n.T(w.Locale, "output.openai.writer.chunk_error"), map[string]any{"error": err.Error()}) // "OpenAI Writer: Failed to send chunk"
			}
			return err
		}
	}

	return nil
}

// WriteGroup writes a message group to the output stream
func (w *Writer) WriteGroup(group *message.Group) error {
	// For OpenAI, we don't send group markers
	// Just send each message individually
	for _, msg := range group.Messages {
		if err := w.Write(msg); err != nil {
			if w.Trace != nil {
				w.Trace.Error(i18n.T(w.Locale, "output.openai.writer.group_error"), map[string]any{ // "OpenAI
Writer: Failed to write message in group" "error": err.Error(), "group_id": group.ID, "message_type": msg.Type, }) } return err } } return nil } // Flush flushes any buffered data to the output stream func (w *Writer) Flush() error { // For SSE, we don't need explicit flushing // The underlying connection handles it return nil } // Close closes the writer and cleans up resources func (w *Writer) Close() error { // Send final [DONE] message for OpenAI compatibility return w.sendDone() } func (w *Writer) sendData(data []byte) error { if w.Writer == nil { return nil // No writer, silently ignore } _, err := w.Writer.Write(data) return err } func (w *Writer) flush() error { if w.Writer == nil { return nil // No writer, silently ignore } if flusher, ok := w.Writer.(interface{ Flush() }); ok { flusher.Flush() } return nil } // sendChunk sends a chunk to the output stream in SSE format func (w *Writer) sendChunk(chunk interface{}) error { // Convert chunk to JSON data, err := json.Marshal(chunk) if err != nil { if w.Trace != nil { w.Trace.Error(i18n.T(w.Locale, "output.openai.writer.marshal_error"), map[string]any{"error": err.Error()}) // "OpenAI Writer: Failed to marshal chunk" } return err } // Format as SSE: "data: {json}\n\n" sseData := append([]byte("data: "), data...) sseData = append(sseData, []byte("\n\n")...) 
// Log outgoing data to trace for debugging if w.Trace != nil { w.Trace.Debug("OpenAI Writer: Sending chunk to client", map[string]any{ "data": string(data), }) } // Send via context's writer if err := w.sendData(sseData); err != nil { if w.Trace != nil { w.Trace.Error(i18n.T(w.Locale, "output.openai.writer.send_error"), map[string]any{"error": err.Error()}) // "OpenAI Writer: Failed to send data to client" } return err } // Flush immediately to ensure real-time streaming // Cast to http.ResponseWriter and call Flush if available w.flush() return nil } // sendDone sends the final [DONE] message func (w *Writer) sendDone() error { // Log completion to trace if w.Trace != nil { w.Trace.Debug("OpenAI Writer: Sending [DONE] to client") } // OpenAI SSE format uses "data: [DONE]\n\n" to signal completion doneData := []byte("data: [DONE]\n\n") if err := w.sendData(doneData); err != nil { if w.Trace != nil { w.Trace.Error(i18n.T(w.Locale, "output.openai.writer.done_error"), map[string]any{"error": err.Error()}) // "OpenAI Writer: Failed to send [DONE] to client" } return err } // Flush the final [DONE] message w.flush() return nil } ================================================ FILE: agent/output/builtin.go ================================================ package output import ( "github.com/yaoapp/yao/agent/output/message" ) // Helper functions for creating built-in message types // NewUserInputMessage creates a user input message (for frontend display) // content can be string or []ContentPart for multimodal content func NewUserInputMessage(content interface{}, role, name string) *message.Message { props := map[string]interface{}{ "content": content, } if role != "" { props["role"] = role } if name != "" { props["name"] = name } return &message.Message{ Type: message.TypeUserInput, Props: props, } } // NewTextMessage creates a text message func NewTextMessage(content string) *message.Message { return &message.Message{ Type: message.TypeText, Props: 
map[string]interface{}{ "content": content, }, } } // NewThinkingMessage creates a thinking message func NewThinkingMessage(content string) *message.Message { return &message.Message{ Type: message.TypeThinking, Props: map[string]interface{}{ "content": content, }, } } // NewLoadingMessage creates a loading message func NewLoadingMessage(msg string) *message.Message { return &message.Message{ Type: message.TypeLoading, Props: map[string]interface{}{ "message": msg, }, } } // NewToolCallMessage creates a tool call message func NewToolCallMessage(id, name, arguments string) *message.Message { return &message.Message{ Type: message.TypeToolCall, Props: map[string]interface{}{ "id": id, "name": name, "arguments": arguments, }, } } // NewErrorMessage creates an error message func NewErrorMessage(msg, code string) *message.Message { return &message.Message{ Type: message.TypeError, Props: map[string]interface{}{ "message": msg, "code": code, }, } } // NewActionMessage creates an action message func NewActionMessage(name string, payload map[string]interface{}) *message.Message { return &message.Message{ Type: message.TypeAction, Props: map[string]interface{}{ "name": name, "payload": payload, }, } } // NewEventMessage creates an event message func NewEventMessage(event string, msg string, data interface{}) *message.Message { return &message.Message{ Type: message.TypeEvent, Props: map[string]interface{}{ "event": event, "message": msg, "data": data, }, } } // NewImageMessage creates an image message func NewImageMessage(url string, alt string) *message.Message { return &message.Message{ Type: message.TypeImage, Props: map[string]interface{}{ "url": url, "alt": alt, }, } } // NewAudioMessage creates an audio message func NewAudioMessage(url string, format string) *message.Message { return &message.Message{ Type: message.TypeAudio, Props: map[string]interface{}{ "url": url, "format": format, }, } } // NewVideoMessage creates a video message func NewVideoMessage(url string) 
*message.Message { return &message.Message{ Type: message.TypeVideo, Props: map[string]interface{}{ "url": url, }, } } // IsBuiltinType checks if a message type is a built-in type func IsBuiltinType(msgType string) bool { switch msgType { case message.TypeUserInput, message.TypeText, message.TypeThinking, message.TypeLoading, message.TypeToolCall, message.TypeError, message.TypeImage, message.TypeAudio, message.TypeVideo, message.TypeAction, message.TypeEvent: return true default: return false } } // GenerateID generates a unique message ID using nanoid // Deprecated: Use message.GenerateMessageID(), message.GenerateChunkID(), // message.GenerateBlockID(), or message.GenerateThreadID() instead func GenerateID() string { return message.GenerateNanoID() } ================================================ FILE: agent/output/jsapi/README.md ================================================ # Output JSAPI The Output JSAPI provides a JavaScript interface for sending output messages to clients from scripts (e.g., hooks, processes). It wraps the Go `output` package functionality and provides a convenient API for sending messages and message groups. ## Overview The Output object allows you to: - Send individual messages to clients in various formats (text, error, loading, etc.) - Send groups of related messages - Support streaming with delta updates - Handle different message types with custom properties ## Constructor ### `new Output(ctx)` Creates a new Output instance. **Parameters:** - `ctx` (Context): The agent context object **Returns:** - Output instance **Example:** ```javascript function Create(ctx, messages) { const output = new Output(ctx); // Use output methods... } ``` ## Methods ### `Send(message)` Sends a single message to the client. 
**Parameters:** - `message` (string | object): The message to send - If string: Automatically converted to a text message - If object: Must have a `type` field and optional `props` and other fields **Returns:** - Output instance (for chaining) **Message Object Structure:** ```javascript { type: string, // Required: Message type (e.g., "text", "error", "loading") props: object, // Optional: Message properties (type-specific) id: string, // Optional: Message ID (for streaming) delta: boolean, // Optional: Whether this is a delta update done: boolean, // Optional: Whether the message is complete delta_path: string, // Optional: Path for delta updates (e.g., "content") delta_action: string, // Optional: Delta action ("append", "replace", "merge", "set") type_change: boolean, // Optional: Whether this is a type correction group_id: string, // Optional: Parent message group ID group_start: boolean, // Optional: Marks the start of a message group group_end: boolean, // Optional: Marks the end of a message group metadata: { // Optional: Message metadata timestamp: number, sequence: number, trace_id: string } } ``` **Examples:** Send a simple text message (shorthand): ```javascript output.Send("Hello, world!"); ``` Send a text message (full): ```javascript output.Send({ type: "text", props: { content: "Hello, world!", }, }); ``` Send an error message: ```javascript output.Send({ type: "error", props: { message: "Something went wrong", code: "ERR_001", details: "Additional error details", }, }); ``` Send a loading indicator: ```javascript output.Send({ type: "loading", props: { message: "Searching knowledge base...", }, }); ``` Send streaming text with delta updates: ```javascript // First chunk output.Send({ type: "text", id: "msg-1", props: { content: "Hello" }, delta: true, done: false, }); // Subsequent chunks output.Send({ type: "text", id: "msg-1", props: { content: " world" }, delta: true, delta_path: "content", delta_action: "append", done: false, }); // Final chunk 
output.Send({ type: "text", id: "msg-1", props: { content: "!" }, delta: true, delta_path: "content", delta_action: "append", done: true, }); ``` Chain multiple sends: ```javascript output .Send("First message") .Send("Second message") .Send({ type: "loading", props: { message: "Processing..." } }); ``` ### `SendGroup(group)` Sends a group of related messages. **Parameters:** - `group` (object): The message group - `id` (string): Required - Group ID - `messages` (array): Required - Array of message objects - `metadata` (object): Optional - Group metadata **Returns:** - Output instance (for chaining) **Group Object Structure:** ```javascript { id: string, // Required: Message group ID messages: [ // Required: Array of messages { type: string, props: object, // ... other message fields } ], metadata: { // Optional: Group metadata timestamp: number, sequence: number, trace_id: string } } ``` **Examples:** Send a simple message group: ```javascript output.SendGroup({ id: "search-results", messages: [ { type: "text", props: { content: "Found 3 results:" } }, { type: "text", props: { content: "Result 1" } }, { type: "text", props: { content: "Result 2" } }, { type: "text", props: { content: "Result 3" } }, ], }); ``` Send a group with metadata: ```javascript output.SendGroup({ id: "analysis-group", messages: [ { type: "loading", props: { message: "Analyzing data..." 
} }, { type: "text", props: { content: "Analysis complete" } }, ], metadata: { timestamp: Date.now(), sequence: 1, trace_id: "trace-123", }, }); ``` ## Built-in Message Types The Output JSAPI supports all built-in message types defined in the output package: ### User Interaction Types - **`user_input`**: User input message (frontend display only) ```javascript { type: "user_input", props: { content: "User's message", role: "user" } } ``` ### Content Types - **`text`**: Plain text or Markdown content ```javascript { type: "text", props: { content: "Hello **world**" } } ``` - **`thinking`**: Reasoning/thinking process (e.g., o1 models) ```javascript { type: "thinking", props: { content: "Let me think about this..." } } ``` - **`loading`**: Loading/processing indicator ```javascript { type: "loading", props: { message: "Processing..." } } ``` - **`tool_call`**: LLM tool/function call ```javascript { type: "tool_call", props: { id: "call_123", name: "search", arguments: "{\"query\":\"test\"}" } } ``` - **`error`**: Error message ```javascript { type: "error", props: { message: "Error occurred", code: "ERR_001", details: "More info" } } ``` ### Media Types - **`image`**: Image content ```javascript { type: "image", props: { url: "https://example.com/image.jpg", alt: "Description", width: 800, height: 600 } } ``` - **`audio`**: Audio content ```javascript { type: "audio", props: { url: "https://example.com/audio.mp3", format: "mp3", duration: 120.5 } } ``` - **`video`**: Video content ```javascript { type: "video", props: { url: "https://example.com/video.mp4", format: "mp4", duration: 300 } } ``` ### System Types - **`action`**: System action (silent in OpenAI clients) ```javascript { type: "action", props: { name: "open_panel", payload: { panel_id: "settings" } } } ``` - **`event`**: Lifecycle event (CUI only, silent in OpenAI clients) ```javascript { type: "event", props: { event: "stream_start", message: "Starting stream..." 
} } ``` ## Usage in Hooks ### Create Hook Example ```javascript /** * Create hook - Called before sending messages to the LLM * @param {Context} ctx - Agent context * @param {Array} messages - User messages * @returns {Object} Hook response */ function Create(ctx, messages) { const output = new Output(ctx); // Send a loading indicator output.Send({ type: "loading", props: { message: "Processing your request..." }, }); // Send custom messages to the user output.Send({ type: "text", props: { content: "I'm thinking about your question..." }, }); // Return hook response return { messages: messages, temperature: 0.7, }; } ``` ### Done Hook Example ```javascript /** * Done hook - Called after assistant completes response * @param {Context} ctx - Agent context * @param {Array} messages - Conversation messages * @param {Object} response - Assistant response */ function Done(ctx, messages, response) { const output = new Output(ctx); // Send a completion message output.Send({ type: "text", props: { content: "Response complete!" }, }); // Send an action output.Send({ type: "action", props: { name: "save_conversation", payload: { chat_id: ctx.chat_id }, }, }); } ``` ### Progress Updates Example ```javascript function ProcessData(ctx, data) { const output = new Output(ctx); // Show progress const steps = ["Loading", "Processing", "Analyzing", "Complete"]; for (let i = 0; i < steps.length; i++) { output.Send({ type: "loading", props: { message: `${steps[i]}... (${i + 1}/${steps.length})`, }, }); // Do some work... processStep(i); } // Send final result output.Send({ type: "text", props: { content: "All done!" 
}, }); } ``` ## Error Handling The Output JSAPI throws exceptions for invalid parameters: ```javascript try { const output = new Output(ctx); // This will throw: message.type is required output.Send({ props: { content: "test" } }); } catch (e) { console.error("Output error:", e.toString()); } ``` Common errors: - `"Output constructor requires a context argument"` - Missing ctx parameter - `"Send requires a message argument"` - Missing message parameter - `"message.type is required and must be a string"` - Missing or invalid type field - `"SendGroup requires a group argument"` - Missing group parameter - `"group.id is required and must be a string"` - Missing group ID - `"group.messages is required and must be an array"` - Missing or invalid messages array ## Notes 1. **Context Requirement**: The Output object must be created with a valid agent context 2. **Writer Required**: The context must have a Writer set (automatically handled in API requests) 3. **Message Format**: Messages are automatically adapted based on the context's Accept type (standard, cui-web, cui-native, cui-desktop) 4. **Streaming**: For streaming responses, use delta updates with proper message IDs 5. **Method Chaining**: All methods return the Output instance for convenient chaining ## See Also - [Output Package Documentation](../README.md) - [Message Types](../BUILTIN_TYPES.md) - [Agent Context](../../context/README.md) - [Hook System](../../assistant/hook/README.md) ================================================ FILE: agent/output/jsapi/output.go ================================================ package jsapi // func init() { // // Auto-register Output JavaScript API when package is imported // v8.RegisterFunction("Output", ExportFunction) // } // // Usage from JavaScript: // // // // const output = new Output(ctx) // // output.Send({ type: "text", props: { content: "Hello" } }) // // output.Send("Hello") // shorthand for text message // // output.SendGroup({ id: "group1", messages: [...] 
}) // // // // Objects: // // - Output: Output manager (constructor) // // ExportFunction exports the Output constructor function template // // This is used by v8.RegisterFunction // func ExportFunction(iso *v8go.Isolate) *v8go.FunctionTemplate { // return v8go.NewFunctionTemplate(iso, outputConstructor) // } // // outputConstructor is the JavaScript constructor for Output // // Usage: new Output(ctx) // func outputConstructor(info *v8go.FunctionCallbackInfo) *v8go.Value { // v8ctx := info.Context() // args := info.Args() // // Require ctx argument // if len(args) < 1 { // return bridge.JsException(v8ctx, "Output constructor requires a context argument") // } // // Get the context object from JavaScript // ctxObj, err := args[0].AsObject() // if err != nil { // return bridge.JsException(v8ctx, fmt.Sprintf("context must be an object: %s", err)) // } // // Get the goValueID from internal field (index 0) // if ctxObj.InternalFieldCount() < 1 { // return bridge.JsException(v8ctx, "context object is missing internal fields") // } // goValueIDValue := ctxObj.GetInternalField(0) // if goValueIDValue == nil || !goValueIDValue.IsString() { // return bridge.JsException(v8ctx, "context object is missing goValueID") // } // goValueID := goValueIDValue.String() // // Retrieve the Go context object from bridge registry // goObj := bridge.GetGoObject(goValueID) // if goObj == nil { // return bridge.JsException(v8ctx, "context object not found in registry") // } // // Type assert to *agentContext.Context // ctx, ok := goObj.(*agentContext.Context) // if !ok { // return bridge.JsException(v8ctx, fmt.Sprintf("object is not a Context, got %T", goObj)) // } // // Create output object // outputObj, err := NewOutputObject(v8ctx, ctx) // if err != nil { // return bridge.JsException(v8ctx, err.Error()) // } // return outputObj // } // // NewOutputObject creates a JavaScript Output object // func NewOutputObject(v8ctx *v8go.Context, ctx *agentContext.Context) (*v8go.Value, error) { // 
jsObject := v8go.NewObjectTemplate(v8ctx.Isolate()) // // Set internal field count to 1 to store the __go_id // // Internal fields are not accessible from JavaScript, providing better security // jsObject.SetInternalFieldCount(1) // // Register context in global bridge registry for efficient Go object retrieval // // The goValueID will be stored in internal field (index 0) after instance creation // goValueID := bridge.RegisterGoObject(ctx) // // Set methods // jsObject.Set("Send", outputSendMethod(v8ctx.Isolate(), ctx)) // jsObject.Set("SendGroup", outputSendGroupMethod(v8ctx.Isolate(), ctx)) // // Set release function that will be called when JavaScript object is released // jsObject.Set("__release", outputGoRelease(v8ctx.Isolate())) // // Create instance // instance, err := jsObject.NewInstance(v8ctx) // if err != nil { // // Clean up: release from global registry if instance creation failed // bridge.ReleaseGoObject(goValueID) // return nil, err // } // // Store the goValueID in internal field (index 0) // // This is not accessible from JavaScript, providing better security // obj, err := instance.Value.AsObject() // if err != nil { // bridge.ReleaseGoObject(goValueID) // return nil, err // } // err = obj.SetInternalField(0, goValueID) // if err != nil { // bridge.ReleaseGoObject(goValueID) // return nil, err // } // return instance.Value, nil // } // // outputGoRelease releases the Go object from the global bridge registry // // It retrieves the goValueID from internal field (index 0) and releases the Go object // func outputGoRelease(iso *v8go.Isolate) *v8go.FunctionTemplate { // return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { // // Get the output object (this) // thisObj, err := info.This().AsObject() // if err == nil && thisObj.InternalFieldCount() > 0 { // // Get goValueID from internal field (index 0) // goValueIDValue := thisObj.GetInternalField(0) // if goValueIDValue != nil && goValueIDValue.IsString() { // 
goValueID := goValueIDValue.String() // // Release from global bridge registry // bridge.ReleaseGoObject(goValueID) // } // } // return v8go.Undefined(info.Context().Isolate()) // }) // } // // outputSendMethod implements the Send method // // Usage: output.Send(message) // // message can be an object with { type: string, props: object, ... } or a simple string (will be converted to text message) // func outputSendMethod(iso *v8go.Isolate, ctx *agentContext.Context) *v8go.FunctionTemplate { // return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { // v8ctx := info.Context() // args := info.Args() // if len(args) < 1 { // return bridge.JsException(v8ctx, "Send requires a message argument") // } // // Parse message argument // msg, err := parseMessage(v8ctx, args[0]) // if err != nil { // return bridge.JsException(v8ctx, fmt.Sprintf("invalid message: %s", err)) // } // // Call output.Send // if err := output.Send(ctx, msg); err != nil { // return bridge.JsException(v8ctx, fmt.Sprintf("Send failed: %s", err)) // } // return info.This().Value // }) // } // // outputSendGroupMethod implements the SendGroup method // // Usage: output.SendGroup(group) // // group must be an object with { id: string, messages: [], ... 
} // func outputSendGroupMethod(iso *v8go.Isolate, ctx *agentContext.Context) *v8go.FunctionTemplate { // return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { // v8ctx := info.Context() // args := info.Args() // if len(args) < 1 { // return bridge.JsException(v8ctx, "SendGroup requires a group argument") // } // // Parse group argument // group, err := parseGroup(v8ctx, args[0]) // if err != nil { // return bridge.JsException(v8ctx, fmt.Sprintf("invalid group: %s", err)) // } // // Call output.SendGroup // if err := output.SendGroup(ctx, group); err != nil { // return bridge.JsException(v8ctx, fmt.Sprintf("SendGroup failed: %s", err)) // } // return info.This().Value // }) // } // // parseMessage parses a JavaScript value into a message.Message // func parseMessage(v8ctx *v8go.Context, jsValue *v8go.Value) (*message.Message, error) { // // Handle string shorthand: convert to text message // if jsValue.IsString() { // return &message.Message{ // Type: message.TypeText, // Props: map[string]interface{}{ // "content": jsValue.String(), // }, // }, nil // } // // Handle object // if !jsValue.IsObject() { // return nil, fmt.Errorf("message must be a string or object") // } // // Convert to Go map // goValue, err := bridge.GoValue(jsValue, v8ctx) // if err != nil { // return nil, fmt.Errorf("failed to convert message: %w", err) // } // msgMap, ok := goValue.(map[string]interface{}) // if !ok { // return nil, fmt.Errorf("message must be an object") // } // // Build message // msg := &message.Message{} // // Type field (required) // if msgType, ok := msgMap["type"].(string); ok { // msg.Type = msgType // } else { // return nil, fmt.Errorf("message.type is required and must be a string") // } // // Props field (optional) // if props, ok := msgMap["props"].(map[string]interface{}); ok { // msg.Props = props // } // // Optional fields // if id, ok := msgMap["id"].(string); ok { // msg.ID = id // } // if delta, ok := msgMap["delta"].(bool); 
ok { // msg.Delta = delta // } // if done, ok := msgMap["done"].(bool); ok { // msg.Done = done // } // if deltaPath, ok := msgMap["delta_path"].(string); ok { // msg.DeltaPath = deltaPath // } // if deltaAction, ok := msgMap["delta_action"].(string); ok { // msg.DeltaAction = deltaAction // } // if typeChange, ok := msgMap["type_change"].(bool); ok { // msg.TypeChange = typeChange // } // if groupID, ok := msgMap["group_id"].(string); ok { // msg.GroupID = groupID // } // if groupStart, ok := msgMap["group_start"].(bool); ok { // msg.GroupStart = groupStart // } // if groupEnd, ok := msgMap["group_end"].(bool); ok { // msg.GroupEnd = groupEnd // } // // Metadata (optional) // if metadataMap, ok := msgMap["metadata"].(map[string]interface{}); ok { // metadata := &message.Metadata{} // if timestamp, ok := metadataMap["timestamp"].(float64); ok { // metadata.Timestamp = int64(timestamp) // } // if sequence, ok := metadataMap["sequence"].(float64); ok { // metadata.Sequence = int(sequence) // } // if traceID, ok := metadataMap["trace_id"].(string); ok { // metadata.TraceID = traceID // } // msg.Metadata = metadata // } // return msg, nil // } // // parseGroup parses a JavaScript value into a message.Group // func parseGroup(v8ctx *v8go.Context, jsValue *v8go.Value) (*message.Group, error) { // // Must be an object // if !jsValue.IsObject() { // return nil, fmt.Errorf("group must be an object") // } // // Convert to Go map // goValue, err := bridge.GoValue(jsValue, v8ctx) // if err != nil { // return nil, fmt.Errorf("failed to convert group: %w", err) // } // groupMap, ok := goValue.(map[string]interface{}) // if !ok { // return nil, fmt.Errorf("group must be an object") // } // // Build group // group := &message.Group{} // // ID field (required) // if id, ok := groupMap["id"].(string); ok { // group.ID = id // } else { // return nil, fmt.Errorf("group.id is required and must be a string") // } // // Messages field (required) // if messagesArray, ok := 
groupMap["messages"].([]interface{}); ok { // group.Messages = make([]*message.Message, 0, len(messagesArray)) // for i, msgInterface := range messagesArray { // // Convert to map // msgMap, ok := msgInterface.(map[string]interface{}) // if !ok { // return nil, fmt.Errorf("group.messages[%d] must be an object", i) // } // // Convert map to Message // msg := &message.Message{} // // Type field (required) // if msgType, ok := msgMap["type"].(string); ok { // msg.Type = msgType // } else { // return nil, fmt.Errorf("group.messages[%d].type is required", i) // } // // Props field (optional) // if props, ok := msgMap["props"].(map[string]interface{}); ok { // msg.Props = props // } // // Optional fields // if id, ok := msgMap["id"].(string); ok { // msg.ID = id // } // if delta, ok := msgMap["delta"].(bool); ok { // msg.Delta = delta // } // if done, ok := msgMap["done"].(bool); ok { // msg.Done = done // } // if deltaPath, ok := msgMap["delta_path"].(string); ok { // msg.DeltaPath = deltaPath // } // if deltaAction, ok := msgMap["delta_action"].(string); ok { // msg.DeltaAction = deltaAction // } // if typeChange, ok := msgMap["type_change"].(bool); ok { // msg.TypeChange = typeChange // } // if groupID, ok := msgMap["group_id"].(string); ok { // msg.GroupID = groupID // } // if groupStart, ok := msgMap["group_start"].(bool); ok { // msg.GroupStart = groupStart // } // if groupEnd, ok := msgMap["group_end"].(bool); ok { // msg.GroupEnd = groupEnd // } // // Metadata (optional) // if metadataMap, ok := msgMap["metadata"].(map[string]interface{}); ok { // metadata := &message.Metadata{} // if timestamp, ok := metadataMap["timestamp"].(float64); ok { // metadata.Timestamp = int64(timestamp) // } // if sequence, ok := metadataMap["sequence"].(float64); ok { // metadata.Sequence = int(sequence) // } // if traceID, ok := metadataMap["trace_id"].(string); ok { // metadata.TraceID = traceID // } // msg.Metadata = metadata // } // group.Messages = append(group.Messages, msg) // 
} // } else { // return nil, fmt.Errorf("group.messages is required and must be an array") // } // // Metadata (optional) // if metadataMap, ok := groupMap["metadata"].(map[string]interface{}); ok { // metadata := &message.Metadata{} // if timestamp, ok := metadataMap["timestamp"].(float64); ok { // metadata.Timestamp = int64(timestamp) // } // if sequence, ok := metadataMap["sequence"].(float64); ok { // metadata.Sequence = int(sequence) // } // if traceID, ok := metadataMap["trace_id"].(string); ok { // metadata.TraceID = traceID // } // group.Metadata = metadata // } // return group, nil // } ================================================ FILE: agent/output/jsapi/output_test.go ================================================ package jsapi // func TestOutputConstructor(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // tests := []struct { // name string // script string // expectError bool // }{ // { // name: "Create Output with context", // script: ` // function test(ctx) { // const output = new Output(ctx); // return output !== undefined && output !== null; // } // `, // expectError: false, // }, // { // name: "Create Output without context should fail", // script: ` // function test(ctx) { // try { // const output = new Output(); // return false; // } catch (e) { // return e.toString().includes("context argument"); // } // } // `, // expectError: false, // }, // } // for _, tt := range tests { // t.Run(tt.name, func(t *testing.T) { // ctx := agentContext.New(context.Background(), nil, "test-chat-123", "") // ctx.AssistantID = "test-assistant-456" // // Execute test script with v8.Call // res, err := v8.Call(v8.CallOptions{}, tt.script, &ctx) // if tt.expectError { // assert.Error(t, err) // return // } // assert.NoError(t, err) // assert.True(t, res.(bool)) // }) // } // } // func TestOutputSend(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // tests := []struct { // name string // script string // expectError 
bool // validate func(*testing.T, *agentContext.Context) // }{ // { // name: "Send text message with object", // script: ` // function test(ctx) { // const output = new Output(ctx); // output.Send({ // type: "text", // props: { content: "Hello World" } // }); // return true; // } // `, // expectError: false, // }, // { // name: "Send text message with string shorthand", // script: ` // function test(ctx) { // const output = new Output(ctx); // output.Send("Hello World"); // return true; // } // `, // expectError: false, // }, // { // name: "Send message with all fields", // script: ` // function test(ctx) { // const output = new Output(ctx); // output.Send({ // type: "text", // props: { content: "Test" }, // id: "msg-1", // delta: true, // done: false, // delta_path: "content", // delta_action: "append", // metadata: { // timestamp: 1234567890, // sequence: 1, // trace_id: "trace-123" // } // }); // return true; // } // `, // expectError: false, // }, // { // name: "Send error message", // script: ` // function test(ctx) { // const output = new Output(ctx); // output.Send({ // type: "error", // props: { // message: "Something went wrong", // code: "ERR_001" // } // }); // return true; // } // `, // expectError: false, // }, // { // name: "Send without message should fail", // script: ` // function test(ctx) { // try { // const output = new Output(ctx); // output.Send(); // return false; // } catch (e) { // return e.toString().includes("message argument"); // } // } // `, // expectError: false, // }, // { // name: "Send message without type should fail", // script: ` // function test(ctx) { // try { // const output = new Output(ctx); // output.Send({ props: { content: "test" } }); // return false; // } catch (e) { // return e.toString().includes("type is required"); // } // } // `, // expectError: false, // }, // } // for _, tt := range tests { // t.Run(tt.name, func(t *testing.T) { // // Create context with mock writer // ctx := 
agentContext.New(context.Background(), nil, "test-chat", "") // ctx.Writer = &mockWriter{} // // Execute test script with v8.Call // res, err := v8.Call(v8.CallOptions{}, tt.script, &ctx) // if tt.expectError { // assert.Error(t, err) // return // } // assert.NoError(t, err) // assert.True(t, res.(bool)) // if tt.validate != nil { // tt.validate(t, &ctx) // } // }) // } // } // func TestOutputSendGroup(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // tests := []struct { // name string // script string // expectError bool // }{ // { // name: "Send message group", // script: ` // function test(ctx) { // const output = new Output(ctx); // output.SendGroup({ // id: "group-1", // messages: [ // { type: "text", props: { content: "Message 1" } }, // { type: "text", props: { content: "Message 2" } } // ] // }); // return true; // } // `, // expectError: false, // }, // { // name: "Send group with metadata", // script: ` // function test(ctx) { // const output = new Output(ctx); // output.SendGroup({ // id: "group-1", // messages: [ // { type: "text", props: { content: "Test" } } // ], // metadata: { // timestamp: 1234567890, // sequence: 1 // } // }); // return true; // } // `, // expectError: false, // }, // { // name: "Send empty group", // script: ` // function test(ctx) { // const output = new Output(ctx); // output.SendGroup({ // id: "group-1", // messages: [] // }); // return true; // } // `, // expectError: false, // }, // { // name: "Send group without id should fail", // script: ` // function test(ctx) { // try { // const output = new Output(ctx); // output.SendGroup({ // messages: [ // { type: "text", props: { content: "Test" } } // ] // }); // return false; // } catch (e) { // return e.toString().includes("id is required"); // } // } // `, // expectError: false, // }, // { // name: "Send group without messages should fail", // script: ` // function test(ctx) { // try { // const output = new Output(ctx); // output.SendGroup({ id: "group-1" 
}); // return false; // } catch (e) { // return e.toString().includes("messages is required"); // } // } // `, // expectError: false, // }, // } // for _, tt := range tests { // t.Run(tt.name, func(t *testing.T) { // // Create context with mock writer // ctx := agentContext.New(context.Background(), nil, "test-chat", "") // ctx.Writer = &mockWriter{} // // Execute test script with v8.Call // res, err := v8.Call(v8.CallOptions{}, tt.script, &ctx) // if tt.expectError { // assert.Error(t, err) // return // } // assert.NoError(t, err) // assert.True(t, res.(bool)) // }) // } // } // func TestOutputChaining(t *testing.T) { // test.Prepare(t, config.Conf) // defer test.Clean() // script := ` // function test(ctx) { // const output = new Output(ctx); // // Send should return the output object for chaining // const result = output.Send("Message 1"); // // Should be able to chain sends // output.Send("Message 2").Send("Message 3"); // return result !== undefined; // } // ` // ctx := agentContext.New(context.Background(), nil, "test-chat", "") // ctx.Writer = &mockWriter{} // // Execute test script with v8.Call // res, err := v8.Call(v8.CallOptions{}, script, &ctx) // assert.NoError(t, err) // assert.True(t, res.(bool)) // } // // mockWriter is a mock implementation of http.ResponseWriter for testing // type mockWriter struct { // data [][]byte // header http.Header // } // func (w *mockWriter) Header() http.Header { // if w.header == nil { // w.header = make(http.Header) // } // return w.header // } // func (w *mockWriter) Write(p []byte) (n int, err error) { // w.data = append(w.data, p) // return len(p), nil // } // func (w *mockWriter) WriteHeader(statusCode int) {} // func (w *mockWriter) Flush() {} ================================================ FILE: agent/output/message/STREAMING.md ================================================ # Message Streaming Architecture This document explains the hierarchical streaming architecture for Agent/LLM/MCP message delivery. 
## Overview The streaming system uses a hierarchical structure to handle complex scenarios including: - Single LLM calls with multiple message types (thinking, tool calls, text) - Agent logic with multiple sequential operations (LLM → MCP → LLM) - Concurrent/parallel calls to multiple LLMs or MCPs - Real-time delta updates for streaming responses ## Hierarchical Structure ``` Agent Stream (entire conversation) └─ ThreadID (concurrent stream, optional: T1, T2, T3...) └─ BlockID (output block/section: B1, B2, B3...) └─ MessageID (logical message: M1, M2, M3...) └─ ChunkID (stream fragment: C1, C2, C3...) ``` ## Field Definitions ### Message Struct Fields ```go type Message struct { // Core fields Type string `json:"type"` Props map[string]interface{} `json:"props,omitempty"` // Streaming control ChunkID string `json:"chunk_id,omitempty"` MessageID string `json:"message_id,omitempty"` BlockID string `json:"block_id,omitempty"` ThreadID string `json:"thread_id,omitempty"` // Delta control Delta bool `json:"delta,omitempty"` DeltaPath string `json:"delta_path,omitempty"` DeltaAction string `json:"delta_action,omitempty"` // ... 
} ``` ### Field Responsibilities | Field | Generated By | Purpose | Example Values | Required | | ----------- | -------------------- | ---------------------------------- | ------------------------------------ | ----------------------------------- | | `ChunkID` | System (auto) | Deduplication, ordering, debugging | `C1`, `C2`, `C3` | Always | | `MessageID` | LLM Provider/Handler | Delta merge target | `M1`, `M2`, `M3` or `thinking_msg_1` | Required for delta scenarios | | `BlockID` | Agent Logic | UI block/section rendering | `B1`, `B2`, `B3` or `llm_response_1` | Required when Agent controls blocks | | `ThreadID` | Agent Logic | Concurrent stream distinction | `T1`, `T2`, `T3` or `thread_llm1` | Optional (concurrent only) | ### Detailed Field Explanation #### ChunkID (Stream Fragment Identifier) - **Purpose**: Uniquely identifies each chunk in the stream - **Generated**: Automatically by the system (sequential: C1, C2, C3...) - **Used For**: - Deduplication (prevent duplicate chunks) - Ordering (maintain correct sequence) - Debugging (trace message flow) - **Scope**: Unique within entire Agent stream - **Always Present**: Yes **Example:** ```json {"chunk_id": "C1", "type": "text", "props": {"content": "Hello"}} {"chunk_id": "C2", "type": "text", "props": {"content": " World"}} {"chunk_id": "C3", "type": "thinking", "props": {"content": "..."}} ``` #### MessageID (Logical Message Identifier) - **Purpose**: Groups multiple chunks into one logical message via delta merging - **Generated**: By LLM Provider or Stream Handler - **Used For**: - Delta merge target (frontend merges all chunks with same MessageID) - Distinguishing different messages within a group - **Scope**: Unique within a Group - **Present When**: Delta streaming is used **Example:** ```json // Multiple chunks combine into one "thinking" message {"chunk_id": "C1", "message_id": "M1", "type": "thinking", "props": {"content": "Let me"}, "delta": true} {"chunk_id": "C2", "message_id": "M1", "type": 
"thinking", "props": {"content": " think"}, "delta": true} {"chunk_id": "C3", "message_id": "M1", "type": "thinking", "props": {"content": "..."}, "delta": true} // Another independent message {"chunk_id": "C4", "message_id": "M2", "type": "text", "props": {"content": "Hello"}, "delta": true} ``` #### BlockID (Output Block Identifier) - **Purpose**: Represents one output block/section (e.g., one LLM call, one MCP call) - **Generated**: By Agent logic - **Used For**: - Frontend UI block/section rendering (visual blocks) - Distinguishing different operations (LLM vs MCP vs custom logic) - Organizing related messages together - **Scope**: Unique within entire Agent stream - **Present When**: Agent explicitly controls output blocks **Key Concept**: Block represents a semantic unit of work from Agent's perspective, NOT from LLM's perspective. Each block is rendered as a distinct UI section in the frontend. **Example:** ```json // BLOCK 1: LLM Response (contains thinking + tool_call + text) {"chunk_id": "C1", "block_id": "B1", "message_id": "M1", "type": "thinking", ...} {"chunk_id": "C2", "block_id": "B1", "message_id": "M2", "type": "tool_call", ...} {"chunk_id": "C3", "block_id": "B1", "message_id": "M3", "type": "text", ...} // BLOCK 2: MCP Call {"chunk_id": "C4", "block_id": "B2", "message_id": "M4", "type": "loading", ...} {"chunk_id": "C5", "block_id": "B2", "message_id": "M5", "type": "text", ...} // BLOCK 3: Another LLM Response {"chunk_id": "C6", "block_id": "B3", "message_id": "M6", "type": "text", ...} ``` #### ThreadID (Concurrent Stream Identifier) - **Purpose**: Distinguishes concurrent/parallel output streams - **Generated**: By Agent logic when spawning concurrent operations - **Used For**: - Separating outputs from parallel LLM/MCP calls - Maintaining independent streaming contexts - **Scope**: Unique within entire Agent stream - **Present When**: Agent makes concurrent calls (optional) **Example:** ```json // Main thread {"chunk_id": "C1", "thread_id": 
"T1", "block_id": "B1", "message_id": "M1", "type": "text", ...} // Parallel MCP calls {"chunk_id": "C2", "thread_id": "T2", "block_id": "B2", "message_id": "M2", "type": "text", ...} {"chunk_id": "C3", "thread_id": "T3", "block_id": "B3", "message_id": "M3", "type": "text", ...} ``` ## Usage Scenarios ### Scenario 1: Simple Text Message **No streaming, no grouping** ```json { "chunk_id": "C1", "type": "text", "props": { "content": "Hello World" } } ``` **Fields Used:** - `chunk_id`: C1 (auto-generated) - No `message_id`, `block_id`, or `thread_id` needed --- ### Scenario 2: LLM Streaming Response (Single Message) **LLM streams one text message** ```json {"chunk_id": "C1", "message_id": "M1", "type": "text", "props": {"content": "Hello"}, "delta": true} {"chunk_id": "C2", "message_id": "M1", "type": "text", "props": {"content": " World"}, "delta": true} {"chunk_id": "C3", "message_id": "M1", "type": "text", "props": {"content": "!"}, "delta": true} ``` **Fields Used:** - `chunk_id`: C1, C2, C3 (unique per chunk) - `message_id`: M1 (same for all, merge target) - `delta`: true **Frontend Behavior:** - Merge all chunks with `message_id: "M1"` into one message - Display: "Hello World!" 
--- ### Scenario 3: Agent-Controlled LLM Call (One Block) **Agent wraps LLM response in an output block** ```typescript // Agent code starts a block for the LLM response // System generates block_id: "B1" // LLM returns thinking + tool_call + text // Agent ends the block ``` ```json // LLM chunks within block B1 {"chunk_id": "C1", "message_id": "M1", "block_id": "B1", "type": "thinking", "props": {...}, "delta": true} {"chunk_id": "C2", "message_id": "M1", "block_id": "B1", "type": "thinking", "props": {...}, "delta": true} {"chunk_id": "C3", "message_id": "M2", "block_id": "B1", "type": "tool_call", "props": {...}} {"chunk_id": "C4", "message_id": "M3", "block_id": "B1", "type": "text", "props": {...}, "delta": true} {"chunk_id": "C5", "message_id": "M3", "block_id": "B1", "type": "text", "props": {...}, "delta": true} ``` **Fields Used:** - `chunk_id`: C1~C5 (unique per chunk) - `message_id`: M1, M2, M3 (per logical message) - `block_id`: B1 (all belong to same LLM call) - `delta`: true (for streaming messages) **Frontend Behavior:** - Render one block/section for `block_id: "B1"` - Within this block, show 3 messages: - Thinking message (chunks C1+C2 merged into M1) - Tool call message (chunk C3 = M2) - Text message (chunks C4+C5 merged into M3) --- ### Scenario 4: Agent Sequential Operations (Multiple Blocks) **Agent orchestrates: LLM → MCP → LLM** ```typescript // Agent code orchestrates three sequential operations: // 1. Block B1: First LLM call // 2. Block B2: MCP call // 3. 
Block B3: Second LLM call ``` ```json // BLOCK 1: First LLM call {"chunk_id": "C1", "message_id": "M1", "block_id": "B1", "type": "text", ...} {"chunk_id": "C2", "message_id": "M1", "block_id": "B1", "type": "text", ...} // BLOCK 2: MCP call {"chunk_id": "C3", "message_id": "M2", "block_id": "B2", "type": "loading", ...} {"chunk_id": "C4", "message_id": "M3", "block_id": "B2", "type": "text", ...} // BLOCK 3: Second LLM call {"chunk_id": "C5", "message_id": "M4", "block_id": "B3", "type": "text", ...} {"chunk_id": "C6", "message_id": "M4", "block_id": "B3", "type": "text", ...} ``` **Frontend Behavior:** - Render 3 distinct blocks/sections: 1. Block 1 (B1): LLM response with text 2. Block 2 (B2): MCP call with loading + result 3. Block 3 (B3): LLM response with text --- ### Scenario 5: Concurrent Operations (Blocks + Threads) **Agent uses concurrent handler to make parallel calls** ```typescript // Agent orchestrates parallel operations within one block (B1) // The concurrent handler automatically assigns thread_id to each operation: // - MCP call for weather (thread_id: "T1") // - MCP call for news (thread_id: "T2") // - LLM call for summary (thread_id: "T3") // // Messages from different threads may arrive in any order ``` ```json // Same block, different threads (may arrive in any order) {"chunk_id": "C1", "message_id": "M1", "block_id": "B1", "thread_id": "T1", "type": "text", "props": {"content": "Weather: Sunny"}} {"chunk_id": "C2", "message_id": "M2", "block_id": "B1", "thread_id": "T2", "type": "text", "props": {"content": "News: ..."}} {"chunk_id": "C3", "message_id": "M1", "block_id": "B1", "thread_id": "T1", "type": "text", "props": {"content": ", 25°C"}} {"chunk_id": "C4", "message_id": "M3", "block_id": "B1", "thread_id": "T3", "type": "text", "props": {"content": "Summary..."}} ``` **Fields Used:** - `chunk_id`: C1, C2, C3, C4 (unique per chunk, chronological order) - `message_id`: M1, M2, M3 (per operation/message) - `block_id`: B1 (all belong to 
same parallel operation block) - `thread_id`: T1, T2, T3 (distinguish concurrent operations) **Frontend Behavior:** - Render one block for `block_id: "B1"` - Within this block, separate messages by `thread_id`: - Thread T1 (Weather): M1 (chunks C1+C3 merged) → "Weather: Sunny, 25°C" - Thread T2 (News): M2 (chunk C2) - Thread T3 (Summary): M3 (chunk C4) - Or interleave by `chunk_id` order (C1, C2, C3, C4) to show real-time arrival --- ## Summary | Field | Level | Purpose | Example | | ----------- | ----------- | ------------------ | ---------- | | `ChunkID` | System | Transport/debug | C1, C2, C3 | | `MessageID` | LLM/Handler | Delta merging | M1, M2, M3 | | `BlockID` | Agent | UI blocks/sections | B1, B2, B3 | | `ThreadID` | Agent | Concurrency | T1, T2, T3 | **Key Insight**: Each field serves a distinct purpose at a specific layer of the architecture. This hierarchical design supports simple single-message scenarios while enabling complex Agent orchestration with concurrent operations. Blocks provide natural UI boundaries for organizing related messages. ================================================ FILE: agent/output/message/interfaces.go ================================================ package message // Writer is the interface for writing output messages // Different writers handle different output formats (SSE, WebSocket, Standard, etc.) 
type Writer interface {
	// Write writes a single message
	Write(msg *Message) error

	// WriteGroup writes a group of messages
	WriteGroup(group *Group) error

	// Flush flushes any buffered data
	Flush() error

	// Close closes the writer and releases resources
	Close() error
}

// Adapter is the interface for adapting messages to different formats
// Adapters transform messages from the universal DSL to specific client formats
type Adapter interface {
	// Adapt transforms a message to the target format
	// Returns a slice of output chunks (some messages may be split into multiple chunks)
	// The chunk element type is untyped (interface{}); its concrete shape is chosen by each adapter
	Adapt(msg *Message) ([]interface{}, error)

	// SupportsType checks if this adapter supports a specific message type
	SupportsType(msgType string) bool
}

// StreamHandler handles streaming message processing
// It bridges between LLM streaming chunks and output messages
type StreamHandler interface {
	// Handle processes a streaming chunk from LLM
	// chunkType is one of the StreamChunkType constants declared in types.go; data is the raw chunk payload
	Handle(chunkType StreamChunkType, data []byte) error

	// Flush flushes any pending messages
	Flush() error

	// Close closes the handler
	Close() error
}

================================================
FILE: agent/output/message/types.go
================================================
package message

import (
	"net/http"

	"github.com/yaoapp/gou/llm"
	traceTypes "github.com/yaoapp/yao/trace/types"
)

// Options are the options for the writer
// NOTE(review): none of the fields are documented in this file; the per-field notes below
// are assumptions drawn from names and types and should be confirmed against the writer implementations
type Options struct {
	BaseURL      string              // Base URL (purpose not visible here - TODO confirm)
	Accept       string              // Presumably selects the output format (SSE, WebSocket, Standard, ...) - TODO confirm
	Writer       http.ResponseWriter // HTTP response writer handed to the output writer
	Trace        traceTypes.Manager  // Trace manager (trace/types package)
	Capabilities *llm.Capabilities   // LLM capabilities (gou/llm package); pointer, so it may be nil
	Locale       string              // Locale identifier (presumably for localized output - TODO confirm)
}

// Message represents a universal message structure (DSL)
// All messages are expressed through Type + Props, without predefining specific types
// Only Type is always serialized; every other field carries omitempty and is dropped from the JSON when it holds its zero value
type Message struct {
	// Core fields
	Type  string                 `json:"type"`            // Message type (frontend decides how to render)
	Props map[string]interface{} `json:"props,omitempty"` // Message properties (passed to frontend component)

	// Streaming control - Hierarchical structure for Agent/LLM/MCP streaming
	// See STREAMING.md for detailed explanation of the streaming architecture
	ChunkID   string `json:"chunk_id,omitempty"`   // Unique chunk ID (auto-generated: C1, C2, C3...; for dedup/ordering/debugging)
	MessageID string `json:"message_id,omitempty"` // Logical message ID (delta merge target; multiple chunks combine into one message)
	BlockID   string `json:"block_id,omitempty"`   // Output block ID (Agent-level control: one LLM call, one MCP call, etc.; for UI rendering blocks/sections)
	ThreadID  string `json:"thread_id,omitempty"`  // Thread ID (optional; for concurrent Agent/LLM/MCP calls to distinguish output streams)

	// Delta control
	Delta       bool   `json:"delta,omitempty"`        // Whether this is an incremental update
	DeltaPath   string `json:"delta_path,omitempty"`   // Update path (e.g., "content", "data", "items.0.name")
	DeltaAction string `json:"delta_action,omitempty"` // Update action (append, replace, merge, set)

	// Type correction (for streaming scenarios)
	TypeChange bool `json:"type_change,omitempty"` // Marks this as a type correction message

	// Metadata
	Metadata *Metadata `json:"metadata,omitempty"` // Additional metadata
}

// Metadata represents message metadata
// All fields are optional and omitted from the JSON when zero
type Metadata struct {
	Timestamp int64  `json:"timestamp,omitempty"` // Timestamp in nanoseconds
	Sequence  int    `json:"sequence,omitempty"`  // Sequence number (for ordering)
	TraceID   string `json:"trace_id,omitempty"`  // Trace ID (for debugging)
}

// Group represents a semantically complete group of messages
// ID and Messages are always serialized (no omitempty), so a nil Messages slice encodes as JSON null
type Group struct {
	ID       string     `json:"id"`                 // Message group ID
	Messages []*Message `json:"messages"`           // List of messages
	Metadata *Metadata  `json:"metadata,omitempty"` // Metadata
}

// Built-in message types that all adapters must support
// These types have standardized Props structures
// The constants are untyped strings intended as values for Message.Type
const (
	// User interaction types
	TypeUserInput = "user_input" // User input message (frontend display only)

	// Content types
	TypeText     = "text"      // Plain text or Markdown content
	TypeThinking = "thinking"  // Reasoning/thinking process (e.g., o1 models)
	TypeLoading  = "loading"   // Loading/processing indicator (preprocessing, knowledge base search, etc.)
	TypeToolCall = "tool_call" // LLM tool/function call
	TypeError    = "error"     // Error message

	// Media types (with OpenAI support)
	TypeImage = "image" // Image content
	TypeAudio = "audio" // Audio content
	TypeVideo = "video" // Video content

	// System types (not visible in standard chat clients)
	TypeAction = "action" // System action (open panel, navigate, etc.) - silent in OpenAI clients
	TypeEvent  = "event"  // Lifecycle event (stream_start, stream_end, etc.) - CUI only, silent in OpenAI clients
)

// Event types for TypeEvent messages
// Hierarchical structure: Stream > Thread > Block > Message > Chunk
// There is no chunk-level event constant; the hierarchy's lowest level with start/end events is Message
const (
	// Stream level events (Agent layer - overall conversation stream)
	EventStreamStart = "stream_start" // Stream started event
	EventStreamEnd   = "stream_end"   // Stream ended event

	// Thread level events (optional - for concurrent scenarios)
	EventThreadStart = "thread_start" // Thread started event
	EventThreadEnd   = "thread_end"   // Thread ended event

	// Block level events (Agent layer - logical output sections)
	EventBlockStart = "block_start" // Block started event
	EventBlockEnd   = "block_end"   // Block ended event

	// Message level events (LLM layer - individual logical messages)
	EventMessageStart = "message_start" // Message started event
	EventMessageEnd   = "message_end"   // Message ended event
)

// Standard Props structures for built-in types

// UserInputProps defines the standard structure for user input messages
// Type: "user_input"
// Props: {"content": string | ContentPart[], "role": string, "name": string}
type UserInputProps struct {
	Content interface{} `json:"content"`        // User input (text string or multimodal ContentPart[])
	Role    string      `json:"role,omitempty"` // User role: "user", "system", "developer" (default: "user")
	Name    string      `json:"name,omitempty"` // Optional participant name
}

// TextProps defines the standard structure for text messages
// Type: "text"
// Props:
// {"content": string}
type TextProps struct {
	Content string `json:"content"` // Text content (supports Markdown)
}

// ThinkingProps defines the standard structure for thinking messages
// Type: "thinking"
// Props: {"content": string}
type ThinkingProps struct {
	Content string `json:"content"` // Reasoning/thinking content
}

// LoadingProps defines the standard structure for loading messages
// Type: "loading"
// Props: {"message": string}
type LoadingProps struct {
	Message string `json:"message"` // Loading message (e.g., "Searching knowledge base...")
}

// ToolCallProps defines the standard structure for tool_call messages
// Type: "tool_call"
// Props: {"id": string, "name": string, "arguments": string}
// Arguments is carried as a JSON-encoded string, not as a decoded object
type ToolCallProps struct {
	ID        string `json:"id"`                  // Tool call ID
	Name      string `json:"name"`                // Function/tool name
	Arguments string `json:"arguments,omitempty"` // JSON string of arguments
}

// ErrorProps defines the standard structure for error messages
// Type: "error"
// Props: {"message": string, "code": string, "details": string}
type ErrorProps struct {
	Message string `json:"message"`           // Error message
	Code    string `json:"code,omitempty"`    // Error code
	Details string `json:"details,omitempty"` // Additional error details
}

// ActionProps defines the standard structure for action messages
// Type: "action"
// Props: {"name": string, "payload": map}
type ActionProps struct {
	Name    string                 `json:"name"`              // Action name (e.g., "open_panel", "navigate")
	Payload map[string]interface{} `json:"payload,omitempty"` // Action payload/parameters
}

// EventProps defines the standard structure for event messages
// Type: "event"
// Props: {"event": string, "message": string, "data": map}
// Event is a plain string, so values outside the declared Event* constants (e.g., "connecting") are representable
type EventProps struct {
	Event   string                 `json:"event"`             // Event type (e.g., "stream_start", "stream_end", "connecting")
	Message string                 `json:"message,omitempty"` // Human-readable message (e.g., "Connecting...")
	Data    map[string]interface{} `json:"data,omitempty"`    // Additional event data
}

// ImageProps defines the
// standard structure for image messages
// Type: "image"
// Props: {"url": string, "alt": string, "width": int, "height": int, "detail": string}
type ImageProps struct {
	URL    string `json:"url"`              // Required: Image URL or base64 encoded data
	Alt    string `json:"alt,omitempty"`    // Alternative text
	Width  int    `json:"width,omitempty"`  // Image width in pixels
	Height int    `json:"height,omitempty"` // Image height in pixels
	Detail string `json:"detail,omitempty"` // OpenAI detail level: "auto", "low", "high"
}

// AudioProps defines the standard structure for audio messages
// Type: "audio"
// Props: {"url": string, "format": string, "duration": float64, "transcript": string, "autoplay": bool, "controls": bool}
//
// NOTE(review): Controls is documented as defaulting to true but is tagged omitempty, so a false
// value is dropped from the JSON and cannot be told apart from "unset" - confirm how consumers treat a missing key
type AudioProps struct {
	URL        string  `json:"url"`                  // Required: Audio URL or base64 encoded data
	Format     string  `json:"format,omitempty"`     // Audio format: "mp3", "wav", "ogg", etc.
	Duration   float64 `json:"duration,omitempty"`   // Duration in seconds
	Transcript string  `json:"transcript,omitempty"` // Audio transcript text
	Autoplay   bool    `json:"autoplay,omitempty"`   // Whether to autoplay
	Controls   bool    `json:"controls,omitempty"`   // Whether to show controls (default: true)
}

// VideoProps defines the standard structure for video messages
// Type: "video"
// Props: {"url": string, "format": string, "duration": float64, "thumbnail": string, "width": int, "height": int, "autoplay": bool, "controls": bool, "loop": bool}
//
// NOTE(review): same omitempty/default-true caveat as AudioProps.Controls applies to Controls here
type VideoProps struct {
	URL       string  `json:"url"`                 // Required: Video URL
	Format    string  `json:"format,omitempty"`    // Video format: "mp4", "webm", etc.
	Duration  float64 `json:"duration,omitempty"`  // Duration in seconds
	Thumbnail string  `json:"thumbnail,omitempty"` // Thumbnail/poster image URL
	Width     int     `json:"width,omitempty"`     // Video width in pixels
	Height    int     `json:"height,omitempty"`    // Video height in pixels
	Autoplay  bool    `json:"autoplay,omitempty"`  // Whether to autoplay
	Controls  bool    `json:"controls,omitempty"`  // Whether to show controls (default: true)
	Loop      bool    `json:"loop,omitempty"`      // Whether to loop
}

// Delta action constants for incremental updates
// These match the action names listed on Message.DeltaAction (append, replace, merge, set)
const (
	DeltaAppend  = "append"  // Append (for arrays, strings)
	DeltaReplace = "replace" // Replace (for any value)
	DeltaMerge   = "merge"   // Merge (for objects)
	DeltaSet     = "set"     // Set (for new fields)
)

// StreamChunkType represents the type of content in a streaming chunk
type StreamChunkType string

// Stream chunk type constants - indicates what type of content is in the current chunk
const (
	// Content chunk types - actual data from the LLM
	ChunkText     StreamChunkType = "text"      // Regular text content
	ChunkThinking StreamChunkType = "thinking"  // Reasoning/thinking content (o1, DeepSeek R1)
	ChunkToolCall StreamChunkType = "tool_call" // Tool/function call
	ChunkRefusal  StreamChunkType = "refusal"   // Model refusal
	ChunkMetadata StreamChunkType = "metadata"  // Metadata (usage, finish_reason, etc.)
	ChunkError    StreamChunkType = "error"     // Error chunk
	ChunkUnknown  StreamChunkType = "unknown"   // Unknown/unrecognized chunk type

	// Lifecycle event types - stream and message boundaries
	ChunkStreamStart  StreamChunkType = "stream_start"  // Stream begins (entire request starts)
	ChunkStreamEnd    StreamChunkType = "stream_end"    // Stream ends (entire request completes)
	ChunkMessageStart StreamChunkType = "message_start" // Message begins (text/tool_call/thinking message starts)
	ChunkMessageEnd   StreamChunkType = "message_end"   // Message ends (text/tool_call/thinking message completes)
)

// StreamFunc the streaming function callback
// Parameters:
//   - chunkType: the type of content in this chunk (text, thinking, tool_call, etc.)
//   - data: the actual chunk data (could be text, JSON, or other format)
//
// Returns:
//   - int: status code (0 = continue, non-zero = stop streaming)
type StreamFunc func(chunkType StreamChunkType, data []byte) int

// AssistantInfo represents the assistant information structure
// Note the JSON key for ID is "assistant_id", not "id"
type AssistantInfo struct {
	ID          string `json:"assistant_id"`          // Assistant ID
	Type        string `json:"type,omitempty"`        // Assistant Type, default is assistant
	Name        string `json:"name,omitempty"`        // Assistant Name
	Avatar      string `json:"avatar,omitempty"`      // Assistant Avatar
	Description string `json:"description,omitempty"` // Assistant Description
}

// UsageInfo represents token usage statistics
// Structure matches OpenAI API: https://platform.openai.com/docs/api-reference/chat/object#chat-object-usage
// The three totals are always serialized; the detail breakdowns are pointers and omitted when nil
type UsageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`     // Number of tokens in the prompt
	CompletionTokens int `json:"completion_tokens"` // Number of tokens in the generated completion
	TotalTokens      int `json:"total_tokens"`      // Total number of tokens used (prompt + completion)

	// Detailed token breakdown
	PromptTokensDetails     *PromptTokensDetails     `json:"prompt_tokens_details,omitempty"`     // Breakdown of tokens used in the prompt
	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` // Breakdown of tokens used in the completion
}

// PromptTokensDetails provides detailed breakdown of tokens used in the prompt
type PromptTokensDetails struct {
	AudioTokens  int `json:"audio_tokens,omitempty"`  // Audio input tokens present in the prompt
	CachedTokens int `json:"cached_tokens,omitempty"` // Cached tokens present in the prompt
}

// CompletionTokensDetails provides detailed breakdown of tokens used in the completion
type CompletionTokensDetails struct {
	AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` // Tokens from predictions that appeared in the completion
	AudioTokens              int `json:"audio_tokens,omitempty"`               // Audio input tokens generated by the model
	ReasoningTokens          int `json:"reasoning_tokens,omitempty"`           // Tokens generated by the model for reasoning (o1, o1-mini, DeepSeek R1)
	RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` // Tokens from predictions that did not appear in the completion
}

// ============================================================================
// Stream Lifecycle Event Data Structures
// ============================================================================
// These structures define the data format for stream lifecycle events.
// They provide a standardized way to communicate stream boundaries and metadata
// to the frontend, enabling better UI/UX (progress indicators, timing, etc.).
// EventStreamStartData represents the data for stream_start event // Sent when a streaming request begins type EventStreamStartData struct { ContextID string `json:"context_id"` // Context ID for the response RequestID string `json:"request_id"` // Unique identifier for this request Timestamp int64 `json:"timestamp"` // Unix timestamp when stream started ChatID string `json:"chat_id"` // Chat ID being used (e.g., "chat-123") TraceID string `json:"trace_id"` // Trace ID being used (e.g., "trace-123") Assistant *AssistantInfo `json:"assistant,omitempty"` // Assistant information Metadata map[string]interface{} `json:"metadata,omitempty"` // Metadata to pass to the page for CUI context } // EventStreamEndData represents the data for stream_end event // Sent when a streaming request completes (successfully or with error) type EventStreamEndData struct { RequestID string `json:"request_id"` // Corresponding request ID ContextID string `json:"context_id"` // Context ID for the response TraceID string `json:"trace_id"` // Trace ID being used (e.g., "trace-123") Timestamp int64 `json:"timestamp"` // Unix timestamp when stream ended DurationMs int64 `json:"duration_ms"` // Total duration in milliseconds Status string `json:"status"` // "completed" | "error" | "cancelled" Error string `json:"error,omitempty"` // Error message if status is "error" Usage *UsageInfo `json:"usage,omitempty"` // Token usage statistics Metadata map[string]interface{} `json:"metadata,omitempty"` // Metadata to pass to the page for CUI context } // EventMessageStartData represents the data for message_start event // Sent when a logical message begins (text, tool_call, thinking, etc.) // LLM layer: Marks the beginning of a single logical message output type EventMessageStartData struct { MessageID string `json:"message_id"` // Message ID (M1, M2, M3...) 
Type string `json:"type"` // Message type: "text" | "thinking" | "tool_call" | "refusal" Timestamp int64 `json:"timestamp"` // Unix timestamp when message started ThreadID string `json:"thread_id,omitempty"` // Thread ID (optional; for concurrent streams) ToolCall *EventToolCallInfo `json:"tool_call,omitempty"` // Tool call metadata (if type is "tool_call") Extra map[string]interface{} `json:"extra,omitempty"` // Additional metadata (for custom providers or future extensions) } // EventMessageEndData represents the data for message_end event // Sent when a logical message completes // LLM layer: Signals that all chunks for this message have been sent, client should merge and process type EventMessageEndData struct { MessageID string `json:"message_id"` // Message ID (M1, M2, M3...) Type string `json:"type"` // Message type (same as in message_start) Timestamp int64 `json:"timestamp"` // Unix timestamp when message ended ThreadID string `json:"thread_id,omitempty"` // Thread ID (optional; for concurrent streams) DurationMs int64 `json:"duration_ms"` // Duration of this message in milliseconds ChunkCount int `json:"chunk_count"` // Number of data chunks in this message Status string `json:"status"` // "completed" | "partial" | "error" ToolCall *EventToolCallInfo `json:"tool_call,omitempty"` // Complete tool call info (if type is "tool_call") Extra map[string]interface{} `json:"extra,omitempty"` // Additional metadata (e.g., complete content for direct use) } // EventToolCallInfo contains tool call information for message events // Used in both message_start (partial info) and message_end (complete info) type EventToolCallInfo struct { ID string `json:"id"` // Tool call ID (e.g., "call_abc123") Name string `json:"name"` // Function name (may be partial in message_start) Arguments string `json:"arguments,omitempty"` // Complete arguments (only in message_end) Index int `json:"index"` // Index in the tool calls array } // EventBlockStartData represents the data for 
block_start event // Sent when an output block begins (one LLM call, one MCP call, one Agent sub-task, etc.) // Agent layer: Groups multiple related messages into a logical section type EventBlockStartData struct { BlockID string `json:"block_id"` // Block ID (B1, B2, B3...) Type string `json:"type"` // Block type: "llm" | "mcp" | "agent" | "tool" | "mixed" Timestamp int64 `json:"timestamp"` // Unix timestamp when block started Label string `json:"label,omitempty"` // Human-readable label (e.g., "Searching knowledge base", "Calling weather API") Extra map[string]interface{} `json:"extra,omitempty"` // Additional metadata } // EventBlockEndData represents the data for block_end event // Sent when an output block completes // Agent layer: Signals that this logical section is complete type EventBlockEndData struct { BlockID string `json:"block_id"` // Block ID (B1, B2, B3...) Type string `json:"type"` // Block type (same as in block_start) Timestamp int64 `json:"timestamp"` // Unix timestamp when block ended DurationMs int64 `json:"duration_ms"` // Duration of this block in milliseconds MessageCount int `json:"message_count"` // Number of messages in this block Status string `json:"status"` // "completed" | "partial" | "error" Extra map[string]interface{} `json:"extra,omitempty"` // Additional metadata } // EventThreadStartData represents the data for thread_start event // Sent when a concurrent thread begins (parallel Agent/LLM/MCP calls) // Used in concurrent scenarios to distinguish multiple parallel output streams type EventThreadStartData struct { ThreadID string `json:"thread_id"` // Thread ID (T1, T2, T3...) 
Type string `json:"type"` // Thread type: "agent" | "llm" | "mcp" | "tool" Timestamp int64 `json:"timestamp"` // Unix timestamp when thread started Label string `json:"label,omitempty"` // Human-readable label (e.g., "Parallel search 1", "Background task") Extra map[string]interface{} `json:"extra,omitempty"` // Additional metadata } // EventThreadEndData represents the data for thread_end event // Sent when a concurrent thread completes type EventThreadEndData struct { ThreadID string `json:"thread_id"` // Thread ID (T1, T2, T3...) Type string `json:"type"` // Thread type (same as in thread_start) Timestamp int64 `json:"timestamp"` // Unix timestamp when thread ended DurationMs int64 `json:"duration_ms"` // Duration of this thread in milliseconds BlockCount int `json:"block_count"` // Number of blocks in this thread Status string `json:"status"` // "completed" | "partial" | "error" Extra map[string]interface{} `json:"extra,omitempty"` // Additional metadata } ================================================ FILE: agent/output/message/utils.go ================================================ package message import ( "fmt" "sync/atomic" gonanoid "github.com/matoous/go-nanoid/v2" ) // IDGenerator generates unique IDs within a context (e.g., one conversation stream) // Each Context should have its own IDGenerator to ensure IDs are unique within that context type IDGenerator struct { chunkCounter uint64 messageCounter uint64 blockCounter uint64 threadCounter uint64 } // NewIDGenerator creates a new ID generator for a context func NewIDGenerator() *IDGenerator { return &IDGenerator{} } // GenerateChunkID generates a unique chunk ID with prefix C // Format: C1, C2, C3... func (g *IDGenerator) GenerateChunkID() string { id := atomic.AddUint64(&g.chunkCounter, 1) return fmt.Sprintf("C%d", id) } // GenerateMessageID generates a unique message ID with prefix M // Format: M1, M2, M3... 
func (g *IDGenerator) GenerateMessageID() string { id := atomic.AddUint64(&g.messageCounter, 1) return fmt.Sprintf("M%d", id) } // GenerateBlockID generates a unique block ID with prefix B // Format: B1, B2, B3... func (g *IDGenerator) GenerateBlockID() string { id := atomic.AddUint64(&g.blockCounter, 1) return fmt.Sprintf("B%d", id) } // GenerateThreadID generates a unique thread ID with prefix T // Format: T1, T2, T3... func (g *IDGenerator) GenerateThreadID() string { id := atomic.AddUint64(&g.threadCounter, 1) return fmt.Sprintf("T%d", id) } // Reset resets all counters (useful for testing) func (g *IDGenerator) Reset() { atomic.StoreUint64(&g.chunkCounter, 0) atomic.StoreUint64(&g.messageCounter, 0) atomic.StoreUint64(&g.blockCounter, 0) atomic.StoreUint64(&g.threadCounter, 0) } // GetCounters returns current counter values (for debugging/testing) func (g *IDGenerator) GetCounters() (chunk, message, block, thread uint64) { return atomic.LoadUint64(&g.chunkCounter), atomic.LoadUint64(&g.messageCounter), atomic.LoadUint64(&g.blockCounter), atomic.LoadUint64(&g.threadCounter) } // GenerateNanoID generates a unique ID using nanoid // Returns a 21-character URL-safe string // This is a static function that doesn't depend on the generator's counter func GenerateNanoID() string { id, err := gonanoid.New() if err != nil { // Fallback to timestamp-based ID if nanoid fails return fmt.Sprintf("id_%d", atomic.AddUint64(new(uint64), 1)) } return id } // GenerateCustomID generates a custom ID with prefix and nanoid // Format: prefix_nanoid (e.g., "msg_V1StGXR8_Z5jdHi6B-myT") // This is a static function that doesn't depend on the generator's counter func GenerateCustomID(prefix string) string { id, err := gonanoid.New() if err != nil { // Fallback to timestamp-based ID return fmt.Sprintf("%s_%d", prefix, atomic.AddUint64(new(uint64), 1)) } return fmt.Sprintf("%s_%s", prefix, id) } ================================================ FILE: agent/output/message/utils_test.go 
================================================ package message import ( "sync" "testing" ) func TestIDGenerator(t *testing.T) { gen := NewIDGenerator() t.Run("GenerateChunkID", func(t *testing.T) { id1 := gen.GenerateChunkID() id2 := gen.GenerateChunkID() id3 := gen.GenerateChunkID() if id1 != "C1" { t.Errorf("Expected C1, got %s", id1) } if id2 != "C2" { t.Errorf("Expected C2, got %s", id2) } if id3 != "C3" { t.Errorf("Expected C3, got %s", id3) } }) t.Run("GenerateMessageID", func(t *testing.T) { gen := NewIDGenerator() id1 := gen.GenerateMessageID() id2 := gen.GenerateMessageID() id3 := gen.GenerateMessageID() if id1 != "M1" { t.Errorf("Expected M1, got %s", id1) } if id2 != "M2" { t.Errorf("Expected M2, got %s", id2) } if id3 != "M3" { t.Errorf("Expected M3, got %s", id3) } }) t.Run("GenerateBlockID", func(t *testing.T) { gen := NewIDGenerator() id1 := gen.GenerateBlockID() id2 := gen.GenerateBlockID() id3 := gen.GenerateBlockID() if id1 != "B1" { t.Errorf("Expected B1, got %s", id1) } if id2 != "B2" { t.Errorf("Expected B2, got %s", id2) } if id3 != "B3" { t.Errorf("Expected B3, got %s", id3) } }) t.Run("GenerateThreadID", func(t *testing.T) { gen := NewIDGenerator() id1 := gen.GenerateThreadID() id2 := gen.GenerateThreadID() id3 := gen.GenerateThreadID() if id1 != "T1" { t.Errorf("Expected T1, got %s", id1) } if id2 != "T2" { t.Errorf("Expected T2, got %s", id2) } if id3 != "T3" { t.Errorf("Expected T3, got %s", id3) } }) t.Run("Reset", func(t *testing.T) { gen := NewIDGenerator() gen.GenerateChunkID() gen.GenerateMessageID() gen.GenerateBlockID() gen.GenerateThreadID() gen.Reset() chunk, message, block, thread := gen.GetCounters() if chunk != 0 || message != 0 || block != 0 || thread != 0 { t.Errorf("Expected all counters to be 0 after reset, got chunk=%d, message=%d, block=%d, thread=%d", chunk, message, block, thread) } // Verify IDs start from 1 again if id := gen.GenerateChunkID(); id != "C1" { t.Errorf("Expected C1 after reset, got %s", id) } if id := 
gen.GenerateMessageID(); id != "M1" { t.Errorf("Expected M1 after reset, got %s", id) } }) t.Run("ConcurrentAccess", func(t *testing.T) { gen := NewIDGenerator() var wg sync.WaitGroup count := 100 // Test concurrent chunk ID generation wg.Add(count) for i := 0; i < count; i++ { go func() { defer wg.Done() gen.GenerateChunkID() }() } wg.Wait() chunk, _, _, _ := gen.GetCounters() if chunk != uint64(count) { t.Errorf("Expected chunk counter to be %d, got %d", count, chunk) } }) t.Run("MultipleGenerators", func(t *testing.T) { gen1 := NewIDGenerator() gen2 := NewIDGenerator() id1 := gen1.GenerateMessageID() id2 := gen2.GenerateMessageID() // Both should start from M1 if id1 != "M1" || id2 != "M1" { t.Errorf("Expected both generators to start from M1, got %s and %s", id1, id2) } // Advance gen1 gen1.GenerateMessageID() gen1.GenerateMessageID() // gen2 should still be at M1 id2_next := gen2.GenerateMessageID() if id2_next != "M2" { t.Errorf("Expected gen2 to be at M2, got %s", id2_next) } // gen1 should be at M3 id1_next := gen1.GenerateMessageID() if id1_next != "M4" { t.Errorf("Expected gen1 to be at M4, got %s", id1_next) } }) } func TestGenerateNanoID(t *testing.T) { id1 := GenerateNanoID() id2 := GenerateNanoID() // NanoID should be 21 characters by default if len(id1) != 21 { t.Errorf("Expected NanoID length to be 21, got %d", len(id1)) } // IDs should be unique if id1 == id2 { t.Error("Expected unique NanoIDs, got duplicates") } t.Logf("Generated NanoIDs: %s, %s", id1, id2) } func TestGenerateCustomID(t *testing.T) { id1 := GenerateCustomID("msg") id2 := GenerateCustomID("evt") // Should have prefix if len(id1) < 4 || id1[:4] != "msg_" { t.Errorf("Expected ID to start with 'msg_', got %s", id1) } if len(id2) < 4 || id2[:4] != "evt_" { t.Errorf("Expected ID to start with 'evt_', got %s", id2) } // IDs should be unique if id1 == id2 { t.Error("Expected unique custom IDs, got duplicates") } t.Logf("Generated custom IDs: %s, %s", id1, id2) } 
================================================ FILE: agent/output/output.go ================================================ package output import ( "fmt" "github.com/yaoapp/yao/agent/output/adapters/cui" "github.com/yaoapp/yao/agent/output/adapters/openai" "github.com/yaoapp/yao/agent/output/message" ) // Accept type constants const ( AcceptStandard = "standard" AcceptWebCUI = "cui-web" AccepNativeCUI = "cui-native" AcceptDesktopCUI = "cui-desktop" ) // Output are the options for the output type Output struct { Writer message.Writer } // NewOutput creates a new output based on Accept type func NewOutput(options message.Options) (*Output, error) { var writer message.Writer var err error // Create writer based on Accept type switch options.Accept { case AcceptStandard: // OpenAI-compatible format writer, err = openai.NewWriter(options) case AcceptWebCUI, AccepNativeCUI, AcceptDesktopCUI: // CUI format writer, err = cui.NewWriter(options) default: // Default to Standard (OpenAI) writer, err = openai.NewWriter(options) } if err != nil { return nil, err } return &Output{ Writer: writer, }, nil } // Send sends a single message using the appropriate writer for the context func (o *Output) Send(msg *message.Message) error { return o.Writer.Write(msg) } // SendGroup sends a message group using the appropriate writer for the context func (o *Output) SendGroup(group *message.Group) error { return o.Writer.WriteGroup(group) } // Flush flushes the writer for the given context func (o *Output) Flush() error { return o.Writer.Flush() } // Close closes the writer for the given context func (o *Output) Close() error { return o.Writer.Close() } // SendMulti sends multiple messages using the appropriate writer for the context func (o *Output) SendMulti(messages ...*message.Message) error { for _, msg := range messages { if err := o.Writer.Write(msg); err != nil { return fmt.Errorf("failed to send message: %w", err) } } return nil } // // Send sends a single message using the 
appropriate writer for the context // func Send(ctx *context.Context, msg *message.Message) error { // writer, err := GetWriter(ctx) // if err != nil { // return err // } // return writer.Write(msg) // } // // SendGroup sends a message group using the appropriate writer for the context // func SendGroup(ctx *context.Context, group *message.Group) error { // writer, err := GetWriter(ctx) // if err != nil { // return err // } // return writer.WriteGroup(group) // } // // GetWriter gets or creates a writer for the given context // // Writers are cached per context to avoid recreating them // func GetWriter(ctx *context.Context) (message.Writer, error) { // // Try to get cached writer // writerMutex.RLock() // writer, exists := writerCache[ctx] // writerMutex.RUnlock() // if exists { // return writer, nil // } // // Create new writer // writerMutex.Lock() // defer writerMutex.Unlock() // // Double-check after acquiring write lock // if writer, exists := writerCache[ctx]; exists { // return writer, nil // } // // Create writer based on context.Accept // writer, err := createWriter(ctx) // if err != nil { // return nil, err // } // // Cache the writer // writerCache[ctx] = writer // return writer, nil // } // // createWriter creates a writer based on context.Accept // func createWriter(ctx *context.Context) (message.Writer, error) { // // If global factory is set, use it // if globalFactory != nil { // return globalFactory.NewWriter(ctx, nil) // } // // Default: create based on Accept type // switch ctx.Accept { // case context.AcceptStandard: // // OpenAI-compatible format // return openai.NewWriter(ctx) // case context.AcceptWebCUI, context.AccepNativeCUI, context.AcceptDesktopCUI: // // CUI format // return cui.NewWriter(ctx) // default: // // Default to Standard // return openai.NewWriter(ctx) // } // } // // SetWriterFactory sets a custom writer factory // // This allows applications to provide their own writer implementations // func SetWriterFactory(factory 
message.WriterFactory) { // globalFactory = factory // } // // ClearWriterCache clears the writer cache // // Should be called when contexts are cleaned up // func ClearWriterCache(ctx *context.Context) { // writerMutex.Lock() // defer writerMutex.Unlock() // delete(writerCache, ctx) // } // // ClearAllWriterCache clears all cached writers // func ClearAllWriterCache() { // writerMutex.Lock() // defer writerMutex.Unlock() // writerCache = make(map[*context.Context]message.Writer) // } // // Flush flushes the writer for the given context // func Flush(ctx *context.Context) error { // writer, err := GetWriter(ctx) // if err != nil { // return err // } // return writer.Flush() // } // // Close closes the writer for the given context and removes it from cache // func Close(ctx *context.Context) error { // writer, err := GetWriter(ctx) // if err != nil { // return err // } // err = writer.Close() // ClearWriterCache(ctx) // return err // } // // SendMulti is a convenience function to send multiple messages // func SendMulti(ctx *context.Context, messages ...*message.Message) error { // writer, err := GetWriter(ctx) // if err != nil { // return err // } // for _, msg := range messages { // if err := writer.Write(msg); err != nil { // return fmt.Errorf("failed to send message: %w", err) // } // } // return nil // } ================================================ FILE: agent/output/safe_writer.go ================================================ package output import ( "context" "net/http" "sync" ) // SafeWriter wraps http.ResponseWriter with a channel-based queue // to serialize concurrent SSE writes and prevent "short write" errors. // // When multiple goroutines (e.g., concurrent sub-agents via ctx.agent.All) // write to the same SSE stream, direct writes can cause data corruption // or "short write" errors. SafeWriter solves this by: // // 1. Accepting write requests via a buffered channel // 2. Processing writes sequentially in a dedicated goroutine // 3. 
Providing non-blocking writes with overflow protection
// 4. Automatic cleanup when context is cancelled (client disconnect)
type SafeWriter struct {
	ch        chan writeRequest   // Buffered queue of pending writes, consumed by run()
	writer    http.ResponseWriter // Underlying response writer; only run() writes data to it
	done      chan struct{}       // Closed by run() when it returns
	ctx       context.Context     // For detecting client disconnection
	cancel    context.CancelFunc  // To signal run() to stop
	closeOnce sync.Once           // Makes Close() idempotent
	closed    bool                // Set by Close(); Write() ignores data once true
	mu        sync.RWMutex        // Guards closed
}

// writeRequest represents a single write request
type writeRequest struct {
	data []byte // Private copy of the caller's bytes, made in Write()
}

// QueueCapacity is the default buffer size for the write queue
// Large enough to handle high concurrency without blocking
const QueueCapacity = 10000

// NewSafeWriter creates a new SafeWriter that wraps an http.ResponseWriter
// and starts a background goroutine to process writes sequentially.
// It takes no context: the writer uses a private context derived from
// context.Background(), so it is not tied to the HTTP request and does not
// observe client disconnection. Use NewSafeWriterWithContext to bind the
// writer to a request context. Call Close() to stop the background goroutine.
func NewSafeWriter(w http.ResponseWriter) *SafeWriter {
	// Create internal context for graceful shutdown
	ctx, cancel := context.WithCancel(context.Background())
	sw := &SafeWriter{
		ch:     make(chan writeRequest, QueueCapacity),
		writer: w,
		done:   make(chan struct{}),
		ctx:    ctx,
		cancel: cancel,
	}
	go sw.run()
	return sw
}

// NewSafeWriterWithContext creates a SafeWriter that respects the given context.
// When the context is cancelled (e.g., client disconnects), run() stops writing
// to the ResponseWriter and discards whatever is queued afterwards.
// The run() goroutine itself only exits once Close() has closed the queue,
// so Close() must still be called to avoid leaking the goroutine.
func NewSafeWriterWithContext(ctx context.Context, w http.ResponseWriter) *SafeWriter { // Derive a cancellable context from the parent childCtx, cancel := context.WithCancel(ctx) sw := &SafeWriter{ ch: make(chan writeRequest, QueueCapacity), writer: w, done: make(chan struct{}), ctx: childCtx, cancel: cancel, } go sw.run() return sw } // run processes write requests from the channel sequentially // Exits when channel is closed OR context is cancelled (client disconnect) func (sw *SafeWriter) run() { defer close(sw.done) for { select { case req, ok := <-sw.ch: if !ok { // Channel closed, exit gracefully return } if sw.writer != nil { sw.writer.Write(req.data) // Flush after each write to ensure SSE data is sent immediately if flusher, ok := sw.writer.(http.Flusher); ok { flusher.Flush() } } case <-sw.ctx.Done(): // Context cancelled (client disconnected or explicit close) // Continue reading from channel until it's closed to avoid blocking senders // and to process any remaining messages that were already queued sw.drainUntilClosed() return } } } // drainUntilClosed reads from channel until it's closed // This prevents senders from blocking after context cancellation func (sw *SafeWriter) drainUntilClosed() { for range sw.ch { // Discard messages - context is cancelled so we don't write them } } // Write implements io.Writer interface // Queues the data for sequential writing by the background goroutine func (sw *SafeWriter) Write(data []byte) (int, error) { sw.mu.RLock() if sw.closed { sw.mu.RUnlock() return 0, nil // Silently ignore writes after close } sw.mu.RUnlock() // Make a copy of data since the caller may reuse the buffer dataCopy := make([]byte, len(data)) copy(dataCopy, data) // Non-blocking send with overflow protection select { case sw.ch <- writeRequest{data: dataCopy}: return len(data), nil default: // Channel full - this shouldn't happen with 10000 capacity // but if it does, drop the message rather than block // Note: In production, this indicates 
either: // 1. Extremely high concurrency (>10000 pending writes) // 2. The underlying writer is blocked/slow // Consider increasing QueueCapacity if this occurs frequently return len(data), nil } } // Header returns the header map from the underlying ResponseWriter func (sw *SafeWriter) Header() http.Header { if sw.writer == nil { return http.Header{} } return sw.writer.Header() } // WriteHeader sends an HTTP response header with the provided status code func (sw *SafeWriter) WriteHeader(statusCode int) { if sw.writer != nil { sw.writer.WriteHeader(statusCode) } } // Flush implements http.Flusher interface // Note: Actual flushing happens in the run() goroutine after each write func (sw *SafeWriter) Flush() { // Flushing is handled automatically in run() after each write // This method exists to satisfy the http.Flusher interface } // Close closes the write channel and waits for all pending writes to complete // This is safe to call multiple times (idempotent via sync.Once) func (sw *SafeWriter) Close() error { sw.closeOnce.Do(func() { // First close channel to signal run() to stop and process remaining messages close(sw.ch) // Wait for run() to finish processing all queued messages <-sw.done // Then mark as closed and cancel context sw.mu.Lock() sw.closed = true sw.mu.Unlock() sw.cancel() }) return nil } // IsClosed returns whether the SafeWriter has been closed func (sw *SafeWriter) IsClosed() bool { sw.mu.RLock() defer sw.mu.RUnlock() return sw.closed } // Underlying returns the underlying http.ResponseWriter // Use with caution - direct writes bypass the queue func (sw *SafeWriter) Underlying() http.ResponseWriter { return sw.writer } ================================================ FILE: agent/output/safe_writer_test.go ================================================ package output import ( "bytes" "context" "net/http" "net/http/httptest" "sync" "testing" "time" ) // mockResponseWriter is a thread-safe mock for testing type mockResponseWriter struct { mu 
sync.Mutex buf bytes.Buffer header http.Header flushed int writeErr error } func newMockResponseWriter() *mockResponseWriter { return &mockResponseWriter{ header: make(http.Header), } } func (m *mockResponseWriter) Header() http.Header { return m.header } func (m *mockResponseWriter) Write(data []byte) (int, error) { m.mu.Lock() defer m.mu.Unlock() if m.writeErr != nil { return 0, m.writeErr } return m.buf.Write(data) } func (m *mockResponseWriter) WriteHeader(statusCode int) {} func (m *mockResponseWriter) Flush() { m.mu.Lock() defer m.mu.Unlock() m.flushed++ } func (m *mockResponseWriter) String() string { m.mu.Lock() defer m.mu.Unlock() return m.buf.String() } func (m *mockResponseWriter) FlushCount() int { m.mu.Lock() defer m.mu.Unlock() return m.flushed } func TestSafeWriter_BasicWrite(t *testing.T) { mock := newMockResponseWriter() sw := NewSafeWriter(mock) defer sw.Close() // Write some data n, err := sw.Write([]byte("hello")) if err != nil { t.Errorf("Write error: %v", err) } if n != 5 { t.Errorf("Expected 5 bytes written, got %d", n) } // Wait for async write to complete time.Sleep(10 * time.Millisecond) // Verify data was written if got := mock.String(); got != "hello" { t.Errorf("Expected 'hello', got '%s'", got) } } func TestSafeWriter_ConcurrentWrites(t *testing.T) { mock := newMockResponseWriter() sw := NewSafeWriter(mock) // Number of concurrent goroutines numGoroutines := 100 // Number of writes per goroutine writesPerGoroutine := 100 var wg sync.WaitGroup wg.Add(numGoroutines) // Launch concurrent writes for i := 0; i < numGoroutines; i++ { go func(id int) { defer wg.Done() for j := 0; j < writesPerGoroutine; j++ { sw.Write([]byte("X")) } }(i) } // Wait for all goroutines to complete wg.Wait() // Close and wait for all writes to be processed sw.Close() // Verify all data was written (no data loss) expectedLen := numGoroutines * writesPerGoroutine if got := len(mock.String()); got != expectedLen { t.Errorf("Expected %d bytes, got %d", expectedLen, 
got) } // Verify flush was called (at least once per write) if mock.FlushCount() < expectedLen { t.Errorf("Expected at least %d flushes, got %d", expectedLen, mock.FlushCount()) } } func TestSafeWriter_NoDataCorruption(t *testing.T) { mock := newMockResponseWriter() sw := NewSafeWriter(mock) // Use exactly 26 goroutines (one per letter A-Z) to avoid duplicates numGoroutines := 26 // Message to write (with unique content per goroutine) msgLen := 100 var wg sync.WaitGroup wg.Add(numGoroutines) // Launch concurrent writes with different content for i := 0; i < numGoroutines; i++ { go func(id int) { defer wg.Done() // Create a message with repeating character (unique per goroutine) char := byte('A' + id) msg := bytes.Repeat([]byte{char}, msgLen) sw.Write(msg) }(i) } // Wait for all goroutines to complete wg.Wait() // Close and wait for all writes to be processed sw.Close() // Verify total length result := mock.String() expectedLen := numGoroutines * msgLen if len(result) != expectedLen { t.Errorf("Expected %d bytes, got %d", expectedLen, len(result)) } // Verify no interleaving (each message should be contiguous) // Check that we have exactly numGoroutines distinct blocks blocks := make(map[byte]int) for i := 0; i < len(result); i += msgLen { if i+msgLen > len(result) { t.Errorf("Unexpected data at end of result") break } block := result[i : i+msgLen] // Verify block is homogeneous (all same character) firstChar := block[0] for j, c := range []byte(block) { if c != firstChar { t.Errorf("Data corruption detected at position %d: expected %c, got %c", i+j, firstChar, c) break } } blocks[firstChar]++ } // Each character should appear exactly once (one block per goroutine) for char, count := range blocks { if count != 1 { t.Errorf("Character %c appeared %d times, expected 1", char, count) } } // Verify we got all 26 letters if len(blocks) != numGoroutines { t.Errorf("Expected %d distinct blocks, got %d", numGoroutines, len(blocks)) } } func 
TestSafeWriter_CloseWaitsForPendingWrites(t *testing.T) { mock := newMockResponseWriter() sw := NewSafeWriter(mock) // Write a large number of messages numWrites := 1000 for i := 0; i < numWrites; i++ { sw.Write([]byte("X")) } // Close should wait for all writes to complete sw.Close() // Verify all data was written if got := len(mock.String()); got != numWrites { t.Errorf("Expected %d bytes after close, got %d", numWrites, got) } } func TestSafeWriter_WriteAfterClose(t *testing.T) { mock := newMockResponseWriter() sw := NewSafeWriter(mock) sw.Write([]byte("before")) sw.Close() // Write after close should be silently ignored n, err := sw.Write([]byte("after")) if err != nil { t.Errorf("Write after close should not error: %v", err) } if n != 0 { t.Errorf("Write after close should return 0, got %d", n) } // Verify only "before" was written if got := mock.String(); got != "before" { t.Errorf("Expected 'before', got '%s'", got) } } func TestSafeWriter_ImplementsHTTPInterfaces(t *testing.T) { mock := newMockResponseWriter() sw := NewSafeWriter(mock) defer sw.Close() // Verify it implements http.ResponseWriter var _ http.ResponseWriter = sw // Verify it implements http.Flusher var _ http.Flusher = sw // Test Header() sw.Header().Set("Content-Type", "text/plain") if got := mock.Header().Get("Content-Type"); got != "text/plain" { t.Errorf("Expected Content-Type 'text/plain', got '%s'", got) } } // BenchmarkSafeWriter_ConcurrentWrites benchmarks concurrent write performance func BenchmarkSafeWriter_ConcurrentWrites(b *testing.B) { mock := newMockResponseWriter() sw := NewSafeWriter(mock) defer sw.Close() data := []byte("benchmark data for SSE streaming") b.RunParallel(func(pb *testing.PB) { for pb.Next() { sw.Write(data) } }) } // TestSafeWriter_RealHTTPServer tests SafeWriter with a real HTTP server func TestSafeWriter_RealHTTPServer(t *testing.T) { // This test verifies SafeWriter works correctly with httptest.ResponseRecorder // which is commonly used in testing HTTP 
handlers recorder := httptest.NewRecorder() sw := NewSafeWriter(recorder) // Simulate concurrent SSE writes from multiple sub-agents var wg sync.WaitGroup numAgents := 10 messagesPerAgent := 10 wg.Add(numAgents) for i := 0; i < numAgents; i++ { go func(agentID int) { defer wg.Done() for j := 0; j < messagesPerAgent; j++ { // Simulate SSE message format msg := []byte("data: {\"agent\":" + string(rune('0'+agentID)) + "}\n\n") sw.Write(msg) } }(i) } wg.Wait() sw.Close() // Verify response contains all messages (no data loss) body := recorder.Body.String() expectedMsgs := numAgents * messagesPerAgent // Count number of "data: " prefixes count := 0 for i := 0; i < len(body); i++ { if i+6 <= len(body) && body[i:i+6] == "data: " { count++ } } if count != expectedMsgs { t.Errorf("Expected %d messages, found %d", expectedMsgs, count) } } // TestSafeWriter_ContextCancellation tests that SafeWriter handles context cancellation // This is critical for enterprise applications to prevent goroutine leaks func TestSafeWriter_ContextCancellation(t *testing.T) { mock := newMockResponseWriter() // Create a cancellable context ctx, cancel := context.WithCancel(context.Background()) sw := NewSafeWriterWithContext(ctx, mock) // Write some data and wait for it to be processed sw.Write([]byte("before")) time.Sleep(20 * time.Millisecond) // Verify "before" was written if got := mock.String(); got != "before" { t.Errorf("Expected 'before' before cancel, got '%s'", got) } // Cancel context (simulates client disconnect) cancel() // Write after context cancellation - these may or may not be written // depending on timing (select may pick ctx.Done() first) sw.Write([]byte("after_cancel")) // Close properly cleans up sw.Close() // After close, run() has exited select { case <-sw.done: // Good - run() has exited default: t.Error("run() should have exited after Close()") } // The key guarantee: run() goroutine exits cleanly, no leak // Data written before cancel is preserved got := mock.String() 
if len(got) < 6 { // At least "before" should be there t.Errorf("Expected at least 'before', got '%s'", got) } } // TestSafeWriter_GoroutineLeak tests that SafeWriter doesn't leak goroutines func TestSafeWriter_GoroutineLeak(t *testing.T) { // Create many SafeWriters and ensure they all clean up properly numWriters := 100 var wg sync.WaitGroup wg.Add(numWriters) for i := 0; i < numWriters; i++ { go func() { defer wg.Done() mock := newMockResponseWriter() ctx, cancel := context.WithCancel(context.Background()) sw := NewSafeWriterWithContext(ctx, mock) // Write some data sw.Write([]byte("test")) // Randomly either close normally or cancel context if time.Now().UnixNano()%2 == 0 { cancel() time.Sleep(5 * time.Millisecond) sw.Close() } else { sw.Close() cancel() // Cancel after close is safe } }() } // All goroutines should complete done := make(chan struct{}) go func() { wg.Wait() close(done) }() select { case <-done: // All completed successfully case <-time.After(5 * time.Second): t.Error("Timeout waiting for goroutines to complete - possible goroutine leak") } } ================================================ FILE: agent/robot/DESIGN-V2-REVIEW-FINDINGS.md ================================================ # DESIGN-V2 Line-by-Line Review Findings **Review Date:** 2026-02-25 **Files Reviewed:** `runner.go`, `run.go` --- ## File 1: runner.go ### 1. ExecuteTask: Is it truly single-call? No retry loop? No validation? **✅ PASS** — Lines 65–108 - Non-assistant: single call to `executeNonAssistantTask` (L74), no loop - Assistant: single call to `executeAssistantTask` (L89), no loop - No `validator` import or call anywhere in the file - Comment at L62–63: "V2 simplified: single call, no validation loop" --- ### 2. Does it correctly split assistant vs non-assistant at the top? 
**✅ PASS** — Lines 72–86 vs 88–107 - L72: `if task.ExecutorType != robottypes.ExecutorAssistant` — non-assistant branch first - L88+: assistant branch follows - Clear split at the top of the function --- ### 3. For non-assistant: does it call executeNonAssistantTask which handles MCP and Process? **✅ PASS** — Lines 74, 110–119 - L74: `output, err := r.executeNonAssistantTask(task, taskCtx)` - `executeNonAssistantTask` (L110–119): switch on `ExecutorMCP` and `ExecutorProcess`, delegates to `ExecuteMCPTask` and `ExecuteProcessTask` --- ### 4. For assistant: does it call executeAssistantTask which returns (output, *CallResult, error)? **✅ PASS** — Lines 89, 123–145 - L89: `output, callResult, err := r.executeAssistantTask(task, taskCtx)` - L123: `func (r *Runner) executeAssistantTask(...) (interface{}, *CallResult, error)` - L144: `return output, turnResult.Result, nil` --- ### 5. Does executeAssistantTask use conv.Turn() (single turn, not multi-turn)? **✅ PASS** — Lines 126, 138 - L126: `conv := NewConversation(task.ExecutorID, chatID, 1)` — maxTurns=1 - L138: `turnResult, err := conv.Turn(r.ctx, input)` — single `Turn` call, no loop --- ### 6. Does detectNeedMoreInfo properly check result.Next for map with "status" == "need_input"? **✅ PASS** — Lines 149–166 - L150–151: nil checks for `result` and `result.Next` - L154: type assertion to `map[string]interface{}` - L157–159: `status, _ := m["status"].(string); if status != "need_input" return false` - Matches DESIGN §16.5 protocol --- ### 7. Does it extract "question" from the map? What happens if question is empty? **✅ PASS** — Lines 161–165 - L161: `question, _ := m["question"].(string)` - L163–164: `if question == "" { question = result.GetText() }` — fallback to `CallResult` text - Empty question handled via fallback --- ### 8. Are result.NeedInput and result.InputQuestion set correctly? 
**✅ PASS** — Lines 102–105 - L102–104: `if needInput, question := detectNeedMoreInfo(callResult); needInput { result.NeedInput = true; result.InputQuestion = question }` - Set only when `detectNeedMoreInfo` returns true --- ### 9. Does result.Duration get set in all paths (success and failure)? **✅ PASS** — Lines 77, 84, 92, 98 - Non-assistant error: L77 - Non-assistant success: L84 - Assistant error: L92 - Assistant success: L98 - All paths set `result.Duration = time.Since(startTime).Milliseconds()` --- ### 10. Does buildResult helper exist or is result construction inline? **✅ PASS (inline)** — Lines 68–107 - No `buildResult` helper; construction is inline - DESIGN §16.4 pseudocode uses `buildResult`; inline construction is acceptable and used here --- ### 11. Any edge cases: what if task.ExecutorType is empty or unknown? **✅ PASS** — Lines 72, 112–118 - Empty/unknown: `!= ExecutorAssistant` is true → non-assistant branch - L116–117: `default` returns `fmt.Errorf("unsupported executor type: %s (expected mcp or process)", task.ExecutorType)` - Error returned and propagated; no silent failure --- ### Additional Finding (runner.go) **⚠️ Minor:** DESIGN §9.1 shows `event.Push("robot.task.failed", ...)` inside `ExecuteTask` when `err != nil`. Implementation pushes `TaskFailed` from `run.go` (L113–120) when `result.Success` is false. Behavior is equivalent; only location differs. --- ## File 2: run.go ### 1. DefaultRunConfig — ContinueOnFailure defaults to true? **✅ PASS** — Lines 21–26 - L23–25: `return &RunConfig{ ContinueOnFailure: true }` - Matches DESIGN §6.3 --- ### 2. RunExecution — does it check exec.ResumeContext for startIndex and PreviousResults? **✅ PASS** — Lines 60–66 - L61–64: `if exec.ResumeContext != nil { startIndex = exec.ResumeContext.TaskIndex; exec.Results = exec.ResumeContext.PreviousResults }` - L72: loop starts at `startIndex` - Matches DESIGN §9.2, §16.3 --- ### 3. Does it NOT reset Results when ResumeContext is present? 
**✅ PASS** — Lines 62–65 - When `ResumeContext != nil`: `exec.Results = exec.ResumeContext.PreviousResults` — restores, does not reset - When `ResumeContext == nil`: `exec.Results = make(...)` — fresh slice --- ### 4. Does it set task.Status to TaskRunning before execution? **✅ PASS** — Lines 86–89 - L87: `task.Status = robottypes.TaskRunning` - L88–89: `task.StartTime = &now` - Set before `ExecuteTask` (L98) --- ### 5. Does it call e.updateTasksState to persist running state? **✅ PASS** — Line 92 - L92: `e.updateTasksState(ctx, exec)` immediately after setting task status - Persists running state before execution --- ### 6. Does result.NeedInput trigger e.Suspend(ctx, exec, i, result.InputQuestion)? **✅ PASS** — Lines 100–103 - L100–102: `if result.NeedInput { return e.Suspend(ctx, exec, i, result.InputQuestion) }` - Correct parameters and early return --- ### 7. Is the result NOT appended before Suspend (avoiding duplicate results per §16.15)? **✅ PASS** — Lines 100–103, 124 - L100–102: `NeedInput` branch returns before any append - L124: `exec.Results = append(exec.Results, *result)` is after the `NeedInput` check - No append on suspend; matches DESIGN §16.15 --- ### 8. Does it push event.Push for TaskFailed when a task fails? **✅ PASS** — Lines 113–120 - L113–120: `event.Push(ctx.Context, robotevents.TaskFailed, robotevents.NeedInputPayload{...})` when `!result.Success` - Event is pushed on task failure **⚠️ Minor:** Uses `NeedInputPayload` with `Question: result.Error`. DESIGN §7.2 does not define a TaskFailed payload. `ExecPayload` (with `Error`) might be more appropriate; `NeedInputPayload.Question` is reused for the error message. Functionally acceptable but semantically odd. --- ### 9. Does it skip remaining tasks when ContinueOnFailure is false? 
**✅ PASS** — Lines 129–137 - L129: `if !result.Success && !config.ContinueOnFailure` - L131–134: marks remaining tasks as `TaskSkipped` - L136: `return fmt.Errorf(...)` — stops execution - Matches DESIGN §9.2 --- ### 10. Does it clear exec.Current and exec.ResumeContext after completion? **✅ PASS** — Lines 141–143 - L142–143: `exec.Current = nil; exec.ResumeContext = nil` after loop completes - Only on normal completion (no early return from Suspend or failure) --- ### 11. Is there any event.Push for TaskCompleted? **❌ FINDING** — run.go - No `event.Push(robotevents.TaskCompleted, ...)` when a task succeeds - DESIGN §7.2 defines `EventTaskCompleted = "robot.task.completed"` - DESIGN-V2 §20.4 (B5) notes missing TaskCompleted event constant; implementation also does not push it - **Recommendation:** Add `event.Push(ctx.Context, robotevents.TaskCompleted, payload)` when `result.Success` (e.g. after L110) --- ### 12. Does getRunConfig properly handle nil data? **✅ PASS** — Lines 50–55 - No separate `getRunConfig`; config is obtained inline - L51–55: `if cfg, ok := data.(*RunConfig); ok && cfg != nil { config = cfg } else { config = DefaultRunConfig() }` - Handles: `data == nil`, wrong type, `cfg == nil` → falls back to `DefaultRunConfig()` --- ### Additional Finding (run.go) **⚠️ Order of operations:** Task status update (L109–120) and result append (L124) occur *after* the NeedInput check. Flow is correct: NeedInput → Suspend (return) → no append, no status update for that task. --- ## Summary | Category | runner.go | run.go | |----------|-----------|--------| | **PASS** | 11/11 | 11/12 | | **Minor** | 1 | 1 | | **Finding** | 0 | 1 (TaskCompleted not pushed) | ### Action Items 1. **run.go L109–110:** Add `event.Push(ctx.Context, robotevents.TaskCompleted, payload)` when `result.Success` for per-task completion events. 2. 
**run.go L113:** Consider introducing a `TaskFailedPayload` (or using `ExecPayload` with `Error`) instead of `NeedInputPayload` for TaskFailed events. ================================================ FILE: agent/robot/DESIGN.md ================================================ # Robot Agent ## 1. What is it? A **Robot Agent** is an AI team member. It works on its own, makes decisions, and runs tasks without waiting for user input. **Key points:** - Belongs to a Team, managed like human members - Has clear job duties (e.g., "Sales Manager: track KPIs, make reports") - Created and deleted via Team API - Runs on schedule, or when triggered by humans or events - Learns from each run, stores knowledge in private KB --- ## 2. Architecture ### 2.1 System Flow > **Architecture Note:** All trigger types flow through Manager. > > - Clock: `Manager.Tick()` (internal ticker) > - Human: `Manager.Intervene()` (API call) > - Event: `Manager.HandleEvent()` (webhook/db trigger) > > The `trigger/` package provides utilities only (validation, clock matching, execution control). 
```mermaid flowchart TB subgraph Triggers["Triggers"] WC[/"⏰ Clock"/] HI[/"👤 Human"/] EV[/"📡 Event"/] end subgraph Manager["Manager (Central Orchestrator)"] TC{"Enabled?"} Cache[("Cache")] Dedup{"Dedup?"} Queue["Queue"] end subgraph Pool["Workers"] W1["Worker"] W2["Worker"] W3["Worker"] end subgraph Executor["Executor"] TT{"Trigger?"} P0["P0: Inspiration"] P1["P1: Goals"] P2["P2: Tasks"] P3["P3: Run"] P4["P4: Deliver"] P5["P5: Learn"] end subgraph Storage["Storage"] KB[("KB")] DB[("DB")] end WC --> TC HI & EV --> TC TC -->|Yes| Cache TC -->|No| X[/Skip/] Cache --> Dedup Dedup -->|OK| Queue Dedup -->|Dup| Cache Queue --> W1 & W2 & W3 W1 & W2 & W3 --> TT TT -->|Clock| P0 TT -->|Human/Event| P1 P0 --> P1 --> P2 --> P3 --> P4 --> P5 P5 --> KB & DB KB -.->|History| P0 ``` ### 2.2 Executor Modes Executor supports multiple execution modes for different use cases: | Mode | Use Case | Status | | -------- | --------------------------------------- | ------------------ | | Standard | Production with real Agent calls | ✅ Implemented | | DryRun | Tests, demos, preview without LLM calls | ✅ Implemented | | Sandbox | Container-isolated for untrusted code | ⬜ Not Implemented | **Standard Mode:** Real execution with LLM calls, full phase execution, logging via kun/log. **DryRun Mode:** Simulated execution without LLM calls. Used for: - Unit tests and integration tests - Demo and preview modes - Scheduling and concurrency testing **Sandbox Mode (Future):** Container-level isolation (Docker/gVisor/Firecracker) for: - Untrusted robot configurations - Multi-tenant environments - Resource-limited execution > **⚠️ Sandbox requires infrastructure support.** Current placeholder behaves like DryRun. 
### 2.3 Team Structure Uses existing `__yao.member` model (`yao/models/member.mod.yao`): ``` ┌─────────────────────────────────────────────────────────────────┐ │ Team │ │ ┌─────────────────────────────────────────────────────────┐ │ │ │ Robot Members │ │ │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ │ │ │Sales Manager│ │Data Analyst │ │CS Specialist│ │ │ │ │ │ • Track KPIs│ │ • Analyze │ │ • Tickets │ │ │ │ │ │ • Reports │ │ • Reports │ │ • Inquiries │ │ │ │ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │ │ └─────────────────────────────────────────────────────────┘ │ │ ┌─────────────────────────────────────────────────────────┐ │ │ │ User Members │ │ │ │ ┌─────────────┐ ┌─────────────┐ │ │ │ │ │ John (Owner)│ │ Jane (Admin)│ │ │ │ │ └─────────────┘ └─────────────┘ │ │ │ └─────────────────────────────────────────────────────────┘ │ └─────────────────────────────────────────────────────────────────┘ ``` **Key fields in `__yao.member` for robot agents:** | Field | Type | Description | | ----------------- | ------ | ----------------------------------------------------------- | | `member_type` | enum | `user` \| `robot` | | `autonomous_mode` | bool | Enable robot execution | | `robot_config` | JSON | Agent configuration (see section 5) | | `robot_status` | enum | `idle` \| `working` \| `paused` \| `error` \| `maintenance` | | `system_prompt` | text | Identity & role prompt | | `robot_email` | string | Robot's email address for sending emails (From address) | | `agents` | JSON | Accessible agents list | | `mcp_servers` | JSON | Accessible MCP servers | | `manager_id` | string | Direct manager user ID | --- ## 3. 
How It Works ### 3.1 Flow: Trigger → Schedule → Run ```mermaid sequenceDiagram autonumber participant T as Trigger participant M as Manager participant S as Scheduler participant W as Worker participant E as Executor participant A as Phase Agents participant KB as KB T->>M: Event M->>M: Check enabled M->>M: Get from cache M->>M: Check dedup M->>S: Submit S->>S: Check quota S->>S: Sort by priority S->>W: Dispatch W->>E: Run alt Clock trigger E->>A: P0: Inspiration (with clock context) A-->>E: Report end loop P1 to P5 E->>A: Call agent A-->>E: Result end E->>KB: Save learning E-->>W: Done ``` ### 3.2 Triggers | Type | What | Config | Handler | | --------- | ----------------------------- | -------------------- | ----------------------- | | **Clock** | Timer (times/interval/daemon) | `triggers.clock` | `Manager.Tick()` | | **Human** | Manual action | `triggers.intervene` | `Manager.Intervene()` | | **Event** | Webhook, DB change | `triggers.event` | `Manager.HandleEvent()` | All on by default. Turn off per agent: ```yaml triggers: clock: { enabled: true } intervene: { enabled: true, actions: ["task.add", "goal.adjust"] } event: { enabled: false } ``` ### 3.3 Concurrency Two levels to prevent one agent from using all resources: ``` ┌─────────────────────────────────────────────────────────────────┐ │ Global Pool (10 workers) │ └─────────────────────────────────────────────────────────────────┘ │ │ │ ▼ ▼ ▼ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ Sales Manager │ │ Data Analyst │ │ CS Specialist │ │ Limit: 3 │ │ Limit: 2 │ │ Limit: 3 │ │ Now: 2 ✓ │ │ Now: 2 (full) │ │ Now: 1 ✓ │ └─────────────────┘ └─────────────────┘ └─────────────────┘ ``` ### 3.4 Dedup **Fast check** (in memory): ```go key := memberID + ":" + triggerType + ":" + window if has(key) { skip } ``` **Smart check** (for goals/tasks): - Dedup Agent looks at history - Returns: `skip` | `merge` | `proceed` ### 3.5 Cache Keeps agents in memory. 
No DB query on each tick: ```go type AgentCache struct { agents map[string]*Agent // member_id -> agent byTeam map[string][]string // team_id -> member_ids } // Refresh: on start, on change, every hour ``` --- ## 4. Phases ### 4.1 Overview ``` Clock: P0 → P1 → P2 → P3 → P4 → P5 Human/Event: P1 → P2 → P3 → P4 → P5 ``` | Phase | Agent | In | Out | When | | ----- | ----------- | ------------------- | --------------- | ---------- | | P0 | Inspiration | Clock + Data + News | Report | Clock only | | P1 | Goal Gen | Report + history | Goals | Always | | P2 | Task Plan | Goals + tools | Tasks | Always | | P3 | Run + Valid | Tasks + Experts | TaskResults | Always | | P4 | Delivery | All results | Email/Webhook/Process | Always | | P5 | Learning | Summary | KB entries | Always | ### 4.2 P0: Inspiration (Clock only) **Skipped for Human/Event triggers.** They already have clear intent. Gathers info to help make good goals. **Clock context is key input** - Agent knows what time it is and can decide what to do (e.g., 5pm Friday → write weekly report). ```go type InspirationReport struct { Clock *ClockContext `json:"clock"` // time context Content string `json:"content"` // markdown text for LLM } // Content is markdown like: // ## Summary // ... // ## Highlights // - [High] Sales up 50% // ## Opportunities / Risks / World News / Pending // ... type ClockContext struct { Now time.Time // Current time Hour int // 0-23 DayOfWeek string // Monday, Tuesday... DayOfMonth int // 1-31 IsWeekend bool IsMonthStart bool // 1st-3rd IsMonthEnd bool // last 3 days IsQuarterEnd bool // Agent uses this to decide: "It's 5pm Friday, time for weekly report" } ``` **Sources:** - **Clock**: Current time, day of week, month end, etc. - Internal: Data changes, events, feedback, pending work - External: Web search (news, competitors) ### 4.3 P1: Goals **For Clock:** Uses inspiration report (with clock context) to make goals. Agent decides based on time what's important now. 
**For Human/Event:** Uses the input directly as goals (or to generate goals). ```go type Goals struct { Content string // markdown text (for LLM) Delivery *DeliveryTarget // where to send results (for P4) } type DeliveryTarget struct { Type DeliveryType // Preferred delivery type (P4 will use Delivery Center) Recipients []string // email addresses, webhook URLs, user IDs Format string // markdown | html | json | text Template string // template name Options map[string]interface{} } ``` **Example prompt:** ``` You are [Sales Manager]. Your job: [track KPIs, make reports]. ## Report ### Key Items - [High] Data: 15 new sales (+50%) - [High] Deadline: Friday report due - [High] News: Competitor launched product ### Chances - Sales up 20% vs last week - Market growing Make today's goals. ``` **Note:** Validation criteria (`ExpectedOutput`, `ValidationRules`) are defined at the **Task level** (P2), not Goals level. This allows each task to have specific validation rules for P3. ### 4.4 P2: Tasks P2 Agent reads Goals markdown and breaks into executable tasks: ```go type Task struct { ID string // unique task ID Description string // human-readable task description (for UI display) Messages []context.Message // original input (text, images, files, audio) GoalRef string // reference to goal (e.g., "Goal 1") Source TaskSource // auto | human | event ExecutorType ExecutorType // assistant | mcp | process ExecutorID string // agent ID or mcp tool name Args []any // arguments for executor Order int // execution order // Validation criteria (used in P3) ExpectedOutput string // what the task should produce ValidationRules []string // specific checks to perform } ``` ### 4.5 P3: Run **Architecture:** P3 uses a modular design with three components: ``` ┌─────────────────────────────────────────────────────────────┐ │ run.go (P3 Entry) │ │ - RunConfig: ContinueOnFailure, ValidationThreshold, │ │ MaxTurnsPerTask │ │ - RunExecution: main loop with task dependency passing │ 
└─────────────────────┬───────────────────────────────────────┘ │ ┌────────────┴────────────┐ ▼ ▼ ┌─────────────────┐ ┌─────────────────┐ │ runner.go │ │ validator.go │ │ - Runner │ │ - Validator │ │ - Multi-turn │ │ - Two-layer │ │ conversation │ │ - Rule+Semantic│ │ - Task context │ │ - NeedReply │ │ building │ │ - ReplyContent │ └────────┬────────┘ └────────┬────────┘ │ │ │ ▼ │ ┌─────────────────┐ │ │ yao/assert │ │ │ - Asserter │ │ │ - 8 types │ │ │ - Extensible │ │ └─────────────────┘ ▼ ┌─────────────────────────────────────────┐ │ Executor Types │ │ - ExecutorAssistant → Multi-turn AI │ │ - ExecutorMCP → Single-call MCP tool │ │ - ExecutorProcess → Single-call Process │ └─────────────────────────────────────────┘ ``` **Execution Flow:** For each task: 1. **Build Context**: Include previous task results as context 2. **Execute**: Call appropriate executor (Assistant/MCP/Process) 3. **Validate**: Use two-layer validation (rule-based + semantic) 4. **Continue or Complete**: - For Assistant tasks: If `NeedReply`, continue conversation with `ReplyContent` - For MCP/Process tasks: Single-call execution, no multi-turn 5. **Update**: Set task status and store result **Task Dependency**: Previous task results are automatically passed as context to subsequent tasks via `Runner.BuildTaskContext()` and formatted using `FormatPreviousResultsAsContext()`. **Two-Layer Validation:** | Layer | Method | Speed | Use Case | |-------|--------|-------|----------| | 1. Rule-based | `yao/assert` | Fast | Type check, contains, regex, json_path | | 2. 
Semantic | Validation Agent | Slow | ExpectedOutput, complex criteria | **Executor Types:** | Type | ExecutorID Format | Example | |------|-------------------|---------| | `assistant` | Agent ID | `experts.text-writer` | | `mcp` | `mcp_server.mcp_tool` | `ark.image.text2img.generate` | | `process` | Process name | `models.user.Find` | **MCP Task Fields:** For MCP tasks, three fields are required: - `executor_id`: Combined format `mcp_server.mcp_tool` - `mcp_server`: MCP server/client ID (e.g., `ark.image.text2img`) - `mcp_tool`: Tool name within the server (e.g., `generate`) **Multi-Turn Conversation Flow:** For assistant tasks, P3 uses a multi-turn conversation approach: 1. **Call**: Call assistant and get result 2. **Validate**: Validate result (determines: passed, complete, needReply, replyContent) 3. **Reply**: If needReply, continue conversation with replyContent 4. **Repeat**: Until complete or max turns exceeded The `Validator.ValidateWithContext()` method determines: - `Complete`: Whether the expected result is obtained - `NeedReply`: Whether to continue conversation - `ReplyContent`: What to send in the next turn (validation feedback, clarification request, etc.) This replaces the traditional retry mechanism with intelligent conversation continuation. 
```go // RunConfig configures P3 execution behavior type RunConfig struct { ContinueOnFailure bool // continue to next task even if current fails (default: false) ValidationThreshold float64 // minimum score to pass validation (default: 0.6) MaxTurnsPerTask int // max conversation turns per task (default: 10) } // ValidationResult with multi-turn conversation support type ValidationResult struct { // Basic validation result Passed bool // overall validation passed Score float64 // 0-1 confidence score Issues []string // what failed Suggestions []string // how to improve Details string // detailed report (markdown) // Execution state (for multi-turn conversation control) Complete bool // whether expected result is obtained NeedReply bool // whether to continue conversation ReplyContent string // content for next turn (if NeedReply) } ``` **yao/assert Package:** Universal assertion library supporting 8 types: | Type | Description | Example | |------|-------------|---------| | `equals` | Exact match | `{"type": "equals", "value": "success"}` | | `contains` | Substring check | `{"type": "contains", "value": "total"}` | | `not_contains` | Negative check | `{"type": "not_contains", "value": "error"}` | | `json_path` | JSON path extraction | `{"type": "json_path", "path": "data.count", "value": 10}` | | `regex` | Pattern matching | `{"type": "regex", "value": "^[A-Z].*"}` | | `type` | Type checking | `{"type": "type", "value": "array"}` | | `script` | Custom script | `{"type": "script", "script": "scripts.validate"}` | | `agent` | AI validation | `{"type": "agent", "use": "validator"}` | ### 4.6 P4: Deliver P4 generates delivery content and pushes to Delivery Center. 
**Agent only generates content, Delivery Center decides channels.** **Architecture:** ``` ┌─────────────────────────────────────────────────────────────┐ │ P4 Delivery Agent │ │ Role: Generate content only (Summary, Body, Attachments) │ │ NOT responsible for: Channel selection │ └─────────────────────┬───────────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ DeliveryRequest │ │ - Content: Summary, Body, Attachments │ │ - Context: member_id, execution_id, trigger, team │ │ (No Channels - Delivery Center decides) │ └─────────────────────┬───────────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ Delivery Center │ │ Role: │ │ 1. Read Robot/User delivery preferences │ │ 2. Decide which channels to use │ │ 3. Execute delivery (email, webhook, process) │ │ 4. Future: auto-notify based on user subscriptions │ │ │ │ (Current: internal, future: yao/delivery) │ └─────────────────────────────────────────────────────────────┘ ``` **Key Design:** - **Separation of concerns**: Agent generates content, Delivery Center handles channels - **User preferences**: Channels decided by Robot/User configuration, not Agent - **Automatic delivery**: If webhook configured, every execution pushes automatically - **Future-ready**: Delivery Center can be extracted to `yao/delivery` package **Delivery Request Structure:** ```go // DeliveryRequest - pushed to Delivery Center // No Channels field - Delivery Center decides based on preferences type DeliveryRequest struct { Content *DeliveryContent `json:"content"` // Agent-generated content Context *DeliveryContext `json:"context"` // Tracking info } // DeliveryContent - content generated by Delivery Agent type DeliveryContent struct { Summary string `json:"summary"` // Brief 1-2 sentence summary Body string `json:"body"` // Full markdown report Attachments []DeliveryAttachment `json:"attachments,omitempty"` // Output artifacts } // 
DeliveryAttachment - task output attachment with metadata type DeliveryAttachment struct { Title string `json:"title"` // Human-readable title Description string `json:"description,omitempty"` // What this artifact is TaskID string `json:"task_id,omitempty"` // Which task produced this File string `json:"file"` // Wrapper: __<uploader>://<fileID> } // DeliveryContext - tracking and audit info type DeliveryContext struct { MemberID string `json:"member_id"` // Robot member ID (globally unique) ExecutionID string `json:"execution_id"` TriggerType TriggerType `json:"trigger_type"` TeamID string `json:"team_id"` } ``` **File Wrapper Format:** Attachments use the standard `yao/attachment` wrapper format: - Format: `__<uploader>://<fileID>` - Example: `__yao.attachment://ccd472d11feb96e03a3fc468f494045c` - Parse: `attachment.Parse(value)` → `(uploader, fileID, isWrapper)` - Read: `attachment.Base64(ctx, value)` → base64 content **Delivery Channels (Delivery Center decides):** | Channel | Description | Multiple Targets | |---------|-------------|------------------| | `email` | Send via yao/messenger | ✅ Multiple recipients/emails | | `webhook` | POST to external URL | ✅ Multiple URLs | | `process` | Yao Process call | ✅ Multiple processes | | `notify` | In-app notification | Future (auto by subscriptions) | **Delivery Agent:** The Delivery Agent **only generates content**, does NOT decide channels: ```go // Delivery Agent Input type DeliveryAgentInput struct { Robot *Robot `json:"robot"` TriggerType TriggerType `json:"trigger"` Inspiration *InspirationReport `json:"inspiration"` // P0 Goals *Goals `json:"goals"` // P1 Tasks []Task `json:"tasks"` // P2 Results []TaskResult `json:"results"` // P3 } // Delivery Agent Output - only content, no channels type DeliveryAgentOutput struct { Content *DeliveryContent `json:"content"` } ``` **Example Agent Output:** ```json { "content": { "summary": "Sales report completed: 15 new leads processed", "body": "## Weekly Sales Report\n\n### Summary\n...", "attachments": [
{"title": "Sales Report.pdf", "file": "__yao.attachment://abc123"}, {"title": "Lead Analysis.xlsx", "file": "__yao.attachment://def456"} ] } } ``` **Delivery Result:** ```go // DeliveryResult - returned by Delivery Center type DeliveryResult struct { RequestID string `json:"request_id"` // Delivery request ID Content *DeliveryContent `json:"content"` // Agent-generated content Results []ChannelResult `json:"results,omitempty"` // Results per channel Success bool `json:"success"` // Overall success Error string `json:"error,omitempty"` // Error if failed SentAt *time.Time `json:"sent_at,omitempty"` // When delivery completed } // ChannelResult - result for a single delivery target type ChannelResult struct { Type DeliveryType `json:"type"` // email | webhook | process Target string `json:"target"` // Target identifier (email, URL, process name) Success bool `json:"success"` // Whether delivery succeeded Recipients []string `json:"recipients,omitempty"` // Who received (for email) Details interface{} `json:"details,omitempty"` // Channel-specific response Error string `json:"error,omitempty"` // Error message if failed SentAt *time.Time `json:"sent_at,omitempty"` // When this target was delivered } ``` **Config (Delivery Preferences):** Robot config defines delivery **preferences** (Delivery Center reads and executes). 
Each channel supports **multiple targets**: ```yaml delivery: preferences: email: enabled: true targets: # Multiple email targets - to: ["manager@company.com"] cc: ["team@company.com"] - to: ["ceo@company.com"] subject_template: "Executive Summary" webhook: enabled: true targets: # Multiple webhook URLs - url: "https://slack.com/webhook/sales" - url: "https://feishu.cn/webhook/reports" headers: {"X-Custom": "value"} process: enabled: true targets: # Multiple Yao Process calls - name: "orders.UpdateStatus" args: ["completed"] - name: "audit.LogDelivery" # Note: notify handled by Delivery Center based on user subscriptions (future) ``` **Use Cases:** | Scenario | Channels | Description | |----------|----------|-------------| | Event callback | `process` | DB change → Robot → Update data via Process | | Multi-channel notify | `email` + `webhook` | Send to multiple emails and Slack/飞书 | | Data pipeline | `process` | Robot result → Save to DB → Update dashboard | ### 4.7 P5: Learn Save to KB: | Type | Examples | | ----------- | ------------------------ | | `execution` | What worked, what failed | | `feedback` | Errors, fixes | | `insight` | Patterns, tips | --- ## 5. 
Config ### 5.1 Structure ```go type Config struct { Triggers *Triggers `json:"triggers,omitempty"` Clock *Clock `json:"clock,omitempty"` Identity *Identity `json:"identity"` Quota *Quota `json:"quota"` KB *KB `json:"kb,omitempty"` // shared KB (same as assistant) DB *DB `json:"db,omitempty"` // shared DB (same as assistant) Learn *Learn `json:"learn,omitempty"` // learning for private KB Resources *Resources `json:"resources"` Delivery *DeliveryPreferences `json:"delivery,omitempty"` Events []Event `json:"events,omitempty"` Executor *Executor `json:"executor,omitempty"` // executor mode settings DefaultLocale string `json:"default_locale,omitempty"` // default language for clock/event triggers ("en-US", "zh-CN") } ``` ### 5.2 Types ```go // Phase - execution phase enum type Phase string const ( PhaseInspiration Phase = "inspiration" // P0: Clock only PhaseGoals Phase = "goals" // P1 PhaseTasks Phase = "tasks" // P2 PhaseRun Phase = "run" // P3 (execution + validation) PhaseDelivery Phase = "delivery" // P4 PhaseLearning Phase = "learning" // P5 ) // AllPhases for iteration var AllPhases = []Phase{ PhaseInspiration, PhaseGoals, PhaseTasks, PhaseRun, PhaseDelivery, PhaseLearning, } // ClockMode - clock trigger mode enum type ClockMode string const ( ClockModeTimes ClockMode = "times" // run at specific times ClockModeInterval ClockMode = "interval" // run every X duration ClockModeDaemon ClockMode = "daemon" // run continuously ) // DeliveryType - output delivery type enum type DeliveryType string const ( DeliveryEmail DeliveryType = "email" // Email via yao/messenger DeliveryWebhook DeliveryType = "webhook" // POST to external URL DeliveryProcess DeliveryType = "process" // Yao Process call DeliveryNotify DeliveryType = "notify" // In-app notification (future) ) // ExecStatus - execution status enum type ExecStatus string const ( ExecPending ExecStatus = "pending" ExecRunning ExecStatus = "running" ExecCompleted ExecStatus = "completed" ExecFailed ExecStatus = 
"failed" ExecCancelled ExecStatus = "cancelled" ) // RobotStatus - matches __yao.member.robot_status enum type RobotStatus string const ( RobotIdle RobotStatus = "idle" // ready to run RobotWorking RobotStatus = "working" // currently executing RobotPaused RobotStatus = "paused" // manually paused RobotError RobotStatus = "error" // encountered error RobotMaintenance RobotStatus = "maintenance" // under maintenance ) // Triggers - all on by default type Triggers struct { Clock *Trigger `json:"clock,omitempty"` Intervene *Trigger `json:"intervene,omitempty"` Event *Trigger `json:"event,omitempty"` } type Trigger struct { Enabled bool `json:"enabled"` Actions []string `json:"actions,omitempty"` // for intervene } // Clock - when to wake up type Clock struct { Mode ClockMode `json:"mode"` Times []string `json:"times"` // for times: ["09:00", "14:00"] Days []string `json:"days"` // ["Mon", "Tue"...] or ["*"] Every string `json:"every"` // for interval: "30m", "1h" TZ string `json:"tz"` // Asia/Shanghai Timeout string `json:"timeout"` // max run time } // Identity type Identity struct { Role string `json:"role"` Duties []string `json:"duties"` Rules []string `json:"rules"` } // Quota type Quota struct { Max int `json:"max"` // max running (default: 2) Queue int `json:"queue"` // queue size (default: 10) Priority int `json:"priority"` // 1-10 (default: 5) } // KB // KB - shared knowledge base (same as assistant) type KB struct { Collections []string `json:"collections,omitempty"` // KB collection IDs Options map[string]interface{} `json:"options,omitempty"` } // DB - shared database (same as assistant) type DB struct { Models []string `json:"models,omitempty"` // database model names Options map[string]interface{} `json:"options,omitempty"` } // Learn - learning config for robot's private KB // Private KB auto-created: robot_{team_id}_{member_id}_kb type Learn struct { On bool `json:"on"` Types []string `json:"types"` // execution, feedback, insight Keep int 
`json:"keep"` // days, 0 = forever } // Resources type Resources struct { Phases map[Phase]string `json:"phases,omitempty"` // optional, defaults to __yao.{phase} Agents []string `json:"agents"` MCP []MCP `json:"mcp"` } type MCP struct { ID string `json:"id"` Tools []string `json:"tools,omitempty"` // empty = all } // DeliveryPreferences - Robot delivery preferences (read by Delivery Center) // Each channel supports multiple targets type DeliveryPreferences struct { Email *EmailPreference `json:"email,omitempty"` Webhook *WebhookPreference `json:"webhook,omitempty"` Process *ProcessPreference `json:"process,omitempty"` // notify is handled automatically based on user subscriptions } type EmailPreference struct { Enabled bool `json:"enabled"` Targets []EmailTarget `json:"targets"` } type EmailTarget struct { To []string `json:"to"` // Recipient addresses Template string `json:"template,omitempty"` // Email template ID Subject string `json:"subject,omitempty"` // Subject template } type WebhookPreference struct { Enabled bool `json:"enabled"` Targets []WebhookTarget `json:"targets"` } type WebhookTarget struct { URL string `json:"url"` // Webhook URL Method string `json:"method,omitempty"` // HTTP method (default: POST) Headers map[string]string `json:"headers,omitempty"` // Custom headers Secret string `json:"secret,omitempty"` // Signing secret } type ProcessPreference struct { Enabled bool `json:"enabled"` Targets []ProcessTarget `json:"targets"` } type ProcessTarget struct { Process string `json:"process"` // Yao Process name, e.g., "orders.UpdateStatus" Args []any `json:"args,omitempty"` // Additional arguments } // ExecutorMode - executor mode enum type ExecutorMode string const ( ExecutorStandard ExecutorMode = "standard" // real Agent calls (default) ExecutorDryRun ExecutorMode = "dryrun" // simulated, no LLM calls ExecutorSandbox ExecutorMode = "sandbox" // container-isolated (NOT IMPLEMENTED) ) // Executor - executor settings type Executor struct { Mode 
ExecutorMode `json:"mode,omitempty"` // standard | dryrun | sandbox MaxDuration string `json:"max_duration,omitempty"` // max execution time (e.g., "30m") } // Note: Sandbox mode requires container infrastructure (Docker/gVisor). // Current implementation falls back to DryRun behavior. // Monitor ``` ### 5.3 Example Example record in `__yao.member` table: ```json { "member_id": "mem_abc123", "team_id": "team_xyz", "member_type": "robot", "display_name": "Sales Bot", "autonomous_mode": true, "robot_status": "idle", "system_prompt": "You are a sales analyst...", "robot_config": { "triggers": { "clock": { "enabled": true }, "intervene": { "enabled": true }, "event": { "enabled": false } }, "clock": { "mode": "times", "times": ["09:00", "14:00", "17:00"], "days": ["Mon", "Tue", "Wed", "Thu", "Fri"], "tz": "Asia/Shanghai", "timeout": "30m" }, "identity": { "role": "Sales Analyst", "duties": ["Analyze sales", "Make weekly reports"], "rules": ["Only access sales data"] }, "quota": { "max": 2, "queue": 10, "priority": 5 }, "kb": { "collections": ["sales-policies", "products"] }, "db": { "models": ["sales", "customers"] }, "learn": { "on": true, "types": ["execution", "feedback", "insight"], "keep": 90 }, "resources": { "phases": { "inspiration": "__yao.inspiration", "goals": "__yao.goals", "tasks": "__yao.tasks", "validation": "__yao.validation", "delivery": "__yao.delivery", "learning": "__yao.learning" }, "agents": ["data-analyst", "chart-gen"], "mcp": [{ "id": "database", "tools": ["query"] }] }, "delivery": { "type": "email", "opts": { "to": ["manager@company.com"] } }, "executor": { "mode": "standard", "max_duration": "30m" } }, "agents": ["data-analyst", "chart-gen"], "mcp_servers": ["database"] } ``` --- ## 6. 
Lifecycle ### 6.1 Agent States ```mermaid stateDiagram-v2 [*] --> Idle: POST create Idle --> Working: trigger Working --> Idle: done Idle --> Paused: PATCH pause Working --> Paused: PATCH pause Paused --> Idle: PATCH resume Idle --> Error: error Working --> Error: error Error --> Idle: PATCH reset Idle --> [*]: DELETE Paused --> [*]: DELETE ``` | From | To | How | | ------- | ------- | --------------------------- | | - | idle | POST create | | idle | working | trigger (clock/human/event) | | working | idle | execution done | | idle | paused | PATCH robot_status="paused" | | paused | idle | PATCH robot_status="idle" | | any | error | execution error | | error | idle | PATCH robot_status="idle" | | any | deleted | DELETE | ### 6.2 On Create 1. Check config 2. Generate member_id if missing 3. Create KB: `robot_{team_id}_{member_id}_kb` 4. Add to cache 5. Set active ### 6.3 On Delete 1. Stop running executions 2. Remove from cache 3. Delete or archive KB 5. Soft delete record ### 6.4 Execution Flow Single execution flow, depends on trigger type: ```mermaid flowchart LR subgraph Trigger T{Trigger} end subgraph Schedule Path P0[P0: Inspiration] end subgraph Common Path P1[P1: Goals] P2[P2: Tasks] P3[P3: Run] P4[P4: Deliver] P5[P5: Learn] end T -->|Clock| P0 T -->|Human/Event| P1 P0 --> P1 P1 --> P2 --> P3 --> P4 --> P5 ``` ```mermaid stateDiagram-v2 [*] --> Triggered Triggered --> P0_Inspiration: Clock Triggered --> P1_Goals: Human/Event P0_Inspiration --> P1_Goals P1_Goals --> P2_Tasks P2_Tasks --> P3_Run P3_Run --> P4_Deliver P4_Deliver --> P5_Learn P5_Learn --> [*] ``` --- ## 7. Integrations ### 7.1 Execution Storage **Relationship:** 1 Robot : N Executions (concurrent) Each trigger creates a new Execution, stored in `ExecutionStore` (`__yao.agent_execution` table). 
Execution data includes: - Status and phase tracking - All phase outputs (Inspiration, Goals, Tasks, Results, Delivery, Learning) - Error information - Timestamps and progress Logging is handled by `kun/log` package for standard application logging. | List Execs | `job.ListExecutions(param, page, pagesize)` | | Get Exec | `job.GetExecution(execID, param)` | | Save Exec | `job.SaveExecution(exec)` | | List Logs | `job.ListLogs(param, page, pagesize)` | | Save Log | `job.SaveLog(log)` | | Push (start) | `j.Push()` | | Stop | `j.Stop()` | | Destroy | `j.Destroy()` | | Active Jobs | `job.GetActiveJobs()` | | Query by Cat | `job.ListJobs({Wheres: [{Column: "category_id", ...}]})` | ### 7.2 Private KB Made on robot member create: `robot_{team_id}_{member_id}_kb` **What it stores:** - `execution`: What worked, what failed - `feedback`: Errors, fixes - `insight`: Patterns, tips **When:** - Create: On robot member create - Update: After P5 - Clean: Based on `keep` days - Delete: On robot member delete ### 7.3 External Input **Types:** - `clock`: Timer (with time context) - `intervene`: Human action - `event`: Webhook, DB change - `callback`: Async result **Human actions (InterventionAction):** - `task.add`: Add a new task - `task.cancel`: Cancel a task - `task.update`: Update task details - `goal.adjust`: Modify current goal - `goal.add`: Add a new goal - `goal.complete`: Mark goal as complete - `goal.cancel`: Cancel a goal - `plan.add`: Schedule for later - `plan.remove`: Remove from plan queue - `plan.update`: Update planned item - `instruct`: Direct instruction to robot **Plan Queue:** - Holds tasks for later - Runs at next cycle start --- ## 8. API ### 8.1 Manager (Internal) > **Note:** Manager is the central orchestrator, handling all trigger types. 
```go type Manager interface { // Lifecycle Start() error Stop() error // Clock trigger (internal, called by ticker) Tick(ctx *Context, now time.Time) error // Manual trigger (for testing/API) TriggerManual(ctx *Context, memberID string, trigger TriggerType, data interface{}) (string, error) // Human intervention (called by API) Intervene(ctx *Context, req *InterveneRequest) (*ExecutionResult, error) // Event trigger (called by webhook/db trigger) HandleEvent(ctx *Context, req *EventRequest) (*ExecutionResult, error) // Execution control PauseExecution(ctx *Context, execID string) error ResumeExecution(ctx *Context, execID string) error StopExecution(ctx *Context, execID string) error // Cache access Cache() Cache } ``` ### 8.2 Trigger (Integrated into Manager) > **Note:** Trigger logic is integrated into Manager, not a separate interface. > The `trigger/` package provides utilities (validation, clock matching, execution control). ```go // TriggerType enum type TriggerType string const ( TriggerClock TriggerType = "clock" TriggerHuman TriggerType = "human" TriggerEvent TriggerType = "event" ) // Manager handles all trigger types: // - Clock: Manager.Tick() called by internal ticker // - Human: Manager.Intervene() called by API // - Event: Manager.HandleEvent() called by webhook/db trigger // trigger/ package provides utilities: // - trigger.ValidateIntervention(req) - validate human intervention request // - trigger.ValidateEvent(req) - validate event request // - trigger.BuildEventInput(req) - build TriggerInput from event // - trigger.ClockMatcher - reusable clock matching logic // - trigger.ExecutionController - pause/resume/stop execution type InterveneRequest struct { TeamID string MemberID string Action InterventionAction // task.add | goal.adjust | task.cancel | plan.add | instruct Messages []context.Message // user input (text, images, files) PlanTime *time.Time // for action=plan.add ExecutorMode ExecutorMode // optional: standard | dryrun (override robot 
config) } type EventRequest struct { MemberID string Source string // webhook path or table name EventType string // lead.created, etc. Data map[string]interface{} ExecutorMode ExecutorMode // optional: standard | dryrun (override robot config) } type ExecutionResult struct { ExecutionID string // Job execution ID Status ExecStatus // pending | running | completed | failed Message string // status message } type RobotState struct { MemberID string // member_id from __yao.member Status RobotStatus // idle | working | paused | error | maintenance LastRun time.Time NextRun time.Time Running int // current running execution count MaxRunning int // max concurrent executions (from Quota.Max) RunningIDs []string // list of running execution IDs } ``` ### 8.3 Execution (Uses ExecutionStore) Uses dedicated `__yao.agent_execution` table via ExecutionStore. **Each trigger creates a new Execution:** ```go // On each trigger (clock/human/event), create a new Execution exec := &types.Execution{ ID: utils.NewID(), MemberID: memberID, TeamID: teamID, TriggerType: triggerType, Status: types.ExecStatusRunning, Phase: types.PhaseP0Init, StartedAt: time.Now(), } // Save to ExecutionStore execStore.Save(exec) ``` **Query executions for a robot:** ```go // List all executions for a robot member executions, err := execStore.List(memberID, 1, 10) ``` **Query examples:** ```go // Get execution by ID exec, err := execStore.Get(executionID) // List executions for a robot executions, err := execStore.List(memberID, page, pageSize) // Update execution status execStore.UpdateStatus(executionID, types.ExecStatusCompleted) // Logging via kun/log log.With(log.F{"execution_id": exec.ID, "phase": "P1"}).Info("Phase started") ``` --- ## 9. Security 1. **Team only**: Agent sees only its team's data 2. **Role rules**: Uses role_id permissions 3. **Limited tools**: Only what's in `resources` 4. **Timeout**: Stops if runs too long 5. **Logs**: All runs saved --- ## 10. 
Quick Ref ### Triggers ```yaml triggers: clock: { enabled: true } intervene: { enabled: true, actions: [...] } event: { enabled: false } ``` ### Clock ```yaml # Mode 1: Specific times clock: mode: times times: ["09:00", "14:00", "17:00"] days: ["Mon", "Tue", "Wed", "Thu", "Fri"] tz: Asia/Shanghai timeout: 30m # Mode 2: Interval clock: mode: interval every: 30m # run every 30 minutes timeout: 10m # Mode 3: Daemon (continuous thinking/analysis) clock: mode: daemon # restart immediately after each run timeout: 10m # max time per run # Use case: Research analyst, market monitor ``` ### Phase Agents ```yaml # Optional - defaults to __yao.{phase} if not specified resources: phases: inspiration: "__yao.inspiration" # Clock only goals: "__yao.goals" tasks: "__yao.tasks" validation: "__yao.validation" delivery: "__yao.delivery" learning: "__yao.learning" ``` ### Quota ```yaml quota: max: 2 # max running queue: 10 # queue size priority: 5 # 1-10 ``` ### Executor ```yaml # Standard mode (default) - real Agent calls executor: mode: standard max_duration: 30m # DryRun mode - simulated execution (for testing/demos) executor: mode: dryrun # Sandbox mode (NOT IMPLEMENTED) - container-isolated # Requires Docker/gVisor infrastructure # executor: # mode: sandbox # max_duration: 10m ``` **API Override:** ```javascript // Override executor mode per trigger const result = Process("robot.Trigger", "mem_abc123", { type: "human", action: "task.add", messages: [{ role: "user", content: "Test task" }], executor_mode: "dryrun", // override robot config }); ``` --- ## 11. 
Examples Each example shows a different trigger mode: | Example | Trigger | Mode | Scenario | | ------- | ------- | --------- | -------------------------------------------- | | 11.1 | Clock | times | SEO/GEO Content - daily content optimization | | 11.2 | Clock | interval | Competitor Monitor - check every 2 hours | | 11.3 | Clock | daemon | Research Analyst - continuous insight mining | | 11.4 | Human | intervene | Sales Assistant - manager assigns tasks | | 11.5 | Event | webhook | Lead Processor - qualify and route new leads | --- ### 11.1 SEO/GEO Content Agent (Clock: times) **Trigger:** Clock - specific times daily **Role:** AI Marketing - auto-generate and optimize SEO/GEO content. ```json // robot_config for SEO Content Agent { "triggers": { "clock": { "enabled": true }, "intervene": { "enabled": true } }, "clock": { "mode": "times", "times": ["06:00", "18:00"], "days": ["Mon", "Tue", "Wed", "Thu", "Fri"], "tz": "Asia/Shanghai" }, "identity": { "role": "SEO/GEO Content Specialist", "duties": [ "Research trending keywords in our industry", "Generate SEO-optimized articles (2-3 per day)", "Optimize existing content for GEO (AI search)", "Track keyword rankings and adjust strategy", "A/B test titles and meta descriptions" ] }, "resources": { "agents": ["keyword-researcher", "content-writer", "seo-optimizer"], "mcp": [ { "id": "google-search", "tools": ["trends", "rankings"] }, { "id": "cms", "tools": ["create", "update", "publish"] } ] }, "delivery": { "type": "notify", "opts": { "channel": "marketing-team" } } } ``` **Example run at 06:00 Monday:** ``` P0 Inspiration: Clock: Monday 06:00, start of week Data: - Keyword "AI app development" trending (+45% this week) - Our article ranks #8, competitor #2 - 3 articles need GEO optimization World: New AI regulation announced last Friday P1 Goals: 1. Write new article targeting "AI app development" 2. Optimize 3 old articles for GEO 3. Update meta descriptions for top 5 pages P2 Tasks: 1. 
Research "AI app development" keywords → keyword-researcher 2. Write article with SEO structure → content-writer 3. Add FAQ schema for GEO → seo-optimizer 4. Publish to CMS → cms.publish P3 Execute: - Keywords: "AI app development", "build AI apps", "AI dev guide" (12 total) - Article: 2500 words, 8 sections, FAQ schema added - Published to CMS, indexed by Google P4 Delivery: → Notify: "Published: 'Complete Guide to AI App Development' - targeting 12 keywords" P5 Learn: - "AI app development" articles perform well on Monday morning - FAQ schema improves GEO visibility by 30% ``` --- ### 11.2 Competitor Monitor (Clock: interval) **Trigger:** Clock - every 2 hours **Role:** Monitor competitors, track market changes, alert on important updates. ```json // robot_config for Competitor Monitor { "triggers": { "clock": { "enabled": true } }, "clock": { "mode": "interval", "every": "2h" }, "identity": { "role": "Competitor Intelligence Analyst", "duties": [ "Monitor competitor websites for changes", "Track competitor pricing updates", "Watch for new product launches", "Analyze competitor content strategy", "Alert team on significant changes" ] }, "resources": { "agents": ["web-scraper", "diff-analyzer", "report-writer"], "mcp": [{ "id": "web-search", "tools": ["search", "news"] }] }, "delivery": { "type": "webhook", "opts": { "url": "https://slack.com/webhook/competitor-alerts" } } } ``` **Example run detecting competitor change:** ``` P0 Inspiration: Clock: Tuesday 14:00 Data: - Competitor A: pricing page changed - Competitor B: new blog post about "AI agents" - Competitor C: no changes P1 Goals: 1. Analyze Competitor A pricing change 2. Summarize Competitor B's new content 3. Assess impact on our positioning P2 Tasks: 1. Scrape old vs new pricing → web-scraper 2. Compare pricing tiers → diff-analyzer 3. 
Generate competitive analysis → report-writer P3 Execute: - Competitor A: dropped price 20% on enterprise tier - Competitor B: targeting same keywords as us P4 Delivery: → Slack: "🚨 Competitor A cut enterprise price 20% - review needed" P5 Learn: - Competitor A tends to change pricing on Tuesdays - Price changes often precede feature launches ``` --- ### 11.3 Industry Research Analyst (Clock: daemon) **Trigger:** Clock - continuous daemon mode **Role:** Continuously read industry news, papers, social media; extract insights; build knowledge. ```json // robot_config for Research Analyst { "triggers": { "clock": { "enabled": true } }, "clock": { "mode": "daemon", "timeout": "10m" }, "identity": { "role": "Industry Research Analyst", "duties": [ "Continuously scan industry news and papers", "Analyze trends and extract key insights", "Identify emerging technologies and competitors", "Build and maintain industry knowledge base", "Alert team on significant developments" ] }, "resources": { "agents": ["content-reader", "insight-extractor", "report-writer"], "mcp": [ { "id": "web-search", "tools": ["search", "news"] }, { "id": "arxiv", "tools": ["search", "fetch"] }, { "id": "twitter", "tools": ["search", "trends"] } ] }, "delivery": { "type": "notify", "opts": { "channel": "research-insights" } } } ``` **Example continuous run:** ``` Run #1 (09:00): P0: Scan sources - 15 new AI news articles - 3 new papers on arXiv - Twitter: "AI Agent" trending P1: Goals: 1. Read and analyze new content 2. Extract insights relevant to our business 3. Update knowledge base P2: Tasks: 1. Read articles → content-reader 2. Analyze papers → content-reader 3. 
Extract insights → insight-extractor P3: Execute: - Article: "OpenAI releases new agent framework" Insight: Validates our direction, watch for API changes - Paper: "Multi-agent collaboration patterns" Insight: Useful for our agent design, save to KB - Twitter: Sentiment positive on AI agents P4: Notify: "📚 3 new insights added to KB" P5: Learn: OpenAI news = high relevance, prioritize → Restart immediately Run #2 (09:12): P0: Scan sources - 2 new articles (low relevance) - No new papers - Twitter: Normal activity P1: Low-value content, skip deep analysis P5: Learn: Mid-morning usually quiet → Restart immediately Run #3 (09:25): P0: Scan sources - Breaking: "Competitor X raises $100M for AI platform" P1: Goals: 1. Deep analyze competitor news 2. Assess impact on our market 3. Alert team immediately P2: Tasks: 1. Gather all competitor X info → web-search 2. Analyze their positioning → insight-extractor 3. Write competitive brief → report-writer P3: Execute: - Competitor X: Focus on enterprise, similar target market - Funding: Will likely expand sales team - Threat level: Medium-High P4: Notify: "🚨 Competitor X raised $100M - brief attached" P5: Learn: Funding news = always high priority → Restart immediately ``` --- ### 11.4 Sales Assistant (Human: intervene) **Trigger:** Human intervention - sales manager assigns tasks **Role:** Help sales team with research, proposals, follow-ups when manager assigns work. 
```json // robot_config for Sales Assistant { "triggers": { "clock": { "enabled": false }, "intervene": { "enabled": true, "actions": ["task.add", "goal.adjust", "instruct"] } }, "identity": { "role": "Sales Assistant", "duties": [ "Research assigned prospects and companies", "Prepare customized proposals and presentations", "Draft follow-up emails", "Analyze deal history and suggest strategies", "Prepare meeting briefs" ] }, "resources": { "agents": ["company-researcher", "proposal-writer", "email-drafter"], "mcp": [ { "id": "crm", "tools": ["query", "update"] }, { "id": "linkedin", "tools": ["search", "profile"] }, { "id": "email", "tools": ["draft", "send"] } ] }, "delivery": { "type": "email", "opts": { "to": ["sales-manager@company.com"] } } } ``` **Example: Sales manager assigns task:** ``` Sales Manager Input: Action: task.add Messages: [{ role: "user", content: "Meeting with BigCorp CTO tomorrow. Prepare materials. They do smart manufacturing, $150M revenue, digital transformation." }] Agent Execution (no P0 for human trigger): P1 Goals (from human input): 1. Research BigCorp and their CTO 2. Prepare meeting brief 3. Draft customized proposal P2 Tasks: 1. Research BigCorp → company-researcher - Company background, recent news - Digital transformation status - Potential pain points 2. Research CTO profile → linkedin.profile - Background, interests - Recent posts/articles 3. Prepare meeting brief → proposal-writer 4. 
Draft proposal → proposal-writer P3 Execute: - BigCorp: Leading smart manufacturing, 3 factories, implementing MES - CTO John: Ex-Google, focused on AI+Manufacturing, recent post on "AI QC" - Pain point: High QC labor cost, 2% defect miss rate - Opportunity: Our AI QC solution can reduce miss rate to 0.1% P4 Delivery: → Email to sales manager: - Attachment 1: BigCorp Research Report (PDF) - Attachment 2: CTO Profile Brief - Attachment 3: Custom Proposal - AI QC Solution - Attachment 4: Meeting Agenda Suggestion Sales Manager Follow-up: Action: task.add Messages: [{ role: "user", content: "Also prepare some similar case studies, manufacturing preferred" }] Agent Continues: P1: Find similar manufacturing case studies P2: Search CRM for manufacturing wins P3: Found 3 cases: Auto parts factory, Electronics plant, Food processing P4: Email: "3 manufacturing case studies attached" P5: Learn: Manufacturing prospects often need QC case studies ``` --- ### 11.5 Lead Processor (Event: webhook) **Trigger:** Event - new lead from website/CRM **Role:** Instantly process and qualify new leads, route to sales. 
```json // robot_config for Lead Processor { "triggers": { "clock": { "enabled": false }, "event": { "enabled": true } }, "events": [ { "type": "webhook", "source": "/webhook/leads", "filter": { "event_types": ["lead.created"] } }, { "type": "database", "source": "crm_leads", "filter": { "trigger": "insert" } } ], "identity": { "role": "Lead Qualification Specialist", "duties": [ "Instantly process new leads", "Enrich lead data (company info, LinkedIn)", "Score lead quality (1-100)", "Route hot leads to sales immediately", "Add cold leads to nurture sequence" ] }, "resources": { "agents": ["data-enricher", "lead-scorer"], "mcp": [ { "id": "clearbit", "tools": ["enrich"] }, { "id": "crm", "tools": ["update", "assign"] }, { "id": "email", "tools": ["send"] } ] }, "delivery": { "type": "webhook", "opts": { "url": "https://slack.com/webhook/sales-leads" } } } ``` **Example: New lead event:** ``` Event Received: Type: lead.created Data: { name: "John Smith", email: "john@bigcorp.com", company: "BigCorp", message: "Interested in Enterprise pricing, team of 50" } Agent Execution (no P0 for events): P1 Goals: 1. Enrich lead data 2. Score lead quality 3. Route appropriately P2 Tasks: 1. Lookup company info → clearbit.enrich 2. Calculate lead score → lead-scorer 3. Update CRM → crm.update 4. Notify sales → slack webhook P3 Execute: - Company: BigCorp, 500 employees, Series C - LinkedIn: VP of Engineering - Lead Score: 85/100 (HOT) - Reason: Enterprise inquiry, decision maker, funded company P4 Delivery: → Slack: "🔥 HOT LEAD (85/100): John Smith @ BigCorp - 500 employees, Series C - Interested in Enterprise (50 seats) - Assigned to: Sales Rep A" → CRM: Lead updated, assigned to Sales Rep A → Email to lead: "Thanks for your inquiry. Our sales rep will contact you within 1 hour." 
P5 Learn: - BigCorp profile saved for future reference - VP-level leads from funded companies = high conversion ``` ================================================ FILE: agent/robot/TECHNICAL.md ================================================ # Robot Agent - Technical Design ## 1. Code Structure ``` yao/agent/robot/ ├── DESIGN.md # Product design doc ├── TECHNICAL.md # This file │ ├── robot.go # Package entry, Init(), Shutdown() │ ├── api/ # All API forms │ ├── api.go # Go API (facade) │ ├── process.go # Yao Process: robot.* │ └── jsapi.go # JS API: robot (global) + Robot (class) │ ├── types/ # Types only (no logic, no external deps) │ ├── enums.go # Phase, ClockMode, TriggerType, etc. │ ├── config.go # Config, Clock, Identity, Quota, etc. │ ├── robot.go # Robot, Execution │ ├── task.go # Goal, Task, TaskResult │ ├── request.go # InterveneRequest, EventRequest, etc. │ ├── inspiration.go # ClockContext, InspirationReport │ ├── interfaces.go # All interfaces (Manager, Trigger, etc.) │ └── errors.go # Error definitions │ ├── manager/ # Manager package (orchestration) │ └── manager.go # Manager struct, Start/Stop, Tick │ ├── pool/ # Worker pool & task dispatch │ ├── pool.go # Pool struct, Submit │ ├── queue.go # Priority queue │ └── worker.go # Worker goroutines │ ├── executor/ # Executor package (pluggable architecture) │ ├── executor.go # Factory functions, unified entry │ ├── types/ │ │ ├── types.go # Executor interface, Config types │ │ └── helpers.go # Shared helper functions │ ├── standard/ │ │ ├── executor.go # Real Agent execution (production) │ │ ├── agent.go # AgentCaller for LLM calls │ │ ├── input.go # InputFormatter for prompts │ │ ├── inspiration.go # P0: Inspiration phase │ │ ├── goals.go # P1: Goals phase │ │ ├── tasks.go # P2: Tasks phase │ │ ├── run.go # P3: Run phase (main entry) │ │ ├── runner.go # P3: Task Runner (execution logic) │ │ ├── validator.go # P3: Validator (two-layer validation) │ │ ├── delivery.go # P4: Delivery phase │ │ └── 
learning.go # P5: Learning phase │ ├── dryrun/ │ │ └── executor.go # Simulated execution (testing/demo) │ └── sandbox/ │ └── executor.go # Container-isolated (NOT IMPLEMENTED) │ ├── utils/ # Utility functions │ ├── convert.go # Type conversions (JSON, map, struct) │ ├── time.go # Time parsing, formatting, timezone │ ├── id.go # ID generation (nanoid, uuid) │ └── validate.go # Validation helpers │ ├── trigger/ # Trigger utilities (logic in manager/) │ ├── trigger.go # Validation helpers, action utilities │ ├── clock.go # ClockMatcher (reusable clock matching logic) │ └── control.go # ExecutionController (pause/resume/stop) │ ├── cache/ # Cache package │ ├── cache.go # Cache struct, Get/List │ ├── load.go # LoadAll, LoadOne │ └── refresh.go # Refresh logic │ ├── dedup/ # Deduplication package │ ├── dedup.go # Dedup struct │ ├── fast.go # Fast in-memory check │ └── semantic.go # Semantic check via agent │ ├── store/ # Data store package (KB, FS, DB access) │ ├── store.go # Store struct, interface │ ├── kb.go # Knowledge base operations │ ├── fs.go # File system operations │ ├── db.go # Database queries │ └── learning.go # Learning entry save (to KB) │ └── plan/ # Plan queue (deferred tasks) ├── plan.go # Plan queue struct └── schedule.go # Schedule for later yao/assert/ # Universal assertion library (global package) ├── types.go # Assertion, Result, interfaces ├── asserter.go # Asserter implementation (8 assertion types) └── helpers.go # Utility functions (ExtractPath, ToString, etc.) ``` ### Dependency Graph (No Cycles) > **Note:** `trigger/` is a utility package (validation, clock matching, execution control). > All trigger logic flows through `manager/`. 
``` ┌──────────┐ │ types/ │ (pure types, no deps) └────┬─────┘ │ ┌───────┬───────┬───────┬──────┼──────┬───────┬───────┬───────┐ │ │ │ │ │ │ │ │ │ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ┌───────┐┌───────┐┌───────┐┌──────┐┌────┐┌──────┐┌───────┐┌─────────┐ │ cache ││ dedup ││ store ││ pool ││ plan ││ utils ││ trigger │ └───┬───┘└───┬───┘└───┬───┘└──┬───┘└──┬─┘└──────┘└───────┘└────┬────┘ │ │ │ │ │ │ └────────┴────────┴───────┴───────┴────────────────────────┘ │ ▼ ┌────────────┐ │ executor/ │ └──────┬─────┘ │ ▼ ┌────────────┐ │ manager/ │ (imports trigger/ for utilities) └──────┬─────┘ │ ┌──────────────┴──────────────┐ │ │ ▼ ▼ ┌─────────────┐ ┌─────────────┐ │ robot.go │ │ api/ │ └─────────────┘ └─────────────┘ ``` ### Package Dependencies | Package | Imports | | ----------- | ----------------------------------------------------------- | | `types/` | stdlib only | | `utils/` | stdlib only | | `cache/` | `types/` | | `dedup/` | `types/` | | `store/` | `types/` | | `pool/` | `types/` | | `trigger/` | `types/` | | `plan/` | `types/` | | `executor/` | `types/`, `cache/`, `dedup/`, `store/`, `pool/`, `yao/assert` | | `manager/` | `types/`, `cache/`, `pool/`, `trigger/`, `executor/` | | | Manager handles all trigger logic (clock, intervene, event) | | `api/` | `types/`, `manager/` | | root | all packages | ### Public API (`api/`) Three API forms, all in `api/` directory. 
#### Go API (`api/api.go`) ```go package api import ( "github.com/yaoapp/yao/agent/robot/types" ) // ==================== CRUD ==================== // Get returns a robot by member ID func Get(ctx *types.Context, memberID string) (*types.Robot, error) // List returns robots with pagination and filtering func List(ctx *types.Context, query *ListQuery) (*ListResult, error) // Create creates a new robot member func Create(ctx *types.Context, teamID string, req *CreateRequest) (*types.Robot, error) // Update updates robot config func Update(ctx *types.Context, memberID string, req *UpdateRequest) (*types.Robot, error) // Remove deletes a robot member func Remove(ctx *types.Context, memberID string) error // ==================== Status ==================== // Status returns current robot runtime state func Status(ctx *types.Context, memberID string) (*RobotState, error) // UpdateStatus updates robot status (idle, paused, etc.) func UpdateStatus(ctx *types.Context, memberID string, status types.RobotStatus) error // ==================== Trigger ==================== // Trigger starts execution with specified trigger type and request func Trigger(ctx *types.Context, memberID string, req *TriggerRequest) (*TriggerResult, error) // ==================== Execution ==================== // GetExecutions returns execution history func GetExecutions(ctx *types.Context, memberID string, query *ExecutionQuery) (*ExecutionResult, error) // GetExecution returns a specific execution by ID func GetExecution(ctx *types.Context, execID string) (*types.Execution, error) // Pause pauses a running execution func Pause(ctx *types.Context, execID string) error // Resume resumes a paused execution func Resume(ctx *types.Context, execID string) error // Stop stops a running execution func Stop(ctx *types.Context, execID string) error ``` #### API Types ```go // ==================== CRUD Types ==================== // CreateRequest - request for Create() type CreateRequest struct { DisplayName 
string `json:"display_name"` SystemPrompt string `json:"system_prompt,omitempty"` Config *types.Config `json:"robot_config"` } // UpdateRequest - request for Update() type UpdateRequest struct { DisplayName *string `json:"display_name,omitempty"` SystemPrompt *string `json:"system_prompt,omitempty"` Config *types.Config `json:"robot_config,omitempty"` } // ListQuery - query options for List() type ListQuery struct { TeamID string `json:"team_id,omitempty"` // filter by team Status types.RobotStatus `json:"status,omitempty"` // idle | working | paused | error | maintenance Keywords string `json:"keywords,omitempty"` // search display_name, role ClockMode types.ClockMode `json:"clock_mode,omitempty"` // times | interval | daemon Page int `json:"page,omitempty"` // default 1 PageSize int `json:"pagesize,omitempty"` // default 20, max 100 Order string `json:"order,omitempty"` // e.g. "created_at desc" } // ListResult - result of List() type ListResult struct { Data []*types.Robot `json:"data"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"pagesize"` } // RobotState - runtime state from Status() type RobotState struct { MemberID string `json:"member_id"` TeamID string `json:"team_id"` DisplayName string `json:"display_name"` Status types.RobotStatus `json:"status"` // idle | working | paused | error | maintenance Running int `json:"running"` // current running execution count MaxRunning int `json:"max_running"` // max concurrent allowed LastRun *time.Time `json:"last_run,omitempty"` NextRun *time.Time `json:"next_run,omitempty"` RunningIDs []string `json:"running_ids,omitempty"` // list of running execution IDs } // ==================== Trigger Types ==================== // TriggerRequest - request for Trigger() // Messages uses []context.Message to support rich content (text, images, files, audio) type TriggerRequest struct { Type types.TriggerType `json:"type"` // human | event // Human intervention fields (when Type = human) Action types.InterventionAction 
`json:"action,omitempty"` // task.add | goal.adjust | task.cancel | plan.add Messages []context.Message `json:"messages,omitempty"` // user's input (supports text, images, files) PlanAt *time.Time `json:"plan_at,omitempty"` // for action=plan.add InsertAt InsertPosition `json:"insert_at,omitempty"` // where to insert: first | last | next | at AtIndex int `json:"at_index,omitempty"` // index when insert_at=at // Event fields (when Type = event) Source types.EventSource `json:"source,omitempty"` // webhook | database EventType string `json:"event_type,omitempty"` // lead.created, order.paid, etc. Data map[string]interface{} `json:"data,omitempty"` // event payload // Executor mode (optional, overrides robot config) ExecutorMode types.ExecutorMode `json:"executor_mode,omitempty"` // standard | dryrun } // InsertPosition - where to insert task in queue type InsertPosition string const ( InsertFirst InsertPosition = "first" // insert at beginning (highest priority) InsertLast InsertPosition = "last" // append at end (default) InsertNext InsertPosition = "next" // insert after current task InsertAt InsertPosition = "at" // insert at specific index (use AtIndex) ) // TriggerResult - result of Trigger() type TriggerResult struct { Accepted bool `json:"accepted"` // whether trigger was accepted Queued bool `json:"queued"` // true if queued (quota full) Execution *types.Execution `json:"execution,omitempty"` // execution info if started Message string `json:"message,omitempty"` // status message } // ==================== Execution Types ==================== // ExecutionQuery - query options for GetExecutions() type ExecutionQuery struct { Status types.ExecStatus `json:"status,omitempty"` // pending | running | completed | failed Trigger types.TriggerType `json:"trigger,omitempty"` // clock | human | event Page int `json:"page,omitempty"` // default 1 PageSize int `json:"pagesize,omitempty"`// default 20 } // ExecutionResult - result of GetExecutions() type ExecutionResult 
struct { Data []*types.Execution `json:"data"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"pagesize"` } ``` #### Process API (`api/process.go`) Yao Process registration. Naming convention: `robot.<Method>` (e.g., `robot.Get`). ```go // Process registration func init() { process.Register("robot.Get", processGet) process.Register("robot.List", processList) process.Register("robot.Create", processCreate) process.Register("robot.Update", processUpdate) process.Register("robot.Remove", processRemove) process.Register("robot.Status", processStatus) process.Register("robot.UpdateStatus", processUpdateStatus) process.Register("robot.Trigger", processTrigger) process.Register("robot.Executions", processExecutions) process.Register("robot.Execution", processExecution) process.Register("robot.Pause", processPause) process.Register("robot.Resume", processResume) process.Register("robot.Stop", processStop) } ``` | Process | Args | Returns | Description | | -------------------- | --------------------- | ----------------- | ------------------ | | `robot.Get` | `memberID` | `Robot` | Get robot by ID | | `robot.List` | `query` | `ListResult` | List robots | | `robot.Create` | `teamID`, `data` | `Robot` | Create robot | | `robot.Update` | `memberID`, `data` | `Robot` | Update robot | | `robot.Remove` | `memberID` | `null` | Delete robot | | `robot.Status` | `memberID` | `RobotState` | Get runtime status | | `robot.UpdateStatus` | `memberID`, `status` | `null` | Update status | | `robot.Trigger` | `memberID`, `request` | `TriggerResult` | Trigger execution | | `robot.Executions` | `memberID`, `query` | `ExecutionResult` | List executions | | `robot.Execution` | `execID` | `Execution` | Get execution | | `robot.Pause` | `execID` | `null` | Pause execution | | `robot.Resume` | `execID` | `null` | Resume execution | | `robot.Stop` | `execID` | `null` | Stop execution | **Usage:** ```javascript // In Yao scripts const robot = Process("robot.Get", "mem_abc123"); const list = 
Process("robot.List", { team_id: "team_xyz", status: "idle", page: 1, pagesize: 20, }); // Trigger with text message const result = Process("robot.Trigger", "mem_abc123", { type: "human", action: "task.add", messages: [ { role: "user", content: "Prepare meeting materials for BigCorp" }, ], insert_at: "first", }); // Trigger with image (multimodal) const imageResult = Process("robot.Trigger", "mem_abc123", { type: "human", action: "task.add", messages: [ { role: "user", content: [ { type: "text", text: "Analyze this chart and summarize key trends" }, { type: "image_url", image_url: { url: "https://example.com/chart.png" }, }, ], }, ], insert_at: "first", }); // Trigger with event const eventResult = Process("robot.Trigger", "mem_abc123", { type: "event", source: "webhook", event_type: "lead.created", data: { name: "John", email: "john@example.com" }, }); const execs = Process("robot.Executions", "mem_abc123", { status: "completed", page: 1, }); ``` #### JSAPI (`api/jsapi.go`) Register to V8 Runtime using constructor pattern, similar to `new FS()`, `new Store()`, `new Query()`. 
```go func init() { // Register Robot constructor v8.RegisterFunction("Robot", ExportFunction) } // ExportFunction exports the Robot constructor func ExportFunction(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, robotConstructor) } // robotConstructor: new Robot(memberID) func robotConstructor(info *v8go.FunctionCallbackInfo) *v8go.Value { ctx := info.Context() args := info.Args() if len(args) < 1 { return bridge.JsException(ctx, "Robot requires member ID") } memberID := args[0].String() robotObj, err := RobotNew(ctx, memberID) if err != nil { return bridge.JsException(ctx, err.Error()) } return robotObj } // RobotNew creates a Robot JS object with methods func RobotNew(ctx *v8go.Context, memberID string) (*v8go.Value, error) { iso := ctx.Isolate() obj := v8go.NewObjectTemplate(iso) // Instance methods (operate on this robot) obj.Set("Status", v8go.NewFunctionTemplate(iso, jsStatus)) obj.Set("UpdateStatus", v8go.NewFunctionTemplate(iso, jsUpdateStatus)) obj.Set("Trigger", v8go.NewFunctionTemplate(iso, jsTrigger)) obj.Set("Executions", v8go.NewFunctionTemplate(iso, jsExecutions)) obj.Set("Pause", v8go.NewFunctionTemplate(iso, jsPause)) obj.Set("Resume", v8go.NewFunctionTemplate(iso, jsResume)) obj.Set("Stop", v8go.NewFunctionTemplate(iso, jsStop)) // ... 
create instance with memberID stored return obj.NewInstance(ctx) } ``` **Global object `robot` (static methods):** ```go func init() { // Register global robot object (lowercase, for static methods) v8.RegisterObject("robot", ExportObject) } // ExportObject exports the robot global object func ExportObject(iso *v8go.Isolate) *v8go.ObjectTemplate { obj := v8go.NewObjectTemplate(iso) obj.Set("List", v8go.NewFunctionTemplate(iso, jsList)) obj.Set("Get", v8go.NewFunctionTemplate(iso, jsGet)) obj.Set("Create", v8go.NewFunctionTemplate(iso, jsCreate)) obj.Set("Update", v8go.NewFunctionTemplate(iso, jsUpdate)) obj.Set("Remove", v8go.NewFunctionTemplate(iso, jsRemove)) obj.Set("Execution", v8go.NewFunctionTemplate(iso, jsExecution)) return obj } ``` **TypeScript Interface:** ```typescript // ==================== Types ==================== interface RobotData { member_id: string; team_id: string; display_name: string; robot_status: "idle" | "working" | "paused" | "error" | "maintenance"; robot_config: RobotConfig; } interface RobotState { member_id: string; team_id: string; display_name: string; status: "idle" | "working" | "paused" | "error" | "maintenance"; running: number; // current running execution count max_running: number; // max concurrent allowed last_run?: string; next_run?: string; running_ids?: string[]; // list of running execution IDs } interface TriggerResult { accepted: boolean; queued: boolean; execution?: Execution; message?: string; } // Message - same as context.Message, supports rich content interface Message { role: "user" | "assistant" | "system" | "tool"; content: string | ContentPart[]; name?: string; tool_call_id?: string; tool_calls?: ToolCall[]; } interface ContentPart { type: "text" | "image_url" | "input_audio" | "file" | "data"; text?: string; image_url?: { url: string; detail?: "auto" | "low" | "high" }; input_audio?: { data: string; format: string }; file?: { url: string; name?: string; mime_type?: string }; data?: { data: string; 
mime_type: string }; } interface TriggerRequest { type: "human" | "event"; // Human intervention fields action?: | "task.add" | "task.cancel" | "task.update" | "goal.adjust" | "goal.add" | "goal.complete" | "goal.cancel" | "plan.add" | "plan.remove" | "plan.update" | "instruct"; messages?: Message[]; // supports text, images, files, audio insert_at?: "first" | "last" | "next" | "at"; at_index?: number; plan_at?: string; // ISO date for plan.add // Event fields source?: "webhook" | "database"; event_type?: string; // lead.created, etc. data?: Record<string, any>; // Executor mode (optional, overrides robot config) executor_mode?: "standard" | "dryrun"; // sandbox not implemented } // ExecutorMode - executor mode type type ExecutorMode = "standard" | "dryrun" | "sandbox"; // Note: "sandbox" requires container infrastructure, falls back to "dryrun" interface ListQuery { team_id?: string; status?: "idle" | "working" | "paused" | "error" | "maintenance"; keywords?: string; clock_mode?: "times" | "interval" | "daemon"; page?: number; pagesize?: number; } interface ListResult { data: RobotData[]; total: number; page: number; pagesize: number; } interface ExecutionQuery { status?: "pending" | "running" | "completed" | "failed" | "cancelled"; trigger?: "clock" | "human" | "event"; page?: number; pagesize?: number; } interface ExecutionResult { data: Execution[]; total: number; page: number; pagesize: number; } interface CreateRequest { display_name: string; system_prompt?: string; robot_config: RobotConfig; } interface UpdateRequest { display_name?: string; system_prompt?: string; robot_config?: RobotConfig; } // ==================== Global object: robot ==================== // Static methods, no instance needed interface RobotStatic { List(query?: ListQuery): ListResult; Get(memberID: string): RobotData; Create(teamID: string, data: CreateRequest): RobotData; Update(memberID: string, data: UpdateRequest): RobotData; Remove(memberID: string): void; Execution(execID: string): Execution; 
} declare const robot: RobotStatic; // ==================== Constructor: Robot ==================== // Instance methods, operate on specific robot declare class Robot { constructor(memberID: string); // Properties readonly memberID: string; // Instance methods Status(): RobotState; UpdateStatus(status: string): void; Trigger(request: TriggerRequest): TriggerResult; Executions(query?: ExecutionQuery): ExecutionResult; Pause(execID: string): void; Resume(execID: string): void; Stop(execID: string): void; } ``` **Usage:** ```javascript // ==================== Global object: robot ==================== // For CRUD and queries (no instance needed) const list = robot.List({ team_id: "team_xyz", status: "idle" }); const data = robot.Get("mem_abc123"); const newRobot = robot.Create("team_xyz", { display_name: "Sales Bot", robot_config: { ... } }); robot.Update("mem_abc123", { display_name: "Updated Bot" }); robot.Remove("mem_abc123"); const exec = robot.Execution("exec_456"); // ==================== Constructor: Robot ==================== // For operating on a specific robot instance const bot = new Robot("mem_abc123"); // Instance methods const state = bot.Status(); if (state.status === "idle") { const result = bot.Trigger({ type: "human", action: "task.add", messages: [{ role: "user", content: "Analyze sales data" }], insert_at: "first", }); console.log("Triggered:", result.accepted); } // Get execution history for this robot const execs = bot.Executions({ status: "completed", page: 1 }); // Control execution bot.Pause("exec_123"); bot.Resume("exec_123"); bot.Stop("exec_123"); // Update status bot.UpdateStatus("paused"); ``` **Usage in Agent Hooks:** ```javascript function Create(ctx, messages) { const bot = new Robot("mem_abc123"); const state = bot.Status(); if (state.status === "working") { ctx.Send({ type: "text", props: { content: "Robot is busy" } }); return null; } const result = bot.Trigger({ type: "human", action: "task.add", messages: [{ role: "user", content: 
"Analyze this data" }], insert_at: "first", }); if (result.accepted) { ctx.memory.context.Set("robot_exec_id", result.execution.id); } return { messages }; } function Next(ctx, payload) { const execID = ctx.memory.context.Get("robot_exec_id"); if (execID) { const exec = robot.Execution(execID); // use global object if (exec.status === "completed") { ctx.Send({ type: "text", props: { content: `Robot completed: ${exec.delivery?.summary}` }, }); } } return null; } ``` --- ## 2. Type Definitions > All types are in `robot/types/` package. Other files import as: > > ```go > import "github.com/yaoapp/yao/agent/robot/types" > ``` ### 2.1 Enums ```go // types/enums.go package types // Phase - execution phase type Phase string const ( PhaseInspiration Phase = "inspiration" // P0: Clock only PhaseGoals Phase = "goals" // P1 PhaseTasks Phase = "tasks" // P2 PhaseRun Phase = "run" // P3 PhaseDelivery Phase = "delivery" // P4 PhaseLearning Phase = "learning" // P5 ) // AllPhases for iteration var AllPhases = []Phase{ PhaseInspiration, PhaseGoals, PhaseTasks, PhaseRun, PhaseDelivery, PhaseLearning, } // ClockMode - clock trigger mode type ClockMode string const ( ClockTimes ClockMode = "times" // run at specific times ClockInterval ClockMode = "interval" // run every X duration ClockDaemon ClockMode = "daemon" // run continuously ) // TriggerType - trigger source type TriggerType string const ( TriggerClock TriggerType = "clock" TriggerHuman TriggerType = "human" TriggerEvent TriggerType = "event" ) // ExecStatus - execution status type ExecStatus string const ( ExecPending ExecStatus = "pending" ExecRunning ExecStatus = "running" ExecCompleted ExecStatus = "completed" ExecFailed ExecStatus = "failed" ExecCancelled ExecStatus = "cancelled" ) // RobotStatus - matches __yao.member.robot_status type RobotStatus string const ( RobotIdle RobotStatus = "idle" RobotWorking RobotStatus = "working" RobotPaused RobotStatus = "paused" RobotError RobotStatus = "error" RobotMaintenance 
RobotStatus = "maintenance" ) // InterventionAction - human intervention action // Format: category.action (e.g., "task.add", "goal.adjust") type InterventionAction string const ( // Task operations ActionTaskAdd InterventionAction = "task.add" // add a new task ActionTaskCancel InterventionAction = "task.cancel" // cancel a task ActionTaskUpdate InterventionAction = "task.update" // update task details // Goal operations ActionGoalAdjust InterventionAction = "goal.adjust" // modify current goal ActionGoalAdd InterventionAction = "goal.add" // add a new goal ActionGoalComplete InterventionAction = "goal.complete" // mark goal as complete ActionGoalCancel InterventionAction = "goal.cancel" // cancel a goal // Plan operations (schedule for later) ActionPlanAdd InterventionAction = "plan.add" // add to plan queue ActionPlanRemove InterventionAction = "plan.remove" // remove from plan queue ActionPlanUpdate InterventionAction = "plan.update" // update planned item // Instruction (direct command) ActionInstruct InterventionAction = "instruct" // direct instruction to robot ) // Priority - task/goal priority type Priority string const ( PriorityHigh Priority = "high" PriorityNormal Priority = "normal" PriorityLow Priority = "low" ) // DeliveryType - output delivery type type DeliveryType string const ( DeliveryEmail DeliveryType = "email" // Email via yao/messenger DeliveryWebhook DeliveryType = "webhook" // POST to external URL DeliveryProcess DeliveryType = "process" // Yao Process call DeliveryNotify DeliveryType = "notify" // In-app notification (future) ) // DedupResult - deduplication result type DedupResult string const ( DedupSkip DedupResult = "skip" // skip execution DedupMerge DedupResult = "merge" // merge with existing DedupProceed DedupResult = "proceed" // proceed normally ) // EventSource - event trigger source type EventSource string const ( EventWebhook EventSource = "webhook" // HTTP webhook EventDatabase EventSource = "database" // DB change trigger ) 
// LearningType - learning entry type type LearningType string const ( LearnExecution LearningType = "execution" // execution record LearnFeedback LearningType = "feedback" // error/fix feedback LearnInsight LearningType = "insight" // pattern/tip insight ) ``` ### 2.2 Context ```go // types/context.go package types import ( "context" "github.com/yaoapp/yao/openapi/oauth/types" ) // Context - robot execution context (lightweight) type Context struct { context.Context // embed standard context Auth *types.AuthorizedInfo `json:"auth,omitempty"` // reuse oauth AuthorizedInfo MemberID string `json:"member_id,omitempty"` // current robot member ID RequestID string `json:"request_id,omitempty"` // request trace ID Locale string `json:"locale,omitempty"` // locale (e.g., "en-US") } // NewContext creates a new robot context func NewContext(parent context.Context, auth *types.AuthorizedInfo) *Context { if parent == nil { parent = context.Background() } return &Context{ Context: parent, Auth: auth, } } // UserID returns user ID from auth func (c *Context) UserID() string { if c.Auth == nil { return "" } return c.Auth.UserID } // TeamID returns team ID from auth func (c *Context) TeamID() string { if c.Auth == nil { return "" } return c.Auth.TeamID } ``` ### 2.3 Config Types ```go // types/config.go package types import "time" // Config - robot_config in __yao.member type Config struct { Triggers *Triggers `json:"triggers,omitempty"` Clock *Clock `json:"clock,omitempty"` Identity *Identity `json:"identity"` Quota *Quota `json:"quota,omitempty"` KB *KB `json:"kb,omitempty"` // shared knowledge base (same as assistant) DB *DB `json:"db,omitempty"` // shared database (same as assistant) Learn *Learn `json:"learn,omitempty"` // learning config for private KB Resources *Resources `json:"resources,omitempty"` Delivery *DeliveryPreferences `json:"delivery,omitempty"` // see section 6.2 Events []Event `json:"events,omitempty"` DefaultLocale string `json:"default_locale,omitempty"` // 
Default language for clock/event triggers (e.g., "en-US", "zh-CN") } // Validate validates the config func (c *Config) Validate() error { if c.Identity == nil || c.Identity.Role == "" { return ErrMissingIdentity } if c.Clock != nil { if err := c.Clock.Validate(); err != nil { return err } } return nil } // Triggers - trigger enable/disable type Triggers struct { Clock *TriggerSwitch `json:"clock,omitempty"` Intervene *TriggerSwitch `json:"intervene,omitempty"` Event *TriggerSwitch `json:"event,omitempty"` } type TriggerSwitch struct { Enabled bool `json:"enabled"` Actions []string `json:"actions,omitempty"` // for intervene } // IsEnabled checks if trigger is enabled (default: true) func (t *Triggers) IsEnabled(typ TriggerType) bool { if t == nil { return true } switch typ { case TriggerClock: return t.Clock == nil || t.Clock.Enabled case TriggerHuman: return t.Intervene == nil || t.Intervene.Enabled case TriggerEvent: return t.Event == nil || t.Event.Enabled } return false } // Clock - when to wake up type Clock struct { Mode ClockMode `json:"mode"` // times | interval | daemon Times []string `json:"times,omitempty"` // ["09:00", "14:00"] Days []string `json:"days,omitempty"` // ["Mon", "Tue"] or ["*"] Every string `json:"every,omitempty"` // "30m", "1h" TZ string `json:"tz,omitempty"` // "Asia/Shanghai" Timeout string `json:"timeout,omitempty"` // "30m" } // Validate validates clock config func (c *Clock) Validate() error { switch c.Mode { case ClockTimes: if len(c.Times) == 0 { return ErrClockTimesEmpty } case ClockInterval: if c.Every == "" { return ErrClockIntervalEmpty } case ClockDaemon: // no extra validation default: return ErrClockModeInvalid } return nil } // GetTimeout returns parsed timeout duration func (c *Clock) GetTimeout() time.Duration { if c.Timeout == "" { return 30 * time.Minute // default } d, err := time.ParseDuration(c.Timeout) if err != nil { return 30 * time.Minute } return d } // GetLocation returns timezone location func (c *Clock) 
GetLocation() *time.Location { if c.TZ == "" { return time.Local } loc, err := time.LoadLocation(c.TZ) if err != nil { return time.Local } return loc } // Identity - who is this robot type Identity struct { Role string `json:"role"` Duties []string `json:"duties,omitempty"` Rules []string `json:"rules,omitempty"` } // Quota - concurrency limits type Quota struct { Max int `json:"max"` // max running (default: 2) Queue int `json:"queue"` // queue size (default: 10) Priority int `json:"priority"` // 1-10 (default: 5) } // GetMax returns max with default func (q *Quota) GetMax() int { if q == nil || q.Max <= 0 { return 2 } return q.Max } // GetQueue returns queue size with default func (q *Quota) GetQueue() int { if q == nil || q.Queue <= 0 { return 10 } return q.Queue } // GetPriority returns priority with default func (q *Quota) GetPriority() int { if q == nil || q.Priority <= 0 { return 5 } return q.Priority } // KB - knowledge base config (same as assistant, from store/types) // Shared KB collections accessible by this robot type KB struct { Collections []string `json:"collections,omitempty"` // KB collection IDs Options map[string]interface{} `json:"options,omitempty"` } // DB - database config (same as assistant, from store/types) // Shared database models accessible by this robot type DB struct { Models []string `json:"models,omitempty"` // database model names Options map[string]interface{} `json:"options,omitempty"` } // Learn - learning config for robot's private KB // Private KB is auto-created: robot_{team_id}_{member_id}_kb type Learn struct { On bool `json:"on"` Types []string `json:"types,omitempty"` // execution, feedback, insight Keep int `json:"keep,omitempty"` // days, 0 = forever } // Resources - available agents and tools type Resources struct { Phases map[Phase]string `json:"phases,omitempty"` // phase -> agent ID Agents []string `json:"agents,omitempty"` MCP []MCPConfig `json:"mcp,omitempty"` } // GetPhaseAgent returns agent ID for phase 
(default: __yao.{phase}) func (r *Resources) GetPhaseAgent(phase Phase) string { if r != nil && r.Phases != nil { if id, ok := r.Phases[phase]; ok && id != "" { return id } } return "__yao." + string(phase) } type MCPConfig struct { ID string `json:"id"` Tools []string `json:"tools,omitempty"` // empty = all } // Note: Delivery preferences moved to DeliveryPreferences (see section 6.2) // Event - event trigger config type Event struct { Type EventSource `json:"type"` // webhook | database Source string `json:"source"` // webhook path or table name Filter map[string]interface{} `json:"filter,omitempty"` } // Monitor - monitoring config ``` ### 2.3 Core Types ```go // types/robot.go package types import ( "context" "sync" "time" ) // Robot - runtime representation of an autonomous robot (from __yao.member) // Relationship: 1 Robot : N Executions (concurrent) // Each trigger creates a new Execution (stored in ExecutionStore) type Robot struct { // From __yao.member MemberID string `json:"member_id"` TeamID string `json:"team_id"` DisplayName string `json:"display_name"` SystemPrompt string `json:"system_prompt"` Status RobotStatus `json:"robot_status"` AutonomousMode bool `json:"autonomous_mode"` RobotEmail string `json:"robot_email"` // Robot's email address for sending emails // Parsed config (from robot_config JSON field) Config *Config `json:"-"` // Runtime state LastRun time.Time `json:"-"` // last execution start time NextRun time.Time `json:"-"` // next scheduled execution (for clock trigger) // Concurrency control // Each Robot can run multiple Executions concurrently (up to Quota.Max) executions map[string]*Execution // execID -> Execution execMu sync.RWMutex } // CanRun checks if robot can accept new execution func (r *Robot) CanRun() bool { r.execMu.RLock() defer r.execMu.RUnlock() return len(r.executions) < r.Config.Quota.GetMax() } // RunningCount returns current running execution count func (r *Robot) RunningCount() int { r.execMu.RLock() defer 
r.execMu.RUnlock() return len(r.executions) } // AddExecution adds an execution to tracking func (r *Robot) AddExecution(exec *Execution) { r.execMu.Lock() defer r.execMu.Unlock() if r.executions == nil { r.executions = make(map[string]*Execution) } r.executions[exec.ID] = exec } // RemoveExecution removes an execution from tracking func (r *Robot) RemoveExecution(execID string) { r.execMu.Lock() defer r.execMu.Unlock() delete(r.executions, execID) } // GetExecution returns an execution by ID func (r *Robot) GetExecution(execID string) *Execution { r.execMu.RLock() defer r.execMu.RUnlock() return r.executions[execID] } // GetExecutions returns all running executions func (r *Robot) GetExecutions() []*Execution { r.execMu.RLock() defer r.execMu.RUnlock() execs := make([]*Execution, 0, len(r.executions)) for _, exec := range r.executions { execs = append(execs, exec) } return execs } // Execution - single execution instance // Each trigger creates a new Execution, stored in ExecutionStore type Execution struct { ID string `json:"id"` // unique execution ID MemberID string `json:"member_id"` // robot member ID (globally unique) TeamID string `json:"team_id"` TriggerType TriggerType `json:"trigger_type"` // clock | human | event StartTime time.Time `json:"start_time"` EndTime *time.Time `json:"end_time,omitempty"` Status ExecStatus `json:"status"` Phase Phase `json:"phase"` Error string `json:"error,omitempty"` // UI display fields (updated by executor at each phase) // These provide human-readable status for frontend display Name string `json:"name,omitempty"` // Execution title (updated when goals complete) CurrentTaskName string `json:"current_task_name,omitempty"` // Current task description (updated during run phase) // Trigger input (stored for traceability) Input *TriggerInput `json:"input,omitempty"` // original trigger input // Phase outputs Inspiration *InspirationReport `json:"inspiration,omitempty"` // P0: markdown Goals *Goals `json:"goals,omitempty"` // 
P1: markdown Tasks []Task `json:"tasks,omitempty"` // P2: structured tasks Current *CurrentState `json:"current,omitempty"` // current executing state Results []TaskResult `json:"results,omitempty"` // P3: task results Delivery *DeliveryResult `json:"delivery,omitempty"` Learning []LearningEntry `json:"learning,omitempty"` // Runtime (internal, not serialized) ctx context.Context `json:"-"` cancel context.CancelFunc `json:"-"` robot *Robot `json:"-"` } // TriggerInput - stored trigger input for traceability type TriggerInput struct { // For human intervention Action InterventionAction `json:"action,omitempty"` // task.add, goal.adjust, etc. Messages []context.Message `json:"messages,omitempty"` // user's input (text, images, files) UserID string `json:"user_id,omitempty"` // who triggered Locale string `json:"locale,omitempty"` // language for UI display (e.g., "en-US", "zh-CN") // For event trigger Source EventSource `json:"source,omitempty"` // webhook | database EventType string `json:"event_type,omitempty"` // lead.created, etc. Data map[string]interface{} `json:"data,omitempty"` // event payload // For clock trigger Clock *ClockContext `json:"clock,omitempty"` // time context when triggered } // CurrentState - current executing goal and task type CurrentState struct { Task *Task `json:"task,omitempty"` // current task being executed TaskIndex int `json:"task_index"` // index in Tasks slice Progress string `json:"progress,omitempty"` // human-readable progress (e.g., "2/5 tasks") } // Goals - P1 output (markdown for LLM + structured metadata) // P1 Agent reads InspirationReport and generates goals as markdown // Example: // ## Goals // 1. [High] Analyze sales data and identify trends // - Reason: Sales up 50%, need to understand why // 2. [Normal] Prepare weekly report for manager // - Reason: Friday 5pm, weekly report due // 3. 
[Low] Update CRM with new leads // - Reason: 3 pending leads from yesterday type Goals struct { Content string `json:"content"` // markdown text Delivery *DeliveryTarget `json:"delivery,omitempty"` // where to send results (for P4) } // DeliveryTarget - where to deliver results (defined in P1, used by P4) // Note: This is a hint from P1 Goals. Actual delivery is handled by Delivery Center // based on Robot/User preferences, not strictly by this target. type DeliveryTarget struct { Type DeliveryType `json:"type"` // Preferred delivery type Recipients []string `json:"recipients,omitempty"` // email addresses, webhook URLs, user IDs Format string `json:"format,omitempty"` // markdown | html | json | text Template string `json:"template,omitempty"` // template name or inline template Options map[string]interface{} `json:"options,omitempty"` // channel-specific options } // Task - planned task (structured, for execution) type Task struct { ID string `json:"id"` Description string `json:"description,omitempty"` // human-readable task description (for UI display) Messages []context.Message `json:"messages"` // original input (text, images, files) GoalRef string `json:"goal_ref,omitempty"` // reference to goal (e.g., "Goal 1") Source TaskSource `json:"source"` // auto | human | event // Executor ExecutorType ExecutorType `json:"executor_type"` ExecutorID string `json:"executor_id"` // unified ID: agent/assistant/process ID, or "mcp_server.mcp_tool" for MCP Args []any `json:"args,omitempty"` // MCP-specific fields (required when executor_type is "mcp") MCPServer string `json:"mcp_server,omitempty"` // MCP server/client ID (e.g., "ark.image.text2img") MCPTool string `json:"mcp_tool,omitempty"` // MCP tool name (e.g., "generate") // Validation (defined in P2, used in P3) ExpectedOutput string `json:"expected_output,omitempty"` // what the task should produce // ValidationRules supports two formats: // 1. 
Natural language: "output must be valid JSON", "must contain 'field'" // 2. JSON assertions: `{"type": "type", "value": "object"}`, `{"type": "contains", "value": "success"}` ValidationRules []string `json:"validation_rules,omitempty"` // specific checks to perform // Runtime Status TaskStatus `json:"status"` Order int `json:"order"` // execution order (0-based) StartTime *time.Time `json:"start_time,omitempty"` EndTime *time.Time `json:"end_time,omitempty"` } // TaskSource - how task was created type TaskSource string const ( TaskSourceAuto TaskSource = "auto" // generated by P2 (task planning) TaskSourceHuman TaskSource = "human" // added via human intervention TaskSourceEvent TaskSource = "event" // added via event trigger ) // ExecutorType - task executor type type ExecutorType string const ( ExecutorAssistant ExecutorType = "assistant" ExecutorMCP ExecutorType = "mcp" ExecutorProcess ExecutorType = "process" ) // TaskStatus - task execution status type TaskStatus string const ( TaskPending TaskStatus = "pending" TaskRunning TaskStatus = "running" TaskCompleted TaskStatus = "completed" TaskFailed TaskStatus = "failed" TaskSkipped TaskStatus = "skipped" TaskCancelled TaskStatus = "cancelled" ) // TaskResult - task execution result type TaskResult struct { TaskID string `json:"task_id"` Success bool `json:"success"` Output interface{} `json:"output,omitempty"` Error string `json:"error,omitempty"` Duration int64 `json:"duration_ms"` Validation *ValidationResult `json:"validation,omitempty"` // P3 validation result } // ValidationResult - P3 validation result with multi-turn conversation support type ValidationResult struct { // Basic validation result Passed bool `json:"passed"` // overall validation passed Score float64 `json:"score,omitempty"` // 0-1 confidence score Issues []string `json:"issues,omitempty"` // what failed Suggestions []string `json:"suggestions,omitempty"` // how to improve Details string `json:"details,omitempty"` // detailed validation 
// DeliveryAttachment - task output attachment with metadata
// File uses wrapper format: __<uploader>://<fileID>
// Example: __yao.attachment://ccd472d11feb96e03a3fc468f494045c
// Parse with attachment.Parse(value) → (uploader, fileID, isWrapper);
// for the example above Parse reports the uploader as "__yao.attachment",
// i.e. including the leading "__".
type DeliveryAttachment struct {
	Title       string `json:"title"`                 // Human-readable title, e.g., "Market Analysis Report"
	Description string `json:"description,omitempty"` // Description of what this artifact is
	TaskID      string `json:"task_id,omitempty"`     // Which task produced this artifact
	File        string `json:"file"`                  // Wrapper format: __<uploader>://<fileID>
}
`json:"email,omitempty"` Webhook *WebhookPreference `json:"webhook,omitempty"` Process *ProcessPreference `json:"process,omitempty"` // notify is handled automatically based on user subscriptions } // EmailPreference - multiple email targets type EmailPreference struct { Enabled bool `json:"enabled"` Targets []EmailTarget `json:"targets"` } type EmailTarget struct { To []string `json:"to"` // Recipient addresses Template string `json:"template,omitempty"` // Email template ID Subject string `json:"subject,omitempty"` // Subject template (default: content.Summary) } // WebhookPreference - multiple webhook targets type WebhookPreference struct { Enabled bool `json:"enabled"` Targets []WebhookTarget `json:"targets"` } type WebhookTarget struct { URL string `json:"url"` // Webhook URL Method string `json:"method,omitempty"` // HTTP method (default: POST) Headers map[string]string `json:"headers,omitempty"` // Custom headers Secret string `json:"secret,omitempty"` // Signing secret } // ProcessPreference - multiple Yao Process targets type ProcessPreference struct { Enabled bool `json:"enabled"` Targets []ProcessTarget `json:"targets"` } type ProcessTarget struct { Process string `json:"process"` // Yao Process name, e.g., "orders.UpdateStatus" Args []any `json:"args,omitempty"` // Additional args (DeliveryContent passed as first arg) } // DeliveryResult - P4 delivery output (returned by Delivery Center) type DeliveryResult struct { RequestID string `json:"request_id"` // Delivery request ID Content *DeliveryContent `json:"content"` // Agent-generated content Results []ChannelResult `json:"results,omitempty"` // Results per channel Success bool `json:"success"` // Overall success Error string `json:"error,omitempty"` // Error if failed SentAt *time.Time `json:"sent_at,omitempty"` // When delivery completed } // ChannelResult - result for a single delivery target type ChannelResult struct { Type DeliveryType `json:"type"` // email | webhook | process Target string 
// ClockContext - time context for P0 inspiration.
// All calendar fields describe Now as observed in the zone named by TZ.
type ClockContext struct {
	Now          time.Time `json:"now"`
	Hour         int       `json:"hour"`           // 0-23
	DayOfWeek    string    `json:"day_of_week"`    // Monday, Tuesday...
	DayOfMonth   int       `json:"day_of_month"`   // 1-31
	WeekOfYear   int       `json:"week_of_year"`   // ISO 8601 week number, 1-53
	Month        int       `json:"month"`          // 1-12
	Year         int       `json:"year"`
	IsWeekend    bool      `json:"is_weekend"`     // Saturday or Sunday
	IsMonthStart bool      `json:"is_month_start"` // 1st-3rd
	IsMonthEnd   bool      `json:"is_month_end"`   // last 3 days
	IsQuarterEnd bool      `json:"is_quarter_end"` // last 3 days of March, June, September, December
	IsYearEnd    bool      `json:"is_year_end"`    // December 29th-31st
	TZ           string    `json:"tz"`
}

// NewClockContext builds the clock context for instant t as seen in timezone tz.
//
// tz is an IANA location name. When it is empty, or cannot be loaded, the
// process-local zone (time.Local) is used instead and no error is reported.
func NewClockContext(t time.Time, tz string) *ClockContext {
	zone := time.Local
	if tz != "" {
		loaded, err := time.LoadLocation(tz)
		if err == nil {
			zone = loaded
		}
	}

	local := t.In(zone)
	year, month, day := local.Date()
	weekday := local.Weekday()
	_, isoWeek := local.ISOWeek()

	// Day 0 of the following month normalizes to the last day of this month.
	monthLen := time.Date(year, month+1, 0, 0, 0, 0, 0, zone).Day()
	inLastThreeDays := day >= monthLen-2

	clock := &ClockContext{Now: local, TZ: zone.String()}
	clock.Hour = local.Hour()
	clock.DayOfWeek = weekday.String()
	clock.DayOfMonth = day
	clock.WeekOfYear = isoWeek
	clock.Month = int(month)
	clock.Year = year
	clock.IsWeekend = weekday == time.Saturday || weekday == time.Sunday
	clock.IsMonthStart = day <= 3
	clock.IsMonthEnd = inLastThreeDays
	clock.IsQuarterEnd = month%3 == 0 && inLastThreeDays
	clock.IsYearEnd = month == time.December && day >= 29
	return clock
}
Data map[string]interface{} `json:"data,omitempty"` ExecutorMode ExecutorMode `json:"executor_mode,omitempty"` // optional: standard | dryrun } // ExecutorMode - executor mode enum type ExecutorMode string const ( ExecutorStandard ExecutorMode = "standard" // real Agent calls (default) ExecutorDryRun ExecutorMode = "dryrun" // simulated, no LLM calls ExecutorSandbox ExecutorMode = "sandbox" // container-isolated (NOT IMPLEMENTED) ) // ExecutionResult - trigger result type ExecutionResult struct { ExecutionID string `json:"execution_id"` Status ExecStatus `json:"status"` Message string `json:"message,omitempty"` } // RobotState - robot status query result type RobotState struct { MemberID string `json:"member_id"` TeamID string `json:"team_id"` DisplayName string `json:"display_name"` Status RobotStatus `json:"status"` Running int `json:"running"` // current running execution count MaxRunning int `json:"max_running"` // max concurrent allowed LastRun *time.Time `json:"last_run,omitempty"` NextRun *time.Time `json:"next_run,omitempty"` RunningIDs []string `json:"running_ids,omitempty"` // list of running execution IDs } ``` --- ## 3. Interfaces > Interfaces are also in `types/` package to avoid cycles. ### 3.1 Manager Interface ```go // types/interfaces.go package types import "time" // ==================== Internal Interfaces ==================== // These are internal implementation interfaces, not exposed via API. // External API is defined in api/api.go // All interfaces use *Context (not context.Context) for consistency. 
// Manager - robot lifecycle, scheduling, and all trigger handling // Manager is the central orchestrator, handling: // - Clock triggers (via Tick) // - Human intervention (via Intervene) // - Event triggers (via HandleEvent) // - Execution control (pause/resume/stop) type Manager interface { // Lifecycle Start() error Stop() error // Clock trigger (called by internal ticker) Tick(ctx *Context, now time.Time) error // Manual trigger (for testing/API) TriggerManual(ctx *Context, memberID string, trigger TriggerType, data interface{}) (string, error) // Human intervention Intervene(ctx *Context, req *InterveneRequest) (*ExecutionResult, error) // Event trigger HandleEvent(ctx *Context, req *EventRequest) (*ExecutionResult, error) // Execution control PauseExecution(ctx *Context, execID string) error ResumeExecution(ctx *Context, execID string) error StopExecution(ctx *Context, execID string) error } // Executor - executes robot phases type Executor interface { Execute(ctx *Context, robot *Robot, trigger TriggerType, data interface{}) (*Execution, error) } // Pool - worker pool for concurrent execution type Pool interface { Start() error Stop() error Submit(ctx *Context, robot *Robot, trigger TriggerType, data interface{}) (string, error) Running() int Queued() int } // Cache - in-memory robot cache type Cache interface { Load(ctx *Context) error Get(memberID string) *Robot List(teamID string) []*Robot Refresh(ctx *Context, memberID string) error Add(robot *Robot) Remove(memberID string) } // Dedup - deduplication check type Dedup interface { Check(ctx *Context, memberID string, trigger TriggerType) (DedupResult, error) Mark(memberID string, trigger TriggerType, window time.Duration) } // Store - data storage operations (KB, DB) type Store interface { SaveLearning(ctx *Context, memberID string, entries []LearningEntry) error GetHistory(ctx *Context, memberID string, limit int) ([]LearningEntry, error) SearchKB(ctx *Context, collections []string, query string) 
([]interface{}, error) QueryDB(ctx *Context, models []string, query interface{}) ([]interface{}, error) } ``` ### 3.2 Trigger Utilities (`trigger/` package) > **Note:** The `trigger/` package provides utilities, not the main trigger logic. > All trigger handling is done by `Manager`. ```go // trigger/trigger.go - Validation and helper functions // ValidateIntervention validates a human intervention request func ValidateIntervention(req *InterveneRequest) error // ValidateEvent validates an event trigger request func ValidateEvent(req *EventRequest) error // BuildEventInput creates a TriggerInput from an event request func BuildEventInput(req *EventRequest) *TriggerInput // GetActionCategory returns the category of an intervention action // e.g., "task.add" -> "task", "goal.adjust" -> "goal" func GetActionCategory(action InterventionAction) string // GetActionDescription returns a human-readable description of an action func GetActionDescription(action InterventionAction) string ``` ```go // trigger/clock.go - Clock matching logic (reusable) // ClockMatcher provides clock trigger matching logic type ClockMatcher struct{} // ShouldTrigger checks if a robot should be triggered based on its clock config func (cm *ClockMatcher) ShouldTrigger(robot *Robot, now time.Time) bool // ParseTime parses a time string in "HH:MM" format func ParseTime(timeStr string) (hour, minute int, err error) // FormatTime formats hour and minute to "HH:MM" string func FormatTime(hour, minute int) string ``` ```go // trigger/control.go - Execution control (pause/resume/stop) // ExecutionController manages execution lifecycle type ExecutionController struct { executions map[string]*ControlledExecution mu sync.RWMutex } // Track starts tracking an execution func (c *ExecutionController) Track(execID, memberID, teamID string) *ControlledExecution // Untrack stops tracking an execution func (c *ExecutionController) Untrack(execID string) // Pause pauses an execution func (c *ExecutionController) 
Pause(execID string) error // Resume resumes a paused execution func (c *ExecutionController) Resume(execID string) error // Stop stops an execution func (c *ExecutionController) Stop(execID string) error // ControlledExecution represents an execution that can be controlled type ControlledExecution struct { ID string MemberID string TeamID string Status ExecStatus Phase Phase StartTime time.Time PausedAt *time.Time // ... internal fields for context and channels } // IsPaused returns true if the execution is paused func (e *ControlledExecution) IsPaused() bool // IsCancelled returns true if the execution is cancelled func (e *ControlledExecution) IsCancelled() bool // WaitIfPaused blocks until the execution is resumed or cancelled func (e *ControlledExecution) WaitIfPaused() error // CheckCancelled checks if the execution is cancelled and returns error if so func (e *ControlledExecution) CheckCancelled() error ``` --- ## 4. Errors ```go // types/errors.go package types import "errors" var ( // Config errors ErrMissingIdentity = errors.New("identity.role is required") ErrClockTimesEmpty = errors.New("clock.times is required for times mode") ErrClockIntervalEmpty = errors.New("clock.every is required for interval mode") ErrClockModeInvalid = errors.New("clock.mode must be times, interval, or daemon") // Runtime errors ErrRobotNotFound = errors.New("robot not found") ErrRobotPaused = errors.New("robot is paused") ErrRobotBusy = errors.New("robot has reached max concurrent executions") ErrTriggerDisabled = errors.New("trigger type is disabled for this robot") ErrExecutionCancelled = errors.New("execution was cancelled") ErrExecutionTimeout = errors.New("execution timed out") // Phase errors ErrPhaseAgentNotFound = errors.New("phase agent not found") ErrGoalGenFailed = errors.New("goal generation failed") ErrTaskPlanFailed = errors.New("task planning failed") ErrDeliveryFailed = errors.New("delivery failed") ) ``` --- ## 5. 
P3 Implementation Details ### 5.1 Multi-Turn Conversation Flow For assistant tasks, P3 uses a validator-driven multi-turn conversation: ``` ┌──────────────────────────────────────────────────────────────┐ │ executeAssistantWithMultiTurn │ ├──────────────────────────────────────────────────────────────┤ │ 1. Create Conversation (single instance for entire task) │ │ 2. Build initial messages with task context │ │ │ │ ┌─────────────────── Turn Loop ───────────────────────────┐ │ │ │ Phase 1: Call assistant via conv.Turn() │ │ │ │ Phase 2: ValidateWithContext() determines: │ │ │ │ - Complete: task done? │ │ │ │ - NeedReply: continue conversation? │ │ │ │ - ReplyContent: what to send next? │ │ │ │ Phase 3: If NeedReply, use ReplyContent as next input │ │ │ │ Break if: Complete && Passed, or !NeedReply │ │ │ └──────────────────────────────────────────────────────────┘ │ │ │ │ 3. Return output, validation, error │ └──────────────────────────────────────────────────────────────┘ ``` Key points: - `ValidateWithContext()` returns `NeedReply` and `ReplyContent` - Conversation continues until `Complete && Passed` or `!NeedReply` - Max turns controlled by `RunConfig.MaxTurnsPerTask` ### 5.2 Validation Rules Format Validation rules support two formats: 1. **Natural language**: `"output must be valid JSON"`, `"must contain 'field'"` 2. **Structured JSON**: `{"type": "type", "path": "field", "value": "array"}` Examples: ```json // Natural language rules (converted to semantic validation) "output must be valid JSON" "must contain product name" // Structured JSON assertions {"type": "equals", "value": "success"} {"type": "contains", "value": "total"} {"type": "regex", "value": "^[A-Z].*"} {"type": "json_path", "path": "data.items", "value": 10} {"type": "type", "path": "result", "value": "object"} ``` ### 5.3 Task Dependencies Task dependencies are handled automatically: 1. `BuildTaskContext()` collects previous task results 2. 
`FormatPreviousResultsAsContext()` formats them for assistant ```go // Previous results are passed as context func (r *Runner) BuildTaskContext(exec *robottypes.Execution, taskIndex int) *RunnerContext { ctx := &RunnerContext{} if taskIndex > 0 { ctx.PreviousResults = exec.Results[:taskIndex] } return ctx } ``` ### 5.4 Resource Management Agent context is properly released to prevent resource leaks: ```go func (c *AgentCaller) Call(ctx *robottypes.Context, assistantID string, messages []agentcontext.Message) (*CallResult, error) { agentCtx := c.buildAgentContext(ctx) defer agentCtx.Release() // IMPORTANT: Release agent context response, err := ast.Stream(agentCtx, messages, opts) // ... } ``` ### 5.5 yao/assert Package The `yao/assert` package is a standalone universal assertion library that can be used by other modules: ```go import "github.com/yaoapp/yao/assert" // Create asserter with optional callbacks asserter := assert.NewAsserter(assert.AssertionOptions{ AgentValidator: myAgentValidator, // for "agent" type assertions ScriptRunner: myScriptRunner, // for "script" type assertions }) // Run assertions results := asserter.Assert(output, []assert.Assertion{ {Type: "type", Value: "object"}, {Type: "contains", Value: "success"}, {Type: "json_path", Path: "data.count", Value: 10}, }) ``` Supported assertion types: - `equals` - exact match - `contains` - substring check - `not_contains` - negative substring check - `json_path` - JSON path extraction and comparison - `regex` - regex pattern matching - `type` - type checking (with optional path) - `script` - custom script validation - `agent` - AI agent validation --- ## 6. P4 Delivery Implementation ### 6.1 Overview P4 Delivery summarizes P3 execution results and delivers to configured channels. 
// DeliveryAttachment - file attachment with metadata
type DeliveryAttachment struct {
	Title       string `json:"title"`                 // Human-readable title
	Description string `json:"description,omitempty"` // What this artifact is
	TaskID      string `json:"task_id,omitempty"`     // Which task produced this
	File        string `json:"file"`                  // Wrapper: __<uploader>://<fileID>, e.g. __yao.attachment://abc123
}
### 6.3 File Wrapper Format

Attachments use the standard `yao/attachment` wrapper format:

```go
// Format:  __<uploader>://<fileID>
// Example: __yao.attachment://ccd472d11feb96e03a3fc468f494045c

import "github.com/yaoapp/yao/attachment"

// Parse wrapper to get uploader and fileID
uploader, fileID, isWrapper := attachment.Parse(wrapper)
// uploader: "__yao.attachment" (returned with its leading "__")
// fileID: "ccd472d11feb96e03a3fc468f494045c"
// isWrapper: true

// Get file info
manager := attachment.Managers[uploader]
fileInfo, err := manager.Info(ctx, fileID)

// Read file content as base64
base64Content := attachment.Base64(ctx, wrapper)

// Read with data URI format
dataURI := attachment.Base64(ctx, wrapper, true) // "data:image/png;base64,..."
```
**Input:** ```go type DeliveryAgentInput struct { Robot *Robot `json:"robot"` // Robot identity and config TriggerType TriggerType `json:"trigger"` // clock | human | event Inspiration *InspirationReport `json:"inspiration"` // P0 (clock only) Goals *Goals `json:"goals"` // P1 Tasks []Task `json:"tasks"` // P2 Results []TaskResult `json:"results"` // P3 } ``` **Output:** ```go // DeliveryAgentOutput - only content, no channels type DeliveryAgentOutput struct { Content *DeliveryContent `json:"content"` // Generated content } ``` **Agent Responsibilities:** The agent focuses on content generation: - **Summary**: Brief 1-2 sentence summary of execution results - **Body**: Full markdown report with details - **Attachments**: Select which P3-generated files to include **Example Output:** ```json { "content": { "summary": "Sales report completed: 15 new leads processed, 3 high-priority", "body": "## Weekly Sales Report\n\n### Summary\n- Total leads: 15\n- High priority: 3\n...", "attachments": [ {"title": "Sales Report.pdf", "task_id": "task_1", "file": "__yao.attachment://abc123"}, {"title": "Lead Analysis.xlsx", "task_id": "task_2", "file": "__yao.attachment://def456"} ] } } ``` ### 6.5 Global Email Configuration Email delivery uses global configuration for channel selection and Robot-specific sender identity: ```go // types/config_global.go // DefaultEmailChannel returns the default messenger channel name // Default: "email" (maps to messengers/channels.yao) func DefaultEmailChannel() string // SetDefaultEmailChannel sets the default channel (call during agent init) func SetDefaultEmailChannel(channel string) ``` **Usage:** - `DefaultEmailChannel()` - returns the messenger channel name for email delivery - `Robot.RobotEmail` - used as the `From` address when sending emails - If `RobotEmail` is empty, falls back to provider's default `from` address ### 6.6 Delivery Center The Delivery Center receives `DeliveryRequest`, reads preferences, and executes delivery to **all 
enabled targets**. **Current implementation:** Internal to P4 (in `executor/delivery.go`) **Future:** Can be extracted to standalone `yao/delivery` package ```go // DeliveryCenter - handles delivery execution to multiple targets type DeliveryCenter struct { messenger *messenger.Manager } // Deliver - main entry point func (dc *DeliveryCenter) Deliver(ctx context.Context, req *DeliveryRequest) *DeliveryResult { requestID := generateID() prefs := dc.getDeliveryPreferences(ctx, req.Context.MemberID) var results []ChannelResult allSuccess := true // Email - send to all targets (robot passed for From address) if prefs.Email != nil && prefs.Email.Enabled { for _, target := range prefs.Email.Targets { result := dc.sendEmail(ctx, req.Content, target, req.Context, robot) results = append(results, result) if !result.Success { allSuccess = false } } } // Webhook - POST to all targets if prefs.Webhook != nil && prefs.Webhook.Enabled { for _, target := range prefs.Webhook.Targets { result := dc.postWebhook(ctx, req.Content, target) results = append(results, result) if !result.Success { allSuccess = false } } } // Process - call all targets if prefs.Process != nil && prefs.Process.Enabled { for _, target := range prefs.Process.Targets { result := dc.callProcess(ctx, req.Content, target) results = append(results, result) if !result.Success { allSuccess = false } } } // Future: auto-notify based on user subscriptions // dc.sendNotifications(ctx, req) return &DeliveryResult{ RequestID: requestID, Content: req.Content, Success: allSuccess, Results: results, } } ``` ### 6.7 Channel Handlers Each delivery channel is handled by dedicated methods in DeliveryCenter: ```go // sendEmail - send to a single email target // Uses Robot.RobotEmail as From address and global DefaultEmailChannel() func (dc *DeliveryCenter) sendEmail( ctx context.Context, content *DeliveryContent, target EmailTarget, deliveryCtx *DeliveryContext, robot *Robot, ) ChannelResult { // Convert attachments to messenger 
format var attachments []messenger.Attachment for _, att := range content.Attachments { uploader, fileID, _ := attachment.Parse(att.File) manager := attachment.Managers[uploader] data, _ := manager.Read(ctx, fileID) info, _ := manager.Info(ctx, fileID) attachments = append(attachments, messenger.Attachment{ Filename: att.Title, ContentType: info.ContentType, Content: data, }) } subject := content.Summary if target.Subject != "" { subject = target.Subject } msg := &messenger.Message{ To: target.To, Subject: subject, Body: content.Body, Attachments: attachments, } // Set From address from Robot's email (if configured) if robot != nil && robot.RobotEmail != "" { msg.From = robot.RobotEmail } // Use global default email channel channel := DefaultEmailChannel() // from types/config_global.go err := dc.messenger.Send(ctx, channel, msg) now := time.Now() return ChannelResult{ Type: DeliveryEmail, Target: strings.Join(target.To, ","), Success: err == nil, Recipients: target.To, SentAt: &now, Error: errStr(err), } } // postWebhook - POST to a single webhook target func (dc *DeliveryCenter) postWebhook(ctx context.Context, content *DeliveryContent, target WebhookTarget) ChannelResult { payload, _ := json.Marshal(content) req, _ := http.NewRequestWithContext(ctx, "POST", target.URL, bytes.NewReader(payload)) req.Header.Set("Content-Type", "application/json") // Add custom headers for k, v := range target.Headers { req.Header.Set(k, v) } resp, err := http.DefaultClient.Do(req) now := time.Now() if err != nil { return ChannelResult{ Type: DeliveryWebhook, Target: target.URL, Success: false, Error: err.Error(), SentAt: &now, } } defer resp.Body.Close() success := resp.StatusCode < 400 return ChannelResult{ Type: DeliveryWebhook, Target: target.URL, Success: success, Details: map[string]interface{}{"status_code": resp.StatusCode}, Error: ternary(!success, fmt.Sprintf("HTTP %d", resp.StatusCode), ""), SentAt: &now, } } // callProcess - call a single Yao Process target func (dc 
*DeliveryCenter) callProcess(ctx context.Context, content *DeliveryContent, target ProcessTarget) ChannelResult { // DeliveryContent as first arg, then additional args args := append([]interface{}{content}, target.Args...) proc := process.Of(target.Process, args...) result, err := proc.Execute() now := time.Now() return ChannelResult{ Type: DeliveryProcess, Target: target.Process, Success: err == nil, Details: map[string]interface{}{ "process": target.Process, "result": result, }, Error: errStr(err), SentAt: &now, } } ``` **Note on Notifications:** `notify` is NOT configured per-Robot. Future Delivery Center will: 1. Check user subscription preferences after receiving DeliveryRequest 2. Automatically send in-app notifications to subscribed users 3. This is transparent to P4 and Delivery Agent ### 6.8 Execution Persistence Robot execution history is stored in `__yao.agent_execution` table for UI display: ```go // Model: yao/models/agent/execution.mod.yao // Table: __yao.agent_execution type ExecutionRecord struct { ID int64 `json:"id,omitempty"` // Auto-increment primary key ExecutionID string `json:"execution_id"` // Unique execution identifier MemberID string `json:"member_id"` // Robot member ID (globally unique) TeamID string `json:"team_id"` // Team ID TriggerType TriggerType `json:"trigger_type"` // clock | human | event // Status tracking (synced with runtime Execution) Status ExecStatus `json:"status"` // pending | running | completed | failed | cancelled Phase Phase `json:"phase"` // Current phase Current *CurrentState `json:"current,omitempty"`// Current executing state (task index, progress) Error string `json:"error,omitempty"` // Error message if failed // Trigger input Input *TriggerInput `json:"input,omitempty"` // Original trigger input // Phase outputs (P0-P5) Inspiration *InspirationReport `json:"inspiration,omitempty"` // P0 result Goals *Goals `json:"goals,omitempty"` // P1 result Tasks []Task `json:"tasks,omitempty"` // P2 result Results 
[]TaskResult `json:"results,omitempty"` // P3 results Delivery *DeliveryResult `json:"delivery,omitempty"` // P4 result Learning []LearningEntry `json:"learning,omitempty"` // P5 entries // Timestamps StartTime *time.Time `json:"start_time,omitempty"` EndTime *time.Time `json:"end_time,omitempty"` CreatedAt *time.Time `json:"created_at,omitempty"` UpdatedAt *time.Time `json:"updated_at,omitempty"` } // CurrentState - current executing state (for JSON storage) type CurrentState struct { TaskIndex int `json:"task_index"` // index in Tasks slice Progress string `json:"progress,omitempty"` // human-readable progress (e.g., "2/5 tasks") } ``` **Store Implementation:** ```go // store/execution.go type ExecutionStore struct { modelID string // "__yao.agent.execution" } func NewExecutionStore() *ExecutionStore // Save creates or updates an execution record func (s *ExecutionStore) Save(ctx context.Context, record *ExecutionRecord) error // Get retrieves an execution by execution_id func (s *ExecutionStore) Get(ctx context.Context, executionID string) (*ExecutionRecord, error) // List retrieves executions with filters func (s *ExecutionStore) List(ctx context.Context, opts *ListOptions) ([]*ExecutionRecord, error) // UpdatePhase updates the current phase and its data func (s *ExecutionStore) UpdatePhase(ctx context.Context, executionID string, phase Phase, data interface{}) error // UpdateStatus updates the execution status func (s *ExecutionStore) UpdateStatus(ctx context.Context, executionID string, status ExecStatus, errorMsg string) error // UpdateCurrent updates the current executing state func (s *ExecutionStore) UpdateCurrent(ctx context.Context, executionID string, current *CurrentState) error // Delete removes an execution record func (s *ExecutionStore) Delete(ctx context.Context, executionID string) error // Conversion helpers func FromExecution(exec *Execution) *ExecutionRecord func (r *ExecutionRecord) ToExecution() *Execution type ListOptions struct { MemberID 
string // Filter by robot member ID (globally unique) TeamID string // Filter by team Status ExecStatus // Filter by status TriggerType TriggerType // Filter by trigger Limit int // Max records to return (default: 100) Offset int // Skip records for pagination OrderBy string // e.g., "start_time desc" } ``` ================================================ FILE: agent/robot/TODO.md ================================================ # Robot Agent - Implementation TODO > Based on DESIGN.md and TECHNICAL.md > Test environment: `source yao/env.local.sh` > Test assistants: `yao-dev-app/assistants/robot/` --- ## Workflow: Human-AI Collaboration **Important:** Follow this workflow strictly for each sub-task. ``` ┌─────────────────────────────────────────────────────────────────┐ │ Implementation Workflow │ ├─────────────────────────────────────────────────────────────────┤ │ 1. AI: Implement code for current sub-task │ │ 2. AI: Present code for review (DO NOT write tests yet) │ │ 3. Human: Review code, provide feedback │ │ 4. AI: Iterate based on feedback │ │ 5. Human: Confirm "LGTM" or "Approved" │ │ 6. AI: Write tests for the approved code │ │ 7. Human: Review tests │ │ 8. AI: Run tests, fix if needed │ │ 9. 
Human: Confirm sub-task complete, move to next │ └─────────────────────────────────────────────────────────────────┘ ``` **Rules:** | Rule | Description | | ------------------------ | ------------------------------------------ | | One sub-task at a time | Focus only on current sub-task | | No tests before approval | Wait for human "LGTM" before writing tests | | No jumping ahead | Do not implement future phases | | Ask if unclear | When in doubt, ask before proceeding | --- ## Core Principle - Phase 1-2: Types + Skeleton (code compiles) - Phase 3: Complete scheduling system (Cache + Pool + Trigger + Dedup + Job), executor is stub - Phase 4-5: Agent call infrastructure and test assistants (foundation for all executor phases) - Phase 6 onward: Implement executor phases one by one (P0 → P5), one executor phase per section (Phase 6 = P0, Phase 7 = P1, ...) - Final phase: API completion, end-to-end tests - Monitoring: Provided by Job system, no separate implementation --- ## Phase 1: Types & Interfaces ✅ **Goal:** Define all types, enums, interfaces. No logic, no external deps. **Status:** Complete - 88.4% test coverage, all tests passing ### 1.1 Enums (`types/enums.go`) - [x] `Phase` - execution phases (inspiration, goals, tasks, run, delivery, learning) - [x] `ClockMode` - clock trigger modes (times, interval, daemon) - [x] `TriggerType` - trigger sources (clock, human, event) - [x] `ExecStatus` - execution status (pending, running, completed, failed, cancelled) - [x] `RobotStatus` - robot status (idle, working, paused, error, maintenance) - [x] `InterventionAction` - human actions (task.add, goal.adjust, etc.)
- [x] `Priority` - priority levels (high, normal, low) - [x] `DeliveryType` - delivery types (email, webhook, process, notify) - [x] `DedupResult` - dedup results (skip, merge, proceed) - [x] `EventSource` - event sources (webhook, database) - [x] `LearningType` - learning types (execution, feedback, insight) - [x] `TaskSource` - task sources (auto, human, event) - [x] `ExecutorType` - executor types (assistant, mcp, process) - [x] `TaskStatus` - task status (pending, running, completed, failed, skipped, cancelled) - [x] `InsertPosition` - insert positions (first, last, next, at) ### 1.2 Context (`types/context.go`) - [x] `Context` struct - robot execution context - [x] `NewContext()` - constructor - [x] `UserID()`, `TeamID()` - helper methods ### 1.3 Config Types (`types/config.go`) - [x] `Config` - main config struct - [x] `Triggers`, `TriggerSwitch` - trigger enable/disable - [x] `Clock` - clock config with validation - [x] `Identity` - role, duties, rules - [x] `Quota` - concurrency limits with defaults - [x] `KB`, `DB` - knowledge base and database config - [x] `Learn` - learning config - [x] `Resources`, `MCPConfig` - available agents and tools - [x] `Delivery` - output delivery config - [x] `Event` - event trigger config ### 1.4 Core Types (`types/robot.go`) - [x] `Robot` struct - runtime robot representation - [x] `Robot` methods - `CanRun()`, `RunningCount()`, `AddExecution()`, `RemoveExecution()`, `GetExecution()`, `GetExecutions()` - [x] `Execution` struct - single execution instance - [x] `TriggerInput` - stored trigger input - [x] `CurrentState` - current executing state - [x] `Goals` - P1 output (markdown) - [x] `Task` - planned task (structured) - [x] `TaskResult` - task execution result - [x] `DeliveryResult` - delivery output - [x] `LearningEntry` - knowledge to save ### 1.5 Clock Context (`types/clock.go`) - [x] `ClockContext` struct - time context for P0 - [x] `NewClockContext()` - constructor ### 1.6 Inspiration (`types/inspiration.go`) - [x] 
`InspirationReport` struct - P0 output ### 1.7 Request/Response (`types/request.go`) - [x] `InterveneRequest` - human intervention request - [x] `EventRequest` - event trigger request - [x] `ExecutionResult` - trigger result - [x] `RobotState` - robot status query result ### 1.8 Interfaces (`types/interfaces.go`) - [x] `Manager` interface - [x] `Executor` interface - [x] `Pool` interface - [x] `Cache` interface - [x] `Dedup` interface - [x] `Store` interface ### 1.9 Errors (`types/errors.go`) - [x] Config errors - [x] Runtime errors - [x] Phase errors ### 1.10 Tests - [x] `types/enums_test.go` - enum validation - [x] `types/config_test.go` - config validation - [x] `types/clock_test.go` - clock context creation - [x] `types/robot_test.go` - robot methods --- ## Phase 2: Skeleton Implementation ✅ **Goal:** Create all packages with empty/stub implementations. Code compiles. **Status:** Complete - All packages compile successfully, no circular dependencies ### 2.1 Utils (`utils/`) ✅ - [x] `utils/convert.go` - JSON, map, struct conversions (implement) - [x] `utils/time.go` - time parsing, formatting, timezone (implement) - [x] `utils/id.go` - ID generation (nanoid) (implement) - [x] `utils/validate.go` - validation helpers (implement) - [x] Test: `utils/utils_test.go` ### 2.2 Package Skeletons ✅ (stubs only, implemented in Phase 3) Create empty structs and stub methods that return nil/empty/success: - [x] `cache/cache.go` - Cache struct, stub methods - [x] `dedup/dedup.go` - Dedup struct, stub methods - [x] `store/store.go` - Store struct, stub methods - [x] `pool/pool.go` - Pool struct, stub methods - [x] `plan/plan.go` - Plan struct, stub methods - [x] `trigger/trigger.go` - trigger dispatcher stub - [x] `executor/executor.go` - Executor struct, stub `Execute()` - [x] `manager/manager.go` - Manager struct, stub methods ### 2.3 API Skeletons ✅ - [x] `api/api.go` - Go API facade (all function signatures, return errors) - [x] `api/process.go` - Yao Process registration 
(all processes, return errors) - [x] `api/jsapi.go` - JSAPI registration (all methods, return errors) ### 2.4 Root ✅ - [x] `robot.go` - package entry - [x] `Init()` - placeholder - [x] `Shutdown()` - placeholder ### 2.5 Compile Test ✅ - [x] All packages compile without errors - [x] All imports resolve correctly - [x] No circular dependencies --- ## Phase 3: Complete Scheduling System ✅ **Goal:** Implement complete scheduling system. Executor is stub (simulates success). **Status:** Complete - All 7 sub-tasks done, 80+ integration tests passing This phase delivers a fully working scheduling pipeline: ``` Trigger → Manager → Cache → Dedup → Pool → Worker → Executor(stub) → Job ``` ### ✅ 3.1 Cache Implementation (COMPLETE) - [x] `cache/cache.go` - Cache struct with thread-safe map - [x] `cache/load.go` - load robots from `__yao.member` where `member_type='robot'` and `autonomous_mode=true` - [x] Implemented pagination (100 robots per page) - [x] Configurable model name via `SetMemberModel()` - [x] `cache/refresh.go` - refresh single robot, periodic full refresh (every hour) - [x] Test: load/refresh with real DB - [x] Created comprehensive integration tests with real database - [x] Tests cover Load, LoadByID, Refresh, ListByTeam, GetByStatus - [x] All tests passing with proper cleanup ### ✅ 3.2 Pool Implementation (COMPLETE) - [x] `pool/pool.go` - worker pool with configurable size (global limit) - [x] Default config: 10 workers, 100 queue size - [x] Configurable via `pool.NewWithConfig()` - [x] `pool/queue.go` - priority queue (sorted by: robot priority, trigger type, wait time) - [x] Two-level limit: global queue + per-robot queue - [x] Priority: Robot Priority × 1000 + Trigger Priority × 100 - [x] `pool/worker.go` - worker goroutines, dispatch to executor - [x] Non-blocking quota check with re-enqueue - [x] Graceful shutdown support - [x] Test: submit jobs, verify execution order, verify concurrency limits - [x] 15 test cases covering all edge cases - [x] All tests 
passing ### ✅ 3.3 Manager Implementation (COMPLETE) > **Note:** Manager is the scheduling core, depends on completed Cache and Pool. - [x] `manager/manager.go` - Manager struct - [x] `Start()` - load cache, start pool, start ticker goroutine - [x] `Stop()` - graceful shutdown (wait for running, drain queue) - [x] `Tick()` - main loop: 1. Get all cached robots 2. For each robot with clock trigger enabled 3. Check if should execute (times/interval/daemon modes) 4. Submit to pool - [x] `TriggerManual()` - manual trigger for testing/API - [x] Clock modes: times, interval, daemon - [x] Day matching for times mode - [x] Timezone handling - [x] Skip paused/error/maintenance robots - [x] Test: manager start/stop, tick cycle, manual trigger, clock modes, goroutine leak ### ✅ 3.4 Trigger Implementation (COMPLETE) - [x] `trigger/trigger.go` - validation and helper functions - [x] `ValidateIntervention()` - validate human intervention requests - [x] `ValidateEvent()` - validate event trigger requests - [x] `BuildEventInput()` - build TriggerInput from event request - [x] `GetActionCategory()` / `GetActionDescription()` - action helpers - [x] `trigger/clock.go` - ClockMatcher for clock trigger matching - [x] `times` mode: match specific times (09:00, 14:00) - [x] `interval` mode: run every X duration (30m, 1h) - [x] `daemon` mode: restart immediately after completion - [x] Timezone handling - [x] Day-of-week filtering - [x] `trigger/control.go` - ExecutionController for pause/resume/stop - [x] Track/Untrack executions - [x] Pause/Resume execution - [x] Stop execution (cancel context) - [x] WaitIfPaused() for executor integration - [x] `manager/manager.go` - integrated trigger handling - [x] `Intervene()` - human intervention handler - [x] `HandleEvent()` - event trigger handler - [x] `PauseExecution()` / `ResumeExecution()` / `StopExecution()` - [x] `ListExecutions()` / `ListExecutionsByMember()` - [x] Tests: `trigger/trigger_test.go`, `trigger/clock_test.go`, 
`trigger/control_test.go` - [x] Validation tests for intervention and event requests - [x] Clock matching tests for all modes - [x] ExecutionController lifecycle tests - [x] Manager integration tests for Intervene/HandleEvent ### ✅ 3.5 Execution Storage (COMPLETE) - [x] ExecutionStore - execution record persistence - [x] Execution data stored in `__yao.agent_execution` table - [x] All phase outputs (Inspiration, Goals, Tasks, Results, Delivery, Learning) - [x] Status and phase tracking - [x] Logging via `kun/log` package - [x] Localization support (en-US, zh-CN) - [x] Test: execution storage, status tracking - [x] `store/execution_test.go` - execution store tests - [x] All tests passing with real database ### ✅ 3.6 Executor Architecture (COMPLETE) Pluggable executor architecture with multiple execution modes: ``` executor/ ├── types/ │ ├── types.go # Executor interface, Config types │ └── helpers.go # Shared helper functions ├── standard/ │ ├── executor.go # Real Agent execution (production) │ ├── agent.go # AgentCaller for LLM calls │ ├── input.go # InputFormatter for prompts │ ├── inspiration.go # P0: Inspiration phase │ ├── goals.go # P1: Goals phase │ ├── tasks.go # P2: Tasks phase │ ├── run.go # P3: Run phase │ ├── delivery.go # P4: Delivery phase │ └── learning.go # P5: Learning phase ├── dryrun/ │ └── executor.go # Simulated execution (testing/demo) ├── sandbox/ │ └── executor.go # Container-isolated (NOT IMPLEMENTED) └── executor.go # Factory functions ``` **Execution Modes:** | Mode | Use Case | Status | | -------- | -------------------------------- | ------------------ | | Standard | Production with real Agent calls | ✅ Implemented | | DryRun | Tests, demos, scheduling tests | ✅ Implemented | | Sandbox | Container-isolated execution | ⬜ Not Implemented | > **⚠️ Sandbox Mode:** Requires container-level isolation (Docker/gVisor/Firecracker) > for true security. Current placeholder behaves like DryRun. Future feature. 
- [x] `executor/types/types.go` - `Executor` interface, `PhaseExecutor` interface - [x] `executor/types/helpers.go` - `BuildTriggerInput()` shared helper - [x] `executor/executor.go` - Factory functions (`New`, `NewDryRun`, `NewWithMode`) - [x] `executor/standard/executor.go` - Real execution with Job integration - [x] `executor/standard/phases.go` - Phase implementations (P0-P5) - [x] `executor/dryrun/executor.go` - Simulated execution with callbacks - [x] `executor/sandbox/executor.go` - Placeholder (NOT IMPLEMENTED) - [x] Manager integration - accepts `Executor` interface via config - [x] Tests use DryRun mode for scheduling/concurrency tests ### 3.7 Integration Test (End-to-End Scheduling) ✅ - [x] Create test robot in `__yao.member` with clock config - [x] Start manager - [x] Wait for clock trigger - [x] Verify: - [x] Robot loaded to cache - [x] Clock trigger matched - [x] Job submitted to pool - [x] Worker picked up job - [x] Executor stub called - [x] Job execution recorded - [x] Logs written - [x] Test human intervention trigger - [x] Test event trigger - [x] Test concurrent executions (multiple robots) - [x] Test quota enforcement (per-robot limit) - [x] Test pause/resume/stop **Test Files Created:** - `manager/integration_test.go` - Core scheduling flow (Cache→Pool→Executor) - `manager/integration_clock_test.go` - Clock trigger modes (times/interval/daemon) - `manager/integration_human_test.go` - Human intervention trigger tests - `manager/integration_event_test.go` - Event trigger tests - `manager/integration_concurrent_test.go` - Concurrent execution & quota tests - `manager/integration_control_test.go` - Pause/Resume/Stop tests **Test Coverage:** - 27 top-level test functions - 80+ sub-tests covering all verification points - 3x run stability verified --- ## Phase 4: Agent Call Infrastructure ✅ **Goal:** Implement unified Agent/Assistant calling mechanism. This is the foundation for all phase implementations (P0-P5). 
**Architecture Note:** - **Prompt construction is handled by Assistant layer** (`prompts.yml` in each assistant) - **Executor only prepares input data** (ClockContext, InspirationReport, etc.) and calls Assistant - **Assistant framework handles** prompt rendering, LLM API calls, streaming **Implemented:** 1. A unified way to call assistants with streaming support 2. Input data formatting for each phase 3. Response parsing (markdown and structured data via `gou/text`) 4. Multi-turn conversation support ### 4.1 Agent Caller Implementation ✅ - [x] `executor/agent.go` - `AgentCaller` struct with `SkipOutput`, `SkipHistory`, `SkipSearch`, `ChatID` - [x] `executor/agent.go` - `Call(ctx, assistantID, messages)` - basic call with full response - [x] `executor/agent.go` - `CallWithMessages(ctx, assistantID, userContent)` - convenience method - [x] `executor/agent.go` - `CallWithSystemAndUser(ctx, assistantID, systemContent, userContent)` - [x] `executor/agent.go` - handle assistant not found error - [x] `executor/agent.go` - handle LLM API errors gracefully - [x] `executor/agent.go` - `CallResult.GetJSON()` / `GetJSONArray()` - parse JSON response using `gou/text` - [x] `executor/agent.go` - `Conversation` struct for multi-turn dialogues - [x] `executor/agent.go` - `Conversation.Turn()`, `RunUntil()`, `Reset()`, `WithSystemPrompt()` - [x] `executor/agent.go` - Use `agentcontext.Noop()` logger to suppress debug output ### 4.2 Input Formatters ✅ - [x] `executor/input.go` - `FormatClockContext(clockCtx, robot)` - format clock context as message content - [x] `executor/input.go` - `FormatInspirationReport(report)` - format P0 output for P1 input - [x] `executor/input.go` - `FormatTriggerInput(input)` - format Human/Event trigger for P1 input - [x] `executor/input.go` - `FormatGoals(goals, robot)` - format P1 output for P2 input - [x] `executor/input.go` - `FormatTasks(tasks)` - format P2 output for P3 input - [x] `executor/input.go` - `FormatTaskResults(results)` - format P3 
output for P4/P5 input - [x] `executor/input.go` - `FormatExecutionSummary(exec)` - format full execution for P5 input - [x] `executor/input.go` - `BuildMessages()`, `BuildMessagesWithSystem()` - helper methods ### 4.3 Test Assistants ✅ - [x] `yao-dev-app/assistants/tests/robot-single/` - Single-turn test assistant - [x] `yao-dev-app/assistants/tests/robot-conversation/` - Multi-turn conversation test assistant ### 4.4 Tests ✅ - [x] `executor/agent_test.go` - 22 test cases for AgentCaller and Conversation - [x] `executor/input_test.go` - 20 test cases for InputFormatter - [x] Verify: assistant can be called and returns response - [x] Verify: multi-turn conversation maintains state - [x] Verify: input data is well-formatted for assistant prompts - [x] Verify: JSON/YAML extraction from LLM output works correctly --- ## Phase 5: Test Scenario & Assistants Setup ✅ **Goal:** Create realistic test scenarios with all required assistants. **Architecture:** ``` ┌─────────────────────────────────────────────────────────────────────────────┐ │ 6 Generic Phase Agents (P0-P5) │ ├─────────────────────────────────────────────────────────────────────────────┤ │ inspiration │ goals │ tasks │ validation │ delivery │ learning │ │ (P0) │ (P1) │ (P2) │ (P3) │ (P4) │ (P5) │ └───────────────┴─────────┴─────────┴──────────────┴────────────┴─────────────┘ ↓ P2 assigns tasks to ┌─────────────────────────────────────────────────────────────────────────────┐ │ Expert Agents (Task Executors) │ ├─────────────────────────────────────────────────────────────────────────────┤ │ text-writer │ web-reader │ data-analyst │ summarizer │ ... │ │ (Generate) │ (Fetch URL) │ (Analyze) │ (Summarize) │ │ └───────────────┴──────────────┴────────────────┴──────────────┴─────────────┘ ``` **Test Strategy:** - Phase Agents (P0-P5) are **generic** and reusable across all robot types - Expert Agents are **specialized** for specific tasks (text, web, data, etc.) 
- Each P0-P5 test uses **different expert combinations** to cover real scenarios - Tests use `interval: 1s` or `TriggerManual()` for easy triggering (no time dependency) ### 5.1 Directory Structure ``` yao-dev-app/assistants/ ├── robot/ # Generic Phase Agents │ ├── inspiration/ # P0: Analyze clock context, generate insights │ │ ├── package.yao │ │ └── prompts.yml │ ├── goals/ # P1: Generate prioritized goals │ │ ├── package.yao │ │ └── prompts.yml │ ├── tasks/ # P2: Split goals into executable tasks │ │ ├── package.yao │ │ └── prompts.yml │ ├── validation/ # P3: Validate task results │ │ ├── package.yao │ │ └── prompts.yml │ ├── delivery/ # P4: Format and deliver results │ │ ├── package.yao │ │ └── prompts.yml │ └── learning/ # P5: Summarize execution, extract insights │ ├── package.yao │ └── prompts.yml │ └── experts/ # Expert Agents (Task Executors) ├── text-writer/ # Generate text content (reports, emails, summaries) │ ├── package.yao │ └── prompts.yml ├── web-reader/ # Fetch and parse web page content │ ├── package.yao │ └── prompts.yml ├── data-analyst/ # Analyze data, generate insights │ ├── package.yao │ └── prompts.yml └── summarizer/ # Summarize long text into key points ├── package.yao └── prompts.yml ``` ### 5.2 Generic Phase Agents #### 5.2.1 Inspiration Agent (P0) - [x] `robot/inspiration/package.yao` - config with model, temperature - [x] `robot/inspiration/prompts.yml` - system prompt: - Input: Clock context (time, day, markers), robot identity - Output: Markdown report with Summary, Highlights, Opportunities, Risks - Style: Analytical, context-aware #### 5.2.2 Goals Agent (P1) - [x] `robot/goals/package.yao` - config - [x] `robot/goals/prompts.yml` - system prompt: - Input: Inspiration report OR trigger input (human/event) - Output: Prioritized goals in markdown (High/Normal/Low) - Style: Strategic, actionable #### 5.2.3 Tasks Agent (P2) - [x] `robot/tasks/package.yao` - config - [x] `robot/tasks/prompts.yml` - system prompt: - Input: Goals, 
available expert agents list - Output: Structured task list (JSON) with executor assignments - Style: Detailed, executable #### 5.2.4 Validation Agent (P3) - [x] `robot/validation/package.yao` - config - [x] `robot/validation/prompts.yml` - system prompt: - Input: Task result, expected outcome - Output: Validation result (pass/fail, issues, suggestions) - Style: Critical, thorough #### 5.2.5 Delivery Agent (P4) - [x] `robot/delivery/package.yao` - config - [x] `robot/delivery/prompts.yml` - system prompt: - Input: Full execution context (P0-P3 results) - Output: Formatted delivery content - Style: Clear, professional #### 5.2.6 Learning Agent (P5) - [x] `robot/learning/package.yao` - config - [x] `robot/learning/prompts.yml` - system prompt: - Input: Full execution summary - Output: Insights, patterns, improvement suggestions - Style: Reflective, insightful ### 5.3 Expert Agents (Task Executors) #### 5.3.1 Text Writer - [x] `experts/text-writer/package.yao` - config - [x] `experts/text-writer/prompts.yml` - system prompt: - Input: Topic, key points, style (formal/casual), length - Output: Generated text content - Use cases: Weekly reports, email drafts, summaries #### 5.3.2 Web Reader - [x] `experts/web-reader/package.yao` - config with hooks - [x] `experts/web-reader/prompts.yml` - system prompt: - Input: URL or topic to search - Output: Extracted content, key information - Use cases: News fetching, competitor monitoring, research - [x] `experts/web-reader/src/fetch.ts` - HTTP fetching utilities - [x] `experts/web-reader/src/fetch_test.ts` - 19 test cases (100% pass) - [x] `experts/web-reader/src/index.ts` - Create/Next hooks #### 5.3.3 Data Analyst - [x] `experts/data-analyst/package.yao` - config - [x] `experts/data-analyst/prompts.yml` - system prompt: - Input: Data description, analysis goal - Output: Analysis report, trends, insights - Use cases: Sales analysis, performance review #### 5.3.4 Summarizer - [x] `experts/summarizer/package.yao` - config - [x] 
`experts/summarizer/prompts.yml` - system prompt: - Input: Long text content - Output: Concise summary with key points - Use cases: Document summarization, meeting notes ### 5.4 Test Scenarios Each phase test uses different expert combinations: | Test | Phase | Trigger | Expert Agents Used | Verification | | ---- | ----- | ---------------- | ------------------------ | ----------------------------- | | T1 | P0 | Clock (interval) | - | Clock → Inspiration report | | T2 | P1 | Clock | - | Inspiration → Goals | | T3 | P1 | Human | - | User input → Goals | | T4 | P2 | Clock | text-writer, web-reader | Goals → Tasks with executors | | T5 | P3 | Clock | text-writer | Task exec → Result validation | | T6 | P3 | Human | summarizer | Task exec → Result validation | | T7 | P4 | Clock | - | Results → Delivery format | | T8 | P5 | Clock | - | Full execution → Insights | | T9 | E2E | Clock | text-writer, summarizer | Full P0→P5 flow | | T10 | E2E | Human | web-reader, data-analyst | Full P1→P5 flow | ### 5.5 Verification - [x] All 6 Phase Agents load correctly (`robot.inspiration`, `robot.goals`, etc.) - [x] All 4 Expert Agents load correctly (`experts.text-writer`, `experts.web-reader`, etc.) - [x] Web Reader `fetch.ts` utilities tested (19 tests, 100% pass) --- ## Phase 6: P0 Inspiration Implementation ✅ **Goal:** Implement P0 (Inspiration Agent). Clock trigger → P0 → stub P1-P5. 
**Depends on:** Phase 4 (Agent Call Infrastructure), Phase 5 (Assistants Setup) **Status:** COMPLETED ### 6.1 P0 Implementation - [x] `executor/inspiration.go` - `RunInspiration(ctx, exec, data)` - real implementation - [x] `executor/inspiration.go` - build prompt using `InputFormatter.FormatClockContext()` - [x] `executor/inspiration.go` - call Inspiration Agent using `AgentCaller` - [x] `executor/inspiration.go` - parse response to `InspirationReport` (markdown content) - [x] `types/robot.go` - added `GetRobot()`/`SetRobot()` methods for Execution - [x] `executor/executor.go` - set robot reference on execution creation ### 6.2 Tests - [x] `executor/inspiration_test.go` - P0 with real LLM call (8 test cases) - [x] Test: clock context correctly formatted in prompt - [x] Test: robot identity included in prompt - [x] Test: markdown report generated with expected sections - [x] Test: handles LLM errors gracefully (robot nil, agent not found) - [x] Test: uses clock from trigger input or creates new one - [x] `InputFormatter.FormatClockContext()` unit tests (4 test cases) ### 6.3 Notes - `executor_test.go` temporarily moved to `.bak` - will restore when all phases implemented - P0 uses `robot.inspiration` test agent from `yao-dev-app/assistants/robot/inspiration/` --- ## Phase 7: P1 Goals Implementation ✅ **Goal:** Implement P1 (Goal Generation Agent). P0 → P1 → stub P2-P5. 
**Depends on:** Phase 6 (P0 Inspiration) **Status:** COMPLETED ### 7.1 P1 Implementation - [x] `executor/goals.go` - `RunGoals(ctx, exec, data)` - real implementation - [x] `executor/goals.go` - build prompt with inspiration report (Clock trigger) - [x] `executor/goals.go` - build prompt with trigger input (Human/Event trigger) - [x] `executor/goals.go` - call Goals Agent using `AgentCaller` - [x] `executor/goals.go` - parse response to `Goals` struct (JSON with content + delivery) - [x] `executor/goals.go` - handle Human/Event trigger (skip P0, use input directly) - [x] `executor/goals.go` - include robot identity in prompt - [x] `executor/goals.go` - include available resources in prompt - [x] `executor/goals.go` - `ParseDelivery()` - parse delivery target from JSON - [x] `executor/goals.go` - `IsValidDeliveryType()` - validate delivery types ### 7.2 Tests - [x] `executor/goals_test.go` - P1 with real LLM call (14 test cases) - [x] Test: inspiration report in prompt (Clock trigger) - [x] Test: user input in prompt (Human trigger) - [x] Test: event data in prompt (Event trigger) - [x] Test: goals markdown generated with priorities - [x] Test: delivery parsing from agent response - [x] Test: error handling (robot nil, agent not found, empty input) - [x] Test: fallback behavior (no inspiration → clock context) - [x] `ParseDelivery()` unit tests (8 test cases covering edge cases) - [x] `IsValidDeliveryType()` unit tests ### 7.3 Notes - P1 uses `robot.goals` test agent from `yao-dev-app/assistants/robot/goals/` - Goals Agent returns JSON: `{ "content": "...", "delivery": {...} }` - Delivery is optional; if not present or invalid, `Goals.Delivery` is nil - Available resources (agents, MCP, KB, DB) are passed to agent for achievable goal generation --- ## Phase 8: P2 Tasks Implementation ✅ **Goal:** Implement P2 (Task Planning Agent). P1 → P2 → stub P3-P5. 
**Depends on:** Phase 7 (P1 Goals) **Status:** COMPLETED ### 8.1 Validation Agent Setup (Prerequisite for P3) ✅ > **Note:** Validation Agent was already set up in Phase 5. - [x] `robot/validation/package.yao` - Validation Agent config (DeepSeek V3, temperature 0.2) - [x] `robot/validation/prompts.yml` - validation prompts - Input: Task result, expected outcome, validation rules - Output: Validation result (pass/fail, score, issues, suggestions) ### 8.2 P2 Implementation ✅ - [x] `executor/tasks.go` - `RunTasks(ctx, exec, data)` - real implementation - [x] `executor/tasks.go` - build prompt with goals (using `FormatGoals`) - [x] `executor/tasks.go` - include available tools/agents in prompt - [x] `executor/tasks.go` - include delivery target in prompt (for task output format) - [x] `executor/tasks.go` - call Tasks Agent using `AgentCaller` - [x] `executor/tasks.go` - parse response to `[]Task` (structured JSON) - [x] `executor/tasks.go` - validate task structure (executor type, ID, messages) - [x] `executor/tasks.go` - `ParseTasks()`, `ParseTask()`, `ParseMessages()` helpers - [x] `executor/tasks.go` - `SortTasksByOrder()` - ensure correct execution sequence - [x] `executor/tasks.go` - `ValidateExecutorExists()` - optional executor existence check - [x] `executor/tasks.go` - `ValidateTasksWithResources()` - validation with warnings - [x] `executor/input.go` - `FormatGoals()` updated to include Delivery Target ### 8.3 Tests ✅ - [x] `executor/tasks_test.go` - P2 with real LLM call (7 integration tests) - [x] Test: goals included in prompt - [x] Test: available tools listed in prompt - [x] Test: delivery target included in prompt - [x] Test: structured tasks generated - [x] Test: each task has valid executor type and ID - [x] Test: each task has expected output and validation rules - [x] `ParseTasks` unit tests (5 tests) - [x] `ValidateTasks` unit tests (5 tests) - [x] `SortTasksByOrder` unit tests (4 tests) - [x] `ValidateExecutorExists` unit tests (7 tests) - [x] 
`ValidateTasksWithResources` unit tests (3 tests) - [x] `ParseExecutorType` unit tests (5 tests) - [x] `IsValidExecutorType` unit tests (2 tests) - [x] `FormatGoals` with delivery target tests (4 tests) ### 8.4 Notes - Tasks Agent returns JSON: `{ "tasks": [...] }` - Each task includes: id, executor_type, executor_id, messages, expected_output, validation_rules, order - Tasks are sorted by `order` field after parsing - Executor existence is optionally validated (warnings only, doesn't block) - Delivery target from P1 is passed to P2 so tasks can produce appropriate output format --- ## Phase 9: P3 Run Implementation ✅ **Goal:** Implement P3 (Task Execution + Validation). P2 → P3 → stub P4-P5. **Depends on:** Phase 8 (P2 Tasks + Validation Agent) **Status:** Complete ### 9.1 Implementation ✅ - [x] `executor/run.go` - `RunExecution(ctx, exec, data)` - real implementation - [x] `RunConfig` - configuration (ContinueOnFailure, ValidationThreshold, MaxTurnsPerTask) - [x] Sequential task execution with progress tracking - [x] Task status updates (Running → Completed/Failed/Skipped) - [x] `ContinueOnFailure` option for graceful failure handling - [x] Previous task results passed as context to subsequent tasks - [x] `executor/runner.go` - `Runner` struct for task execution - [x] `ExecuteWithRetry()` - multi-turn conversation flow for assistant tasks - [x] `executeNonAssistantTask()` - single-call execution for MCP/Process - [x] `executeAssistantWithMultiTurn()` - AI assistant with conversation support - [x] `ExecuteMCPTask()` - MCP tool execution (format: `clientID.toolName`) - [x] `ExecuteProcessTask()` - Yao process execution - [x] `BuildTaskContext()` - context with previous results - [x] `BuildAssistantMessages()` - build messages for assistant - [x] `FormatPreviousResultsAsContext()` - format previous results as context - [x] `extractOutput()` - extract output from CallResult - [x] `generateDefaultReply()` - fallback reply generation - [x] `executor/validator.go` - 
Two-layer validation system - [x] Layer 1: Rule-based validation using `yao/assert` - [x] Layer 2: Semantic validation using Validation Agent - [x] `ValidateWithContext()` - validation with multi-turn support - [x] `isComplete()` - determine if expected result is obtained - [x] `checkNeedReply()` - determine if conversation should continue - [x] `generateFeedbackReply()` - generate validation feedback for next turn - [x] `detectNeedMoreInfo()` - detect if assistant needs clarification - [x] `convertStringRule()` - natural language rules to assertions - [x] `parseRules()` - JSON and string rule parsing - [x] `mergeResults()` - combine rule and semantic results ### 9.2 Assert Package ✅ Created new `yao/assert` package for universal assertion/validation: - [x] `assert/types.go` - `Assertion`, `Result`, `AssertionOptions` types - [x] `assert/asserter.go` - `Asserter` with 8 assertion types: - [x] `equals` - exact match - [x] `contains` - substring check - [x] `not_contains` - negative substring check - [x] `json_path` - JSON path extraction and comparison - [x] `regex` - regex pattern matching - [x] `type` - type checking (with optional path) - [x] `script` - custom script validation - [x] `agent` - AI agent validation - [x] `assert/helpers.go` - `ValidateOutput()`, `ExtractPath()`, `ToString()`, `GetType()` - [x] `assert/asserter_test.go` - 98.7% test coverage ### 9.3 Tests **Completed:** - [x] `assert/asserter_test.go` - 40+ test cases (98.7% coverage) - [x] `types/robot_test.go` - Task structure tests with validation rules - [x] `tasks_test.go` - ParseTasks with validation rules format - [x] Validation rules format aligned with `prompts.yml` guidelines **Completed Tests:** - [x] `executor/standard/run_test.go` - P3 RunExecution tests ✅ - [x] Test: tasks executed in order (`TestRunExecutionBasic`) - [x] Test: task status updates (`TestRunExecutionTaskStatus`) - [x] Test: remaining tasks marked as skipped on failure - [x] Test: error handling (robot nil, no tasks, 
non-existent assistant) - [x] Test: rule-based and semantic validation (`TestRunExecutionValidation`) - [x] Test: previous results passed as context to subsequent tasks - [x] `executor/standard/runner_test.go` - Runner tests ✅ - [x] Test: ExecuteWithRetry with multi-turn conversation flow - [x] Test: max turns limit enforcement - [x] Test: BuildTaskContext with previous results - [x] Test: FormatPreviousResultsAsContext formatting - [x] Test: BuildAssistantMessages with task content - [x] Test: FormatMessagesAsText (string, multipart, map) - [x] Test: MCP and Process tasks (skipped - requires runtime) - [x] `executor/standard/validator_test.go` - Validator tests ✅ - [x] Test: ValidateWithContext with multi-turn state - [x] Test: isComplete determination logic - [x] Test: checkNeedReply scenarios - [x] Test: convertStringRule for natural language rules - [x] Test: parseRules for JSON assertions (equals, regex, json_path, type) - [x] Test: validateSemantic with Validation Agent - [x] Test: mergeResults logic (rule + semantic) **Completed:** - [x] Test: ContinueOnFailure option (run_test.go) ✅ - [x] `stops_on_first_failure_when_ContinueOnFailure_is_false` - [x] `continues_execution_when_ContinueOnFailure_is_true` - [x] `multiple_failures_with_ContinueOnFailure` --- ## Phase 10: P4 Delivery Implementation **Goal:** Implement P4 (Delivery). P3 → P4 → stub P5. **Depends on:** Phase 9 (P3 Run) ### 10.1 Execution Persistence (Prerequisite) ✅ > **Background:** Each Robot execution (P0-P5) needs persistent storage for UI history queries. 
- [x] `yao/models/agent/execution.mod.yao` - Execution record model (`agent_execution` table) - [x] id, execution_id (unique) - [x] member_id (globally unique), team_id - [x] trigger_type (enum: clock, human, event) - [x] **Status tracking** (synced with runtime Execution): - [x] status (enum: pending, running, completed, failed, cancelled) - [x] phase (enum: inspiration, goals, tasks, run, delivery, learning) - [x] current (JSON) - current executing state (task_index, progress) - [x] error - error message if failed - [x] input (JSON) - trigger input - [x] **Phase outputs** (P0-P5): - [x] inspiration (JSON) - P0 output - [x] goals (JSON) - P1 output - [x] tasks (JSON) - P2 output - [x] results (JSON) - P3 output - [x] delivery (JSON) - P4 output - [x] learning (JSON) - P5 output - [x] **Timestamps**: start_time, end_time, created_at, updated_at - [x] Relations: member (hasOne __yao.member) - [x] `agent/robot/store/execution.go` - Execution record storage - [x] `Save(ctx, record)` - create or update execution record - [x] `Get(ctx, execID)` - get execution by ID - [x] `List(ctx, opts)` - query execution history with filters - [x] `UpdatePhase(ctx, execID, phase, data)` - update current phase and data - [x] `UpdateStatus(ctx, execID, status, error)` - update execution status - [x] `UpdateCurrent(ctx, execID, current)` - update current executing state - [x] `Delete(ctx, execID)` - delete execution record - [x] `FromExecution(exec, robotID)` - convert runtime Execution to record - [x] `ToExecution()` - convert record to runtime Execution - [x] Tests: `agent/robot/store/execution_test.go` (9 test groups, all passing) - [x] Integrate into Executor - call `UpdatePhase()` after each phase completes - [x] Added `SkipPersistence` config option to `executor/types/Config` - [x] Added `ExecutionStore` to `executor/standard/Executor` - [x] Save execution record at start of `Execute()` - [x] Call `UpdatePhase()` after each phase completes in `runPhase()` - [x] Call 
`UpdateStatus()` on status changes (running, completed, failed) ### 10.2 Messenger Attachment Support ✅ > **Conclusion:** All email providers now support attachments. **Implementation Status:** | Provider | Attachment Support | Implementation | |----------|-------------------|----------------| | Twilio/SendGrid | ✅ Supported | `buildAttachments()` - base64 encoded | | Mailgun | ✅ Supported | `sendEmailWithAttachments()` - multipart/form-data | | SMTP (mailer) | ✅ Supported | `buildMessageWithAttachments()` - MIME multipart/mixed | **Features Supported:** - Regular attachments (Content-Disposition: attachment) - Inline attachments (Content-Disposition: inline) with Content-ID for HTML embedding - Multiple attachments per email - Automatic content type detection - Base64 encoding for SMTP (RFC 2045 compliant, 76-char line wrapping) **Tests Added:** - `messenger/providers/mailgun/mailgun_test.go`: - `TestSend_EmailWithAttachments_MockServer` - `TestSend_EmailWithInlineAttachment_MockServer` - `TestSend_EmailWithAttachments_RealAPI` - `messenger/providers/mailer/mailer_test.go`: - `TestBuildMessage_WithAttachments` (single, multiple, inline, no attachments) - `TestSend_EmailWithAttachments_RealAPI` ```go // messenger/types/types.go type Attachment struct { Filename string `json:"filename"` ContentType string `json:"content_type"` Content []byte `json:"content"` Inline bool `json:"inline,omitempty"` CID string `json:"cid,omitempty"` } ``` Supported channels: - [x] Email - Full attachment support - [x] SMS - No attachment (text only) - [x] WhatsApp - TBD ### 10.3 Type Updates (Prerequisite) ✅ - [x] Update `types/enums.go` - Update `DeliveryType` enum - [x] Remove `DeliveryFile` - [x] Add `DeliveryProcess` - [x] Update `types/robot.go` - Delivery types for new architecture - [x] `DeliveryResult` - update to new structure (RequestID, Content, Results[]) - [x] Add `DeliveryContent` struct - [x] Add `DeliveryAttachment` struct - [x] Add `DeliveryRequest` struct - [x] Add 
`DeliveryContext` struct - [x] Add `DeliveryPreferences` struct (with Email, Webhook, Process) - [x] Add `EmailPreference`, `EmailTarget` structs - [x] Add `WebhookPreference`, `WebhookTarget` structs - [x] Add `ProcessPreference`, `ProcessTarget` structs - [x] Add `ChannelResult` struct (with Target field) - [x] Update `types/enums_test.go` - Update DeliveryType tests - [x] Update `types/robot_test.go` - Update delivery result tests ### 10.4 Delivery Agent Setup - [x] `robot/delivery/package.yao` - Delivery Agent config - [x] `robot/delivery/prompts.yml` - delivery prompts - [x] Input: Full execution context (P0-P3 results) - [x] Output: DeliveryContent (Summary, Body, Attachments) - **only content, no channels** - [x] Agent focuses on content generation, NOT channel selection ### 10.5 Delivery Content Structure ```go // DeliveryRequest - pushed to Delivery Center // No Channels - Delivery Center decides based on preferences type DeliveryRequest struct { Content *DeliveryContent `json:"content"` // Agent-generated content Context *DeliveryContext `json:"context"` // Tracking info } // DeliveryContent - Content generated by Delivery Agent (only content) type DeliveryContent struct { Summary string `json:"summary"` // Brief 1-2 sentence summary Body string `json:"body"` // Full markdown report Attachments []DeliveryAttachment `json:"attachments,omitempty"` // Output artifacts from P3 } // DeliveryAttachment - Task output attachment with metadata type DeliveryAttachment struct { Title string `json:"title"` // Human-readable title Description string `json:"description,omitempty"` // What this artifact is TaskID string `json:"task_id,omitempty"` // Which task produced this File string `json:"file"` // Wrapper: __<uploader>://<fileID> } // DeliveryContext - tracking info type DeliveryContext struct { MemberID string `json:"member_id"` // Robot member ID (globally unique) ExecutionID string `json:"execution_id"` TriggerType TriggerType `json:"trigger_type"` // clock | human | event TeamID 
string `json:"team_id"` } ``` **Key Design:** - **Agent only generates content** (Summary, Body, Attachments) - **Delivery Center decides channels** based on Robot/User preferences - If webhook configured, every execution pushes automatically **File Wrapper:** - Format: `__<uploader>://<fileID>` - Parse: `attachment.Parse(value)` → `(uploader, fileID, isWrapper)` - Read: `attachment.Base64(ctx, value)` → base64 content **Delivery Channels (each supports multiple targets):** | Channel | Description | Multiple Targets | |---------|-------------|------------------| | `email` | Send via yao/messenger | ✅ Multiple recipients | | `webhook` | POST to external URL | ✅ Multiple URLs | | `process` | Yao Process call | ✅ Multiple processes | | `notify` | In-app notification | Future (auto by subscriptions) | ### 10.6 Implementation **P4 Entry (executor/delivery.go):** - [x] `RunDelivery(ctx, exec, data)` - P4 entry point - [x] Call Delivery Agent to generate content (only content, no channels) - [x] Build DeliveryRequest (Content + Context) - [x] Push to Delivery Center - [x] Store DeliveryResult in exec.Delivery **Delivery Center (executor/delivery_center.go):** - [x] `DeliveryCenter.Deliver(ctx, request)` - main entry - [x] Read Robot/User delivery preferences - [x] Iterate through all enabled targets for each channel - [x] Aggregate ChannelResults into DeliveryResult **Channel Handlers (each supports multiple targets):** - [x] `sendEmail()` - uses yao/messenger - [x] Convert DeliveryAttachment to messenger.Attachment - [x] Support multiple EmailTarget - [x] Support custom subject_template per target - [x] Use `Robot.RobotEmail` as From address (if configured) - [x] Use global `DefaultEmailChannel()` for messenger channel selection - [x] `postWebhook()` - POST JSON - [x] POST DeliveryContent as JSON payload - [x] Support multiple WebhookTarget - [x] Support custom headers per target - [x] `callProcess()` - Yao Process call - [x] DeliveryContent as first arg - [x] Support multiple 
ProcessTarget - [x] Support additional args per target ### 10.7 Tests - [x] `executor/delivery_test.go` - P4 delivery - [x] Test: Delivery Agent generates content (only content) - [x] Test: DeliveryCenter reads preferences - [x] Test: Multiple email targets (TestDeliveryCenterEmail) - [x] Test: Multiple webhook targets - [x] Test: Multiple process targets (TestDeliveryCenterProcess) - [x] Test: Mixed channels (email + webhook + process) (TestDeliveryCenterAllChannels) - [x] Test: sendEmail with attachments (TestDeliveryCenterEmail) - [x] Test: postWebhook with custom headers - [x] Test: callProcess with args (TestDeliveryCenterProcess) - [x] Test: Partial success (some targets fail) - [x] Test: DeliveryResult aggregation --- ## Phase 11: API & Integration (MVP) **Goal:** Complete Go API and end-to-end tests. Main flow: P0 → P1 → P2 → P3 → P4. **Depends on:** Phase 10 (P4 Delivery) > **Note:** Process handlers and JSAPI are optional wrappers, moved to Phase 12. > Go API is sufficient for MVP integration. 
### 11.1 Go API Implementation - [x] `api/types.go` - API request/response types - [x] `api/lifecycle.go` - manager lifecycle - [x] `Start()` / `Stop()` - manager lifecycle - [x] `StartWithConfig(config)` - start with custom config - [x] `IsRunning()` - check if system is running - [x] `api/robot.go` - robot query functions - [x] `GetRobot(memberID)` - get robot by member ID - [x] `ListRobots(query)` - list robots with filtering - [x] `GetRobotStatus(memberID)` - get robot runtime status - [x] `api/trigger.go` - trigger functions - [x] `Trigger(memberID, request)` - main trigger entry point - [x] `TriggerManual(memberID, triggerType, data)` - manual trigger for testing - [x] `Intervene(memberID, request)` - human intervention - [x] `HandleEvent(memberID, request)` - event trigger - [x] `api/execution.go` - execution query and control - [x] `GetExecution(execID)` - get execution by ID - [x] `ListExecutions(memberID, query)` - list executions - [x] `GetExecutionStatus(execID)` - get execution with runtime status - [x] `PauseExecution(execID)` / `ResumeExecution(execID)` / `StopExecution(execID)` - [x] `api/api.go` - package documentation - [x] Tests: `api/*_test.go` (black-box tests, 16 test cases) ### 11.2 End-to-End Tests - [x] Full clock trigger flow (P0 → P1 → P2 → P3 → P4) - `e2e_clock_test.go` - [x] Human intervention flow (P1 → P2 → P3 → P4) - `e2e_human_test.go` - [x] Event trigger flow (P1 → P2 → P3 → P4) - `e2e_event_test.go` - [x] Concurrent execution test - `e2e_concurrent_test.go` - [x] Pause/Resume/Stop test - `e2e_control_test.go` --- ## Phase 12: OpenAPI Integration **Goal:** HTTP endpoints for Robot Agent management and triggers. **Depends on:** Phase 11 (API & Integration), Frontend UI Design > **Note:** This phase will be planned in detail after frontend UI design is complete. > The API endpoints will be designed based on actual UI requirements. 
### 12.1 Planned Features - [ ] HTTP endpoints for robot management (CRUD) - [ ] HTTP endpoints for human intervention triggers - [ ] Webhook endpoints for external events - [ ] WebSocket for real-time execution status updates - [ ] Authentication and authorization integration ### 12.2 Design Dependencies - Frontend dashboard design (robot list, status, controls) - Execution history UI design - Human intervention UI design - Real-time notification requirements --- ## Phase 13: Advanced Features **Goal:** P5 Learning, Process/JSAPI wrappers, dedup, plan queue. > **Note:** These are optional features. Main flow works without them. ### 13.1 Process & JSAPI Wrappers > **Note:** These are convenience wrappers around Go API for Yao ecosystem integration. - [ ] `api/process.go` - implement Process handlers - [ ] `robot.Start` / `robot.Stop` - [ ] `robot.Trigger` / `robot.Intervene` / `robot.HandleEvent` - [ ] `robot.Pause` / `robot.Resume` / `robot.Stop` - [ ] `robot.Get` / `robot.List` - [ ] `robot.GetExecution` / `robot.ListExecutions` - [ ] `api/jsapi.go` - implement JSAPI for JavaScript runtime - [ ] Tests for Process and JSAPI ### 13.2 P5 Learning Implementation > **Background:** P5 Learning is async, runs after P4 Delivery completes. > User doesn't wait for it. Results stored in private KB for future reference. 
#### 13.2.1 Learning Agent Setup - [ ] `robot/learning/package.yao` - Learning Agent config - [ ] `robot/learning/prompts.yml` - learning prompts #### 13.2.2 Store Implementation - [ ] `store/store.go` - Store interface and struct - [ ] `store/kb.go` - KB operations (create, save, search) - [ ] `store/learning.go` - save learning entries to private KB #### 13.2.3 Implementation - [ ] `executor/learning.go` - `RunLearning(ctx, exec, data)` - real implementation - [ ] `executor/learning.go` - extract learnings from execution - [ ] `executor/learning.go` - call Learning Agent - [ ] `executor/learning.go` - save to private KB #### 13.2.4 Tests - [ ] `executor/learning_test.go` - P5 learning - [ ] Test: learnings extracted from execution - [ ] Test: learnings saved to KB - [ ] Test: KB can be queried for past learnings ### 13.3 Fast Dedup (Time-Window) > **Note:** Manager has `// TODO: dedup check` comment placeholder. Integrate after implementation. - [ ] `dedup/dedup.go` - Dedup struct - [ ] `dedup/fast.go` - fast in-memory time-window dedup - [ ] Key: `memberID:triggerType:window` - [ ] Check before submit - [ ] Mark after submit - [ ] Integrate into Manager.Tick() - [ ] Test: dedup check/mark, window expiry ### 13.4 Semantic Dedup - [ ] `dedup/semantic.go` - call Dedup Agent for goal/task level dedup - [ ] Dedup Agent setup (`assistants/robot/dedup/`) - [ ] Test: semantic dedup with real LLM ### 13.5 Plan Queue - [ ] `plan/plan.go` - plan queue implementation - [ ] Store planned tasks/goals - [ ] Execute at next cycle or specified time - [ ] `plan/schedule.go` - schedule for later - [ ] Test: plan queue operations > **Note:** Monitoring is provided by Job system (Activity Monitor UI). No separate implementation needed. 
--- ## Test Assistants Structure ``` yao-dev-app/assistants/robot/ ├── inspiration/ # P0: Inspiration Agent │ ├── package.yao │ └── prompts.yml ├── goals/ # P1: Goal Generation Agent │ ├── package.yao │ └── prompts.yml ├── tasks/ # P2: Task Planning Agent │ ├── package.yao │ └── prompts.yml ├── validation/ # P3: Validation Agent │ ├── package.yao │ └── prompts.yml ├── delivery/ # P4: Delivery Agent │ ├── package.yao │ └── prompts.yml ├── learning/ # P5: Learning Agent │ ├── package.yao │ └── prompts.yml └── dedup/ # Deduplication Agent ├── package.yao └── prompts.yml ``` --- ## Notes ### Test Environment Setup 1. **Environment Variables:** Run `source yao/env.local.sh` before tests 2. **Test Preparation:** Use `testutils.Prepare(t)` to load config, KB, and agents ```go package robot_test import ( "testing" "github.com/yaoapp/yao/agent/testutils" ) func TestExample(t *testing.T) { // Load environment config (from YAO_TEST_APPLICATION) // This loads: config, connectors, KB, agents, models, etc. testutils.Prepare(t) defer testutils.Clean(t) // Your test code here } ``` ### Test Conventions 1. **Black-box Tests:** All tests in `*_test` package (external package) 2. **Real LLM Calls:** Use `gpt-4o` or `deepseek` connectors for agent tests 3. **Incremental:** Each phase builds on previous, all tests must pass before next phase 4. **No Skip:** Do NOT use `t.Skip()` except for `testing.Short()` (CI mode) 5. **Must Assert:** Every test MUST have result validation assertions ```go func TestWithLLM(t *testing.T) { // Only allowed Skip: testing.Short() for CI if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Your test code... result, err := SomeFunction() // MUST have assertions - no empty tests! 
assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, expected, result.Field) } ``` ### Test Rules | Rule | Description | | ----------------- | ----------------------------------------- | | No arbitrary Skip | Only `testing.Short()` skip allowed | | Must assert | Every test must validate results | | No empty tests | Tests without assertions will fail review | | Real calls | LLM tests use real API calls, not mocks | ### Key Environment Variables | Variable | Description | | ---------------------- | ------------------------------- | | `YAO_TEST_APPLICATION` | Test app path (`yao-dev-app`) | | `OPENAI_TEST_KEY` | OpenAI API key | | `DEEPSEEK_API_KEY` | DeepSeek API key | | `YAO_DB_DRIVER` | Database driver (mysql/sqlite3) | | `YAO_DB_PRIMARY` | Database connection string | --- ## Progress Tracking | Phase | Status | Description | | --------------------- | ------ | ---------------------------------------------------------------------------- | | 1. Types & Interfaces | ✅ | All types, enums, interfaces | | 2. Skeleton | ✅ | Empty stubs, code compiles | | 3. Scheduling System | ✅ | Cache + Pool + Trigger + Job + Executor architecture | | 4. Agent Infra | ✅ | AgentCaller, InputFormatter, test assistants | | 5. Test Scenarios | ✅ | Phase agents (P0-P5), expert agents | | 6. P0 Inspiration | ✅ | Inspiration Agent integration | | 7. P1 Goals | ✅ | Goal Generation Agent integration | | 8. P2 Tasks | ✅ | Task Planning Agent integration | | 9. P3 Run | ✅ | Task execution + validation + yao/assert + multi-turn conversation | | 10. P4 Delivery | ✅ | Output delivery (email/webhook/process, notify future) | | 11. API & Integration | ✅ | Go API, end-to-end tests (main flow: P0→P1→P2→P3→P4) | | 12. OpenAPI | ⬜ | HTTP endpoints (depends on frontend UI design) | | 13. 
Advanced | ⬜ | Process/JSAPI, P5 Learning, dedup, plan queue, Sandbox | Legend: ⬜ Not started | 🟡 In progress | ✅ Complete **Main Flow (MVP):** P0 Inspiration → P1 Goals → P2 Tasks → P3 Run → P4 Delivery ✅ **OpenAPI (Phase 12):** HTTP endpoints - planned after frontend UI design **Advanced (Phase 13):** P5 Learning (async), Process/JSAPI, Dedup, Plan Queue, Sandbox --- ## Quick Commands ```bash # Setup environment source yao/env.local.sh # Run all robot tests go test -v ./agent/robot/... # Run specific phase tests go test -v ./agent/robot/types/... go test -v ./agent/robot/cache/... go test -v ./agent/robot/pool/... go test -v ./agent/robot/executor/... # Run with coverage go test -cover ./agent/robot/... ``` ================================================ FILE: agent/robot/V2-IMPROVEMENTS.md ================================================ # Robot Agent V2 — Improvement Plan > Generated: 2026-02-25 > Based on: DESIGN-V2.md deep review against implementation code > Scope: Bug fixes, missing unit tests, code quality improvements --- ## Auth Context Clarification Robot is a legitimate team member in `__yao.member`. Auth is always present: | Trigger Path | Auth Source | Code | |-------------|------------|------| | Clock | `manager.buildRobotAuth(robot)` → `{UserID: robot.MemberID, TeamID: robot.TeamID}` | `manager.go:270` | | Human / Event | Caller's `ctx.Auth` passthrough from HTTP middleware | `openapi/agent/robot/*.go` | | Resume | Loaded from execution record → `buildRobotAuth` or caller passthrough | `executor.go:764+` | The `openapi/agent/robot/` layer constructs `ctx := &robottypes.Context{}` without Auth — this is the **existing V1 pattern** across ALL openapi handlers (trigger.go, execution.go, list.go, etc). Auth checking is done via `authorized.GetInfo(c)` at the Gin middleware level; `robottypes.Context` is a downstream execution context. 
**However**, `manager/interact.go:createConfirmingExecution` calls `ctx.UserID()` which returns `""` when openapi passes an empty Context. This is a V2-specific issue since V1 handlers don't need `ctx.UserID()`. --- ## 1. Bugs ### BUG-1 [P0] `advanceExecution` discards confirmed Goals/Tasks **File**: `manager/interact.go:415-431` **Problem**: After multi-round Host Agent confirmation (which may have generated Goals and Tasks stored in `record.Goals` / `record.Tasks`), `advanceExecution()` submits to Pool via `m.pool.SubmitWithID(...)`. The Pool Worker then calls `ExecuteWithControl()` which starts from P1 (Goals) and re-generates everything — the confirmed plan is lost. **Design intent (§10.1)**: Confirmation → use confirmed Goals/Tasks → skip P1/P2 → directly execute P3. **Fix**: `advanceExecution` must inject `record.Goals` and `record.Tasks` into the `TriggerInput` or use a dedicated `Resume`-like path that skips P1/P2 when Goals/Tasks already exist. **Tests required**: - Confirm with pre-existing Goals/Tasks → verify P3 uses those Goals/Tasks, not re-generated ones - Confirm without Goals/Tasks → verify normal P1→P2→P3 flow --- ### BUG-2 [P0] `standard.New()` creates orphan Executor instances **Files**: `manager/interact.go:501, 525, 544` **Problem**: `skipWaitingTask()`, `resumeWithContext()`, and `directResume()` all call `standard.New()`, creating a fresh Executor with independent counters. Consequences: 1. `currentCount` / `execCount` not shared — monitoring inaccurate 2. No `execController.Untrack()` after Resume completes — memory leak 3. Separate `store` / `robotStore` instances (less critical, stateless) **Fix**: Manager should hold a reference to the live Executor (obtained via Pool) and expose a `Resume` method, or provide the Executor as a constructor parameter. 
**Tests required**: - Resume via `skipWaitingTask` → verify `execController.Untrack()` called - Resume via `resumeWithContext` → verify executor `currentCount` incremented/decremented correctly --- ### BUG-3 [P1] `buildRobotStatusSnapshot` returns near-empty snapshot **File**: `manager/interact.go:266-278` **Problem**: Only populates `ActiveCount` and `MaxQuota`. Missing: `WaitingCount`, `QueuedCount`, `ActiveExecs`, `RecentExecs`. Host Agent cannot make informed decisions about robot workload. **Fix**: Query `robot.Executions` to compute `WaitingCount`, collect `ActiveExecs` briefs, and optionally query recent completed executions from store. **Tests required**: - Robot with 2 running + 1 waiting execution → snapshot reflects correct counts - Robot with no executions → all counts zero --- ### BUG-4 [P1] `openapi/agent/robot/interact.go` passes empty Context to Manager **File**: `openapi/agent/robot/interact.go:67, 152, 209` **Problem**: `ctx := &robottypes.Context{}` — no `Auth`, no `context.Context`. When `HandleInteract` → `createConfirmingExecution` calls `ctx.UserID()`, returns `""`. The `TriggerInput.UserID` in the DB record is empty. **Note**: This is NOT about Robot's own Auth (which is always set via `buildRobotAuth` in execution paths). This is about tracking **which human user** initiated the interaction. **Fix**: In V2 interact handlers, construct Context properly: ```go ctx := robottypes.NewContext(c.Request.Context(), &oauthtypes.AuthorizedInfo{ UserID: authInfo.UserID, TeamID: authInfo.TeamID, }) ``` **Tests required**: - InteractRobot handler → verify ctx.UserID() returns the authenticated user's ID - CreateConfirmingExecution → verify record.Input.UserID is populated --- ### BUG-5 [P2] `HostContext.Goals` type mismatch with design **File**: `types/host.go:15` **Problem**: Design §5.7 defines `Goals string`, implementation uses `*Goals` (struct with `Content` field). Host Agent receives `{"goals": {"content": "..."}}` instead of `{"goals": "..."}`. 
**Fix**: Either update the Host Agent prompt to expect the struct format, or flatten to `string` in `buildHostContext`: ```go if record.Goals != nil { hostCtx.GoalsContent = record.Goals.Content // string } ``` **Tests required**: - `buildHostContext` with Goals → verify JSON output matches Host Agent prompt expectations --- ## 2. Missing Unit Tests All tests should be **black-box** tests (test exported APIs only), must **verify return values and side effects**, and must **not require real LLM calls**. ### 2.1 `executor/standard/host.go` — CallHostAgent **Current coverage**: 0 tests | # | Test Case | Verify | |---|-----------|--------| | H1 | Robot is nil | Returns error "robot cannot be nil" | | H2 | No Host Agent configured (empty Resources) | Returns error "no Host Agent configured" | | H3 | Valid Host Agent call returns JSON | Parsed `HostOutput` with correct Action and Reply | | H4 | Host Agent returns non-JSON text | Fallback to `HostActionConfirm` with text as Reply | | H5 | Host Agent returns invalid JSON structure | Fallback to `HostActionConfirm` | | H6 | Host Agent call fails (network error) | Returns wrapped error | | H7 | Input marshalling (verify HostInput fields) | Correct JSON sent to agent | **Status**: ✅ All tests implemented. H1-H2, H7 are pure unit tests. H3-H5 use real LLM integration via `yao-dev-app` test assistants (`tests.host-json`, `tests.host-plaintext`, `tests.host-badjson`). H6 uses real assistant framework. 
--- ### 2.2 `manager/interact.go` — processHostAction (all branches) **Current coverage**: 2/7 branches (WaitForMore, default) | # | Test Case | Action | Verify | |---|-----------|--------|--------| | PA1 | HostActionConfirm | `confirm` | `resp.Status == "confirmed"`, `advanceExecution` called | | PA2 | HostActionAdjust with goals data | `adjust` | Record Goals updated, `resp.Status == "adjusted"` | | PA3 | HostActionAdjust with tasks data | `adjust` | Record Tasks updated | | PA4 | HostActionAdjust with nil data | `adjust` | No error, noop | | PA5 | HostActionAddTask | `add_task` | New task appended to record.Tasks with generated ID | | PA6 | HostActionAddTask with nil data | `add_task` | Returns error "task data is required" | | PA7 | HostActionSkip with waiting task | `skip` | Task status = skipped | | PA8 | HostActionSkip without waiting task | `skip` | Returns error "no task is waiting" | | PA9 | HostActionInjectCtx with string reply | `inject_context` | Resume called with correct reply | | PA10 | HostActionInjectCtx → re-suspend | `inject_context` | `resp.Status == "waiting"` | | PA11 | HostActionCancel | `cancel` | Execution status = cancelled, event pushed | | PA12 | WaitForMore = true | — | `resp.Status == "waiting_for_more"`, `resp.WaitForMore == true` | | PA13 | Unknown action | — | `resp.Status == "acknowledged"` | **Note**: PA1, PA7, PA9, PA11 require mocking Executor.Resume and Pool.SubmitWithID. 
--- ### 2.3 `manager/interact.go` — HandleInteract routing **Current coverage**: Parameter validation only | # | Test Case | Verify | |---|-----------|--------| | HI1 | No execution_id → creates confirming execution | Record saved with status=confirming, Host Agent called with "assign" | | HI2 | execution_id with status=confirming | Host Agent called with "assign" scenario | | HI3 | execution_id with status=waiting | Host Agent called with "clarify" scenario | | HI4 | execution_id with status=running | Host Agent called with "guide" scenario | | HI5 | execution_id with status=completed | Returns error "cannot interact" | | HI6 | execution_id not found | Returns error "execution not found" | | HI7 | Host Agent unavailable → direct assign fallback | Execution started without Host Agent | | HI8 | Host Agent unavailable → direct resume fallback | Execution resumed directly | --- ### 2.4 `manager/interact.go` — CancelExecution **Current coverage**: "manager not started" only | # | Test Case | Verify | |---|-----------|--------| | CE1 | Cancel waiting execution | Status → cancelled, `Untrack` called, event pushed | | CE2 | Cancel confirming execution | Status → cancelled | | CE3 | Cancel running execution | Returns error (only waiting/confirming allowed) | | CE4 | Cancel non-existent execution | Returns error "execution not found" | | CE5 | Cancel already cancelled | Returns error | --- ### 2.5 `executor/standard/executor.go` — Resume method **Current coverage**: Only via E2E tests (requires real LLM) | # | Test Case | Verify | |---|-----------|--------| | R1 | Resume non-waiting execution | Returns error "not in waiting status" | | R2 | Resume non-existent execution | Returns error "execution not found" | | R3 | Resume with nil store | Returns error "store is required" | | R4 | Resume injects reply into task messages | `exec.Tasks[i].Messages` contains `[Human reply]` prefixed message | | R5 | Resume clears waiting fields | `WaitingTaskID`, `WaitingQuestion`, 
`WaitingSince` all empty after resume | | R6 | Resume updates status to running | `exec.Status == ExecRunning` | | R7 | Resume → re-suspend | Returns `ErrExecutionSuspended`, execution stays tracked | | R8 | Resume → complete → P4 → P5 | Status == ExecCompleted, `ResumeContext` cleared | | R9 | Resume → P3 error | Status == ExecFailed with error message | | R10 | Resume maintains executor currentCount | `currentCount +1 before, -1 after` | **Note**: R4-R10 require mocking store.Get, store.UpdateResumeState, RunExecution, runPhase. --- ### 2.6 `manager/interact.go` — Helper methods **Current coverage**: buildRobotStatusSnapshot (3), findWaitingTask (3), buildHostContext (2) | # | Missing Test Case | Verify | |---|-------------------|--------| | HL1 | `createConfirmingExecution` | Record has correct fields (execID, chatID, status=confirming, input) | | HL2 | `adjustExecution` with goals string | `record.Goals.Content` updated | | HL3 | `adjustExecution` with tasks array | `record.Tasks` replaced | | HL4 | `adjustExecution` with non-map data | Graceful handling | | HL5 | `injectTask` with valid task | Task appended with auto-generated ID | | HL6 | `injectTask` preserves existing tasks | len(tasks) == original + 1 | | HL7 | `callHostAgentForScenario` — no host agent | Returns error | | HL8 | `directAssign` | Returns "confirmed" status | | HL9 | `directResume` — re-suspend | Returns "waiting" status | | HL10 | `directResume` — complete | Returns "resumed" status | --- ### 2.7 `api/interact.go` — Interact/Reply/Confirm/CancelExecution **Current coverage**: 0 tests | # | Test Case | Verify | |---|-----------|--------| | AI1 | `Interact` with manager available | Delegates to `managerInteract` | | AI2 | `Interact` without manager, with execution_id | Delegates to `legacyResume` | | AI3 | `Interact` without manager, without execution_id | Returns error | | AI4 | `Interact` with empty member_id | Returns error | | AI5 | `Interact` with nil request | Returns error | | AI6 | 
`Reply` shortcut | Calls Interact with correct TaskID and Source | | AI7 | `Confirm` shortcut | Calls Interact with correct Action | | AI8 | `CancelExecution` with manager | Delegates correctly | | AI9 | `CancelExecution` without manager | Returns error | | AI10 | `legacyResume` → success | Returns "resumed" status | | AI11 | `legacyResume` → re-suspend | Returns "waiting" status | | AI12 | `legacyResume` → error | Returns wrapped error | --- ### 2.8 `events/events.go` + `events/handlers.go` — Event integration **Current coverage**: DeliveryHandler basic (3 tests) | # | Missing Test Case | Verify | |---|-------------------|--------| | EV1 | DeliveryHandler — payload deserialization | All fields (`ExecutionID`, `MemberID`, `Content`, `Preferences`) correctly parsed | | EV2 | Verify event constants match design §7.2 | All 9 constants present and correctly named | | EV3 | `NeedInputPayload` marshalling | Correct JSON roundtrip | | EV4 | `TaskPayload` marshalling | Correct JSON roundtrip with optional Error field | | EV5 | `ExecPayload` marshalling | Correct JSON roundtrip | --- ### 2.9 `openapi/agent/robot/interact.go` — HTTP handlers **Current coverage**: 0 tests | # | Test Case | Verify | |---|-----------|--------| | OH1 | `InteractRobot` — valid request | 200 with InteractResponse | | OH2 | `InteractRobot` — missing robot ID | 400 error | | OH3 | `InteractRobot` — missing message | 400 error | | OH4 | `InteractRobot` — robot not found | 404 error | | OH5 | `InteractRobot` — forbidden (no write permission) | 403 error | | OH6 | `ReplyToTask` — valid request | 200 with response | | OH7 | `ReplyToTask` — missing params | 400 error | | OH8 | `ConfirmExecution` — valid request | 200 with response | | OH9 | `ConfirmExecution` — empty body allowed | 200 (confirm without message) | --- ### 2.10 Event push verification in execution flow **Current coverage**: 0 (events are pushed but never verified in tests) | # | Test Case | File | Verify | |---|-----------|------|--------| 
| EP1 | Task completes → TaskCompleted event | `run.go:111` | Event type + payload fields | | EP2 | Task fails → TaskFailed event | `run.go:120` | Event type + error in payload | | EP3 | Execution suspends → ExecWaiting event | `executor.go:750` | Event type + question in payload | | EP4 | Execution resumes → ExecResumed event | `executor.go:856` | Event type + chatID | | EP5 | Execution completes → ExecCompleted event | `executor.go:287` | Event type + status | | EP6 | Execution cancelled → ExecCancelled event | `manager/interact.go:66` | Event type + status | | EP7 | Delivery → Delivery event | `delivery.go:102` | Content + Preferences in payload | **Approach**: Use `event.Subscribe` in test to capture pushed events, or mock `event.Push`. --- ## 3. Code Quality Improvements ### CQ1 — Extract common Executor resume logic `skipWaitingTask`, `resumeWithContext`, `directResume` all duplicate: create executor → call Resume → handle ErrExecutionSuspended. Extract to a private helper: ```go func (m *Manager) executeResume(ctx *types.Context, execID, reply string) error { // Use shared executor reference, not standard.New() return m.getExecutor().Resume(ctx, execID, reply) } ``` ### CQ2 — `processHostAction` needs explicit `store.Save()` after Confirm `advanceExecution` changes execution status but doesn't save the Goals/Tasks that may have been set during confirming flow. Needs explicit persist before Pool submit. ### CQ3 — `RobotStatusSnapshot` should include `MemberID` and `Status` Add back `MemberID` and `Status` fields to match design §5.7. These help Host Agent identify which robot it's serving. --- ## 4. Implementation Priority | Priority | Items | Est. 
Effort | |----------|-------|-------------| | **P0** | BUG-1 (advanceExecution), BUG-2 (standard.New) | 1 day | | **P1** | BUG-3 (snapshot), BUG-4 (context auth) | 0.5 day | | **P1** | Tests §2.2 (processHostAction), §2.3 (HandleInteract), §2.5 (Resume) | 1.5 days | | **P2** | BUG-5 (Goals type), CQ1-CQ3 | 0.5 day | | **P2** | Tests §2.1 (CallHostAgent), §2.4 (Cancel), §2.6-2.10 | 2 days | | | **Total** | **~5.5 days** | --- ## 5. Test Infrastructure Notes 1. **Source env before test**: `source yao/env.local.sh` 2. **Test app**: `yao-dev-app` — all test assistants live there 3. **No recompile needed**: `yao-dev` runs from Go source directly 4. **Mock strategy**: For unit tests not requiring real LLM, create interfaces for `ConversationCaller`, `ExecutionStore`, `Pool` to enable mock injection. Alternatively, use `SkipPersistence: true` config + in-memory stubs. 5. **Event verification**: Wrap `event.Push` calls with a test interceptor or use `event.Subscribe` to capture events during test. ================================================ FILE: agent/robot/api/README.md ================================================ # Robot Agent Go API Go API for managing autonomous robot agents. 
## Quick Start ```go import "github.com/yaoapp/yao/agent/robot/api" // Start system api.Start() defer api.Stop() // Trigger execution result, _ := api.Trigger(ctx, "member_123", &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, Messages: []agentcontext.Message{{Role: "user", Content: "Analyze sales"}}, }) // Check status exec, _ := api.GetExecution(ctx, result.ExecutionID) ``` ## Lifecycle ```go api.Start() // Start with defaults api.StartWithConfig(config) // Start with custom config api.Stop() // Graceful shutdown api.IsRunning() // Check if running ``` ## Robot Query ```go // Get single robot robot, err := api.GetRobot(ctx, "member_123") // List robots with filters result, err := api.ListRobots(ctx, &api.ListQuery{ TeamID: "team_1", Status: types.RobotIdle, Keywords: "sales", ClockMode: types.ClockInterval, Page: 1, PageSize: 20, Order: "created_at desc", }) // Get runtime status state, err := api.GetRobotStatus(ctx, "member_123") // state.Running, state.MaxRunning, state.RunningIDs, state.LastRun, state.NextRun ``` ## Triggers ### Human Intervention ```go result, err := api.Trigger(ctx, "member_123", &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: "user", Content: "Generate weekly report"}, }, }) // Or use shorthand: result, err := api.Intervene(ctx, "member_123", req) ``` ### Event Trigger ```go result, err := api.Trigger(ctx, "member_123", &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventWebhook, EventType: "order.created", Data: map[string]interface{}{"order_id": "12345"}, }) // Or use shorthand: result, err := api.HandleEvent(ctx, "member_123", req) ``` ### Manual Trigger (Testing) ```go result, err := api.TriggerManual(ctx, "member_123", types.TriggerClock, nil) ``` ## Execution Management ```go // Get execution by ID exec, err := api.GetExecution(ctx, "exec_abc123") // List executions with filters result, err := api.ListExecutions(ctx, 
"member_123", &api.ExecutionQuery{ Status: types.ExecRunning, Trigger: types.TriggerClock, Page: 1, PageSize: 10, }) // Get execution with runtime status exec, err := api.GetExecutionStatus(ctx, "exec_abc123") // Control execution api.PauseExecution(ctx, "exec_abc123") api.ResumeExecution(ctx, "exec_abc123") api.StopExecution(ctx, "exec_abc123") ``` ## Types ### ListQuery ```go type ListQuery struct { TeamID string // Filter by team Status types.RobotStatus // Filter by status (idle|working|paused|error) Keywords string // Search in display_name ClockMode types.ClockMode // Filter by clock mode (times|interval|daemon) Page int // Page number (default: 1) PageSize int // Page size (default: 20, max: 100) Order string // Order by column (default: "created_at desc") } ``` ### TriggerRequest ```go type TriggerRequest struct { Type types.TriggerType // human | event | clock Action types.InterventionAction // task.add, goal.adjust, etc. Messages []agentcontext.Message // User input PlanAt *time.Time // Schedule for later InsertPosition InsertPosition // first | last | next | at AtIndex int // When InsertPosition = "at" Source types.EventSource // webhook | database EventType string // Event name Data map[string]interface{} // Event payload ExecutorMode types.ExecutorMode // standard | dryrun | sandbox } ``` ### TriggerResult ```go type TriggerResult struct { Accepted bool // Whether trigger was accepted Queued bool // Whether queued (vs immediate) Execution *types.Execution // Execution details ExecutionID string // Execution ID for tracking Message string // Status message } ``` ### ExecutionQuery ```go type ExecutionQuery struct { Status types.ExecStatus // Filter by status Trigger types.TriggerType // Filter by trigger type Page int // Page number (default: 1) PageSize int // Page size (default: 20, max: 100) } ``` ### RobotState ```go type RobotState struct { MemberID string // Robot member ID TeamID string // Team ID DisplayName string // Display name Status 
// ActivityQuery - query parameters for listing activities.
// A nil query is accepted by ListActivities and treated as an empty one.
type ActivityQuery struct {
	TeamID string     `json:"team_id,omitempty"` // Filter by team ID (empty = no team filter)
	Limit  int        `json:"limit,omitempty"`   // Max items to return; applyDefaults uses 20 when <= 0 and caps at 100
	Since  *time.Time `json:"since,omitempty"`   // Only activities after this time
	Type   string     `json:"type,omitempty"`    // Filter by activity type: execution.started, execution.completed, execution.failed, execution.cancelled
}

// Activity - activity item for feed.
// All fields except RobotName are copied from the store's activity record.
type Activity struct {
	Type        store.ActivityType `json:"type"`
	RobotID     string             `json:"robot_id"`
	RobotName   string             `json:"robot_name,omitempty"` // Display name from robot; left empty when the robot lookup fails
	ExecutionID string             `json:"execution_id"`
	Message     string             `json:"message"`
	Timestamp   time.Time          `json:"timestamp"`
}

// ActivityListResponse - response with activities
type ActivityListResponse struct {
	Data []*Activity `json:"data"` // Always a non-nil slice when returned by ListActivities
}
derived from execution status changes func ListActivities(ctx *types.Context, query *ActivityQuery) (*ActivityListResponse, error) { if query == nil { query = &ActivityQuery{} } query.applyDefaults() // Build store options opts := &store.ActivityListOptions{ Limit: query.Limit, Since: query.Since, } if query.TeamID != "" { opts.TeamID = query.TeamID } // Pass type filter if provided if query.Type != "" { opts.Type = store.ActivityType(query.Type) } // Query from store storeActivities, err := getExecutionStore().ListActivities(context.Background(), opts) if err != nil { return nil, fmt.Errorf("failed to list activities: %w", err) } // Transform to Activity slice // Also enrich with robot display names activities := make([]*Activity, 0, len(storeActivities)) robotNames := make(map[string]string) // Cache robot names for _, sa := range storeActivities { activity := &Activity{ Type: sa.Type, RobotID: sa.RobotID, ExecutionID: sa.ExecutionID, Message: sa.Message, Timestamp: sa.Timestamp, } // Try to get robot name (with caching) if name, ok := robotNames[sa.RobotID]; ok { activity.RobotName = name } else { // Try to get robot display name robotResp, err := GetRobotResponse(ctx, sa.RobotID) if err == nil && robotResp != nil { activity.RobotName = robotResp.DisplayName robotNames[sa.RobotID] = robotResp.DisplayName } } activities = append(activities, activity) } return &ActivityListResponse{ Data: activities, }, nil } // ==================== Helper Functions ==================== // applyDefaults applies default values to ActivityQuery func (q *ActivityQuery) applyDefaults() { if q.Limit <= 0 { q.Limit = 20 } if q.Limit > 100 { q.Limit = 100 } } ================================================ FILE: agent/robot/api/api_test.go ================================================ package api_test // Integration tests for the Robot Agent API // These tests verify the complete API functionality with real database operations. 
// TestAPIFullLifecycle exercises the public API end to end against a real database:
// Create Robot → Start → Query Robot → Robot Status → List Robots → Trigger → Stop
//
// Skipped in -short mode. Robots and executions created here use the
// "robot_api_" prefix so the cleanup helpers can remove them before and after the run.
func TestAPIFullLifecycle(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	// Remove leftovers from earlier runs now, and remove this run's data on exit
	cleanupAPITestRobots(t)
	cleanupAPITestExecutions(t)
	defer cleanupAPITestRobots(t)
	defer cleanupAPITestExecutions(t)

	t.Run("complete workflow", func(t *testing.T) {
		// 1. Setup: Create test robot in database (before Start, so it exists when the system starts)
		setupAPITestRobot(t, "robot_api_lifecycle_001", "team_api_001")

		// 2. Start the API system
		err := api.Start()
		require.NoError(t, err)
		assert.True(t, api.IsRunning())
		// Safety net: stops the system if an assertion below aborts the subtest
		// before the explicit Stop in step 8 is reached
		defer api.Stop()

		// 3. Query the robot via API (context carries no auth info)
		ctx := types.NewContext(context.Background(), nil)
		robot, err := api.GetRobot(ctx, "robot_api_lifecycle_001")
		require.NoError(t, err)
		require.NotNil(t, robot)
		assert.Equal(t, "robot_api_lifecycle_001", robot.MemberID)
		assert.Equal(t, "team_api_001", robot.TeamID)
		assert.Equal(t, "API Test Robot robot_api_lifecycle_001", robot.DisplayName)

		// 4. Get robot status: nothing has been triggered yet, so the robot is idle.
		// MaxRunning 5 matches quota.max written by setupAPITestRobot.
		status, err := api.GetRobotStatus(ctx, "robot_api_lifecycle_001")
		require.NoError(t, err)
		require.NotNil(t, status)
		assert.Equal(t, "robot_api_lifecycle_001", status.MemberID)
		assert.Equal(t, types.RobotIdle, status.Status)
		assert.Equal(t, 0, status.Running)
		assert.Equal(t, 5, status.MaxRunning)

		// 5. List robots of the test team
		listResult, err := api.ListRobots(ctx, &api.ListQuery{
			TeamID:   "team_api_001",
			Page:     1,
			PageSize: 10,
		})
		require.NoError(t, err)
		require.NotNil(t, listResult)
		assert.GreaterOrEqual(t, listResult.Total, 1)

		// Find our robot in the list
		found := false
		for _, r := range listResult.Data {
			if r.MemberID == "robot_api_lifecycle_001" {
				found = true
				break
			}
		}
		assert.True(t, found, "Robot should be in list")

		// 6. Trigger manual execution (clock trigger type, no extra data)
		triggerResult, err := api.TriggerManual(ctx, "robot_api_lifecycle_001", types.TriggerClock, nil)
		require.NoError(t, err)
		require.NotNil(t, triggerResult)
		assert.True(t, triggerResult.Accepted)
		assert.NotEmpty(t, triggerResult.ExecutionID)

		// 7. Give the execution a fixed grace period to run.
		// The outcome of the execution itself is not queried or asserted here.
		time.Sleep(500 * time.Millisecond)

		// 8. Stop the system and verify it reports not running
		err = api.Stop()
		require.NoError(t, err)
		assert.False(t, api.IsRunning())
	})
}
TeamID: "team_api_query", Page: 2, PageSize: 1, }) require.NoError(t, err) require.GreaterOrEqual(t, len(result2.Data), 1, "Should have at least 1 robot on page 2") // Should be different robots assert.NotEqual(t, result1.Data[0].MemberID, result2.Data[0].MemberID) }) t.Run("ListRobots filters by keywords", func(t *testing.T) { result, err := api.ListRobots(ctx, &api.ListQuery{ Keywords: "robot_api_query_001", Page: 1, PageSize: 10, }) require.NoError(t, err) require.NotNil(t, result) // Should find at least 1 robot matching keywords assert.GreaterOrEqual(t, result.Total, 1) for _, robot := range result.Data { assert.Contains(t, robot.DisplayName, "robot_api_query_001") } }) } // TestListRobotsAutonomousModeFilter tests the autonomous_mode filter func TestListRobotsAutonomousModeFilter(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupAPITestRobots(t) defer cleanupAPITestRobots(t) // Setup: Create robots with different autonomous_mode settings setupAPITestRobotWithMode(t, "robot_api_auto_001", "team_api_mode", true) // autonomous setupAPITestRobotWithMode(t, "robot_api_auto_002", "team_api_mode", true) // autonomous setupAPITestRobotWithMode(t, "robot_api_demand_001", "team_api_mode", false) // on-demand ctx := types.NewContext(context.Background(), nil) t.Run("ListRobots returns all robots when autonomous_mode is nil", func(t *testing.T) { result, err := api.ListRobots(ctx, &api.ListQuery{ TeamID: "team_api_mode", Page: 1, PageSize: 10, }) require.NoError(t, err) require.NotNil(t, result) // Should have all 3 robots assert.Equal(t, 3, result.Total) }) t.Run("ListRobots filters by autonomous_mode=true", func(t *testing.T) { autonomousMode := true result, err := api.ListRobots(ctx, &api.ListQuery{ TeamID: "team_api_mode", AutonomousMode: &autonomousMode, Page: 1, PageSize: 10, }) require.NoError(t, err) require.NotNil(t, result) // Should have only 2 autonomous robots assert.Equal(t, 2, 
result.Total) for _, robot := range result.Data { assert.True(t, robot.AutonomousMode, "All returned robots should be autonomous") } }) t.Run("ListRobots filters by autonomous_mode=false", func(t *testing.T) { autonomousMode := false result, err := api.ListRobots(ctx, &api.ListQuery{ TeamID: "team_api_mode", AutonomousMode: &autonomousMode, Page: 1, PageSize: 10, }) require.NoError(t, err) require.NotNil(t, result) // Should have only 1 on-demand robot assert.Equal(t, 1, result.Total) for _, robot := range result.Data { assert.False(t, robot.AutonomousMode, "All returned robots should be on-demand") } }) } // TestAPIExecutionQueryWithData tests execution query APIs with real data func TestAPIExecutionQueryWithData(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupAPITestExecutions(t) defer cleanupAPITestExecutions(t) // Setup: Create test executions setupAPITestExecution(t, "exec_api_query_001", "member_api_exec", types.TriggerClock, types.ExecCompleted) setupAPITestExecution(t, "exec_api_query_002", "member_api_exec", types.TriggerHuman, types.ExecRunning) setupAPITestExecution(t, "exec_api_query_003", "member_api_exec", types.TriggerClock, types.ExecFailed) setupAPITestExecution(t, "exec_api_query_004", "member_api_other", types.TriggerEvent, types.ExecCompleted) ctx := types.NewContext(context.Background(), nil) t.Run("GetExecution returns correct execution", func(t *testing.T) { exec, err := api.GetExecution(ctx, "exec_api_query_001") require.NoError(t, err) require.NotNil(t, exec) assert.Equal(t, "exec_api_query_001", exec.ID) assert.Equal(t, "member_api_exec", exec.MemberID) assert.Equal(t, types.TriggerClock, exec.TriggerType) assert.Equal(t, types.ExecCompleted, exec.Status) }) t.Run("ListExecutions filters by member", func(t *testing.T) { result, err := api.ListExecutions(ctx, "member_api_exec", &api.ExecutionQuery{ Page: 1, PageSize: 10, }) require.NoError(t, err) 
require.NotNil(t, result) // Should have 3 executions for member_api_exec assert.Equal(t, 3, result.Total) assert.Len(t, result.Data, 3) // Verify all returned executions are for the correct member for _, exec := range result.Data { assert.Equal(t, "member_api_exec", exec.MemberID) } }) t.Run("ListExecutions filters by status", func(t *testing.T) { result, err := api.ListExecutions(ctx, "member_api_exec", &api.ExecutionQuery{ Status: types.ExecCompleted, Page: 1, PageSize: 10, }) require.NoError(t, err) require.NotNil(t, result) // Should have only completed executions assert.Equal(t, 1, result.Total) for _, exec := range result.Data { assert.Equal(t, types.ExecCompleted, exec.Status) } }) t.Run("ListExecutions filters by trigger type", func(t *testing.T) { result, err := api.ListExecutions(ctx, "member_api_exec", &api.ExecutionQuery{ Trigger: types.TriggerClock, Page: 1, PageSize: 10, }) require.NoError(t, err) require.NotNil(t, result) // Should have only clock trigger executions assert.Equal(t, 2, result.Total) for _, exec := range result.Data { assert.Equal(t, types.TriggerClock, exec.TriggerType) } }) t.Run("ListExecutions pagination works", func(t *testing.T) { // Page 1 with size 2 result1, err := api.ListExecutions(ctx, "member_api_exec", &api.ExecutionQuery{ Page: 1, PageSize: 2, }) require.NoError(t, err) assert.GreaterOrEqual(t, len(result1.Data), 2, "Should have at least 2 executions on page 1") // Page 2 with size 2 result2, err := api.ListExecutions(ctx, "member_api_exec", &api.ExecutionQuery{ Page: 2, PageSize: 2, }) require.NoError(t, err) assert.GreaterOrEqual(t, len(result2.Data), 1, "Should have at least 1 execution on page 2") }) } // TestAPITriggerWithData tests trigger APIs with real robots func TestAPITriggerWithData(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupAPITestRobots(t) defer cleanupAPITestRobots(t) // Setup: Create test robot setupAPITestRobot(t, 
"robot_api_trigger_001", "team_api_trigger") // Start manager err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), nil) t.Run("TriggerManual accepts valid robot", func(t *testing.T) { result, err := api.TriggerManual(ctx, "robot_api_trigger_001", types.TriggerClock, nil) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.Accepted) assert.NotEmpty(t, result.ExecutionID) assert.Contains(t, result.Message, "submitted") }) t.Run("Trigger with human type", func(t *testing.T) { result, err := api.Trigger(ctx, "robot_api_trigger_001", &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, }) require.NoError(t, err) require.NotNil(t, result) // Result should be returned (accepted or not depends on robot state) // The important thing is that the API doesn't error t.Logf("Trigger result: accepted=%v, message=%s", result.Accepted, result.Message) }) t.Run("Trigger with event type", func(t *testing.T) { result, err := api.Trigger(ctx, "robot_api_trigger_001", &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventWebhook, EventType: "test.event", Data: map[string]interface{}{"key": "value"}, }) require.NoError(t, err) require.NotNil(t, result) // Should be accepted (robot exists and has event enabled) assert.True(t, result.Accepted) }) t.Run("Trigger rejects non-existent robot", func(t *testing.T) { result, err := api.Trigger(ctx, "robot_api_nonexistent", &api.TriggerRequest{ Type: types.TriggerHuman, }) // Should not return error, but result should show not accepted assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.Accepted) }) } // ==================== Helper Functions ==================== // setupAPITestRobotWithMode creates a test robot with specific autonomous_mode setting func setupAPITestRobotWithMode(t *testing.T, memberID, teamID string, autonomousMode bool) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := 
capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "API Test Robot", "duties": []string{"Testing API functions"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "API Test Robot " + memberID, "system_prompt": "You are an API test robot.", "status": "active", "role_id": "member", "autonomous_mode": autonomousMode, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } // setupAPITestRobot creates a test robot in the database func setupAPITestRobot(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "API Test Robot", "duties": []string{"Testing API functions"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "interval", "every": "1h", "timeout": "30m", }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "API Test Robot " + memberID, "system_prompt": "You are an API test robot.", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } // setupAPITestExecution creates a test execution in the database func 
setupAPITestExecution(t *testing.T, execID, memberID string, triggerType types.TriggerType, status types.ExecStatus) { s := store.NewExecutionStore() ctx := context.Background() startTime := time.Now().Add(-1 * time.Hour) record := &store.ExecutionRecord{ ExecutionID: execID, MemberID: memberID, TeamID: "team_api_exec", TriggerType: triggerType, Status: status, Phase: types.PhaseDelivery, StartTime: &startTime, } if status == types.ExecCompleted || status == types.ExecFailed { endTime := time.Now() record.EndTime = &endTime } err := s.Save(ctx, record) if err != nil { t.Fatalf("Failed to insert execution %s: %v", execID, err) } } // cleanupAPITestRobots removes all API test robots func cleanupAPITestRobots(t *testing.T) { m := model.Select("__yao.member") if m == nil { return } tableName := m.MetaData.Table.Name qb := capsule.Query() // Delete all robots with member_id starting with "robot_api_" or "api_robot_" _, err := qb.Table(tableName).Where("member_id", "like", "robot_api_%").Delete() if err != nil { t.Logf("Warning: cleanup robots error: %v", err) } // Also delete "api_robot_" prefixed robots (new tests) _, err = qb.Table(tableName).Where("member_id", "like", "api_robot_%").Delete() if err != nil { t.Logf("Warning: cleanup robots error: %v", err) } } // cleanupAPITestExecutions removes all API test executions func cleanupAPITestExecutions(t *testing.T) { m := model.Select("__yao.agent.execution") if m == nil { t.Logf("Warning: model __yao.agent.execution not found, skipping cleanup") return } tableName := m.MetaData.Table.Name qb := capsule.Query() // Delete all executions with execution_id starting with "exec_api_" _, err := qb.Table(tableName).Where("execution_id", "like", "exec_api_%").Delete() if err != nil { t.Logf("Warning: cleanup executions error: %v", err) } // Also delete executions for API test members _, err = qb.Table(tableName).Where("member_id", "like", "member_api_%").Delete() if err != nil { t.Logf("Warning: cleanup executions error: %v", 
err) } } ================================================ FILE: agent/robot/api/e2e_clock_test.go ================================================ package api_test // End-to-end tests for Clock trigger flow // These tests use REAL LLM calls via Standard executor (not DryRun) // // Test Flow: Clock Trigger → P0 (Inspiration) → P1 (Goals) → P2 (Tasks) → P3 (Run) → P4 (Delivery) // // Prerequisites: // - Valid LLM API keys (OPENAI_TEST_KEY or DEEPSEEK_API_KEY) // - Test assistants in yao-dev-app/assistants/robot/ // - Database connection (YAO_DB_PRIMARY) import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // testAuth returns test auth info for E2E tests func testAuth() *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "e2e-test-user", TeamID: "e2e-test-team", } } // TestE2EClockTriggerFullFlow tests the complete clock trigger flow with real LLM calls // Flow: Clock → P0 (Inspiration) → P1 (Goals) → P2 (Tasks) → P3 (Run) → P4 (Delivery) func TestE2EClockTriggerFullFlow(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("complete_P0_to_P4_flow", func(t *testing.T) { // Setup: Create a robot configured for clock trigger memberID := "robot_e2e_clock_001" setupE2ERobotForClock(t, memberID, "team_e2e_clock") // Start the API system err := api.Start() require.NoError(t, err) defer api.Stop() // Verify robot is loaded ctx := types.NewContext(context.Background(), testAuth()) robot, err := api.GetRobot(ctx, memberID) require.NoError(t, err) 
require.NotNil(t, robot) assert.Equal(t, memberID, robot.MemberID) // Trigger execution via clock trigger type result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.Accepted, "Clock trigger should be accepted: %s", result.Message) assert.NotEmpty(t, result.ExecutionID, "Should return execution ID") t.Logf("Execution started: ExecutionID=%s", result.ExecutionID) // Wait for execution to complete (real LLM calls take time) // P0→P4 typically takes 30-60 seconds with real LLM var exec *types.Execution maxWait := 180 * time.Second pollInterval := 2 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(pollInterval) // Query all executions and find a completed one executions, err := api.ListExecutions(ctx, memberID, &api.ExecutionQuery{ Page: 1, PageSize: 10, }) if err != nil { t.Logf("Query error (retrying): %v", err) continue } // Look for a completed execution for _, e := range executions.Data { t.Logf("Execution %s: status=%s, phase=%s", e.ID, e.Status, e.Phase) if e.Status == types.ExecCompleted { exec = e break } } if exec != nil { break } // Also check if there's any running execution hasRunning := false for _, e := range executions.Data { if e.Status == types.ExecRunning || e.Status == types.ExecPending { hasRunning = true break } } if !hasRunning && len(executions.Data) > 0 { // All executions finished but none completed - take the first one for error reporting exec = executions.Data[0] break } } // Verify execution completed successfully - ALL phases must pass require.NotNil(t, exec, "Execution should exist") if exec.Status == types.ExecFailed { t.Fatalf("Execution failed: %s", exec.Error) } // Strict assertion: execution MUST complete successfully assert.Equal(t, types.ExecCompleted, exec.Status, "Execution must complete successfully") // Verify P0 (Inspiration) output exists require.NotNil(t, exec.Inspiration, "P0 Inspiration 
output must exist") t.Logf("P0 Inspiration: %+v", exec.Inspiration) // Verify P1 (Goals) output exists require.NotNil(t, exec.Goals, "P1 Goals output must exist") t.Logf("P1 Goals content length: %d", len(exec.Goals.Content)) // Verify P2 (Tasks) output exists require.NotNil(t, exec.Tasks, "P2 Tasks output must exist") require.Greater(t, len(exec.Tasks), 0, "P2 must have at least 1 task") t.Logf("P2 Tasks count: %d", len(exec.Tasks)) // Verify P3 (Results) output exists - THIS IS CRITICAL require.NotNil(t, exec.Results, "P3 Results output must exist") require.Greater(t, len(exec.Results), 0, "P3 must have at least 1 result") t.Logf("P3 Results count: %d", len(exec.Results)) // Verify P4 (Delivery) output exists require.NotNil(t, exec.Delivery, "P4 Delivery output must exist") t.Logf("P4 Delivery: RequestID=%s, Success=%v", exec.Delivery.RequestID, exec.Delivery.Success) t.Logf("✅ Clock trigger E2E: ALL PHASES (P0-P4) completed successfully") }) } // TestE2EClockTriggerPhaseProgression tests that phases execute in correct order func TestE2EClockTriggerPhaseProgression(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("phases_execute_P0_P1_P2_P3_P4", func(t *testing.T) { memberID := "robot_e2e_clock_phases" setupE2ERobotForClock(t, memberID, "team_e2e_clock") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuth()) // Trigger execution result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) assert.True(t, result.Accepted) // Track phase progression phasesObserved := make([]types.Phase, 0) lastPhase := types.Phase("") maxWait := 120 * time.Second pollInterval := 1 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(pollInterval) 
executions, err := api.ListExecutions(ctx, memberID, &api.ExecutionQuery{ Page: 1, PageSize: 1, }) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] // Record phase changes if exec.Phase != lastPhase { phasesObserved = append(phasesObserved, exec.Phase) lastPhase = exec.Phase t.Logf("Phase changed to: %s", exec.Phase) } if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { break } } // Verify phase order (should include at least P0, P1, P2, P3, P4) t.Logf("Phases observed: %v", phasesObserved) assert.GreaterOrEqual(t, len(phasesObserved), 1, "Should observe at least one phase") // The final phase should be delivery or learning if len(phasesObserved) > 0 { lastObserved := phasesObserved[len(phasesObserved)-1] assert.True(t, lastObserved == types.PhaseDelivery || lastObserved == types.PhaseLearning, "Last phase should be delivery or learning, got: %s", lastObserved) } }) } // TestE2EClockTriggerDataPersistence tests that execution data is persisted to database func TestE2EClockTriggerDataPersistence(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("execution_data_persisted_to_database", func(t *testing.T) { memberID := "robot_e2e_clock_persist" setupE2ERobotForClock(t, memberID, "team_e2e_clock") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuth()) // Trigger and wait for completion result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) assert.True(t, result.Accepted) // Wait for completion var execID string maxWait := 120 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || 
len(executions.Data) == 0 { continue } exec := executions.Data[0] execID = exec.ID if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { break } } require.NotEmpty(t, execID, "Should have execution ID") // Query execution by ID to verify persistence exec, err := api.GetExecution(ctx, execID) require.NoError(t, err) require.NotNil(t, exec) // Verify all fields are persisted assert.Equal(t, execID, exec.ID) assert.Equal(t, memberID, exec.MemberID) assert.Equal(t, types.TriggerClock, exec.TriggerType) assert.NotNil(t, exec.StartTime, "StartTime should be set") if exec.Status == types.ExecCompleted { assert.NotNil(t, exec.EndTime, "EndTime should be set for completed execution") } t.Logf("Persisted execution: ID=%s, Status=%s, Phase=%s", exec.ID, exec.Status, exec.Phase) }) } // ==================== Helper Functions ==================== // setupE2ERobotForClock creates a robot configured for clock trigger E2E tests // Uses extremely simple tasks to ensure quick completion through all phases func setupE2ERobotForClock(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Robot config optimized for E2E testing - tasks must complete quickly robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Greeter Bot", "duties": []string{"Output greeting message"}, // Extremely simple "rules": []string{"Always complete in one step", "No tools needed", "Just output text directly"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": "tests.e2e-tasks", // Use simple E2E test task planner "run": 
"robot.validation", "validation": "tests.e2e-validation", // Use lenient E2E test validator "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": []string{}, }, "delivery": map[string]interface{}{ "email": map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) // Critical: System prompt must guide LLM to generate simple, completable tasks systemPrompt := `You are a simple greeter bot for E2E testing. CRITICAL RULES FOR ALL PHASES: 1. Goals: Generate exactly ONE goal: "Output a greeting" 2. Tasks: Generate exactly ONE task with description "Say Hello World" 3. Run: Complete the task by directly outputting "Hello World!" - NO tools needed 4. The task is complete when you output the greeting text You do NOT need any tools or external calls. Just output text directly. When asked to execute a task, immediately respond with the greeting and mark complete.` err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Clock Test Robot " + memberID, "system_prompt": systemPrompt, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } // cleanupE2ERobots removes all E2E test robots func cleanupE2ERobots(t *testing.T) { m := model.Select("__yao.member") if m == nil { return } tableName := m.MetaData.Table.Name qb := capsule.Query() _, err := qb.Table(tableName).Where("member_id", "like", "robot_e2e_%").Delete() if err != nil { t.Logf("Warning: cleanup robots error: %v", err) } } // cleanupE2EExecutions removes all E2E test executions func cleanupE2EExecutions(t *testing.T) { m := model.Select("__yao.agent.execution") if m == nil { return } tableName := m.MetaData.Table.Name qb := 
capsule.Query() _, err := qb.Table(tableName).Where("member_id", "like", "robot_e2e_%").Delete() if err != nil { t.Logf("Warning: cleanup executions error: %v", err) } } ================================================ FILE: agent/robot/api/e2e_concurrent_test.go ================================================ package api_test // End-to-end tests for Concurrent execution // These tests verify that multiple robots can execute simultaneously // and that quota limits are enforced correctly. // // Prerequisites: // - Valid LLM API keys (OPENAI_TEST_KEY or DEEPSEEK_API_KEY) // - Test assistants in yao-dev-app/assistants/robot/ // - Database connection (YAO_DB_PRIMARY) import ( "context" "encoding/json" "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // testAuthConcurrent returns test auth info for concurrent E2E tests func testAuthConcurrent() *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "e2e-concurrent-user", TeamID: "e2e-concurrent-team", } } // TestE2EConcurrentMultipleRobots tests concurrent execution of multiple different robots func TestE2EConcurrentMultipleRobots(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("multiple_robots_execute_concurrently", func(t *testing.T) { // Create 3 different robots robots := []string{ "robot_e2e_concurrent_001", "robot_e2e_concurrent_002", "robot_e2e_concurrent_003", } for _, memberID := range robots { setupE2ERobotForConcurrent(t, memberID, "team_e2e_concurrent") } err := api.Start() require.NoError(t, err) 
defer api.Stop() ctx := types.NewContext(context.Background(), testAuthConcurrent()) // Trigger all robots concurrently var wg sync.WaitGroup var acceptedCount atomic.Int32 results := make([]*api.TriggerResult, len(robots)) var mu sync.Mutex for i, memberID := range robots { wg.Add(1) go func(idx int, id string) { defer wg.Done() result, err := api.TriggerManual(ctx, id, types.TriggerClock, nil) if err != nil { t.Logf("Robot %s trigger error: %v", id, err) return } mu.Lock() results[idx] = result mu.Unlock() if result.Accepted { acceptedCount.Add(1) t.Logf("Robot %s accepted: ExecutionID=%s", id, result.ExecutionID) } }(i, memberID) } wg.Wait() // All 3 should be accepted (different robots, no quota conflict) assert.Equal(t, int32(3), acceptedCount.Load(), "All 3 robots should be accepted") // Wait for all executions to complete maxWait := 180 * time.Second // Longer timeout for concurrent deadline := time.Now().Add(maxWait) completedCount := 0 for time.Now().Before(deadline) { time.Sleep(3 * time.Second) completedCount = 0 for _, memberID := range robots { executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { completedCount++ } } t.Logf("Completed: %d/%d", completedCount, len(robots)) if completedCount == len(robots) { break } } // Verify all completed assert.Equal(t, len(robots), completedCount, "All robots should complete execution") }) } // TestE2EConcurrentSameRobotMultipleTriggers tests multiple triggers on the same robot func TestE2EConcurrentSameRobotMultipleTriggers(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("same_robot_handles_multiple_triggers", func(t *testing.T) { memberID := 
"robot_e2e_concurrent_same" // Create robot with high quota to allow multiple concurrent executions setupE2ERobotHighQuota(t, memberID, "team_e2e_concurrent") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthConcurrent()) // Trigger 3 executions on the same robot triggerCount := 3 var wg sync.WaitGroup var acceptedCount atomic.Int32 for i := 0; i < triggerCount; i++ { wg.Add(1) go func(idx int) { defer wg.Done() result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) if err != nil { t.Logf("Trigger %d error: %v", idx, err) return } if result.Accepted { acceptedCount.Add(1) t.Logf("Trigger %d accepted: ExecutionID=%s", idx, result.ExecutionID) } else { t.Logf("Trigger %d rejected: %s", idx, result.Message) } }(i) // Small delay between triggers to avoid race conditions time.Sleep(100 * time.Millisecond) } wg.Wait() // With high quota (max=5), all 3 should be accepted assert.GreaterOrEqual(t, acceptedCount.Load(), int32(1), "At least 1 trigger should be accepted") t.Logf("Accepted triggers: %d/%d", acceptedCount.Load(), triggerCount) // Wait for executions to complete maxWait := 180 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(3 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil { continue } completedCount := 0 for _, exec := range executions.Data { if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { completedCount++ } } t.Logf("Completed: %d/%d", completedCount, int(acceptedCount.Load())) if completedCount >= int(acceptedCount.Load()) { break } } // Verify execution count executions, err := api.ListExecutions(ctx, memberID, nil) require.NoError(t, err) assert.GreaterOrEqual(t, len(executions.Data), 1, "Should have at least 1 execution") }) } // TestE2EConcurrentQuotaEnforcement tests that quota limits are enforced func TestE2EConcurrentQuotaEnforcement(t *testing.T) { if 
testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("quota_limit_enforced", func(t *testing.T) { memberID := "robot_e2e_concurrent_quota" // Create robot with low quota (max=2) setupE2ERobotLowQuota(t, memberID, "team_e2e_concurrent") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthConcurrent()) // Try to trigger 5 executions on robot with max=2 triggerCount := 5 var acceptedCount atomic.Int32 var rejectedCount atomic.Int32 for i := 0; i < triggerCount; i++ { result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) if err != nil { t.Logf("Trigger %d error: %v", i, err) continue } if result.Accepted { acceptedCount.Add(1) t.Logf("Trigger %d accepted", i) } else { rejectedCount.Add(1) t.Logf("Trigger %d rejected: %s", i, result.Message) } // Small delay to allow execution to start time.Sleep(200 * time.Millisecond) } // With max=2, only 2 should be accepted at a time // Some may be rejected due to quota t.Logf("Accepted: %d, Rejected: %d", acceptedCount.Load(), rejectedCount.Load()) // At least some should be accepted assert.GreaterOrEqual(t, acceptedCount.Load(), int32(1), "At least 1 should be accepted") // Wait for completion time.Sleep(120 * time.Second) // Query final execution count executions, err := api.ListExecutions(ctx, memberID, nil) require.NoError(t, err) t.Logf("Total executions: %d", len(executions.Data)) }) } // TestE2EConcurrentMixedTriggerTypes tests concurrent execution with different trigger types func TestE2EConcurrentMixedTriggerTypes(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) 
t.Run("mixed_trigger_types_execute_concurrently", func(t *testing.T) { // Create robots for different trigger types clockRobot := "robot_e2e_concurrent_clock" humanRobot := "robot_e2e_concurrent_human" eventRobot := "robot_e2e_concurrent_event" setupE2ERobotForConcurrent(t, clockRobot, "team_e2e_concurrent") setupE2ERobotForConcurrent(t, humanRobot, "team_e2e_concurrent") setupE2ERobotForConcurrent(t, eventRobot, "team_e2e_concurrent") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthConcurrent()) // Trigger all three types concurrently var wg sync.WaitGroup var acceptedCount atomic.Int32 // Clock trigger wg.Add(1) go func() { defer wg.Done() result, err := api.TriggerManual(ctx, clockRobot, types.TriggerClock, nil) if err == nil && result.Accepted { acceptedCount.Add(1) t.Logf("Clock trigger accepted") } }() // Human trigger wg.Add(1) go func() { defer wg.Done() result, err := api.Trigger(ctx, humanRobot, &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, }) if err == nil && (result.Accepted || result.Queued) { acceptedCount.Add(1) t.Logf("Human trigger accepted/queued") } }() // Event trigger wg.Add(1) go func() { defer wg.Done() result, err := api.Trigger(ctx, eventRobot, &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventWebhook, EventType: "test.concurrent", Data: map[string]interface{}{"test": true}, }) if err == nil && result.Accepted { acceptedCount.Add(1) t.Logf("Event trigger accepted") } }() wg.Wait() // All should be accepted (different robots) assert.GreaterOrEqual(t, acceptedCount.Load(), int32(2), "At least 2 triggers should be accepted") // Wait for executions time.Sleep(120 * time.Second) // Verify executions exist for each robot for _, memberID := range []string{clockRobot, humanRobot, eventRobot} { executions, err := api.ListExecutions(ctx, memberID, nil) if err == nil && len(executions.Data) > 0 { t.Logf("Robot %s has %d executions", 
memberID, len(executions.Data)) } } }) } // ==================== Helper Functions ==================== // setupE2ERobotForConcurrent creates a robot for concurrent execution tests func setupE2ERobotForConcurrent(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Simple config for E2E testing - minimal tasks robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Simple E2E Test Robot", "duties": []string{"Say hello"}, // Very simple duty "rules": []string{"Keep responses under 50 words"}, }, "quota": map[string]interface{}{ "max": 3, "queue": 10, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "interval", "every": "1h", }, "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": "tests.e2e-tasks", // Use simple E2E test task planner "run": "robot.validation", "validation": "tests.e2e-validation", // Use lenient E2E test validator "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": []string{"experts.text-writer"}, }, "delivery": map[string]interface{}{ "email": map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) systemPrompt := `You are a simple E2E test robot. Your job is to say hello. When generating goals: create exactly 1 simple goal. When generating tasks: create exactly 1 simple task. Keep all outputs brief. 
No complex analysis needed.` err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Concurrent Robot " + memberID, "system_prompt": systemPrompt, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } // setupE2ERobotHighQuota creates a robot with high quota for concurrent tests func setupE2ERobotHighQuota(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "E2E Test Robot - High Quota", }, "quota": map[string]interface{}{ "max": 5, // High quota "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, // Resources: phase agents and expert agents from yao-dev-app/assistants/ "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": "tests.e2e-tasks", // Use simple E2E test task planner "run": "robot.validation", "validation": "tests.e2e-validation", // Use lenient E2E test validator "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": []string{"experts.text-writer"}, }, "delivery": map[string]interface{}{ "email": map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E High Quota Robot " + memberID, 
"system_prompt": "You are a high quota test robot.", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } // setupE2ERobotLowQuota creates a robot with low quota for quota enforcement tests func setupE2ERobotLowQuota(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "E2E Test Robot - Low Quota", }, "quota": map[string]interface{}{ "max": 2, // Low quota for testing limits "queue": 5, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, // Resources: phase agents and expert agents from yao-dev-app/assistants/ "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": "tests.e2e-tasks", // Use simple E2E test task planner "run": "robot.validation", "validation": "tests.e2e-validation", // Use lenient E2E test validator "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": []string{"experts.text-writer"}, }, "delivery": map[string]interface{}{ "email": map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Low Quota Robot " + memberID, "system_prompt": "You are a low quota test robot.", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) 
if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } ================================================ FILE: agent/robot/api/e2e_control_test.go ================================================ package api_test // End-to-end tests for Execution Control (Pause/Resume/Stop) // These tests verify that executions can be controlled during runtime // // Prerequisites: // - Valid LLM API keys (OPENAI_TEST_KEY or DEEPSEEK_API_KEY) // - Test assistants in yao-dev-app/assistants/robot/ // - Database connection (YAO_DB_PRIMARY) import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // testAuthControl returns test auth info for control E2E tests func testAuthControl() *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "e2e-control-user", TeamID: "e2e-control-team", } } // TestE2EControlPauseResume tests pausing and resuming an execution func TestE2EControlPauseResume(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("pause_and_resume_execution", func(t *testing.T) { memberID := "robot_e2e_control_pause" setupE2ERobotForControl(t, memberID, "team_e2e_control") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthControl()) // Start execution result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) require.True(t, result.Accepted) t.Logf("Execution started: ExecutionID=%s", result.ExecutionID) // Wait for execution to start running var 
execID string maxWait := 30 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(500 * time.Millisecond) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecRunning { execID = exec.ID t.Logf("Execution running: ID=%s, Phase=%s", execID, exec.Phase) break } } if execID == "" { t.Skip("Execution did not start in time - may have completed too quickly") return } // Pause the execution err = api.PauseExecution(ctx, execID) if err != nil { t.Logf("Pause error (may be expected if execution completed): %v", err) } else { t.Logf("Execution paused") // Verify paused state time.Sleep(1 * time.Second) status, err := api.GetExecutionStatus(ctx, execID) if err == nil && status != nil { t.Logf("Status after pause: %s", status.Status) } // Resume the execution err = api.ResumeExecution(ctx, execID) if err != nil { t.Logf("Resume error: %v", err) } else { t.Logf("Execution resumed") } } // Wait for completion maxWait = 120 * time.Second deadline = time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) exec, err := api.GetExecution(ctx, execID) if err != nil { continue } t.Logf("Execution status: %s, phase: %s", exec.Status, exec.Phase) if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed || exec.Status == types.ExecCancelled { // Execution finished (completed, failed, or cancelled) t.Logf("Execution finished with status: %s", exec.Status) return } } t.Logf("Execution did not complete in time") }) } // TestE2EControlStop tests stopping an execution func TestE2EControlStop(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("stop_running_execution", func(t *testing.T) { memberID := 
"robot_e2e_control_stop" setupE2ERobotForControl(t, memberID, "team_e2e_control") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthControl()) // Start execution result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) require.True(t, result.Accepted) t.Logf("Execution started: ExecutionID=%s", result.ExecutionID) // Wait for execution to start running var execID string maxWait := 30 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(500 * time.Millisecond) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecRunning { execID = exec.ID t.Logf("Execution running: ID=%s, Phase=%s", execID, exec.Phase) break } } if execID == "" { t.Skip("Execution did not start in time - may have completed too quickly") return } // Stop the execution err = api.StopExecution(ctx, execID) if err != nil { t.Logf("Stop error (may be expected if execution completed): %v", err) } else { t.Logf("Stop signal sent") } // Wait and verify stopped/cancelled state (with retry) maxWait = 30 * time.Second deadline = time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) exec, err := api.GetExecution(ctx, execID) if err != nil { t.Logf("Get execution error: %v", err) continue } t.Logf("Current status: %s", exec.Status) // Execution should eventually be cancelled, completed, or failed if exec.Status == types.ExecCancelled || exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { t.Logf("Final status: %s", exec.Status) return } } // If we get here, check final state exec, err := api.GetExecution(ctx, execID) if err != nil { t.Logf("Get execution error: %v", err) return } // Allow running state if stop didn't take effect in time (execution may have already completed) t.Logf("Final status after 
wait: %s", exec.Status) assert.True(t, exec.Status == types.ExecCancelled || exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed || exec.Status == types.ExecRunning, // Allow running if stop didn't take effect "Execution should be in terminal state or still running, got: %s", exec.Status) }) } // TestE2EControlStopBeforeStart tests stopping an execution before it starts func TestE2EControlStopBeforeStart(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("stop_queued_execution", func(t *testing.T) { memberID := "robot_e2e_control_stop_early" setupE2ERobotForControl(t, memberID, "team_e2e_control") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthControl()) // Start execution result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) require.True(t, result.Accepted) // Immediately try to get execution ID and stop time.Sleep(100 * time.Millisecond) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { t.Skip("No execution found") return } execID := executions.Data[0].ID // Try to stop immediately err = api.StopExecution(ctx, execID) if err != nil { t.Logf("Stop error: %v", err) } else { t.Logf("Stop signal sent for execution %s", execID) } // Wait and check status time.Sleep(5 * time.Second) exec, err := api.GetExecution(ctx, execID) if err != nil { t.Logf("Get execution error: %v", err) return } t.Logf("Final status: %s", exec.Status) }) } // TestE2EControlMultipleOperations tests a sequence of control operations func TestE2EControlMultipleOperations(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) 
cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("pause_resume_pause_stop_sequence", func(t *testing.T) { memberID := "robot_e2e_control_multi" setupE2ERobotForControl(t, memberID, "team_e2e_control") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthControl()) // Start execution result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) require.True(t, result.Accepted) // Wait for running state var execID string maxWait := 30 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(500 * time.Millisecond) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecRunning { execID = exec.ID break } } if execID == "" { t.Skip("Execution did not start in time") return } // Sequence: Pause → Resume → Pause → Stop operations := []struct { name string fn func() error }{ {"Pause", func() error { return api.PauseExecution(ctx, execID) }}, {"Resume", func() error { return api.ResumeExecution(ctx, execID) }}, {"Pause", func() error { return api.PauseExecution(ctx, execID) }}, {"Stop", func() error { return api.StopExecution(ctx, execID) }}, } for _, op := range operations { err := op.fn() if err != nil { t.Logf("%s error (may be expected): %v", op.name, err) // If execution already completed, stop the sequence exec, _ := api.GetExecution(ctx, execID) if exec != nil && (exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed || exec.Status == types.ExecCancelled) { t.Logf("Execution already finished: %s", exec.Status) return } } else { t.Logf("%s successful", op.name) } time.Sleep(2 * time.Second) } // Verify final state exec, err := api.GetExecution(ctx, execID) if err != nil { t.Logf("Get execution error: %v", err) return } t.Logf("Final status after operations: %s", 
exec.Status) }) } // TestE2EControlStatusQuery tests querying execution status during control func TestE2EControlStatusQuery(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("query_status_during_execution", func(t *testing.T) { memberID := "robot_e2e_control_status" setupE2ERobotForControl(t, memberID, "team_e2e_control") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthControl()) // Start execution result, err := api.TriggerManual(ctx, memberID, types.TriggerClock, nil) require.NoError(t, err) require.True(t, result.Accepted) // Track status changes statusHistory := make([]types.ExecStatus, 0) phaseHistory := make([]types.Phase, 0) maxWait := 120 * time.Second deadline := time.Now().Add(maxWait) lastStatus := types.ExecStatus("") lastPhase := types.Phase("") for time.Now().Before(deadline) { time.Sleep(1 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] // Track status changes if exec.Status != lastStatus { statusHistory = append(statusHistory, exec.Status) lastStatus = exec.Status t.Logf("Status changed: %s", exec.Status) } // Track phase changes if exec.Phase != lastPhase { phaseHistory = append(phaseHistory, exec.Phase) lastPhase = exec.Phase t.Logf("Phase changed: %s", exec.Phase) } // Also test GetExecutionStatus status, err := api.GetExecutionStatus(ctx, exec.ID) if err == nil && status != nil { // Status query should return valid data assert.NotEmpty(t, status.ID) assert.Equal(t, exec.Status, status.Status) } if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed || exec.Status == types.ExecCancelled { break } } t.Logf("Status history: %v", statusHistory) t.Logf("Phase history: 
%v", phaseHistory) // Should have observed at least pending → running transition assert.GreaterOrEqual(t, len(statusHistory), 1, "Should observe at least one status") }) } // ==================== Helper Functions ==================== // setupE2ERobotForControl creates a robot for control tests func setupE2ERobotForControl(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Simple config for E2E testing - minimal tasks robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Simple E2E Test Robot", "duties": []string{"Say hello"}, // Very simple duty "rules": []string{"Keep responses under 50 words"}, }, "quota": map[string]interface{}{ "max": 3, "queue": 10, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "interval", "every": "1h", }, "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": "tests.e2e-tasks", // Use simple E2E test task planner "run": "robot.validation", "validation": "tests.e2e-validation", // Use lenient E2E test validator "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": []string{"experts.text-writer"}, }, "delivery": map[string]interface{}{ "email": map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) systemPrompt := `You are a simple E2E test robot. Your job is to say hello. When generating goals: create exactly 1 simple goal. When generating tasks: create exactly 1 simple task. Keep all outputs brief. 
No complex analysis needed.` err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Control Robot " + memberID, "system_prompt": systemPrompt, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } ================================================ FILE: agent/robot/api/e2e_event_test.go ================================================ package api_test // End-to-end tests for Event trigger flow // These tests use REAL LLM calls via Standard executor (not DryRun) // // Test Flow: Event Trigger → P1 (Goals) → P2 (Tasks) → P3 (Run) → P4 (Delivery) // Note: Event trigger SKIPS P0 (Inspiration) - event data provides the context // // Prerequisites: // - Valid LLM API keys (OPENAI_TEST_KEY or DEEPSEEK_API_KEY) // - Test assistants in yao-dev-app/assistants/robot/ // - Database connection (YAO_DB_PRIMARY) import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // testAuthEvent returns test auth info for event E2E tests func testAuthEvent() *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "e2e-event-user", TeamID: "e2e-event-team", } } // TestE2EEventTriggerFullFlow tests the complete event trigger flow with real LLM calls // Flow: Event → P1 (Goals) → P2 (Tasks) → P3 (Run) → P4 (Delivery) func TestE2EEventTriggerFullFlow(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) 
defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("complete_P1_to_P4_flow_with_webhook_event", func(t *testing.T) { memberID := "robot_e2e_event_001" setupE2ERobotForEvent(t, memberID, "team_e2e_event") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthEvent()) // Verify robot is loaded robot, err := api.GetRobot(ctx, memberID) require.NoError(t, err) require.NotNil(t, robot) // Trigger with webhook event - simulating external system notification result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventWebhook, EventType: "order.created", Data: map[string]interface{}{ "order_id": "ORD-2025-001", "customer": "John Doe", "total": 299.99, "items_count": 3, "priority": "high", "created_at": time.Now().Format(time.RFC3339), }, }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.Accepted, "Event trigger should be accepted") t.Logf("Event trigger result: Accepted=%v, ExecutionID=%s", result.Accepted, result.ExecutionID) // Wait for execution to complete var exec *types.Execution maxWait := 120 * time.Second pollInterval := 2 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(pollInterval) executions, err := api.ListExecutions(ctx, memberID, &api.ExecutionQuery{ Page: 1, PageSize: 1, }) if err != nil || len(executions.Data) == 0 { continue } exec = executions.Data[0] t.Logf("Execution status: %s, phase: %s", exec.Status, exec.Phase) if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { break } } require.NotNil(t, exec, "Execution should exist") // E2E test validates the flow executes correctly isFinished := exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed assert.True(t, isFinished, "Execution should finish (completed or failed), got: %s", exec.Status) if exec.Status == types.ExecFailed { t.Logf("Execution finished with status=failed 
(acceptable for E2E): %s", exec.Error) } else { t.Logf("Execution finished with status=completed") } // Verify trigger type assert.Equal(t, types.TriggerEvent, exec.TriggerType, "Should be event trigger") // Event trigger skips P0, so Inspiration should be nil assert.Nil(t, exec.Inspiration, "P0 Inspiration should be nil for event trigger") // P1 Goals should always exist for event trigger assert.NotNil(t, exec.Goals, "P1 Goals should exist") // P2-P4 may or may not exist depending on where failure occurred if exec.Tasks != nil { t.Logf("P2 Tasks count: %d", len(exec.Tasks)) } if exec.Results != nil { t.Logf("P3 Results count: %d", len(exec.Results)) } if exec.Delivery != nil { t.Logf("P4 Delivery: RequestID=%s", exec.Delivery.RequestID) } t.Logf("Event trigger E2E completed") }) } // TestE2EEventTriggerDatabaseEvent tests event trigger from database changes func TestE2EEventTriggerDatabaseEvent(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("handles_database_event_source", func(t *testing.T) { memberID := "robot_e2e_event_db" setupE2ERobotForEvent(t, memberID, "team_e2e_event") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthEvent()) // Trigger with database event - simulating record change notification result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventDatabase, EventType: "user.updated", Data: map[string]interface{}{ "table": "users", "operation": "UPDATE", "record_id": 12345, "changes": map[string]interface{}{ "status": map[string]interface{}{ "old": "pending", "new": "active", }, }, }, }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.Accepted) // Wait for execution maxWait := 120 * time.Second deadline := 
time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { assert.Equal(t, types.ExecCompleted, exec.Status) t.Logf("Database event E2E completed") return } } t.Fatal("Execution did not complete in time") }) } // TestE2EEventTriggerVariousEventTypes tests different event types // Optimized: Only tests one representative event type to reduce CI time // The event handling logic is the same for all event types, so testing one is sufficient func TestE2EEventTriggerVariousEventTypes(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) // Test only one representative event type (notification) // All event types use the same code path, so one test is sufficient t.Run("webhook_event", func(t *testing.T) { memberID := "robot_e2e_event_webhook" setupE2ERobotForEvent(t, memberID, "team_e2e_event") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthEvent()) result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventWebhook, EventType: "notification.received", Data: map[string]interface{}{ "message": "Test notification", "priority": "normal", }, }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.Accepted, "Event should be accepted") t.Logf("Event triggered: ExecutionID=%s", result.ExecutionID) // Wait for execution maxWait := 120 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || 
len(executions.Data) == 0 { continue } exec := executions.Data[0] t.Logf("Event status: %s, phase: %s", exec.Status, exec.Phase) if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { if exec.Status == types.ExecFailed { t.Logf("Event execution failed: %s", exec.Error) } else { t.Logf("Event execution completed") } return } } t.Logf("Event execution did not complete in time (may be CI latency)") }) } // TestE2EEventTriggerWithComplexData tests event with nested/complex data structures func TestE2EEventTriggerWithComplexData(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("handles_complex_nested_data", func(t *testing.T) { memberID := "robot_e2e_event_complex" setupE2ERobotForEvent(t, memberID, "team_e2e_event") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthEvent()) // Complex nested event data result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventWebhook, EventType: "report.generated", Data: map[string]interface{}{ "report": map[string]interface{}{ "id": "RPT-2025-001", "type": "sales_summary", "period": "monthly", "generated": time.Now().Format(time.RFC3339), "department": "Sales", }, "metrics": []map[string]interface{}{ {"name": "total_sales", "value": 150000, "unit": "USD"}, {"name": "orders_count", "value": 450, "unit": "orders"}, {"name": "avg_order_value", "value": 333.33, "unit": "USD"}, }, "comparison": map[string]interface{}{ "previous_period": map[string]interface{}{ "total_sales": 140000, "orders_count": 420, "change_percent": 7.14, }, }, "highlights": []string{ "Sales increased by 7.14% compared to last month", "Top performing product: Widget Pro", "New customer acquisition up 15%", }, }, }) require.NoError(t, err) 
require.NotNil(t, result) assert.True(t, result.Accepted) // Wait for execution maxWait := 120 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { assert.Equal(t, types.ExecCompleted, exec.Status, "Complex data event should complete") t.Logf("Complex data event E2E completed") return } } t.Fatal("Execution did not complete in time") }) } // ==================== Helper Functions ==================== // setupE2ERobotForEvent creates a robot configured for event trigger E2E tests func setupE2ERobotForEvent(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Simple config for E2E testing - minimal tasks robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Simple E2E Test Robot", "duties": []string{"Acknowledge events"}, // Very simple duty "rules": []string{"Keep responses under 50 words"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "event": map[string]interface{}{ "types": []string{"*"}, }, "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": "tests.e2e-tasks", // Use simple E2E test task planner "run": "robot.validation", "validation": "tests.e2e-validation", // Use lenient E2E test validator "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": []string{"experts.text-writer"}, }, "delivery": map[string]interface{}{ "email": 
map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) systemPrompt := `You are a simple E2E test robot. Your job is to acknowledge events. When generating goals: create exactly 1 simple goal. When generating tasks: create exactly 1 simple task. Keep all outputs brief. No complex analysis needed.` err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Event Test Robot " + memberID, "system_prompt": systemPrompt, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } ================================================ FILE: agent/robot/api/e2e_human_test.go ================================================ package api_test // End-to-end tests for Human intervention trigger flow // These tests use REAL LLM calls via Standard executor (not DryRun) // // Test Flow: Human Trigger → P1 (Goals) → P2 (Tasks) → P3 (Run) → P4 (Delivery) // Note: Human trigger SKIPS P0 (Inspiration) - user provides the input directly // // Prerequisites: // - Valid LLM API keys (OPENAI_TEST_KEY or DEEPSEEK_API_KEY) // - Test assistants in yao-dev-app/assistants/robot/ // - Database connection (YAO_DB_PRIMARY) import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // testAuthHuman returns test auth info for human E2E tests func testAuthHuman() 
*oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "e2e-human-user", TeamID: "e2e-human-team", } } // TestE2EHumanTriggerFullFlow tests the complete human intervention flow with real LLM calls // Flow: Human Input → P1 (Goals) → P2 (Tasks) → P3 (Run) → P4 (Delivery) func TestE2EHumanTriggerFullFlow(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("complete_P1_to_P4_flow_with_user_input", func(t *testing.T) { memberID := "robot_e2e_human_001" setupE2ERobotForHuman(t, memberID, "team_e2e_human") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthHuman()) // Verify robot is loaded robot, err := api.GetRobot(ctx, memberID) require.NoError(t, err) require.NotNil(t, robot) // Trigger with human input - user requesting a specific task result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ { Role: "user", Content: "Please write a brief summary of today's key tasks and priorities.", }, }, }) require.NoError(t, err) require.NotNil(t, result) // Human trigger returns Queued=true (goes through Intervene) t.Logf("Trigger result: Accepted=%v, Queued=%v, Message=%s", result.Accepted, result.Queued, result.Message) // Wait for execution to complete var exec *types.Execution maxWait := 120 * time.Second pollInterval := 2 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(pollInterval) executions, err := api.ListExecutions(ctx, memberID, &api.ExecutionQuery{ Page: 1, PageSize: 1, }) if err != nil || len(executions.Data) == 0 { continue } exec = executions.Data[0] t.Logf("Execution status: %s, phase: %s", exec.Status, exec.Phase) if exec.Status == types.ExecCompleted || 
exec.Status == types.ExecFailed { break } } require.NotNil(t, exec, "Execution should exist") // E2E test validates the flow executes correctly isFinished := exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed assert.True(t, isFinished, "Execution should finish (completed or failed), got: %s", exec.Status) if exec.Status == types.ExecFailed { t.Logf("Execution finished with status=failed (acceptable for E2E): %s", exec.Error) } else { t.Logf("Execution finished with status=completed") } // Verify trigger type assert.Equal(t, types.TriggerHuman, exec.TriggerType, "Should be human trigger") // Human trigger skips P0, so Inspiration should be nil assert.Nil(t, exec.Inspiration, "P0 Inspiration should be nil for human trigger") // P1 Goals should always exist for human trigger assert.NotNil(t, exec.Goals, "P1 Goals should exist") // P2-P4 may or may not exist depending on where failure occurred if exec.Tasks != nil { t.Logf("P2 Tasks count: %d", len(exec.Tasks)) } if exec.Results != nil { t.Logf("P3 Results count: %d", len(exec.Results)) } if exec.Delivery != nil { t.Logf("P4 Delivery: RequestID=%s", exec.Delivery.RequestID) } t.Logf("Human trigger E2E completed") }) } // TestE2EHumanTriggerWithMultimodalInput tests human trigger with rich content func TestE2EHumanTriggerWithMultimodalInput(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) t.Run("handles_multipart_message_input", func(t *testing.T) { memberID := "robot_e2e_human_multi" setupE2ERobotForHuman(t, memberID, "team_e2e_human") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthHuman()) // Trigger with multipart message (text parts) result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerHuman, Action: 
types.ActionGoalAdjust, Messages: []agentcontext.Message{ { Role: "user", Content: []map[string]interface{}{ { "type": "text", "text": "I need you to focus on the following priorities:", }, { "type": "text", "text": "1. Review pending tasks\n2. Summarize progress\n3. Identify blockers", }, }, }, }, }) require.NoError(t, err) require.NotNil(t, result) t.Logf("Multipart trigger result: Queued=%v", result.Queued) // Wait for execution maxWait := 120 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { assert.Equal(t, types.ExecCompleted, exec.Status) t.Logf("Multipart input E2E completed") return } } t.Fatal("Execution did not complete in time") }) } // TestE2EHumanTriggerAllActions tests different intervention actions func TestE2EHumanTriggerAllActions(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ERobots(t) cleanupE2EExecutions(t) defer cleanupE2ERobots(t) defer cleanupE2EExecutions(t) actions := []struct { name string action types.InterventionAction input string }{ { name: "task_add", action: types.ActionTaskAdd, input: "Add a new task: Review system logs for errors", }, { name: "goal_adjust", action: types.ActionGoalAdjust, input: "Adjust goal: Focus on performance optimization instead of new features", }, { name: "instruct", action: types.ActionInstruct, input: "Please prioritize security review as the top task", }, } for i, tc := range actions { t.Run(tc.name, func(t *testing.T) { memberID := "robot_e2e_human_action_" + tc.name setupE2ERobotForHuman(t, memberID, "team_e2e_human") // Start fresh for each action test if i == 0 { err := api.Start() require.NoError(t, err) } ctx := 
types.NewContext(context.Background(), testAuthHuman()) result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerHuman, Action: tc.action, Messages: []agentcontext.Message{ {Role: "user", Content: tc.input}, }, }) require.NoError(t, err) require.NotNil(t, result) t.Logf("Action %s: Queued=%v", tc.action, result.Queued) // Wait for execution (shorter timeout for action tests) maxWait := 90 * time.Second deadline := time.Now().Add(maxWait) for time.Now().Before(deadline) { time.Sleep(2 * time.Second) executions, err := api.ListExecutions(ctx, memberID, nil) if err != nil || len(executions.Data) == 0 { continue } exec := executions.Data[0] if exec.Status == types.ExecCompleted || exec.Status == types.ExecFailed { if exec.Status == types.ExecFailed { t.Logf("Action %s failed: %s", tc.action, exec.Error) } else { t.Logf("Action %s completed successfully", tc.action) } return } } t.Logf("Action %s: execution did not complete in time (may still be running)", tc.action) }) } // Stop after all action tests api.Stop() } // ==================== Helper Functions ==================== // setupE2ERobotForHuman creates a robot configured for human intervention E2E tests func setupE2ERobotForHuman(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Simple config for E2E testing - minimal tasks robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Simple E2E Test Robot", "duties": []string{"Echo user input"}, // Very simple duty "rules": []string{"Keep responses under 50 words"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": 
"robot.inspiration", "goals": "robot.goals", "tasks": "tests.e2e-tasks", // Use simple E2E test task planner "run": "robot.validation", "validation": "tests.e2e-validation", // Use lenient E2E test validator "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": []string{"experts.text-writer"}, }, "delivery": map[string]interface{}{ "email": map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) systemPrompt := `You are a simple E2E test robot. Your job is to echo user requests. When generating goals: create exactly 1 simple goal. When generating tasks: create exactly 1 simple task. Keep all outputs brief. No complex analysis needed.` err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Human Test Robot " + memberID, "system_prompt": systemPrompt, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot %s: %v", memberID, err) } } ================================================ FILE: agent/robot/api/e2e_interact_test.go ================================================ package api_test import ( "context" "encoding/json" "strings" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // TestE2EInteractNewAssignment tests the full Interact flow for a new task assignment. 
// With the conversational Host Agent, the first turn may return natural language // (waiting_for_more) or an action decision depending on request clarity. func TestE2EInteractNewAssignment(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupInteractRobots(t) cleanupInteractExecutions(t) defer cleanupInteractRobots(t) defer cleanupInteractExecutions(t) t.Run("assign_via_interact_creates_execution_and_gets_host_reply", func(t *testing.T) { memberID := "robot_e2e_interact_assign" setupInteractRobot(t, memberID, "team_e2e_interact") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuth()) robot, err := api.GetRobot(ctx, memberID) require.NoError(t, err) require.NotNil(t, robot) result, err := api.Interact(ctx, memberID, &api.InteractRequest{ Source: types.InteractSourceUI, Message: "Please write a short greeting email for our team meeting tomorrow morning.", }) require.NoError(t, err) require.NotNil(t, result) t.Logf("Interact result: status=%s, message=%s, reply=%s, exec_id=%s, wait_for_more=%v", result.Status, result.Message, result.Reply, result.ExecutionID, result.WaitForMore) assert.NotEmpty(t, result.ExecutionID, "should create an execution") assert.NotEmpty(t, result.ChatID, "should have a chat session") assert.NotEmpty(t, result.Reply, "Host Agent should provide a reply") validStatuses := []string{"confirmed", "waiting_for_more", "adjusted", "acknowledged"} assert.Contains(t, validStatuses, result.Status, "status should be one of the valid Host Agent action outcomes") if result.Status == "confirmed" { time.Sleep(2 * time.Second) executions, err := api.ListExecutions(ctx, memberID, &api.ExecutionQuery{Page: 1, PageSize: 5}) require.NoError(t, err) assert.Greater(t, len(executions.Data), 0, "confirmed execution should exist in store") } }) } // TestE2EInteractStream tests the streaming version end-to-end. 
func TestE2EInteractStream(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupInteractRobots(t) cleanupInteractExecutions(t) defer cleanupInteractRobots(t) defer cleanupInteractExecutions(t) t.Run("stream_assign_returns_chunks_and_valid_result", func(t *testing.T) { memberID := "robot_e2e_interact_stream" setupInteractRobot(t, memberID, "team_e2e_interact") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuth()) var mu sync.Mutex var chunks []*standard.StreamChunk streamFn := func(chunk *standard.StreamChunk) int { mu.Lock() defer mu.Unlock() chunks = append(chunks, chunk) return 0 } result, err := api.InteractStream(ctx, memberID, &api.InteractRequest{ Source: types.InteractSourceUI, Message: "Help me draft a brief status update email about completing the Q4 report.", }, streamFn) require.NoError(t, err) require.NotNil(t, result) mu.Lock() chunkCount := len(chunks) var textChunks []string for _, c := range chunks { if c.Type == "text" && c.Delta { textChunks = append(textChunks, c.Content) } } mu.Unlock() combined := strings.Join(textChunks, "") t.Logf("Stream received %d total chunks, %d text chunks, combined length: %d", chunkCount, len(textChunks), len(combined)) t.Logf("Result: status=%s, exec_id=%s, reply_len=%d, wait_for_more=%v", result.Status, result.ExecutionID, len(result.Reply), result.WaitForMore) assert.Greater(t, len(textChunks), 0, "should receive streaming text chunks from Host Agent") assert.NotEmpty(t, combined, "combined text should not be empty") assert.NotEmpty(t, result.ExecutionID, "should create an execution") assert.NotEmpty(t, result.Reply, "final result should contain reply") validStatuses := []string{"confirmed", "waiting_for_more", "adjusted"} assert.Contains(t, validStatuses, result.Status) }) } // TestE2EInteractMultiTurn tests a multi-turn conversation: // Turn 1: Send vague 
message -> Host Agent replies conversationally (waiting_for_more) // Turn 2: Send clear confirmation -> Host Agent returns action JSON (confirmed or other action) func TestE2EInteractMultiTurn(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupInteractRobots(t) cleanupInteractExecutions(t) defer cleanupInteractRobots(t) defer cleanupInteractExecutions(t) t.Run("multi_turn_assign_conversation", func(t *testing.T) { memberID := "robot_e2e_interact_multiturn" setupInteractRobot(t, memberID, "team_e2e_interact") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuth()) // Turn 1: Send vague message — expect conversational reply result1, err := api.Interact(ctx, memberID, &api.InteractRequest{ Source: types.InteractSourceUI, Message: "Do something with emails.", }) require.NoError(t, err) require.NotNil(t, result1) t.Logf("Turn 1: status=%s, reply=%s, exec_id=%s, wait_for_more=%v", result1.Status, result1.Reply, result1.ExecutionID, result1.WaitForMore) assert.NotEmpty(t, result1.ExecutionID) assert.NotEmpty(t, result1.Reply) // Turn 2: Clarify/confirm with the same execution_id result2, err := api.Interact(ctx, memberID, &api.InteractRequest{ ExecutionID: result1.ExecutionID, Source: types.InteractSourceUI, Message: "Yes, please write a brief thank-you email to the design team for their Q4 work. 
Go ahead and confirm.", }) require.NoError(t, err) require.NotNil(t, result2) t.Logf("Turn 2: status=%s, reply=%s, exec_id=%s, wait_for_more=%v", result2.Status, result2.Reply, result2.ExecutionID, result2.WaitForMore) assert.NotEmpty(t, result2.Reply) assert.Equal(t, result1.ExecutionID, result2.ExecutionID, "should be same execution") validStatuses := []string{"confirmed", "waiting_for_more", "adjusted", "acknowledged"} assert.Contains(t, validStatuses, result2.Status, "second turn should produce a valid outcome") }) } // TestE2EInteractStreamMultiTurn tests multi-turn with streaming. func TestE2EInteractStreamMultiTurn(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupInteractRobots(t) cleanupInteractExecutions(t) defer cleanupInteractRobots(t) defer cleanupInteractExecutions(t) t.Run("stream_multi_turn", func(t *testing.T) { memberID := "robot_e2e_interact_stream_mt" setupInteractRobot(t, memberID, "team_e2e_interact") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuth()) // Turn 1 var mu1 sync.Mutex var chunks1 []*standard.StreamChunk result1, err := api.InteractStream(ctx, memberID, &api.InteractRequest{ Source: types.InteractSourceUI, Message: "I need help with something.", }, func(chunk *standard.StreamChunk) int { mu1.Lock() chunks1 = append(chunks1, chunk) mu1.Unlock() return 0 }) require.NoError(t, err) require.NotNil(t, result1) mu1.Lock() t.Logf("Turn 1 stream: %d chunks, status=%s, reply=%s, wait_for_more=%v", len(chunks1), result1.Status, result1.Reply, result1.WaitForMore) mu1.Unlock() assert.NotEmpty(t, result1.ExecutionID) assert.NotEmpty(t, result1.Reply) // Turn 2: Clarify with same execution_id var mu2 sync.Mutex var chunks2 []*standard.StreamChunk result2, err := api.InteractStream(ctx, memberID, &api.InteractRequest{ ExecutionID: result1.ExecutionID, Source: types.InteractSourceUI, 
Message: "Please compose a short farewell message for a colleague leaving the team. Yes, go ahead.", }, func(chunk *standard.StreamChunk) int { mu2.Lock() chunks2 = append(chunks2, chunk) mu2.Unlock() return 0 }) require.NoError(t, err) require.NotNil(t, result2) mu2.Lock() t.Logf("Turn 2 stream: %d chunks, status=%s, reply=%s, wait_for_more=%v", len(chunks2), result2.Status, result2.Reply, result2.WaitForMore) mu2.Unlock() assert.NotEmpty(t, result2.Reply) assert.Equal(t, result1.ExecutionID, result2.ExecutionID) }) } // ==================== Helper Functions ==================== func setupInteractRobot(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Email Assistant", "duties": []string{"Write and manage emails"}, "rules": []string{"Always confirm before sending", "Keep emails professional"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "intervene": map[string]interface{}{"enabled": true}, }, "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": "robot.tasks", "run": "robot.validation", "validation": "robot.validation", "delivery": "robot.delivery", "learning": "robot.learning", "host": "robot.host", }, "agents": []string{}, }, } configJSON, _ := json.Marshal(robotConfig) systemPrompt := `You are an email assistant for E2E testing of the Interact API. 
When asked to write an email, confirm the task and generate a brief email draft.` err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Interact Test Robot " + memberID, "system_prompt": systemPrompt, "status": "active", "role_id": "member", "autonomous_mode": false, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert interact robot %s: %v", memberID, err) } } func cleanupInteractRobots(t *testing.T) { m := model.Select("__yao.member") if m == nil { return } qb := capsule.Query() _, err := qb.Table(m.MetaData.Table.Name).Where("member_id", "like", "robot_e2e_interact%").Delete() if err != nil { t.Logf("Warning: cleanup interact robots: %v", err) } } func cleanupInteractExecutions(t *testing.T) { m := model.Select("__yao.agent.execution") if m == nil { return } qb := capsule.Query() _, err := qb.Table(m.MetaData.Table.Name).Where("member_id", "like", "robot_e2e_interact%").Delete() if err != nil { t.Logf("Warning: cleanup interact executions: %v", err) } } ================================================ FILE: agent/robot/api/e2e_suspend_test.go ================================================ package api_test // End-to-end tests for V2 Suspend/Resume flow // Tests the complete lifecycle: execution → need_input → suspend → reply → resume → complete/re-suspend // // Prerequisites: // - Valid LLM API keys // - Test assistants: tests.robot-need-input, experts.text-writer // - Database connection (YAO_DB_PRIMARY) import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes 
"github.com/yaoapp/yao/openapi/oauth/types" ) func testAuthSuspend() *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "e2e-suspend-user", TeamID: "e2e-suspend-team", } } // triggerSuspendRobot triggers a robot via the Trigger API (human trigger path) func triggerSuspendRobot(t *testing.T, ctx *types.Context, memberID string, message string) *api.TriggerResult { t.Helper() result, err := api.Trigger(ctx, memberID, &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: "user", Content: message}, }, }) require.NoError(t, err) require.NotNil(t, result) if !result.Accepted { t.Fatalf("Trigger not accepted: %s", result.Message) } return result } // waitForStatus polls execution status until it matches one of the expected statuses func waitForStatus(t *testing.T, execID string, statuses []types.ExecStatus, timeout time.Duration) *types.Execution { t.Helper() deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { time.Sleep(time.Second) exec := getExecution(t, execID) if exec == nil { continue } for _, s := range statuses { if exec.Status == s { return exec } } } return nil } // TestE2ENormalExecutionNoSuspend verifies that a normal execution (no need_input) // completes without entering the suspend path. 
func TestE2ENormalExecutionNoSuspend(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ESuspendRobots(t) defer cleanupE2ESuspendRobots(t) memberID := "robot_e2e_suspend_001" setupE2ESuspendRobotWithTasksPlanner(t, memberID, "team_e2e_suspend", []string{"experts.text-writer"}, "tests.e2e-tasks") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthSuspend()) result := triggerSuspendRobot(t, ctx, memberID, "Write a one-sentence greeting") exec := waitForStatus(t, result.ExecutionID, []types.ExecStatus{types.ExecCompleted, types.ExecFailed}, 60*time.Second) require.NotNil(t, exec, "Execution should exist and reach terminal state") if exec.Status == types.ExecFailed { t.Logf("Execution failed with error: %s", exec.Error) } assert.Equal(t, types.ExecCompleted, exec.Status, "Normal execution should complete") assert.NotEmpty(t, exec.ChatID, "ChatID should be set") assert.Nil(t, exec.ResumeContext, "No resume context for normal execution") assert.Empty(t, exec.WaitingTaskID, "No waiting task for normal execution") } // TestE2ESuspendResumeFlow tests the full suspend-resume lifecycle: // 1. Trigger execution with robot-need-input assistant (signals need_input) // 2. Verify execution enters waiting status // 3. Reply to resume execution via api.Interact // 4. 
Verify execution re-suspends (since robot-need-input always signals need_input) func TestE2ESuspendResumeFlow(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ESuspendRobots(t) defer cleanupE2ESuspendRobots(t) memberID := "robot_e2e_suspend_002" setupE2ESuspendRobot(t, memberID, "team_e2e_suspend", []string{"tests.robot-need-input"}) err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthSuspend()) // Step 1: Trigger execution — the robot-need-input assistant always returns need_input result := triggerSuspendRobot(t, ctx, memberID, "Analyze sales data") execID := result.ExecutionID // Step 2: Wait for the execution to reach waiting status exec := waitForStatus(t, execID, []types.ExecStatus{types.ExecWaiting, types.ExecCompleted, types.ExecFailed}, 60*time.Second) require.NotNil(t, exec, "Execution should exist") require.Equal(t, types.ExecWaiting, exec.Status, "Execution should be in waiting status") assert.NotEmpty(t, exec.WaitingTaskID, "WaitingTaskID should be set") assert.NotEmpty(t, exec.WaitingQuestion, "WaitingQuestion should be set") assert.NotNil(t, exec.WaitingSince, "WaitingSince should be set") assert.NotNil(t, exec.ResumeContext, "ResumeContext should be set") assert.NotEmpty(t, exec.ChatID, "ChatID should be set") t.Logf("Execution suspended: execID=%s task=%s question=%s", execID, exec.WaitingTaskID, exec.WaitingQuestion) // Step 3: Resume via api.Interact (reply to the waiting execution) interactResult, err := api.Interact(ctx, memberID, &api.InteractRequest{ ExecutionID: execID, Message: "Use the last 30 days for analysis", }) require.NoError(t, err) require.NotNil(t, interactResult) // Since robot-need-input always signals need_input, the resumed execution // will re-suspend. The Interact API returns "waiting" status in this case. 
assert.Equal(t, "waiting", interactResult.Status, "Should re-suspend since assistant always signals need_input") t.Logf("Interact result: status=%s message=%s", interactResult.Status, interactResult.Message) // Step 4: Verify the execution is in waiting status again (re-suspended) exec = getExecution(t, execID) require.NotNil(t, exec) assert.Equal(t, types.ExecWaiting, exec.Status, "Execution should be waiting again after re-suspend") assert.NotNil(t, exec.ResumeContext, "ResumeContext should be set after re-suspend") } // TestE2EReplyShortcut tests the Reply semantic shortcut func TestE2EReplyShortcut(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ESuspendRobots(t) defer cleanupE2ESuspendRobots(t) memberID := "robot_e2e_suspend_004" setupE2ESuspendRobot(t, memberID, "team_e2e_suspend", []string{"tests.robot-need-input"}) err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthSuspend()) result := triggerSuspendRobot(t, ctx, memberID, "Check inventory levels") exec := waitForStatus(t, result.ExecutionID, []types.ExecStatus{types.ExecWaiting}, 60*time.Second) require.NotNil(t, exec, "Execution should reach waiting status") require.Equal(t, types.ExecWaiting, exec.Status) // Use Reply shortcut replyResult, err := api.Reply(ctx, memberID, result.ExecutionID, exec.WaitingTaskID, "Use warehouse A data") require.NoError(t, err) require.NotNil(t, replyResult) assert.Contains(t, []string{"waiting", "resumed"}, replyResult.Status) t.Logf("Reply result: status=%s", replyResult.Status) } // TestE2EResumeContextPersistence verifies that suspend state is properly persisted // and can be loaded back from the database. 
func TestE2EResumeContextPersistence(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ESuspendRobots(t) defer cleanupE2ESuspendRobots(t) memberID := "robot_e2e_suspend_003" setupE2ESuspendRobot(t, memberID, "team_e2e_suspend", []string{"tests.robot-need-input"}) err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthSuspend()) result := triggerSuspendRobot(t, ctx, memberID, "Analyze user behavior") exec := waitForStatus(t, result.ExecutionID, []types.ExecStatus{types.ExecWaiting, types.ExecCompleted, types.ExecFailed}, 60*time.Second) require.NotNil(t, exec) if exec.Status != types.ExecWaiting { t.Skipf("Execution did not reach waiting status (status=%s), skipping persistence test", exec.Status) } // Load from DB directly using store to verify persistence execStore := store.NewExecutionStore() record, err := execStore.Get(context.Background(), result.ExecutionID) require.NoError(t, err) require.NotNil(t, record) assert.Equal(t, types.ExecWaiting, record.Status) assert.NotEmpty(t, record.WaitingTaskID) assert.NotEmpty(t, record.WaitingQuestion) assert.NotNil(t, record.WaitingSince) assert.NotNil(t, record.ResumeContext) assert.Equal(t, exec.ChatID, record.ChatID) // Verify resume context deserialization restored := record.ToExecution() assert.NotNil(t, restored.ResumeContext) assert.GreaterOrEqual(t, restored.ResumeContext.TaskIndex, 0) t.Logf("Persisted resume context: TaskIndex=%d, PreviousResults=%d", restored.ResumeContext.TaskIndex, len(restored.ResumeContext.PreviousResults)) } // TestE2EInteractRequiresExecutionID tests that Interact API returns error when // execution_id is not provided (Host Agent deferred). 
func TestE2EInteractRequiresExecutionID(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuthSuspend()) _, err := api.Interact(ctx, "some-member", &api.InteractRequest{ Message: "hello", }) assert.Error(t, err) assert.Contains(t, err.Error(), "execution_id is required") } // TestE2EInteractWithNonWaitingExecution tests that Interact API returns error // when trying to resume an execution that is not in waiting status. func TestE2EInteractWithNonWaitingExecution(t *testing.T) { if testing.Short() { t.Skip("Skipping E2E test - requires real LLM calls") } testutils.Prepare(t) defer testutils.Clean(t) cleanupE2ESuspendRobots(t) defer cleanupE2ESuspendRobots(t) memberID := "robot_e2e_suspend_005" setupE2ESuspendRobotWithTasksPlanner(t, memberID, "team_e2e_suspend", []string{"experts.text-writer"}, "tests.e2e-tasks") err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), testAuthSuspend()) result := triggerSuspendRobot(t, ctx, memberID, "Say hello") // Wait for completion exec := waitForStatus(t, result.ExecutionID, []types.ExecStatus{types.ExecCompleted, types.ExecFailed}, 60*time.Second) require.NotNil(t, exec, "Execution should reach terminal state") // Try to interact with the completed execution _, err = api.Interact(ctx, memberID, &api.InteractRequest{ ExecutionID: result.ExecutionID, Message: "This should fail", }) assert.Error(t, err) assert.Contains(t, err.Error(), "cannot interact") } // ============================================================================ // Helper Functions // ============================================================================ func setupE2ESuspendRobotWithTasksPlanner(t *testing.T, memberID, teamID string, agents []string, tasksPlanner string) { t.Helper() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "V2 Suspend Test Robot", "duties": []string{"Execute test tasks"}, "rules": 
[]string{"Keep responses under 50 words"}, }, "resources": map[string]interface{}{ "phases": map[string]interface{}{ "inspiration": "robot.inspiration", "goals": "robot.goals", "tasks": tasksPlanner, "run": "robot.validation", "delivery": "robot.delivery", "learning": "robot.learning", }, "agents": agents, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "delivery": map[string]interface{}{ "email": map[string]interface{}{"enabled": false}, "webhook": map[string]interface{}{"enabled": false}, "process": map[string]interface{}{"enabled": false}, }, } configJSON, err := json.Marshal(robotConfig) require.NoError(t, err) m := model.Select("__yao.member") require.NotNil(t, m) tableName := m.MetaData.Table.Name qb := capsule.Query() err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "E2E Suspend Test Robot " + memberID, "system_prompt": "You are a simple E2E test robot. 
Your job is to execute tasks.\nWhen generating goals: create exactly 1 simple goal.\nWhen generating tasks: create exactly 1 simple task.\nKeep all outputs brief.", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) require.NoError(t, err) } func setupE2ESuspendRobot(t *testing.T, memberID, teamID string, agents []string) { t.Helper() setupE2ESuspendRobotWithTasksPlanner(t, memberID, teamID, agents, "tests.e2e-suspend-tasks") } func cleanupE2ESuspendRobots(t *testing.T) { t.Helper() mod := model.Select("__yao.member") if mod == nil { return } mod.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", OP: "like", Value: "robot_e2e_suspend_%"}, }, }) } func getExecution(t *testing.T, execID string) *types.Execution { t.Helper() execStore := store.NewExecutionStore() record, err := execStore.Get(context.Background(), execID) if err != nil || record == nil { return nil } return record.ToExecution() } ================================================ FILE: agent/robot/api/execution.go ================================================ package api import ( "context" "fmt" "sync" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" ) // executionStore singleton var ( execStore *store.ExecutionStore execStoreOnce sync.Once ) // getExecutionStore returns the singleton execution store func getExecutionStore() *store.ExecutionStore { execStoreOnce.Do(func() { execStore = store.NewExecutionStore() }) return execStore } // ResetExecutionStore resets the singleton for testing purposes // This should only be called in tests func ResetExecutionStore() { execStoreOnce = sync.Once{} execStore = nil } // ==================== Execution Query API ==================== // These functions query and manage execution history // GetExecution returns a specific execution by ID func GetExecution(ctx *types.Context, execID string) (*types.Execution, error) { if 
execID == "" { return nil, fmt.Errorf("execution_id is required") } // Try to get from execution store record, err := getExecutionStore().Get(context.Background(), execID) if err != nil { return nil, fmt.Errorf("failed to get execution: %w", err) } if record == nil { return nil, fmt.Errorf("execution not found: %s", execID) } return record.ToExecution(), nil } // ListExecutions returns execution history for a robot func ListExecutions(ctx *types.Context, memberID string, query *ExecutionQuery) (*ExecutionResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if query == nil { query = &ExecutionQuery{} } query.applyDefaults() opts := &store.ListOptions{ MemberID: memberID, Page: query.Page, PageSize: query.PageSize, OrderBy: "start_time desc", } if query.Status != "" { opts.Status = query.Status } if len(query.ExcludeStatuses) > 0 { opts.ExcludeStatuses = query.ExcludeStatuses } if query.Trigger != "" { opts.TriggerType = query.Trigger } result, err := getExecutionStore().List(context.Background(), opts) if err != nil { return nil, fmt.Errorf("failed to list executions: %w", err) } executions := make([]*types.Execution, 0, len(result.Data)) for _, record := range result.Data { executions = append(executions, record.ToExecution()) } return &ExecutionResult{ Data: executions, Total: result.Total, Page: result.Page, PageSize: result.PageSize, }, nil } // ==================== Execution Control API ==================== // These functions control running executions // PauseExecution pauses a running execution func PauseExecution(ctx *types.Context, execID string) error { if execID == "" { return fmt.Errorf("execution_id is required") } mgr, err := getManager() if err != nil { return err } if err := mgr.PauseExecution(ctx, execID); err != nil { return err } // Update database status to paused return getExecutionStore().UpdateStatus(context.Background(), execID, types.ExecPaused, "") } // ResumeExecution resumes a paused execution func 
ResumeExecution(ctx *types.Context, execID string) error { if execID == "" { return fmt.Errorf("execution_id is required") } mgr, err := getManager() if err != nil { return err } if err := mgr.ResumeExecution(ctx, execID); err != nil { return err } // Update database status back to running return getExecutionStore().UpdateStatus(context.Background(), execID, types.ExecRunning, "") } // StopExecution stops a running execution func StopExecution(ctx *types.Context, execID string) error { if execID == "" { return fmt.Errorf("execution_id is required") } mgr, err := getManager() if err != nil { return err } if err := mgr.StopExecution(ctx, execID); err != nil { return err } // Update database status to cancelled return getExecutionStore().UpdateStatus(context.Background(), execID, types.ExecCancelled, "User cancelled") } // ==================== Execution Status API ==================== // GetExecutionStatus returns the current status of an execution // This combines stored data with runtime state func GetExecutionStatus(ctx *types.Context, execID string) (*types.Execution, error) { if execID == "" { return nil, fmt.Errorf("execution_id is required") } // Get from store first exec, err := GetExecution(ctx, execID) if err != nil { return nil, err } // If manager is running, check for runtime state mgr, mgrErr := getManager() if mgrErr == nil { // Check if execution is being tracked (running) ctrlExec, ctrlErr := mgr.GetExecutionStatus(execID) if ctrlErr == nil && ctrlExec != nil { // Update with runtime state exec.Status = ctrlExec.Status exec.Phase = ctrlExec.Phase } } return exec, nil } ================================================ FILE: agent/robot/api/execution_test.go ================================================ package api_test import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // 
TestGetExecutionValidation tests parameter validation for GetExecution func TestGetExecutionValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty execution_id", func(t *testing.T) { exec, err := api.GetExecution(ctx, "") assert.Error(t, err) assert.Nil(t, exec) assert.Contains(t, err.Error(), "execution_id is required") }) t.Run("returns error for non-existent execution", func(t *testing.T) { exec, err := api.GetExecution(ctx, "non_existent_exec_id_xyz") assert.Error(t, err) assert.Nil(t, exec) }) } // TestListExecutionsValidation tests parameter validation for ListExecutions func TestListExecutionsValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty member_id", func(t *testing.T) { result, err := api.ListExecutions(ctx, "", nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("applies default pagination when query is nil", func(t *testing.T) { result, err := api.ListExecutions(ctx, "test_member", nil) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, 1, result.Page) assert.Equal(t, 20, result.PageSize) }) t.Run("caps pagesize at 100", func(t *testing.T) { result, err := api.ListExecutions(ctx, "test_member", &api.ExecutionQuery{ Page: 1, PageSize: 200, }) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, 100, result.PageSize) }) } // TestPauseExecutionValidation tests parameter validation for PauseExecution func TestPauseExecutionValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty execution_id", func(t 
*testing.T) { err := api.PauseExecution(ctx, "") assert.Error(t, err) assert.Contains(t, err.Error(), "execution_id is required") }) t.Run("returns error when manager not started", func(t *testing.T) { err := api.PauseExecution(ctx, "test_exec_id") assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) } // TestResumeExecutionValidation tests parameter validation for ResumeExecution func TestResumeExecutionValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty execution_id", func(t *testing.T) { err := api.ResumeExecution(ctx, "") assert.Error(t, err) assert.Contains(t, err.Error(), "execution_id is required") }) t.Run("returns error when manager not started", func(t *testing.T) { err := api.ResumeExecution(ctx, "test_exec_id") assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) } // TestStopExecutionValidation tests parameter validation for StopExecution func TestStopExecutionValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty execution_id", func(t *testing.T) { err := api.StopExecution(ctx, "") assert.Error(t, err) assert.Contains(t, err.Error(), "execution_id is required") }) t.Run("returns error when manager not started", func(t *testing.T) { err := api.StopExecution(ctx, "test_exec_id") assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) } // TestGetExecutionStatusValidation tests parameter validation for GetExecutionStatus func TestGetExecutionStatusValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty execution_id", func(t 
*testing.T) { exec, err := api.GetExecutionStatus(ctx, "") assert.Error(t, err) assert.Nil(t, exec) assert.Contains(t, err.Error(), "execution_id is required") }) t.Run("returns error for non-existent execution", func(t *testing.T) { exec, err := api.GetExecutionStatus(ctx, "non_existent_exec_id_xyz") assert.Error(t, err) assert.Nil(t, exec) }) } // TestExecutionControlWithManagerStarted tests execution control APIs when manager is running func TestExecutionControlWithManagerStarted(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Start manager err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), nil) t.Run("pause returns error for non-existent execution", func(t *testing.T) { err := api.PauseExecution(ctx, "non_existent_exec_id_xyz") assert.Error(t, err) }) t.Run("resume returns error for non-existent execution", func(t *testing.T) { err := api.ResumeExecution(ctx, "non_existent_exec_id_xyz") assert.Error(t, err) }) t.Run("stop returns error for non-existent execution", func(t *testing.T) { err := api.StopExecution(ctx, "non_existent_exec_id_xyz") assert.Error(t, err) }) } ================================================ FILE: agent/robot/api/interact.go ================================================ package api import ( "fmt" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/types" ) // InteractRequest represents a unified interaction with a robot. type InteractRequest struct { ExecutionID string `json:"execution_id,omitempty"` TaskID string `json:"task_id,omitempty"` Source types.InteractSource `json:"source,omitempty"` Message string `json:"message"` Action string `json:"action,omitempty"` } // InteractResult is the response from an interaction. 
type InteractResult struct { ExecutionID string `json:"execution_id,omitempty"` Status string `json:"status"` Message string `json:"message,omitempty"` ChatID string `json:"chat_id,omitempty"` Reply string `json:"reply,omitempty"` WaitForMore bool `json:"wait_for_more,omitempty"` } // Interact handles all human-robot interactions through a unified entry point. // // Routing logic: // - If manager is running, delegate to Manager.HandleInteract (full V2 flow with Host Agent) // - Otherwise, use legacy direct-executor path for backward compatibility func Interact(ctx *types.Context, memberID string, req *InteractRequest) (*InteractResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil { return nil, fmt.Errorf("interact request is required") } // Try V2 path via manager mgr, err := getManager() if err == nil && mgr != nil { return managerInteract(ctx, mgr, memberID, req) } // V1 fallback: require execution_id for direct resume if req.ExecutionID == "" { return nil, fmt.Errorf("execution_id is required for current version (Host Agent deferred)") } return legacyResume(ctx, req) } // managerInteract delegates to the manager's HandleInteract. func managerInteract(ctx *types.Context, mgr *manager.Manager, memberID string, req *InteractRequest) (*InteractResult, error) { mgrReq := &manager.InteractRequest{ ExecutionID: req.ExecutionID, TaskID: req.TaskID, Source: req.Source, Message: req.Message, Action: req.Action, } resp, err := mgr.HandleInteract(ctx, memberID, mgrReq) if err != nil { return nil, err } return &InteractResult{ ExecutionID: resp.ExecutionID, Status: resp.Status, Message: resp.Message, ChatID: resp.ChatID, Reply: resp.Reply, WaitForMore: resp.WaitForMore, }, nil } // legacyResume handles the direct executor resume path (backward compatible). 
func legacyResume(ctx *types.Context, req *InteractRequest) (*InteractResult, error) { executor := standard.New() err := executor.Resume(ctx, req.ExecutionID, req.Message) if err != nil { if err == types.ErrExecutionSuspended { return &InteractResult{ ExecutionID: req.ExecutionID, Status: "waiting", Message: "Execution suspended again: needs more input", }, nil } return nil, fmt.Errorf("failed to resume execution: %w", err) } return &InteractResult{ ExecutionID: req.ExecutionID, Status: "resumed", Message: "Execution resumed and completed successfully", }, nil } // Reply is a semantic shortcut for replying to a specific waiting task. func Reply(ctx *types.Context, memberID string, execID string, taskID string, message string) (*InteractResult, error) { return Interact(ctx, memberID, &InteractRequest{ ExecutionID: execID, TaskID: taskID, Source: types.InteractSourceUI, Message: message, }) } // Confirm is a semantic shortcut for confirming a pending execution. func Confirm(ctx *types.Context, memberID string, execID string, message string) (*InteractResult, error) { return Interact(ctx, memberID, &InteractRequest{ ExecutionID: execID, Source: types.InteractSourceUI, Message: message, Action: "confirm", }) } // InteractStream is the streaming version of Interact. // It streams Host Agent text tokens via streamFn while still returning the final InteractResult. // V1 fallback does not support streaming and returns an error. 
func InteractStream(ctx *types.Context, memberID string, req *InteractRequest, streamFn standard.StreamCallback) (*InteractResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil { return nil, fmt.Errorf("interact request is required") } mgr, err := getManager() if err != nil || mgr == nil { return nil, fmt.Errorf("streaming requires V2 manager (not available)") } mgrReq := &manager.InteractRequest{ ExecutionID: req.ExecutionID, TaskID: req.TaskID, Source: req.Source, Message: req.Message, Action: req.Action, } resp, err := mgr.HandleInteractStream(ctx, memberID, mgrReq, streamFn) if err != nil { return nil, err } return &InteractResult{ ExecutionID: resp.ExecutionID, Status: resp.Status, Message: resp.Message, ChatID: resp.ChatID, Reply: resp.Reply, WaitForMore: resp.WaitForMore, }, nil } // InteractStreamRaw is the CUI-protocol-aligned streaming version of Interact. // It passes raw message.Message objects to the onMessage callback, preserving all CUI // protocol fields for direct SSE passthrough to the frontend. 
func InteractStreamRaw(ctx *types.Context, memberID string, req *InteractRequest, onMessage agentcontext.OnMessageFunc) (*InteractResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil { return nil, fmt.Errorf("interact request is required") } mgr, err := getManager() if err != nil || mgr == nil { return nil, fmt.Errorf("raw streaming requires V2 manager (not available)") } mgrReq := &manager.InteractRequest{ ExecutionID: req.ExecutionID, TaskID: req.TaskID, Source: req.Source, Message: req.Message, Action: req.Action, } resp, err := mgr.HandleInteractStreamRaw(ctx, memberID, mgrReq, onMessage) if err != nil { return nil, err } return &InteractResult{ ExecutionID: resp.ExecutionID, Status: resp.Status, Message: resp.Message, ChatID: resp.ChatID, Reply: resp.Reply, WaitForMore: resp.WaitForMore, }, nil } // CancelExecution cancels a waiting/confirming execution via the manager. func CancelExecution(ctx *types.Context, execID string) error { mgr, err := getManager() if err != nil { return fmt.Errorf("cancel not available: %w", err) } return mgr.CancelExecution(ctx, execID) } ================================================ FILE: agent/robot/api/interact_test.go ================================================ package api import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/types" ) // AI1-AI3: Interact routing func TestInteract(t *testing.T) { t.Run("empty member_id returns error", func(t *testing.T) { ctx := types.NewContext(nil, nil) _, err := Interact(ctx, "", &InteractRequest{Message: "test"}) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("nil request returns error", func(t *testing.T) { ctx := types.NewContext(nil, nil) _, err := Interact(ctx, "member-1", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "interact request is required") }) t.Run("no manager and no execution_id returns error", 
func(t *testing.T) { ctx := types.NewContext(nil, nil) _, err := Interact(ctx, "member-1", &InteractRequest{Message: "test"}) assert.Error(t, err) assert.Contains(t, err.Error(), "execution_id is required") }) } // AI6: Reply shortcut func TestReply(t *testing.T) { t.Run("empty member_id returns error", func(t *testing.T) { ctx := types.NewContext(nil, nil) _, err := Reply(ctx, "", "exec-1", "task-1", "hello") assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("routes through Interact", func(t *testing.T) { ctx := types.NewContext(nil, nil) // legacyResume accesses the DB model which panics if not initialized. // Verify the routing reaches legacyResume by catching the expected panic. assert.Panics(t, func() { Reply(ctx, "member-1", "exec-1", "task-1", "hello") }, "should reach legacyResume which requires DB model") }) } // AI7: Confirm shortcut func TestConfirm(t *testing.T) { t.Run("empty member_id returns error", func(t *testing.T) { ctx := types.NewContext(nil, nil) _, err := Confirm(ctx, "", "exec-1", "yes") assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("routes through Interact", func(t *testing.T) { ctx := types.NewContext(nil, nil) assert.Panics(t, func() { Confirm(ctx, "member-1", "exec-1", "yes") }, "should reach legacyResume which requires DB model") }) } // AI8-AI9: CancelExecution func TestCancelExecution(t *testing.T) { t.Run("no manager returns error", func(t *testing.T) { ctx := types.NewContext(nil, nil) err := CancelExecution(ctx, "exec-1") assert.Error(t, err) assert.Contains(t, err.Error(), "cancel not available") }) } // AI10-AI12: legacyResume func TestLegacyResume(t *testing.T) { t.Run("non-existent execution panics without DB", func(t *testing.T) { ctx := types.NewContext(nil, nil) assert.Panics(t, func() { legacyResume(ctx, &InteractRequest{ ExecutionID: "nonexistent-exec", Message: "test", }) }, "should panic because DB model is not initialized") }) } // AI1: 
managerInteract delegates correctly func TestManagerInteract(t *testing.T) { t.Run("converts request fields correctly", func(t *testing.T) { // This would require a running Manager; test the field mapping logic req := &InteractRequest{ ExecutionID: "exec-ai1", TaskID: "task-ai1", Source: types.InteractSourceUI, Message: "do it", Action: "confirm", } // Verify InteractRequest has all expected fields assert.Equal(t, "exec-ai1", req.ExecutionID) assert.Equal(t, "task-ai1", req.TaskID) assert.Equal(t, types.InteractSourceUI, req.Source) assert.Equal(t, "do it", req.Message) assert.Equal(t, "confirm", req.Action) }) } // AI2: Interact with execution_id and no manager falls back to legacy func TestInteractLegacyFallback(t *testing.T) { t.Run("with execution_id delegates to legacyResume", func(t *testing.T) { ctx := types.NewContext(nil, nil) assert.Panics(t, func() { Interact(ctx, "member-1", &InteractRequest{ ExecutionID: "exec-1", Message: "resume this", }) }, "should reach legacyResume which requires DB model") }) } // Test InteractResult field mapping func TestInteractResultFields(t *testing.T) { result := &InteractResult{ ExecutionID: "exec-test", Status: "confirmed", Message: "Done", ChatID: "chat-test", Reply: "I'll do it", WaitForMore: true, } assert.Equal(t, "exec-test", result.ExecutionID) assert.Equal(t, "confirmed", result.Status) assert.Equal(t, "Done", result.Message) assert.Equal(t, "chat-test", result.ChatID) assert.Equal(t, "I'll do it", result.Reply) assert.True(t, result.WaitForMore) // Verify zero-value result empty := &InteractResult{} assert.Empty(t, empty.ExecutionID) assert.Empty(t, empty.Status) assert.False(t, empty.WaitForMore) } // Test that legacyResume returns "waiting" on ErrExecutionSuspended func TestLegacyResumeStatusMapping(t *testing.T) { // ErrExecutionSuspended handling is tested via the suspend E2E tests. // Here we verify the InteractResult field structure. 
result := &InteractResult{ ExecutionID: "exec-lr", Status: "waiting", Message: "Execution suspended again: needs more input", } assert.Equal(t, "waiting", result.Status) assert.Contains(t, result.Message, "suspended") resultOK := &InteractResult{ ExecutionID: "exec-lr2", Status: "resumed", Message: "Execution resumed and completed successfully", } require.Equal(t, "resumed", resultOK.Status) } ================================================ FILE: agent/robot/api/lifecycle.go ================================================ package api import ( "context" "fmt" "sync" robotevents "github.com/yaoapp/yao/agent/robot/events" "github.com/yaoapp/yao/agent/robot/events/integrations" dtadapter "github.com/yaoapp/yao/agent/robot/events/integrations/dingtalk" dcadapter "github.com/yaoapp/yao/agent/robot/events/integrations/discord" fsadapter "github.com/yaoapp/yao/agent/robot/events/integrations/feishu" "github.com/yaoapp/yao/agent/robot/events/integrations/telegram" "github.com/yaoapp/yao/agent/robot/logger" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/types" ) var log = logger.New("robot") func init() { robotevents.RegisterTriggerFunc(func(ctx *types.Context, memberID string, triggerType types.TriggerType, data interface{}) (string, bool, error) { result, err := TriggerManual(ctx, memberID, triggerType, data) if err != nil { return "", false, err } return result.ExecutionID, result.Accepted, nil }) } // ==================== Lifecycle API ==================== // These functions manage the robot agent system lifecycle var ( globalManager *manager.Manager globalDispatcher *integrations.Dispatcher managerMu sync.RWMutex ) // Start starts the robot agent system // This initializes and starts the manager which handles: // - Robot cache loading // - Worker pool // - Clock ticker for scheduled triggers func Start() error { managerMu.Lock() defer managerMu.Unlock() if globalManager != nil && globalManager.IsStarted() { return fmt.Errorf("robot 
agent system already started") } // Create new manager if not exists if globalManager == nil { globalManager = manager.New() } if err := globalManager.Start(); err != nil { return err } // Start integration dispatcher (Telegram polling, webhook subscriptions, etc.) adapters := map[string]integrations.Adapter{ "telegram": telegram.NewAdapter(), "feishu": fsadapter.NewAdapter(), "dingtalk": dtadapter.NewAdapter(), "discord": dcadapter.NewAdapter(), } globalDispatcher = integrations.NewDispatcher(globalManager.Cache(), adapters) if err := globalDispatcher.Start(context.Background()); err != nil { log.Error("failed to start integration dispatcher: %v", err) } return nil } // StartWithConfig starts the robot agent system with custom configuration func StartWithConfig(config *manager.Config) error { managerMu.Lock() defer managerMu.Unlock() if globalManager != nil && globalManager.IsStarted() { return fmt.Errorf("robot agent system already started") } globalManager = manager.NewWithConfig(config) return globalManager.Start() } // Stop stops the robot agent system gracefully // This will: // - Stop the clock ticker // - Stop cache auto-refresh // - Wait for running jobs to complete // - Stop the worker pool func Stop() error { managerMu.Lock() defer managerMu.Unlock() if globalManager == nil { return nil } if globalDispatcher != nil { globalDispatcher.Stop() globalDispatcher = nil } err := globalManager.Stop() if err != nil { return err } globalManager = nil return nil } // IsRunning returns true if the robot agent system is running func IsRunning() bool { managerMu.RLock() defer managerMu.RUnlock() return globalManager != nil && globalManager.IsStarted() } // getManager returns the global manager instance // Returns error if manager is not started func getManager() (*manager.Manager, error) { managerMu.RLock() defer managerMu.RUnlock() if globalManager == nil || !globalManager.IsStarted() { return nil, fmt.Errorf("robot agent system not started") } return globalManager, 
nil } // SetManager sets the global manager instance (for testing) func SetManager(m *manager.Manager) { managerMu.Lock() defer managerMu.Unlock() globalManager = m } ================================================ FILE: agent/robot/api/lifecycle_test.go ================================================ package api_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/testutils" ) // TestLifecycle tests the Start/Stop lifecycle APIs func TestLifecycle(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("start and stop cycle", func(t *testing.T) { // Initially not running assert.False(t, api.IsRunning()) // Start err := api.Start() require.NoError(t, err) assert.True(t, api.IsRunning()) // Start again should fail err = api.Start() assert.Error(t, err) assert.Contains(t, err.Error(), "already started") // Stop err = api.Stop() require.NoError(t, err) assert.False(t, api.IsRunning()) // Stop again should be no-op (not error) err = api.Stop() assert.NoError(t, err) }) t.Run("can restart after stop", func(t *testing.T) { // Start err := api.Start() require.NoError(t, err) assert.True(t, api.IsRunning()) // Stop err = api.Stop() require.NoError(t, err) assert.False(t, api.IsRunning()) // Start again should work err = api.Start() require.NoError(t, err) assert.True(t, api.IsRunning()) // Cleanup api.Stop() }) } ================================================ FILE: agent/robot/api/results.go ================================================ package api import ( "context" "fmt" "time" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" ) // ==================== Result Types ==================== // ResultQuery - query parameters for listing results type ResultQuery struct { TriggerType types.TriggerType `json:"trigger_type,omitempty"` // clock | human | 
event Keyword string `json:"keyword,omitempty"` // Search in name/summary Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` } // ResultItem - result list item (subset of execution) type ResultItem struct { ID string `json:"id"` MemberID string `json:"member_id"` TriggerType types.TriggerType `json:"trigger_type"` Status types.ExecStatus `json:"status"` Name string `json:"name"` Summary string `json:"summary"` StartTime time.Time `json:"start_time"` EndTime *time.Time `json:"end_time,omitempty"` HasAttachments bool `json:"has_attachments"` } // ResultDetail - full result with delivery content type ResultDetail struct { ID string `json:"id"` MemberID string `json:"member_id"` TriggerType types.TriggerType `json:"trigger_type"` Status types.ExecStatus `json:"status"` Name string `json:"name"` Delivery *types.DeliveryResult `json:"delivery,omitempty"` StartTime time.Time `json:"start_time"` EndTime *time.Time `json:"end_time,omitempty"` } // ResultListResponse - paginated response type ResultListResponse struct { Data []*ResultItem `json:"data"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"pagesize"` } // ==================== Result API Functions ==================== // ListResults returns completed executions with delivery content for a robot func ListResults(ctx *types.Context, memberID string, query *ResultQuery) (*ResultListResponse, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if query == nil { query = &ResultQuery{} } query.applyDefaults() opts := &store.ResultListOptions{ MemberID: memberID, Page: query.Page, PageSize: query.PageSize, } if query.TriggerType != "" { opts.TriggerType = query.TriggerType } if query.Keyword != "" { opts.Keyword = query.Keyword } // Query from store result, err := getExecutionStore().ListResults(context.Background(), opts) if err != nil { return nil, fmt.Errorf("failed to list results: %w", err) } // Transform to ResultItem slice items := 
make([]*ResultItem, 0, len(result.Data)) for _, record := range result.Data { item := recordToResultItem(record) if item != nil { items = append(items, item) } } return &ResultListResponse{ Data: items, Total: result.Total, Page: result.Page, PageSize: result.PageSize, }, nil } // GetResult returns a single result by execution ID func GetResult(ctx *types.Context, execID string) (*ResultDetail, error) { if execID == "" { return nil, fmt.Errorf("execution_id is required") } // Get from store record, err := getExecutionStore().Get(context.Background(), execID) if err != nil { return nil, fmt.Errorf("failed to get result: %w", err) } if record == nil { return nil, fmt.Errorf("result not found: %s", execID) } // Verify it has delivery content if record.Delivery == nil || record.Delivery.Content == nil { return nil, fmt.Errorf("result not found: %s (no delivery content)", execID) } return recordToResultDetail(record), nil } // ==================== Helper Functions ==================== // applyDefaults applies default values to ResultQuery func (q *ResultQuery) applyDefaults() { if q.Page <= 0 { q.Page = 1 } if q.PageSize <= 0 { q.PageSize = 20 } if q.PageSize > 100 { q.PageSize = 100 } } // recordToResultItem converts ExecutionRecord to ResultItem func recordToResultItem(record *store.ExecutionRecord) *ResultItem { if record == nil { return nil } item := &ResultItem{ ID: record.ExecutionID, MemberID: record.MemberID, TriggerType: record.TriggerType, Status: record.Status, Name: record.Name, } // Set times if record.StartTime != nil { item.StartTime = *record.StartTime } item.EndTime = record.EndTime // Extract summary and attachments from delivery if record.Delivery != nil && record.Delivery.Content != nil { item.Summary = record.Delivery.Content.Summary item.HasAttachments = len(record.Delivery.Content.Attachments) > 0 } return item } // recordToResultDetail converts ExecutionRecord to ResultDetail func recordToResultDetail(record *store.ExecutionRecord) *ResultDetail 
{ if record == nil { return nil } detail := &ResultDetail{ ID: record.ExecutionID, MemberID: record.MemberID, TriggerType: record.TriggerType, Status: record.Status, Name: record.Name, Delivery: record.Delivery, } // Set times if record.StartTime != nil { detail.StartTime = *record.StartTime } detail.EndTime = record.EndTime return detail } ================================================ FILE: agent/robot/api/robot.go ================================================ package api import ( "context" "fmt" "time" gonanoid "github.com/matoous/go-nanoid/v2" "github.com/yaoapp/gou/model" "github.com/yaoapp/kun/maps" robotevents "github.com/yaoapp/yao/agent/robot/events" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/event" ) // ==================== Robot Query API ==================== // These functions query robot information // memberModel is the model name for member table const memberModel = "__yao.member" // robotStore is the shared robot store instance var robotStore = store.NewRobotStore() // executionStore is the shared execution store instance var executionStore = store.NewExecutionStore() // GetRobot returns a robot by member ID // Returns the robot from cache if available, otherwise loads from database func GetRobot(ctx *types.Context, memberID string) (*types.Robot, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } mgr, err := getManager() if err != nil { // Manager not started, try to load directly from database return loadRobotFromDB(memberID) } // Try cache first robot := mgr.Cache().Get(memberID) if robot != nil { return robot, nil } // Not in cache, try to load from database robot, err = mgr.Cache().LoadByID(ctx, memberID) if err != nil { return nil, err } return robot, nil } // ListRobots returns robots with pagination and filtering func ListRobots(ctx *types.Context, query *ListQuery) (*ListResult, error) { if query == nil { query = &ListQuery{} } 
query.applyDefaults() mgr, err := getManager() if err != nil { // Manager not started, load directly from database return listRobotsFromDB(query) } // If only teamID specified AND explicitly filtering for autonomous_mode=true, use cache // Cache only contains autonomous_mode=true robots // When autonomous_mode is not specified or false, must query database to include all robots if query.TeamID != "" && query.Status == "" && query.Keywords == "" && query.ClockMode == "" && query.AutonomousMode != nil && *query.AutonomousMode == true { robots := mgr.Cache().List(query.TeamID) return paginateRobots(robots, query), nil } // For complex queries, load from database return listRobotsFromDB(query) } // GetRobotStatus returns the runtime status of a robot func GetRobotStatus(ctx *types.Context, memberID string) (*RobotState, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } robot, err := GetRobot(ctx, memberID) if err != nil { return nil, err } // Get permission fields from store (for access control) record, _ := robotStore.Get(context.Background(), memberID) state := &RobotState{ MemberID: robot.MemberID, TeamID: robot.TeamID, DisplayName: robot.DisplayName, Bio: robot.Bio, Status: robot.Status, MaxRunning: 2, // default } // Add permission fields if available if record != nil { state.YaoCreatedBy = record.YaoCreatedBy state.YaoTeamID = record.YaoTeamID } if robot.Config != nil && robot.Config.Quota != nil { state.MaxRunning = robot.Config.Quota.GetMax() } // Get running execution IDs from ExecutionStore (more reliable than in-memory) // This ensures we get accurate status even when robot is loaded from database runningResult, err := executionStore.List(context.Background(), &store.ListOptions{ MemberID: memberID, Status: types.ExecRunning, PageSize: 100, }) if err == nil && runningResult != nil && len(runningResult.Data) > 0 { state.Running = len(runningResult.Data) state.RunningIDs = make([]string, 0, len(runningResult.Data)) for _, exec := 
range runningResult.Data { state.RunningIDs = append(state.RunningIDs, exec.ExecutionID) } // Update status based on running count state.Status = types.RobotWorking } else { // No running executions from store, check in-memory executions := robot.GetExecutions() state.Running = len(executions) state.RunningIDs = make([]string, 0, len(executions)) for _, exec := range executions { state.RunningIDs = append(state.RunningIDs, exec.ID) } // If there are running executions in memory, update status if state.Running > 0 { state.Status = types.RobotWorking } } // Set last run time if !robot.LastRun.IsZero() { state.LastRun = &robot.LastRun } // Set next run time if !robot.NextRun.IsZero() { state.NextRun = &robot.NextRun } return state, nil } // ==================== Helper Functions ==================== // loadRobotFromDB loads a robot directly from database func loadRobotFromDB(memberID string) (*types.Robot, error) { m := model.Select(memberModel) if m == nil { return nil, fmt.Errorf("model %s not found", memberModel) } records, err := m.Get(model.QueryParam{ Select: []interface{}{ "id", "member_id", "team_id", "display_name", "bio", "system_prompt", "robot_status", "autonomous_mode", "robot_config", "robot_email", "agents", "mcp_servers", "manager_id", "language_model", }, Wheres: []model.QueryWhere{ {Column: "member_id", Value: memberID}, {Column: "member_type", Value: "robot"}, }, Limit: 1, }) if err != nil { return nil, fmt.Errorf("failed to load robot: %w", err) } if len(records) == 0 { return nil, types.ErrRobotNotFound } return types.NewRobotFromMap(map[string]interface{}(records[0])) } // listRobotsFromDB loads robots from database with filtering func listRobotsFromDB(query *ListQuery) (*ListResult, error) { m := model.Select(memberModel) if m == nil { return nil, fmt.Errorf("model %s not found", memberModel) } // Build where conditions wheres := []model.QueryWhere{ {Column: "member_type", Value: "robot"}, {Column: "status", Value: "active"}, } if query.TeamID != 
"" { wheres = append(wheres, model.QueryWhere{Column: "team_id", Value: query.TeamID}) } if query.Status != "" { wheres = append(wheres, model.QueryWhere{Column: "robot_status", Value: string(query.Status)}) } if query.Keywords != "" { wheres = append(wheres, model.QueryWhere{ Column: "display_name", OP: "like", Value: "%" + query.Keywords + "%", }) } if query.AutonomousMode != nil { wheres = append(wheres, model.QueryWhere{Column: "autonomous_mode", Value: *query.AutonomousMode}) } // Build order orders := []model.QueryOrder{} if query.Order != "" { orders = append(orders, model.QueryOrder{Column: query.Order}) } else { orders = append(orders, model.QueryOrder{Column: "created_at", Option: "desc"}) } // Execute paginated query result, err := m.Paginate(model.QueryParam{ Select: []interface{}{ "id", "member_id", "team_id", "display_name", "bio", "system_prompt", "robot_status", "autonomous_mode", "robot_config", "robot_email", "agents", "mcp_servers", "language_model", }, Wheres: wheres, Orders: orders, }, query.Page, query.PageSize) if err != nil { return nil, fmt.Errorf("failed to list robots: %w", err) } // Parse result listResult := &ListResult{ Data: []*types.Robot{}, Page: query.Page, PageSize: query.PageSize, } // Get total count if total, ok := result.Get("total").(int); ok { listResult.Total = total } // Parse robot records - handle both []maps.MapStr and []map[string]interface{} data := result.Get("data") switch records := data.(type) { case []maps.MapStr: for _, record := range records { robot, err := types.NewRobotFromMap(map[string]interface{}(record)) if err != nil { continue // skip invalid records } listResult.Data = append(listResult.Data, robot) } case []map[string]interface{}: for _, record := range records { robot, err := types.NewRobotFromMap(record) if err != nil { continue // skip invalid records } listResult.Data = append(listResult.Data, robot) } } return listResult, nil } // paginateRobots applies pagination to a slice of robots func 
paginateRobots(robots []*types.Robot, query *ListQuery) *ListResult { total := len(robots) // Calculate offset offset := (query.Page - 1) * query.PageSize if offset >= total { return &ListResult{ Data: []*types.Robot{}, Total: total, Page: query.Page, PageSize: query.PageSize, } } // Calculate end index end := offset + query.PageSize if end > total { end = total } return &ListResult{ Data: robots[offset:end], Total: total, Page: query.Page, PageSize: query.PageSize, } } // ==================== Robot CRUD API ==================== // These functions create, update, and delete robots // They call store layer for persistence and manage cache // Request/Response types are defined in types.go // CreateRobot creates a new robot member // Calls store.RobotStore.Save() and refreshes cache // If member_id is not provided, it will be auto-generated func CreateRobot(ctx *types.Context, req *CreateRobotRequest) (*RobotResponse, error) { // Validate required fields if req.TeamID == "" { return nil, fmt.Errorf("team_id is required") } if req.DisplayName == "" { return nil, fmt.Errorf("display_name is required") } // Generate member_id if not provided if req.MemberID == "" { generatedID, err := generateMemberID(context.Background()) if err != nil { return nil, fmt.Errorf("failed to generate member_id: %w", err) } req.MemberID = generatedID } // Check if robot already exists existing, err := robotStore.Get(context.Background(), req.MemberID) if err != nil { return nil, fmt.Errorf("failed to check existing robot: %w", err) } if existing != nil { return nil, fmt.Errorf("robot with member_id '%s' already exists", req.MemberID) } // Determine autonomous_mode value autonomousMode := false if req.AutonomousMode != nil { autonomousMode = *req.AutonomousMode } // Determine status values status := "active" if req.Status != "" { status = req.Status } robotStatus := "idle" if req.RobotStatus != "" { robotStatus = req.RobotStatus } // Create store record with all fields now := time.Now() 
record := &store.RobotRecord{ // Required MemberID: req.MemberID, TeamID: req.TeamID, MemberType: "robot", Status: status, RobotStatus: robotStatus, AutonomousMode: autonomousMode, // Profile DisplayName: req.DisplayName, Bio: req.Bio, Avatar: req.Avatar, // Identity & Role SystemPrompt: req.SystemPrompt, RoleID: req.RoleID, ManagerID: req.ManagerID, // Communication RobotEmail: req.RobotEmail, AuthorizedSenders: req.AuthorizedSenders, EmailFilterRules: req.EmailFilterRules, // Capabilities RobotConfig: req.RobotConfig, Agents: req.Agents, MCPServers: req.MCPServers, LanguageModel: req.LanguageModel, // Limits CostLimit: req.CostLimit, // Timestamps JoinedAt: &now, } // Apply Yao permission fields if provided if req.AuthScope != nil { record.YaoCreatedBy = req.AuthScope.CreatedBy record.YaoTeamID = req.AuthScope.TeamID record.YaoTenantID = req.AuthScope.TenantID // Set invited_by from CreatedBy if not explicitly set if record.InvitedBy == "" && req.AuthScope.CreatedBy != "" { record.InvitedBy = req.AuthScope.CreatedBy } } // Save to database err = robotStore.Save(context.Background(), record) if err != nil { return nil, fmt.Errorf("failed to create robot: %w", err) } // Refresh cache if manager is running // Use Refresh() which handles autonomous_mode correctly: // - If autonomous_mode=true: adds to cache for scheduling // - If autonomous_mode=false: does not add to cache mgr, err := getManager() if err == nil && mgr != nil { _ = mgr.Cache().Refresh(ctx, req.MemberID) } // Notify integrations of new robot config event.Push(context.Background(), robotevents.RobotConfigCreated, robotevents.RobotConfigPayload{ MemberID: req.MemberID, TeamID: req.TeamID, }) // Return the created robot as response return GetRobotResponse(ctx, req.MemberID) } // UpdateRobot updates an existing robot member // Calls store.RobotStore.Save() and refreshes cache func UpdateRobot(ctx *types.Context, memberID string, req *UpdateRobotRequest) (*RobotResponse, error) { if memberID == "" { return 
nil, fmt.Errorf("member_id is required") } // Get existing record existing, err := robotStore.Get(context.Background(), memberID) if err != nil { return nil, fmt.Errorf("failed to get robot: %w", err) } if existing == nil { return nil, types.ErrRobotNotFound } // Apply updates - only non-nil fields are updated // Profile if req.DisplayName != nil { existing.DisplayName = *req.DisplayName } if req.Bio != nil { existing.Bio = *req.Bio } if req.Avatar != nil { existing.Avatar = *req.Avatar } // Identity & Role if req.SystemPrompt != nil { existing.SystemPrompt = *req.SystemPrompt } if req.RoleID != nil { existing.RoleID = *req.RoleID } if req.ManagerID != nil { existing.ManagerID = *req.ManagerID } // Status if req.Status != nil { existing.Status = *req.Status } if req.RobotStatus != nil { existing.RobotStatus = *req.RobotStatus } if req.AutonomousMode != nil { existing.AutonomousMode = *req.AutonomousMode } // Communication if req.RobotEmail != nil { existing.RobotEmail = *req.RobotEmail } if req.AuthorizedSenders != nil { existing.AuthorizedSenders = req.AuthorizedSenders } if req.EmailFilterRules != nil { existing.EmailFilterRules = req.EmailFilterRules } // Capabilities if req.RobotConfig != nil { existing.RobotConfig = req.RobotConfig } if req.Agents != nil { existing.Agents = req.Agents } if req.MCPServers != nil { existing.MCPServers = req.MCPServers } if req.LanguageModel != nil { existing.LanguageModel = *req.LanguageModel } // Limits if req.CostLimit != nil { existing.CostLimit = *req.CostLimit } // Apply Yao permission fields if provided (update scope) if req.AuthScope != nil { existing.YaoUpdatedBy = req.AuthScope.UpdatedBy // Team and Tenant are typically set on create, not update // But allow override if explicitly provided if req.AuthScope.TeamID != "" { existing.YaoTeamID = req.AuthScope.TeamID } if req.AuthScope.TenantID != "" { existing.YaoTenantID = req.AuthScope.TenantID } } // Save to database err = robotStore.Save(context.Background(), existing) 
if err != nil { return nil, fmt.Errorf("failed to update robot: %w", err) } // Refresh cache if manager is running // Use Refresh() which handles autonomous_mode correctly: // - If autonomous_mode=true: adds to cache for scheduling // - If autonomous_mode=false: removes from cache mgr, err := getManager() if err == nil && mgr != nil { _ = mgr.Cache().Refresh(ctx, memberID) // Ignore error, database is already saved } // Notify integrations of updated robot config event.Push(context.Background(), robotevents.RobotConfigUpdated, robotevents.RobotConfigPayload{ MemberID: memberID, TeamID: existing.TeamID, }) // Return the updated robot as response return GetRobotResponse(ctx, memberID) } // RemoveRobot deletes a robot member // Calls store.RobotStore.Delete() and invalidates cache func RemoveRobot(ctx *types.Context, memberID string) error { if memberID == "" { return fmt.Errorf("member_id is required") } // Check if robot exists existing, err := robotStore.Get(context.Background(), memberID) if err != nil { return fmt.Errorf("failed to get robot: %w", err) } if existing == nil { return types.ErrRobotNotFound } // Check if robot has running executions mgr, err := getManager() if err == nil && mgr != nil { robot := mgr.Cache().Get(memberID) if robot != nil && robot.RunningCount() > 0 { return fmt.Errorf("cannot delete robot with running executions") } } // Delete from database err = robotStore.Delete(context.Background(), memberID) if err != nil { return fmt.Errorf("failed to delete robot: %w", err) } // Invalidate cache if manager is running if mgr != nil { mgr.Cache().Remove(memberID) } // Notify integrations of deleted robot config event.Push(context.Background(), robotevents.RobotConfigDeleted, robotevents.RobotConfigPayload{ MemberID: memberID, TeamID: existing.TeamID, }) return nil } // GetRobotResponse retrieves a robot and converts to API response format func GetRobotResponse(ctx *types.Context, memberID string) (*RobotResponse, error) { record, err := 
robotStore.Get(context.Background(), memberID) if err != nil { return nil, fmt.Errorf("failed to get robot: %w", err) } if record == nil { return nil, types.ErrRobotNotFound } return recordToResponse(record), nil } // recordToResponse converts a store.RobotRecord to API RobotResponse func recordToResponse(record *store.RobotRecord) *RobotResponse { return &RobotResponse{ ID: record.ID, MemberID: record.MemberID, TeamID: record.TeamID, Status: record.Status, RobotStatus: record.RobotStatus, AutonomousMode: record.AutonomousMode, DisplayName: record.DisplayName, Bio: record.Bio, Avatar: record.Avatar, SystemPrompt: record.SystemPrompt, RoleID: record.RoleID, ManagerID: record.ManagerID, RobotEmail: record.RobotEmail, AuthorizedSenders: record.AuthorizedSenders, EmailFilterRules: record.EmailFilterRules, RobotConfig: record.RobotConfig, Agents: record.Agents, MCPServers: record.MCPServers, LanguageModel: record.LanguageModel, CostLimit: record.CostLimit, InvitedBy: record.InvitedBy, JoinedAt: record.JoinedAt, YaoCreatedBy: record.YaoCreatedBy, YaoTeamID: record.YaoTeamID, CreatedAt: record.CreatedAt, UpdatedAt: record.UpdatedAt, } } // ==================== Member ID Generation ==================== // generateMemberID generates a unique member_id with collision detection // Uses 12-digit numeric ID to match existing pattern in openapi/oauth/providers/user func generateMemberID(ctx context.Context) (string, error) { const maxRetries = 10 for i := 0; i < maxRetries; i++ { // Generate 12-digit numeric ID id, err := gonanoid.Generate("0123456789", 12) if err != nil { return "", fmt.Errorf("failed to generate member_id: %w", err) } // Check if ID already exists exists, err := memberIDExists(ctx, id) if err != nil { return "", fmt.Errorf("failed to check member_id existence: %w", err) } if !exists { return id, nil } // ID exists, retry } return "", fmt.Errorf("failed to generate unique member_id after %d retries", maxRetries) } // memberIDExists checks if a member_id already 
exists in the database func memberIDExists(ctx context.Context, memberID string) (bool, error) { m := model.Select(memberModel) if m == nil { return false, fmt.Errorf("model %s not found", memberModel) } members, err := m.Get(model.QueryParam{ Select: []interface{}{"id"}, Wheres: []model.QueryWhere{ {Column: "member_id", Value: memberID}, }, Limit: 1, }) if err != nil { return false, err } return len(members) > 0, nil } ================================================ FILE: agent/robot/api/robot_test.go ================================================ package api_test import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // TestGetRobotValidation tests parameter validation for GetRobot func TestGetRobotValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("returns error for empty member_id", func(t *testing.T) { ctx := types.NewContext(context.Background(), nil) robot, err := api.GetRobot(ctx, "") assert.Error(t, err) assert.Nil(t, robot) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("returns error for non-existent robot", func(t *testing.T) { ctx := types.NewContext(context.Background(), nil) robot, err := api.GetRobot(ctx, "non_existent_member_id_xyz") assert.Error(t, err) assert.Nil(t, robot) }) } // TestListRobotsValidation tests parameter validation for ListRobots func TestListRobotsValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("applies default pagination when query is nil", func(t *testing.T) { result, err := api.ListRobots(ctx, nil) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, 1, result.Page) assert.Equal(t, 20, 
result.PageSize)
	})

	// Zero values are treated the same as "not provided"
	t.Run("applies default pagination when values are zero", func(t *testing.T) {
		result, err := api.ListRobots(ctx, &api.ListQuery{
			Page:     0,
			PageSize: 0,
		})
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.Equal(t, 1, result.Page)
		assert.Equal(t, 20, result.PageSize)
	})

	// Oversized page sizes are clamped rather than rejected
	t.Run("caps pagesize at 100", func(t *testing.T) {
		result, err := api.ListRobots(ctx, &api.ListQuery{
			Page:     1,
			PageSize: 500,
		})
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.Equal(t, 100, result.PageSize)
	})
}

// TestGetRobotStatusValidation tests parameter validation for GetRobotStatus
// (empty member_id and unknown member_id both yield an error and a nil status).
func TestGetRobotStatusValidation(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	t.Run("returns error for empty member_id", func(t *testing.T) {
		ctx := types.NewContext(context.Background(), nil)
		status, err := api.GetRobotStatus(ctx, "")
		assert.Error(t, err)
		assert.Nil(t, status)
		assert.Contains(t, err.Error(), "member_id is required")
	})

	t.Run("returns error for non-existent robot", func(t *testing.T) {
		ctx := types.NewContext(context.Background(), nil)
		status, err := api.GetRobotStatus(ctx, "non_existent_member_id_xyz")
		assert.Error(t, err)
		assert.Nil(t, status)
	})
}

// ==================== Robot CRUD API Tests ====================

// TestCreateRobotValidation tests parameter validation for CreateRobot:
// member_id is optional (auto-generated), team_id and display_name are required.
func TestCreateRobotValidation(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), nil)

	t.Run("auto_generates_member_id_when_empty", func(t *testing.T) {
		req := &api.CreateRobotRequest{
			MemberID:    "",
			TeamID:      "team_001",
			DisplayName: "Test Robot Auto ID",
		}
		result, err := api.CreateRobot(ctx, req)
		require.NoError(t, err)
		require.NotNil(t, result)
		// Verify member_id was auto-generated (12-digit numeric)
		assert.NotEmpty(t, result.MemberID)
		assert.Len(t, result.MemberID, 12, "Auto-generated member_id should be 12 digits")

		// Cleanup
		// (best effort: the generated ID is not covered by the fixed-name cleanup helper)
		_ = api.RemoveRobot(ctx, result.MemberID)
	})

	t.Run("returns_error_for_empty_team_id", func(t *testing.T) {
		req := &api.CreateRobotRequest{
			MemberID:    "robot_test_001",
			TeamID:      "",
			DisplayName: "Test Robot",
		}
		result, err := api.CreateRobot(ctx, req)
		assert.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "team_id is required")
	})

	t.Run("returns_error_for_empty_display_name", func(t *testing.T) {
		req := &api.CreateRobotRequest{
			MemberID:    "robot_test_001",
			TeamID:      "team_001",
			DisplayName: "",
		}
		result, err := api.CreateRobot(ctx, req)
		assert.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "display_name is required")
	})
}

// TestCreateRobot tests the CreateRobot API function
// Subtests run in order and share database state: the duplicate-ID subtest
// relies on the robot created by the first subtest.
func TestCreateRobot(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	// Cleanup before and after
	cleanupAPITestRobots(t)
	defer cleanupAPITestRobots(t)

	ctx := types.NewContext(context.Background(), nil)

	t.Run("creates_robot_with_required_fields", func(t *testing.T) {
		req := &api.CreateRobotRequest{
			MemberID:    "api_robot_create_001",
			TeamID:      "api_team_001",
			DisplayName: "API Test Robot",
		}

		result, err := api.CreateRobot(ctx, req)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.Equal(t, "api_robot_create_001", result.MemberID)
		assert.Equal(t, "api_team_001", result.TeamID)
		assert.Equal(t, "API Test Robot", result.DisplayName)
		// Status and RobotStatus were not set in the request: these are the defaults
		assert.Equal(t, "active", result.Status)
		assert.Equal(t, "idle", result.RobotStatus)
	})

	t.Run("creates_robot_with_all_fields", func(t *testing.T) {
		autonomousMode := true
		req := &api.CreateRobotRequest{
			MemberID:       "api_robot_create_002",
			TeamID:         "api_team_002",
			DisplayName:    "Full Robot",
			Bio:            "A fully configured robot",
			SystemPrompt:   "You are a helpful assistant",
			Avatar:         "https://example.com/avatar.png",
			RoleID:         "admin",
			ManagerID:      "user_001",
			AutonomousMode: &autonomousMode,
			RobotEmail:     "fullrobot@test.com",
			LanguageModel:  "gpt-4",
			CostLimit:      100.0,
			RobotConfig: map[string]interface{}{
				"clock_mode":     "on",
				"max_concurrent": 3,
			},
		}

		result, err := api.CreateRobot(ctx, req)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.Equal(t, "api_robot_create_002", result.MemberID)
		assert.Equal(t, "Full Robot", result.DisplayName)
		assert.Equal(t, "A fully configured robot", result.Bio)
		assert.Equal(t, "You are a helpful assistant", result.SystemPrompt)
		assert.Equal(t, "admin", result.RoleID)
		assert.True(t, result.AutonomousMode)
		assert.Equal(t, "fullrobot@test.com", result.RobotEmail)
		assert.Equal(t, "gpt-4", result.LanguageModel)
		assert.Equal(t, 100.0, result.CostLimit)
	})

	t.Run("creates_robot_with_auth_scope", func(t *testing.T) {
		req := &api.CreateRobotRequest{
			MemberID:    "api_robot_create_003",
			TeamID:      "api_team_003",
			DisplayName: "Robot with Auth",
			AuthScope: &api.AuthScope{
				CreatedBy: "user_123",
				TeamID:    "perm_team_001",
				TenantID:  "tenant_001",
			},
		}

		result, err := api.CreateRobot(ctx, req)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.Equal(t, "api_robot_create_003", result.MemberID)
		// InvitedBy should be set from AuthScope.CreatedBy
		assert.Equal(t, "user_123", result.InvitedBy)
	})

	t.Run("returns_error_for_duplicate_member_id", func(t *testing.T) {
		req := &api.CreateRobotRequest{
			MemberID:    "api_robot_create_001", // Already created above
			TeamID:      "api_team_001",
			DisplayName: "Duplicate Robot",
		}

		result, err := api.CreateRobot(ctx, req)
		assert.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "already exists")
	})
}

// TestUpdateRobot tests the UpdateRobot API function
// All update subtests operate on the single robot created below, so each one
// observes the changes made by the subtests before it.
func TestUpdateRobot(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupAPITestRobots(t)
	defer cleanupAPITestRobots(t)

	ctx := types.NewContext(context.Background(), nil)

	// Create a robot to update
	createReq := &api.CreateRobotRequest{
		MemberID:    "api_robot_update_001",
		TeamID:      "api_team_update",
		DisplayName: "Original Name",
		Bio:         "Original bio",
	}
	_, err := api.CreateRobot(ctx, createReq)
	require.NoError(t, err)

	t.Run("returns_error_for_empty_member_id", func(t *testing.T) {
		req := &api.UpdateRobotRequest{}
		result, err := api.UpdateRobot(ctx, "", req)
		assert.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "member_id is required")
	})

	t.Run("returns_error_for_non_existent_robot", func(t *testing.T) {
		newName := "New Name"
		req := &api.UpdateRobotRequest{
			DisplayName: &newName,
		}
		result, err := api.UpdateRobot(ctx, "non_existent_robot", req)
		assert.Error(t, err)
		assert.Nil(t, result)
	})

	t.Run("updates_display_name", func(t *testing.T) {
		newName := "Updated Name"
		req := &api.UpdateRobotRequest{
			DisplayName: &newName,
		}

		result, err := api.UpdateRobot(ctx, "api_robot_update_001", req)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.Equal(t, "Updated Name", result.DisplayName)
		// Bio should be unchanged
		// (nil request fields are left untouched by UpdateRobot)
		assert.Equal(t, "Original bio", result.Bio)
	})

	t.Run("updates_multiple_fields", func(t *testing.T) {
		newBio := "New bio description"
		newPrompt := "Updated system prompt"
		autonomousMode := true

		req := &api.UpdateRobotRequest{
			Bio:            &newBio,
			SystemPrompt:   &newPrompt,
			AutonomousMode: &autonomousMode,
		}

		result, err := api.UpdateRobot(ctx, "api_robot_update_001", req)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.Equal(t, "New bio description", result.Bio)
		assert.Equal(t, "Updated system prompt", result.SystemPrompt)
		assert.True(t, result.AutonomousMode)
	})

	t.Run("updates_robot_status", func(t *testing.T) {
		newStatus := "working"
		req := &api.UpdateRobotRequest{
			RobotStatus: &newStatus,
		}

		result, err := api.UpdateRobot(ctx, "api_robot_update_001", req)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.Equal(t, "working", result.RobotStatus)
	})

	// Only checks that a config comes back; the stored values are not compared
	t.Run("updates_config", func(t *testing.T) {
		newConfig := map[string]interface{}{
			"clock_mode":     "off",
			"max_concurrent": 5,
		}
		req := &api.UpdateRobotRequest{
			RobotConfig: newConfig,
		}

		result, err :=
api.UpdateRobot(ctx, "api_robot_update_001", req) require.NoError(t, err) require.NotNil(t, result) assert.NotNil(t, result.RobotConfig) }) } // TestRemoveRobot tests the RemoveRobot API function func TestRemoveRobot(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupAPITestRobots(t) defer cleanupAPITestRobots(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns_error_for_empty_member_id", func(t *testing.T) { err := api.RemoveRobot(ctx, "") assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("returns_error_for_non_existent_robot", func(t *testing.T) { err := api.RemoveRobot(ctx, "non_existent_robot") assert.Error(t, err) }) t.Run("removes_existing_robot", func(t *testing.T) { // Create a robot createReq := &api.CreateRobotRequest{ MemberID: "api_robot_remove_001", TeamID: "api_team_remove", DisplayName: "Robot to Remove", } _, err := api.CreateRobot(ctx, createReq) require.NoError(t, err) // Verify it exists robot, err := api.GetRobot(ctx, "api_robot_remove_001") require.NoError(t, err) require.NotNil(t, robot) // Remove it err = api.RemoveRobot(ctx, "api_robot_remove_001") require.NoError(t, err) // Verify it's gone robot, err = api.GetRobot(ctx, "api_robot_remove_001") assert.Error(t, err) // Should return error for non-existent }) } // TestGetRobotResponse tests the GetRobotResponse API function func TestGetRobotResponse(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupAPITestRobots(t) defer cleanupAPITestRobots(t) ctx := types.NewContext(context.Background(), nil) // Create a robot autonomousMode := true createReq := &api.CreateRobotRequest{ MemberID: "api_robot_response_001", TeamID: "api_team_response", DisplayName: "Response Test Robot", Bio: "Test bio for response", SystemPrompt: "Test prompt", AutonomousMode: &autonomousMode, RobotEmail: 
"response@test.com", CostLimit: 50.0, } _, err := api.CreateRobot(ctx, createReq) require.NoError(t, err) t.Run("returns_robot_response_format", func(t *testing.T) { result, err := api.GetRobotResponse(ctx, "api_robot_response_001") require.NoError(t, err) require.NotNil(t, result) // Verify all fields are present in response assert.Equal(t, "api_robot_response_001", result.MemberID) assert.Equal(t, "api_team_response", result.TeamID) assert.Equal(t, "Response Test Robot", result.DisplayName) assert.Equal(t, "Test bio for response", result.Bio) assert.Equal(t, "Test prompt", result.SystemPrompt) assert.True(t, result.AutonomousMode) assert.Equal(t, "response@test.com", result.RobotEmail) assert.Equal(t, 50.0, result.CostLimit) assert.Equal(t, "active", result.Status) assert.Equal(t, "idle", result.RobotStatus) }) t.Run("returns_error_for_non_existent", func(t *testing.T) { result, err := api.GetRobotResponse(ctx, "non_existent") assert.Error(t, err) assert.Nil(t, result) }) } // Note: cleanupAPITestRobots is defined in api_test.go (shared helper) ================================================ FILE: agent/robot/api/trigger.go ================================================ package api import ( "fmt" "github.com/yaoapp/yao/agent/robot/types" ) // ==================== Trigger API ==================== // These functions handle robot execution triggers // Trigger starts a robot execution with the specified trigger type and request // This is the main entry point for triggering robot execution func Trigger(ctx *types.Context, memberID string, req *TriggerRequest) (*TriggerResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil { return nil, fmt.Errorf("trigger request is required") } mgr, err := getManager() if err != nil { return nil, err } switch req.Type { case types.TriggerHuman: return triggerHuman(ctx, mgr, memberID, req) case types.TriggerEvent: return triggerEvent(ctx, mgr, memberID, req) case types.TriggerClock: 
return triggerManual(ctx, mgr, memberID, req) default: return nil, fmt.Errorf("invalid trigger type: %s", req.Type) } } // TriggerManual manually triggers a robot execution (for testing or debugging) // This bypasses normal trigger validation and directly submits to the pool func TriggerManual(ctx *types.Context, memberID string, triggerType types.TriggerType, data interface{}) (*TriggerResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } mgr, err := getManager() if err != nil { return nil, err } execID, err := mgr.TriggerManual(ctx, memberID, triggerType, data) if err != nil { return &TriggerResult{ Accepted: false, Message: err.Error(), }, nil } return &TriggerResult{ Accepted: true, ExecutionID: execID, Message: fmt.Sprintf("Manual trigger (%s) submitted", triggerType), }, nil } // Intervene processes a human intervention request // Human intervention skips P0 (inspiration) and goes directly to P1 (goals) func Intervene(ctx *types.Context, memberID string, req *TriggerRequest) (*TriggerResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil { return nil, fmt.Errorf("intervention request is required") } mgr, err := getManager() if err != nil { return nil, err } return triggerHuman(ctx, mgr, memberID, req) } // HandleEvent processes an event trigger request // Event trigger skips P0 (inspiration) and goes directly to P1 (goals) func HandleEvent(ctx *types.Context, memberID string, req *TriggerRequest) (*TriggerResult, error) { if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil { return nil, fmt.Errorf("event request is required") } mgr, err := getManager() if err != nil { return nil, err } return triggerEvent(ctx, mgr, memberID, req) } // ==================== Internal Trigger Functions ==================== // triggerHuman handles human intervention trigger func triggerHuman(ctx *types.Context, mgr managerInterface, memberID string, req *TriggerRequest) 
(*TriggerResult, error) { // Build intervention request interveneReq := &types.InterveneRequest{ MemberID: memberID, TeamID: ctx.TeamID(), Action: req.Action, Messages: req.Messages, PlanTime: req.PlanAt, ExecutorMode: req.ExecutorMode, } // Call manager's Intervene result, err := mgr.Intervene(ctx, interveneReq) if err != nil { return &TriggerResult{ Accepted: false, Message: err.Error(), }, nil } return &TriggerResult{ Accepted: true, ExecutionID: result.ExecutionID, Message: result.Message, }, nil } // triggerEvent handles event trigger func triggerEvent(ctx *types.Context, mgr managerInterface, memberID string, req *TriggerRequest) (*TriggerResult, error) { // Build event request eventReq := &types.EventRequest{ MemberID: memberID, Source: string(req.Source), EventType: req.EventType, Data: req.Data, ExecutorMode: req.ExecutorMode, } // Call manager's HandleEvent result, err := mgr.HandleEvent(ctx, eventReq) if err != nil { return &TriggerResult{ Accepted: false, Message: err.Error(), }, nil } return &TriggerResult{ Accepted: true, ExecutionID: result.ExecutionID, Message: result.Message, }, nil } // triggerManual handles manual/clock trigger func triggerManual(ctx *types.Context, mgr managerInterface, memberID string, req *TriggerRequest) (*TriggerResult, error) { // For clock trigger, pass clock context if available var data interface{} if req.Data != nil { data = req.Data } execID, err := mgr.TriggerManual(ctx, memberID, req.Type, data) if err != nil { return &TriggerResult{ Accepted: false, Message: err.Error(), }, nil } return &TriggerResult{ Accepted: true, ExecutionID: execID, Message: fmt.Sprintf("Trigger (%s) submitted", req.Type), }, nil } // managerInterface defines the methods we need from manager // This allows for easier testing with mocks type managerInterface interface { TriggerManual(ctx *types.Context, memberID string, trigger types.TriggerType, data interface{}) (string, error) Intervene(ctx *types.Context, req *types.InterveneRequest) 
(*types.ExecutionResult, error) HandleEvent(ctx *types.Context, req *types.EventRequest) (*types.ExecutionResult, error) PauseExecution(ctx *types.Context, execID string) error ResumeExecution(ctx *types.Context, execID string) error StopExecution(ctx *types.Context, execID string) error } ================================================ FILE: agent/robot/api/trigger_test.go ================================================ package api_test import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // TestTriggerValidation tests parameter validation for Trigger func TestTriggerValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty member_id", func(t *testing.T) { result, err := api.Trigger(ctx, "", &api.TriggerRequest{ Type: types.TriggerHuman, }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("returns error for nil request", func(t *testing.T) { result, err := api.Trigger(ctx, "test_member", nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "trigger request is required") }) t.Run("returns error when manager not started", func(t *testing.T) { result, err := api.Trigger(ctx, "test_member", &api.TriggerRequest{ Type: types.TriggerHuman, }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "not started") }) } // TestTriggerManualValidation tests parameter validation for TriggerManual func TestTriggerManualValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty member_id", 
func(t *testing.T) { result, err := api.TriggerManual(ctx, "", types.TriggerClock, nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("returns error when manager not started", func(t *testing.T) { result, err := api.TriggerManual(ctx, "test_member", types.TriggerClock, nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "not started") }) } // TestInterveneValidation tests parameter validation for Intervene func TestInterveneValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty member_id", func(t *testing.T) { result, err := api.Intervene(ctx, "", &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("returns error for nil request", func(t *testing.T) { result, err := api.Intervene(ctx, "test_member", nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "intervention request is required") }) } // TestHandleEventValidation tests parameter validation for HandleEvent func TestHandleEventValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) t.Run("returns error for empty member_id", func(t *testing.T) { result, err := api.HandleEvent(ctx, "", &api.TriggerRequest{ Type: types.TriggerEvent, Source: types.EventWebhook, EventType: "test.event", }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("returns error for nil request", func(t *testing.T) { result, err := api.HandleEvent(ctx, "test_member", nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "event 
request is required") }) } // TestTriggerWithManagerStarted tests trigger APIs when manager is running func TestTriggerWithManagerStarted(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Start manager err := api.Start() require.NoError(t, err) defer api.Stop() ctx := types.NewContext(context.Background(), nil) t.Run("returns not accepted for non-existent robot", func(t *testing.T) { result, err := api.Trigger(ctx, "non_existent_robot_xyz", &api.TriggerRequest{ Type: types.TriggerHuman, Action: types.ActionTaskAdd, }) // Should not error, but return not accepted assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.Accepted) }) t.Run("returns error for invalid trigger type", func(t *testing.T) { result, err := api.Trigger(ctx, "test_member", &api.TriggerRequest{ Type: "invalid_type", }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "invalid trigger type") }) } ================================================ FILE: agent/robot/api/types.go ================================================ package api import ( "time" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/types" ) // ListQuery - query options for List() type ListQuery struct { TeamID string `json:"team_id,omitempty"` Status types.RobotStatus `json:"status,omitempty"` Keywords string `json:"keywords,omitempty"` ClockMode types.ClockMode `json:"clock_mode,omitempty"` AutonomousMode *bool `json:"autonomous_mode,omitempty"` // nil=all, true=autonomous only, false=on-demand only Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` Order string `json:"order,omitempty"` } // ListResult - result of List() type ListResult struct { Data []*types.Robot `json:"data"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"pagesize"` } // RobotState - runtime state from Status() type RobotState struct { MemberID string 
`json:"member_id"` TeamID string `json:"team_id"` DisplayName string `json:"display_name"` Bio string `json:"bio,omitempty"` Status types.RobotStatus `json:"status"` Running int `json:"running"` MaxRunning int `json:"max_running"` LastRun *time.Time `json:"last_run,omitempty"` NextRun *time.Time `json:"next_run,omitempty"` RunningIDs []string `json:"running_ids,omitempty"` YaoCreatedBy string `json:"__yao_created_by,omitempty"` // Creator user_id for permission check YaoTeamID string `json:"__yao_team_id,omitempty"` // Team ID for permission check } // ==================== Trigger Types ==================== // TriggerRequest - request for Trigger() // Input uses []context.Message to support rich content (text, images, files, audio) type TriggerRequest struct { Type types.TriggerType `json:"type"` // human | event | clock // Human intervention fields (when Type = human) Action types.InterventionAction `json:"action,omitempty"` Messages []agentcontext.Message `json:"messages,omitempty"` // user's input (supports text, images, files) PlanAt *time.Time `json:"plan_at,omitempty"` InsertPosition InsertPosition `json:"insert_at,omitempty"` AtIndex int `json:"at_index,omitempty"` // Event fields (when Type = event) Source types.EventSource `json:"source,omitempty"` EventType string `json:"event_type,omitempty"` Data map[string]interface{} `json:"data,omitempty"` // Executor mode (optional, overrides robot config) ExecutorMode types.ExecutorMode `json:"executor_mode,omitempty"` // i18n support Locale string `json:"locale,omitempty"` // Locale for UI messages (e.g., "en", "zh") } // InsertPosition - where to insert task in queue type InsertPosition string const ( // InsertFirst inserts at beginning (highest priority) InsertFirst InsertPosition = "first" // InsertLast appends at end (default) InsertLast InsertPosition = "last" // InsertNext inserts after current task InsertNext InsertPosition = "next" // InsertAt inserts at specific index (use AtIndex) InsertAt InsertPosition 
= "at" ) // TriggerResult - result of Trigger() type TriggerResult struct { Accepted bool `json:"accepted"` Queued bool `json:"queued"` Execution *types.Execution `json:"execution,omitempty"` ExecutionID string `json:"execution_id,omitempty"` // Execution ID Message string `json:"message,omitempty"` } // ==================== Execution Types ==================== // ExecutionQuery - query options for GetExecutions() type ExecutionQuery struct { Status types.ExecStatus `json:"status,omitempty"` ExcludeStatuses []types.ExecStatus `json:"exclude_statuses,omitempty"` Trigger types.TriggerType `json:"trigger,omitempty"` Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` } // ExecutionResult - result of GetExecutions() type ExecutionResult struct { Data []*types.Execution `json:"data"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"pagesize"` } // ==================== CRUD Types ==================== // AuthScope contains Yao permission fields for data scoping // These fields are used by Yao's permission system (when model has permission: true) type AuthScope struct { CreatedBy string `json:"__yao_created_by,omitempty"` // Creator user_id UpdatedBy string `json:"__yao_updated_by,omitempty"` // Updater user_id TeamID string `json:"__yao_team_id,omitempty"` // Permission team scope TenantID string `json:"__yao_tenant_id,omitempty"` // Permission tenant scope } // CreateRobotRequest - request for CreateRobot() type CreateRobotRequest struct { // Identity (member_id is optional - auto-generated if not provided) MemberID string `json:"member_id,omitempty"` // Unique robot identifier (auto-generated if empty) TeamID string `json:"team_id"` // Team ID (required) // Profile DisplayName string `json:"display_name,omitempty"` // Display name Bio string `json:"bio,omitempty"` // Robot description Avatar string `json:"avatar,omitempty"` // Avatar URL // Identity & Role SystemPrompt string `json:"system_prompt,omitempty"` // System prompt 
// UpdateRobotRequest - request for UpdateRobot()
//
// Scalar fields are pointers and every field is tagged omitempty, so a nil
// field is left out of the encoded JSON. NOTE(review): presumably a nil field
// means "leave this column unchanged" - confirm against UpdateRobot().
// The interface{} fields hold already-decoded JSON values (arrays/objects).
type UpdateRobotRequest struct {
	// Profile
	DisplayName *string `json:"display_name,omitempty"` // Display name
	Bio         *string `json:"bio,omitempty"`          // Robot description
	Avatar      *string `json:"avatar,omitempty"`       // Avatar URL

	// Identity & Role
	SystemPrompt *string `json:"system_prompt,omitempty"` // System prompt
	RoleID       *string `json:"role_id,omitempty"`       // Role within team
	ManagerID    *string `json:"manager_id,omitempty"`    // Direct manager user_id

	// Status
	Status         *string `json:"status,omitempty"`          // Member status
	RobotStatus    *string `json:"robot_status,omitempty"`    // Robot status
	AutonomousMode *bool   `json:"autonomous_mode,omitempty"` // Autonomous mode

	// Communication
	RobotEmail        *string     `json:"robot_email,omitempty"`        // Robot email address
	AuthorizedSenders interface{} `json:"authorized_senders,omitempty"` // Email whitelist
	EmailFilterRules  interface{} `json:"email_filter_rules,omitempty"` // Email filter rules

	// Capabilities
	RobotConfig   interface{} `json:"robot_config,omitempty"`   // Robot config JSON
	Agents        interface{} `json:"agents,omitempty"`         // Accessible agents
	MCPServers    interface{} `json:"mcp_servers,omitempty"`    // MCP servers
	LanguageModel *string     `json:"language_model,omitempty"` // Language model name

	// Limits
	CostLimit *float64 `json:"cost_limit,omitempty"` // Monthly cost limit USD

	// Auth scope (optional, used by OpenAPI layer via WithUpdateScope)
	AuthScope *AuthScope `json:"auth_scope,omitempty"`
}
`json:"__yao_created_by,omitempty"` // Creator user_id for permission check YaoTeamID string `json:"__yao_team_id,omitempty"` // Team ID for permission check // Timestamps CreatedAt *time.Time `json:"created_at,omitempty"` UpdatedAt *time.Time `json:"updated_at,omitempty"` } // ==================== Helper Functions ==================== // applyDefaults applies default values to ListQuery func (q *ListQuery) applyDefaults() { if q.Page <= 0 { q.Page = 1 } if q.PageSize <= 0 { q.PageSize = 20 } if q.PageSize > 100 { q.PageSize = 100 } } // applyDefaults applies default values to ExecutionQuery func (q *ExecutionQuery) applyDefaults() { if q.Page <= 0 { q.Page = 1 } if q.PageSize <= 0 { q.PageSize = 20 } if q.PageSize > 100 { q.PageSize = 100 } } ================================================ FILE: agent/robot/cache/cache.go ================================================ package cache import ( "sync" "github.com/yaoapp/yao/agent/robot/types" ) // Cache implements types.Cache interface // Thread-safe in-memory cache for Robot instances type Cache struct { robots map[string]*types.Robot // memberID -> Robot byTeam map[string][]string // teamID -> memberIDs mu sync.RWMutex } // New creates a new cache instance func New() *Cache { return &Cache{ robots: make(map[string]*types.Robot), byTeam: make(map[string][]string), } } // Get returns a robot by member ID // Stub: returns nil (will be implemented in Phase 3) func (c *Cache) Get(memberID string) *types.Robot { c.mu.RLock() defer c.mu.RUnlock() return c.robots[memberID] } // List returns all robots for a team func (c *Cache) List(teamID string) []*types.Robot { c.mu.RLock() defer c.mu.RUnlock() memberIDs := c.byTeam[teamID] robots := make([]*types.Robot, 0, len(memberIDs)) for _, memberID := range memberIDs { if robot := c.robots[memberID]; robot != nil { robots = append(robots, robot) } } return robots } // Note: Refresh is implemented in refresh.go // Add adds or updates a robot in cache func (c *Cache) Add(robot 
*types.Robot) { if robot == nil { return } c.mu.Lock() defer c.mu.Unlock() c.robots[robot.MemberID] = robot // Update team index if _, exists := c.byTeam[robot.TeamID]; !exists { c.byTeam[robot.TeamID] = []string{} } // Check if member ID already in team list found := false for _, id := range c.byTeam[robot.TeamID] { if id == robot.MemberID { found = true break } } if !found { c.byTeam[robot.TeamID] = append(c.byTeam[robot.TeamID], robot.MemberID) } } // Remove removes a robot from cache func (c *Cache) Remove(memberID string) { c.mu.Lock() defer c.mu.Unlock() robot := c.robots[memberID] if robot == nil { return } delete(c.robots, memberID) // Remove from team index teamMembers := c.byTeam[robot.TeamID] for i, id := range teamMembers { if id == memberID { c.byTeam[robot.TeamID] = append(teamMembers[:i], teamMembers[i+1:]...) break } } } ================================================ FILE: agent/robot/cache/cache_test.go ================================================ package cache_test import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/cache" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // TestCacheLoad tests loading all active robots from database func TestCacheLoad(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Clean up any existing test data first cleanupTestRobots(t) // Create test robots in database setupTestRobots(t) defer cleanupTestRobots(t) c := cache.New() ctx := types.NewContext(context.Background(), nil) // Load all robots err := c.Load(ctx) assert.NoError(t, err) // Count should be at least 2 (may have other robots in DB) count := c.Count() assert.GreaterOrEqual(t, count, 2, "Should load at least 2 active autonomous robots") // Verify first robot robot1 := c.Get("robot_test_sales_001") assert.NotNil(t, 
robot1, "Sales bot should be loaded") if robot1 == nil { t.Fatal("robot_test_sales_001 not found in cache") } assert.Equal(t, "robot_test_sales_001", robot1.MemberID) assert.Equal(t, "team_test_cache_001", robot1.TeamID) assert.Equal(t, "Test Sales Bot", robot1.DisplayName) assert.Equal(t, types.RobotIdle, robot1.Status) assert.True(t, robot1.AutonomousMode) assert.NotNil(t, robot1.Config, "Robot config should be parsed") assert.NotNil(t, robot1.Config.Identity, "Identity should be parsed") assert.Equal(t, "Sales Manager", robot1.Config.Identity.Role) assert.Equal(t, 3, robot1.Config.Quota.GetMax()) // Verify second robot robot2 := c.Get("robot_test_support_002") assert.NotNil(t, robot2, "Support bot should be loaded") assert.Equal(t, "robot_test_support_002", robot2.MemberID) assert.Equal(t, "Test Support Bot", robot2.DisplayName) assert.NotNil(t, robot2.Config) assert.Equal(t, "Customer Support", robot2.Config.Identity.Role) assert.Equal(t, 2, robot2.Config.Quota.GetMax()) // Verify inactive robot is not loaded robot3 := c.Get("robot_test_inactive_003") assert.Nil(t, robot3, "Inactive robot should not be loaded") } // TestCacheLoadByID tests loading a single robot by member ID func TestCacheLoadByID(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) c := cache.New() ctx := types.NewContext(context.Background(), nil) t.Run("load existing robot", func(t *testing.T) { robot, err := c.LoadByID(ctx, "robot_test_sales_001") assert.NoError(t, err) assert.NotNil(t, robot) assert.Equal(t, "robot_test_sales_001", robot.MemberID) assert.Equal(t, "Test Sales Bot", robot.DisplayName) assert.NotNil(t, robot.Config) }) t.Run("load non-existent robot", func(t *testing.T) { robot, err := c.LoadByID(ctx, "robot_nonexistent") assert.Error(t, err) assert.Equal(t, types.ErrRobotNotFound, err) assert.Nil(t, robot) }) t.Run("load inactive robot by 
ID", func(t *testing.T) { // LoadByID doesn't filter by status, so it should load robot, err := c.LoadByID(ctx, "robot_test_inactive_003") assert.NoError(t, err) assert.NotNil(t, robot) assert.Equal(t, "robot_test_inactive_003", robot.MemberID) }) } // TestCacheRefresh tests refreshing a single robot from database func TestCacheRefresh(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) c := cache.New() ctx := types.NewContext(context.Background(), nil) // Load initial data err := c.Load(ctx) assert.NoError(t, err) t.Run("refresh existing robot", func(t *testing.T) { err := c.Refresh(ctx, "robot_test_sales_001") assert.NoError(t, err) // Robot should still be in cache robot := c.Get("robot_test_sales_001") assert.NotNil(t, robot) }) t.Run("refresh removes non-existent robot", func(t *testing.T) { // Add a fake robot to cache c.Add(&types.Robot{MemberID: "robot_test_fake", TeamID: "team_test_cache_001"}) assert.NotNil(t, c.Get("robot_test_fake")) // Refresh should remove it err := c.Refresh(ctx, "robot_test_fake") assert.NoError(t, err) assert.Nil(t, c.Get("robot_test_fake"), "Non-existent robot should be removed") }) } // TestCacheListByTeam tests listing robots by team func TestCacheListByTeam(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) c := cache.New() ctx := types.NewContext(context.Background(), nil) // Load all robots err := c.Load(ctx) assert.NoError(t, err) // List robots by team robots := c.List("team_test_cache_001") assert.Len(t, robots, 2, "Should have 2 robots in team_test_cache_001") // List robots for non-existent team robots = c.List("team_nonexistent") assert.Len(t, robots, 0, "Non-existent team should have no robots") } // TestCacheGetByStatus tests getting robots by status 
func TestCacheGetByStatus(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) c := cache.New() ctx := types.NewContext(context.Background(), nil) // Load all robots err := c.Load(ctx) assert.NoError(t, err) // Get idle robots (may have others in DB) idle := c.GetIdle() assert.GreaterOrEqual(t, len(idle), 2, "Should have at least 2 idle robots") // Verify our test robots are not working testRobot1 := c.Get("robot_test_sales_001") testRobot2 := c.Get("robot_test_support_002") assert.Equal(t, types.RobotIdle, testRobot1.Status, "Test robot 1 should be idle") assert.Equal(t, types.RobotIdle, testRobot2.Status, "Test robot 2 should be idle") } // TestCacheAutoRefresh tests auto-refresh functionality and goroutine leak prevention func TestCacheAutoRefresh(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) // Verify test data is set up c := cache.New() ctx := types.NewContext(context.Background(), nil) err := c.Load(ctx) assert.NoError(t, err) assert.GreaterOrEqual(t, c.Count(), 1, "Should have at least one robot loaded") t.Run("start and stop auto-refresh", func(t *testing.T) { // Use a fresh cache for this test testCache := cache.New() testCtx := types.NewContext(context.Background(), nil) err := testCache.Load(testCtx) assert.NoError(t, err) // Start auto-refresh with short interval config := &cache.RefreshConfig{Interval: 100 * time.Millisecond} testCache.StartAutoRefresh(testCtx, config) // Wait a bit to let it run (should trigger at least 2 refreshes) time.Sleep(250 * time.Millisecond) // Stop auto-refresh testCache.StopAutoRefresh() // Verify it stopped by checking that no more refreshes happen countBefore := testCache.Count() time.Sleep(200 * time.Millisecond) countAfter := testCache.Count() // 
Count should be stable (no errors from stopped goroutine) assert.Equal(t, countBefore, countAfter, "Cache should be stable after stop") }) t.Run("multiple start calls should replace previous", func(t *testing.T) { // Use a fresh cache for this test testCache := cache.New() testCtx := types.NewContext(context.Background(), nil) err := testCache.Load(testCtx) assert.NoError(t, err) // Track refresh count using a counter refreshCount := 0 originalCount := testCache.Count() // Start multiple times without stopping config := &cache.RefreshConfig{Interval: 50 * time.Millisecond} testCache.StartAutoRefresh(testCtx, config) time.Sleep(30 * time.Millisecond) testCache.StartAutoRefresh(testCtx, config) // Should stop previous one time.Sleep(30 * time.Millisecond) testCache.StartAutoRefresh(testCtx, config) // Should stop previous one // Wait for some refreshes time.Sleep(150 * time.Millisecond) // Stop once should be enough testCache.StopAutoRefresh() // Verify cache still works correctly assert.GreaterOrEqual(t, testCache.Count(), 0, "Cache should still be functional") // Verify we can still access robots _ = refreshCount // suppress unused warning _ = originalCount // suppress unused warning }) t.Run("stop without start should not panic", func(t *testing.T) { // Use a fresh cache for this test testCache := cache.New() // Multiple stops should be safe assert.NotPanics(t, func() { testCache.StopAutoRefresh() testCache.StopAutoRefresh() testCache.StopAutoRefresh() }) }) t.Run("concurrent start and stop should be safe", func(t *testing.T) { // Use a fresh cache for this test testCache := cache.New() testCtx := types.NewContext(context.Background(), nil) err := testCache.Load(testCtx) assert.NoError(t, err) config := &cache.RefreshConfig{Interval: 50 * time.Millisecond} // Rapidly start and stop multiple times - should not panic or deadlock done := make(chan bool) go func() { for i := 0; i < 10; i++ { testCache.StartAutoRefresh(testCtx, config) time.Sleep(10 * time.Millisecond) 
testCache.StopAutoRefresh() time.Sleep(10 * time.Millisecond) } done <- true }() // Wait with timeout to detect deadlocks select { case <-done: // Success - no deadlock case <-time.After(5 * time.Second): t.Fatal("Rapid start/stop cycles caused deadlock") } // Final cleanup testCache.StopAutoRefresh() // Verify cache is still functional assert.GreaterOrEqual(t, testCache.Count(), 0, "Cache should still be functional after rapid cycles") }) } // setupTestRobots creates 3 test robot records in database func setupTestRobots(t *testing.T) { // Get the actual table name from model m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Robot 1: Sales Bot (active, autonomous) robotConfig1 := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Sales Manager", "duties": []string{"Manage leads", "Follow up customers"}, "rules": []string{"Be professional", "Reply within 24h"}, }, "quota": map[string]interface{}{ "max": 3, "queue": 15, "priority": 7, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00", "14:00"}, "tz": "Asia/Shanghai", }, } config1JSON, _ := json.Marshal(robotConfig1) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_sales_001", "team_id": "team_test_cache_001", "member_type": "robot", "display_name": "Test Sales Bot", "system_prompt": "You are a professional sales manager assistant.", "status": "active", "role_id": "member", // required field "autonomous_mode": true, "robot_status": "idle", "robot_config": string(config1JSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_sales_001: %v", err) } // Robot 2: Support Bot (active, autonomous) robotConfig2 := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Customer Support", "duties": []string{"Answer questions", "Resolve issues"}, }, "quota": map[string]interface{}{ "max": 2, "queue": 10, "priority": 5, }, "clock": map[string]interface{}{ "mode": "interval", 
"every": "1h", }, } config2JSON, _ := json.Marshal(robotConfig2) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_support_002", "team_id": "team_test_cache_001", "member_type": "robot", "display_name": "Test Support Bot", "system_prompt": "You are a helpful customer support assistant.", "status": "active", "role_id": "member", // required field "autonomous_mode": true, "robot_status": "idle", "robot_config": string(config2JSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_support_002: %v", err) } // Robot 3: Inactive robot (should not be loaded by Load()) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_inactive_003", "team_id": "team_test_cache_001", "member_type": "robot", "display_name": "Test Inactive Bot", "status": "inactive", "role_id": "member", // required field "autonomous_mode": true, "robot_status": "paused", }, }) if err != nil { t.Fatalf("Failed to insert robot_test_inactive_003: %v", err) } } // cleanupTestRobots removes test robot records func cleanupTestRobots(t *testing.T) { qb := capsule.Query() // Use the member model to perform soft delete m := model.Select("__yao.member") // Delete test robots m.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: "robot_test_sales_001"}, }, }) m.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: "robot_test_support_002"}, }, }) m.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: "robot_test_inactive_003"}, }, }) // Hard delete from database (cleanup for next test run) m2 := model.Select("__yao.member") tableName2 := m2.MetaData.Table.Name qb.Table(tableName2).Where("member_id", "robot_test_sales_001").Delete() qb.Table(tableName2).Where("member_id", "robot_test_support_002").Delete() qb.Table(tableName2).Where("member_id", "robot_test_inactive_003").Delete() } ================================================ FILE: 
agent/robot/cache/load.go ================================================ package cache import ( "fmt" "github.com/yaoapp/gou/model" "github.com/yaoapp/kun/maps" "github.com/yaoapp/yao/agent/robot/types" ) // memberModel is the model name for member table // Can be changed via SetMemberModel() during system initialization var memberModel = "__yao.member" // memberFields are the fields to select when loading robots var memberFields = []interface{}{ "id", "member_id", "team_id", "display_name", "bio", "system_prompt", "robot_status", "autonomous_mode", "robot_config", "robot_email", "agents", "mcp_servers", "manager_id", "language_model", } // SetMemberModel sets the member model name // Call this during system initialization to override the default func SetMemberModel(model string) { if model != "" { memberModel = model } } // Load loads all active robots from database with pagination // Query: member_type='robot' AND autonomous_mode=true AND status='active' func (c *Cache) Load(ctx *types.Context) error { m := model.Select(memberModel) // Clear existing cache first c.mu.Lock() c.robots = make(map[string]*types.Robot) c.byTeam = make(map[string][]string) c.mu.Unlock() // Paginate to handle large number of robots page := 1 pageSize := 100 // load 100 robots per page totalLoaded := 0 for { // Query with pagination result, err := m.Paginate(model.QueryParam{ Select: memberFields, Wheres: []model.QueryWhere{ {Column: "member_type", Value: "robot"}, {Column: "autonomous_mode", Value: true}, {Column: "status", Value: "active"}, }, }, page, pageSize) if err != nil { return fmt.Errorf("failed to load robots (page %d): %w", page, err) } // Extract records from pagination result data, ok := result.Get("data").([]maps.MapStr) if !ok || len(data) == 0 { break } // Parse and add each robot for _, record := range data { robot, err := types.NewRobotFromMap(map[string]interface{}(record)) if err != nil { // Log error but continue loading other robots continue } c.Add(robot) 
totalLoaded++ } // Check if there are more pages total, _ := result.Get("total").(int) if totalLoaded >= total { break } page++ } return nil } // LoadByID loads a single robot from database by member ID func (c *Cache) LoadByID(ctx *types.Context, memberID string) (*types.Robot, error) { m := model.Select(memberModel) records, err := m.Get(model.QueryParam{ Select: memberFields, Wheres: []model.QueryWhere{ {Column: "member_id", Value: memberID}, {Column: "member_type", Value: "robot"}, }, Limit: 1, }) if err != nil { return nil, fmt.Errorf("failed to load robot %s: %w", memberID, err) } if len(records) == 0 { return nil, types.ErrRobotNotFound } return types.NewRobotFromMap(map[string]interface{}(records[0])) } ================================================ FILE: agent/robot/cache/refresh.go ================================================ package cache import ( "sync" "time" "github.com/yaoapp/yao/agent/robot/types" ) // RefreshConfig holds refresh configuration type RefreshConfig struct { Interval time.Duration // full refresh interval (default: 1 hour) } // DefaultRefreshConfig returns default refresh configuration func DefaultRefreshConfig() *RefreshConfig { return &RefreshConfig{ Interval: time.Hour, } } // refreshState holds the refresh goroutine state type refreshState struct { ticker *time.Ticker done chan struct{} mu sync.Mutex } var refresher = &refreshState{} // Refresh refreshes a single robot's config from database func (c *Cache) Refresh(ctx *types.Context, memberID string) error { robot, err := c.LoadByID(ctx, memberID) if err != nil { // If robot not found or no longer autonomous, remove from cache if err == types.ErrRobotNotFound { c.Remove(memberID) return nil } return err } // Check if robot is still active and autonomous if !robot.AutonomousMode { c.Remove(memberID) return nil } // Update cache c.Add(robot) return nil } // StartAutoRefresh starts periodic full refresh func (c *Cache) StartAutoRefresh(ctx *types.Context, config *RefreshConfig) 
{ if config == nil { config = DefaultRefreshConfig() } refresher.mu.Lock() defer refresher.mu.Unlock() // Stop existing refresher if any if refresher.done != nil { close(refresher.done) } refresher.ticker = time.NewTicker(config.Interval) refresher.done = make(chan struct{}) go func() { for { select { case <-refresher.done: refresher.ticker.Stop() return case <-refresher.ticker.C: // Perform full refresh _ = c.Load(ctx) } } }() } // StopAutoRefresh stops the periodic refresh func (c *Cache) StopAutoRefresh() { refresher.mu.Lock() defer refresher.mu.Unlock() if refresher.done != nil { close(refresher.done) refresher.done = nil } } // RefreshAll reloads all robots from database func (c *Cache) RefreshAll(ctx *types.Context) error { return c.Load(ctx) } // Count returns the number of cached robots func (c *Cache) Count() int { c.mu.RLock() defer c.mu.RUnlock() return len(c.robots) } // ListAll returns all cached robots (across all teams) func (c *Cache) ListAll() []*types.Robot { c.mu.RLock() defer c.mu.RUnlock() robots := make([]*types.Robot, 0, len(c.robots)) for _, robot := range c.robots { robots = append(robots, robot) } return robots } // GetByStatus returns robots with the specified status func (c *Cache) GetByStatus(status types.RobotStatus) []*types.Robot { c.mu.RLock() defer c.mu.RUnlock() var robots []*types.Robot for _, robot := range c.robots { if robot.Status == status { robots = append(robots, robot) } } return robots } // GetIdle returns all idle robots ready to execute func (c *Cache) GetIdle() []*types.Robot { return c.GetByStatus(types.RobotIdle) } // GetWorking returns all currently working robots func (c *Cache) GetWorking() []*types.Robot { return c.GetByStatus(types.RobotWorking) } ================================================ FILE: agent/robot/dedup/dedup.go ================================================ package dedup import ( "sync" "time" "github.com/yaoapp/yao/agent/robot/types" ) // Dedup implements types.Dedup interface // This is a 
// Check checks if execution should be deduplicated
// Stub: always returns types.DedupProceed with a nil error (will be implemented in Phase 3).
// None of ctx, memberID or trigger is read, and d.marks is not consulted,
// so callers currently never see an execution being deduplicated.
func (d *Dedup) Check(ctx *types.Context, memberID string, trigger types.TriggerType) (types.DedupResult, error) {
	return types.DedupProceed, nil
}
func (h *robotHandler) handleDelivery(ctx context.Context, ev *eventtypes.Event, resp chan<- eventtypes.Result) { var payload DeliveryPayload if err := ev.Should(&payload); err != nil { log.Error("delivery handler: invalid payload: %v", err) if ev.IsCall { resp <- eventtypes.Result{Err: err} } return } log.Info("delivery handler: execution=%s member=%s", payload.ExecutionID, payload.MemberID) content := payload.Content prefs := payload.Preferences if content == nil { log.Warn("delivery handler: nil content for execution=%s", payload.ExecutionID) if ev.IsCall { resp <- eventtypes.Result{Data: "no content"} } return } if prefs == nil { if ev.IsCall { resp <- eventtypes.Result{Data: "no preferences, skipped"} } return } deliveryCtx := &robottypes.DeliveryContext{ MemberID: payload.MemberID, ExecutionID: payload.ExecutionID, TeamID: payload.TeamID, } var results []robottypes.ChannelResult var lastErr error if prefs.Email != nil && prefs.Email.Enabled { for _, target := range prefs.Email.Targets { r := h.sendEmail(ctx, content, target, deliveryCtx) results = append(results, r) if !r.Success && lastErr == nil { lastErr = fmt.Errorf("email delivery failed: %s", r.Error) } } } if prefs.Webhook != nil && prefs.Webhook.Enabled { for _, target := range prefs.Webhook.Targets { r := h.postWebhook(ctx, content, target, deliveryCtx) results = append(results, r) if !r.Success && lastErr == nil { lastErr = fmt.Errorf("webhook delivery failed: %s", r.Error) } } } if prefs.Process != nil && prefs.Process.Enabled { for _, target := range prefs.Process.Targets { r := h.callProcess(ctx, content, target, deliveryCtx) results = append(results, r) if !r.Success && lastErr == nil { lastErr = fmt.Errorf("process delivery failed: %s", r.Error) } } } // Push delivery to integration channels only when the task originated from one if reply := getReplyFunc(); reply != nil && payload.ChatID != "" { channel, chatID := splitChannelChatID(payload.ChatID) if channel != "" && chatID != "" { msg := 
buildDeliveryMessage(content) if msg != nil { extra := map[string]any{ "member_id": payload.MemberID, "execution_id": payload.ExecutionID, } for k, v := range payload.Extra { extra[k] = v } metadata := &MessageMetadata{ Channel: channel, ChatID: chatID, Extra: extra, } if err := reply(ctx, msg, metadata); err != nil { log.Error("delivery handler: integration reply failed channel=%s execution=%s: %v", channel, payload.ExecutionID, err) } } } } if lastErr != nil { log.Error("delivery handler: partial failure execution=%s: %v", payload.ExecutionID, lastErr) } if ev.IsCall { resp <- eventtypes.Result{ Data: map[string]interface{}{ "execution_id": payload.ExecutionID, "results": results, }, Err: lastErr, } } } // buildDeliveryMessage converts DeliveryContent into a standard assistant Message. func buildDeliveryMessage(content *robottypes.DeliveryContent) *agentcontext.Message { if content == nil { return nil } var parts []interface{} text := content.Body if text == "" { text = content.Summary } if text != "" { parts = append(parts, map[string]interface{}{ "type": "text", "text": text, }) } for _, att := range content.Attachments { if att.File == "" { continue } part := map[string]interface{}{ "type": "file", "file": map[string]interface{}{ "url": att.File, "filename": att.Title, }, } parts = append(parts, part) } if len(parts) == 0 { return nil } var msgContent interface{} if len(parts) == 1 { if tp, ok := parts[0].(map[string]interface{}); ok && tp["type"] == "text" { msgContent = tp["text"] } else { msgContent = parts } } else { msgContent = parts } return &agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: msgContent, } } // ============================================================================ // Email // ============================================================================ func (h *robotHandler) sendEmail( ctx context.Context, content *robottypes.DeliveryContent, target robottypes.EmailTarget, deliveryCtx *robottypes.DeliveryContext, ) 
robottypes.ChannelResult { now := time.Now() targetID := strings.Join(target.To, ",") if targetID == "" { targetID = "no-recipients" } result := robottypes.ChannelResult{ Type: robottypes.DeliveryEmail, Target: targetID, SentAt: &now, } svc := messenger.Instance if svc == nil { result.Error = "messenger service not available" return result } htmlBody, plainBody := buildEmailBody(target.Template, content) msg := &messengerTypes.Message{ To: target.To, Subject: buildEmailSubject(target.Subject, target.Template, content, deliveryCtx), Body: plainBody, HTML: htmlBody, Type: messengerTypes.MessageTypeEmail, } attachments := convertAttachments(ctx, content.Attachments) if len(attachments) > 0 { msg.Attachments = attachments } channel := robottypes.DefaultEmailChannel() if err := svc.Send(ctx, channel, msg); err != nil { result.Error = err.Error() return result } result.Success = true result.Recipients = target.To return result } // ============================================================================ // Webhook // ============================================================================ func (h *robotHandler) postWebhook( ctx context.Context, content *robottypes.DeliveryContent, target robottypes.WebhookTarget, deliveryCtx *robottypes.DeliveryContext, ) robottypes.ChannelResult { now := time.Now() result := robottypes.ChannelResult{ Type: robottypes.DeliveryWebhook, Target: target.URL, SentAt: &now, } payload := map[string]interface{}{ "event": "robot.delivery", "timestamp": now.Format(time.RFC3339), "execution_id": deliveryCtx.ExecutionID, "member_id": deliveryCtx.MemberID, "team_id": deliveryCtx.TeamID, "trigger_type": deliveryCtx.TriggerType, "content": map[string]interface{}{ "summary": content.Summary, "body": content.Body, }, } if len(content.Attachments) > 0 { info := make([]map[string]interface{}, 0, len(content.Attachments)) for _, att := range content.Attachments { info = append(info, map[string]interface{}{ "title": att.Title, "description": 
att.Description, "task_id": att.TaskID, "file": att.File, }) } payload["attachments"] = info } payloadBytes, err := json.Marshal(payload) if err != nil { result.Error = fmt.Sprintf("failed to marshal payload: %v", err) return result } method := target.Method if method == "" { method = "POST" } req, err := http.NewRequestWithContext(ctx, method, target.URL, bytes.NewReader(payloadBytes)) if err != nil { result.Error = fmt.Sprintf("failed to create request: %v", err) return result } req.Header.Set("Content-Type", "application/json") for key, value := range target.Headers { req.Header.Set(key, value) } if target.Secret != "" { signature := ComputeHMACSignature(payloadBytes, target.Secret) req.Header.Set("X-Yao-Signature", signature) req.Header.Set("X-Yao-Signature-Algorithm", "HMAC-SHA256") } httpResp, err := h.httpClient.Do(req) if err != nil { result.Error = fmt.Sprintf("request failed: %v", err) return result } defer httpResp.Body.Close() body, _ := io.ReadAll(httpResp.Body) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { result.Error = fmt.Sprintf("webhook returned status %d: %s", httpResp.StatusCode, string(body)) return result } result.Success = true result.Details = map[string]interface{}{ "status_code": httpResp.StatusCode, "response": string(body), } return result } // ============================================================================ // Process // ============================================================================ func (h *robotHandler) callProcess( ctx context.Context, content *robottypes.DeliveryContent, target robottypes.ProcessTarget, deliveryCtx *robottypes.DeliveryContext, ) robottypes.ChannelResult { now := time.Now() result := robottypes.ChannelResult{ Type: robottypes.DeliveryProcess, Target: target.Process, SentAt: &now, } args := make([]interface{}, 0, 1+len(target.Args)) args = append(args, map[string]interface{}{ "content": map[string]interface{}{ "summary": content.Summary, "body": content.Body, "attachments": 
content.Attachments, }, "context": map[string]interface{}{ "execution_id": deliveryCtx.ExecutionID, "member_id": deliveryCtx.MemberID, "team_id": deliveryCtx.TeamID, "trigger_type": deliveryCtx.TriggerType, }, }) args = append(args, target.Args...) proc, err := process.Of(target.Process, args...) if err != nil { result.Error = fmt.Sprintf("failed to create process: %v", err) return result } proc.Context = ctx if err = proc.Execute(); err != nil { result.Error = err.Error() return result } result.Success = true result.Details = toJSONSerializable(proc.Value) return result } // ============================================================================ // Helpers // ============================================================================ func toJSONSerializable(v interface{}) interface{} { if v == nil { return nil } if _, err := json.Marshal(v); err != nil { return fmt.Sprintf("%v", v) } return v } func buildEmailSubject(subject, template string, content *robottypes.DeliveryContent, ctx *robottypes.DeliveryContext) string { if subject != "" { return subject } if content.Summary != "" { return content.Summary } return fmt.Sprintf("Execution %s Complete", ctx.ExecutionID) } func buildEmailBody(template string, content *robottypes.DeliveryContent) (string, string) { markdown := content.Body if markdown == "" { markdown = content.Summary } html, err := text.MarkdownToHTML(markdown) if err != nil { return markdown, markdown } return html, markdown } func convertAttachments(ctx context.Context, attachments []robottypes.DeliveryAttachment) []messengerTypes.Attachment { if len(attachments) == 0 { return nil } result := make([]messengerTypes.Attachment, 0, len(attachments)) for _, att := range attachments { uploader, fileID, isWrapper := attachment.Parse(att.File) if !isWrapper { log.Warn("convertAttachments: skipping non-wrapper file value=%q title=%q", att.File, att.Title) continue } manager, ok := attachment.Managers[uploader] if !ok { log.Warn("convertAttachments: 
manager not found uploader=%q file=%q title=%q (available: %v)", uploader, att.File, att.Title, attachmentManagerKeys()) continue } info, err := manager.Info(ctx, fileID) if err != nil { log.Warn("convertAttachments: failed to get file info fileID=%q uploader=%q: %v", fileID, uploader, err) continue } content, err := manager.Read(ctx, fileID) if err != nil { log.Warn("convertAttachments: failed to read file fileID=%q uploader=%q: %v", fileID, uploader, err) continue } filename := info.Filename if att.Title != "" { ext := "" if idx := strings.LastIndex(info.Filename, "."); idx >= 0 { ext = info.Filename[idx:] } titleExt := "" if idx := strings.LastIndex(att.Title, "."); idx >= 0 { titleExt = att.Title[idx:] } if titleExt != "" { filename = att.Title } else { filename = att.Title + ext } } log.Info("convertAttachments: added attachment filename=%q contentType=%q size=%d", filename, info.ContentType, len(content)) result = append(result, messengerTypes.Attachment{ Filename: filename, ContentType: info.ContentType, Content: content, }) } return result } func attachmentManagerKeys() []string { keys := make([]string, 0, len(attachment.Managers)) for k := range attachment.Managers { keys = append(keys, k) } return keys } // ComputeHMACSignature computes HMAC-SHA256 signature for webhook payload. func ComputeHMACSignature(payload []byte, secret string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(payload) return hex.EncodeToString(mac.Sum(nil)) } // VerifyHMACSignature verifies the HMAC-SHA256 signature of a webhook payload. func VerifyHMACSignature(payload []byte, secret, signature string) bool { expected := ComputeHMACSignature(payload, secret) return hmac.Equal([]byte(expected), []byte(signature)) } // splitChannelChatID splits a composite "channel:chatID" string (e.g. "telegram:8134167376") // into its channel and chatID parts. If no colon is present, channel is empty. 
func splitChannelChatID(composite string) (channel, chatID string) { if idx := strings.Index(composite, ":"); idx >= 0 { return composite[:idx], composite[idx+1:] } return "", composite } ================================================ FILE: agent/robot/events/event_push_test.go ================================================ package events import ( "encoding/json" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // EP1: ExecPayload with all execution statuses func TestExecPayloadAllStatuses(t *testing.T) { statuses := []string{ "running", "completed", "failed", "cancelled", "waiting", "confirming", } for _, s := range statuses { payload := ExecPayload{ ExecutionID: "exec-ep1", MemberID: "member-ep1", TeamID: "team-ep1", Status: s, } data, err := json.Marshal(payload) require.NoError(t, err) var parsed ExecPayload err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, s, parsed.Status, "Status %s should round-trip", s) } } // EP2: NeedInputPayload with empty question func TestNeedInputPayloadEmptyQuestion(t *testing.T) { payload := NeedInputPayload{ ExecutionID: "exec-ep2", MemberID: "member-ep2", TeamID: "team-ep2", TaskID: "task-ep2", Question: "", } data, err := json.Marshal(payload) require.NoError(t, err) var parsed map[string]interface{} err = json.Unmarshal(data, &parsed) require.NoError(t, err) // Empty string should be present but empty q, ok := parsed["question"] assert.True(t, ok) assert.Equal(t, "", q) } // EP3: TaskPayload serializes error correctly func TestTaskPayloadErrorSerialization(t *testing.T) { payload := TaskPayload{ ExecutionID: "exec-ep3", MemberID: "member-ep3", TeamID: "team-ep3", TaskID: "task-ep3", Error: "context deadline exceeded", } data, err := json.Marshal(payload) require.NoError(t, err) var parsed map[string]interface{} err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, "context deadline 
exceeded", parsed["error"]) } // EP4: DeliveryPayload with nested content func TestDeliveryPayloadNestedContent(t *testing.T) { payload := DeliveryPayload{ ExecutionID: "exec-ep4", MemberID: "member-ep4", TeamID: "team-ep4", Content: &robottypes.DeliveryContent{ Summary: "Daily Summary", Body: "Full body with sections: intro, body, conclusion", }, } data, err := json.Marshal(payload) require.NoError(t, err) var parsed DeliveryPayload err = json.Unmarshal(data, &parsed) require.NoError(t, err) require.NotNil(t, parsed.Content) assert.Equal(t, "Daily Summary", parsed.Content.Summary) assert.Contains(t, parsed.Content.Body, "sections") } // EP5: Event constants follow naming convention func TestEventConstantNamingConvention(t *testing.T) { allEvents := []string{ TaskNeedInput, TaskFailed, TaskCompleted, ExecWaiting, ExecResumed, ExecCompleted, ExecFailed, ExecCancelled, Delivery, } for _, e := range allEvents { assert.Contains(t, e, "robot.", "Event %q should start with 'robot.'", e) } } // EP6: ExecPayload omits empty ChatID func TestExecPayloadOmitsEmptyOptionalFields(t *testing.T) { payload := ExecPayload{ ExecutionID: "exec-ep6", MemberID: "member-ep6", TeamID: "team-ep6", Status: "completed", } data, err := json.Marshal(payload) require.NoError(t, err) var parsed map[string]interface{} err = json.Unmarshal(data, &parsed) require.NoError(t, err) _, hasChatID := parsed["chat_id"] if hasChatID { assert.Equal(t, "", parsed["chat_id"]) } } // EP7: All payloads share common fields (ExecutionID, MemberID, TeamID) func TestPayloadCommonFields(t *testing.T) { needInput := NeedInputPayload{ExecutionID: "e1", MemberID: "m1", TeamID: "t1"} task := TaskPayload{ExecutionID: "e2", MemberID: "m2", TeamID: "t2"} exec := ExecPayload{ExecutionID: "e3", MemberID: "m3", TeamID: "t3"} delivery := DeliveryPayload{ExecutionID: "e4", MemberID: "m4", TeamID: "t4"} assert.Equal(t, "e1", needInput.ExecutionID) assert.Equal(t, "m2", task.MemberID) assert.Equal(t, "t3", exec.TeamID) 
assert.Equal(t, "e4", delivery.ExecutionID) } ================================================ FILE: agent/robot/events/events.go ================================================ package events import ( "context" "strings" "sync" agentcontext "github.com/yaoapp/yao/agent/context" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // TriggerFunc is the callback for triggering a robot execution. // Injected by the api package at startup to break the import cycle. // Returns (executionID, accepted, error). type TriggerFunc func(ctx *robottypes.Context, memberID string, triggerType robottypes.TriggerType, data interface{}) (string, bool, error) // ReplyFunc is the callback for replying to the originating channel. // Injected by the dispatcher at startup. The implementation routes the // reply to the correct adapter based on metadata.Channel. // msg is the standard assistant message (text, Content Parts, media, etc.). type ReplyFunc func(ctx context.Context, msg *agentcontext.Message, metadata *MessageMetadata) error var ( triggerFn TriggerFunc triggerFnMu sync.RWMutex replyFn ReplyFunc replyFnMu sync.RWMutex ) // RegisterTriggerFunc sets the function used by handleMessage to trigger // robot execution when a confirmed action is detected. func RegisterTriggerFunc(fn TriggerFunc) { triggerFnMu.Lock() defer triggerFnMu.Unlock() triggerFn = fn } func getTriggerFunc() TriggerFunc { triggerFnMu.RLock() defer triggerFnMu.RUnlock() return triggerFn } // RegisterReplyFunc sets the function used by handleMessage to reply // to the originating channel after processing. func RegisterReplyFunc(fn ReplyFunc) { replyFnMu.Lock() defer replyFnMu.Unlock() replyFn = fn } func getReplyFunc() ReplyFunc { replyFnMu.RLock() defer replyFnMu.RUnlock() return replyFn } // Robot event type constants for event.Push integration. // Events are fire-and-forget; handlers are registered via event.Register(). 
const (
	// Task-level events.
	TaskNeedInput = "robot.task.need_input"
	TaskFailed    = "robot.task.failed"
	TaskCompleted = "robot.task.completed"

	// Execution-level events.
	ExecWaiting   = "robot.exec.waiting"
	ExecResumed   = "robot.exec.resumed"
	ExecCompleted = "robot.exec.completed"
	ExecFailed    = "robot.exec.failed"
	ExecCancelled = "robot.exec.cancelled"

	// Delivery and Message are the two types robotHandler.Handle dispatches
	// (to handleDelivery and handleMessage); every other robot.* type only
	// produces a debug log there.
	Delivery = "robot.delivery"
	Message  = "robot.message"
)

// Robot configuration change events (used by integrations Receiver).
const (
	RobotConfigCreated = "robot.config.created"
	RobotConfigUpdated = "robot.config.updated"
	RobotConfigDeleted = "robot.config.deleted"
)

// NeedInputPayload is the event payload for TaskNeedInput / ExecWaiting events.
// Question is always serialized, even when empty; ChatID is omitted when empty.
type NeedInputPayload struct {
	ExecutionID string `json:"execution_id"`
	MemberID    string `json:"member_id"`
	TeamID      string `json:"team_id"`
	TaskID      string `json:"task_id"`
	Question    string `json:"question"`
	ChatID      string `json:"chat_id,omitempty"`
}

// ExecPayload is a generic execution event payload.
// Status, Error and ChatID are omitted from the JSON form when empty.
type ExecPayload struct {
	ExecutionID string `json:"execution_id"`
	MemberID    string `json:"member_id"`
	TeamID      string `json:"team_id"`
	Status      string `json:"status,omitempty"`
	Error       string `json:"error,omitempty"`
	ChatID      string `json:"chat_id,omitempty"`
}

// TaskPayload is the event payload for TaskFailed / TaskCompleted events.
// Error and ChatID are omitted from the JSON form when empty.
type TaskPayload struct {
	ExecutionID string `json:"execution_id"`
	MemberID    string `json:"member_id"`
	TeamID      string `json:"team_id"`
	TaskID      string `json:"task_id"`
	Error       string `json:"error,omitempty"`
	ChatID      string `json:"chat_id,omitempty"`
}

// DeliveryPayload is the event payload for Delivery events.
// handleDelivery treats a nil Content or nil Preferences as "nothing to
// deliver" and returns early.
type DeliveryPayload struct {
	ExecutionID string `json:"execution_id"`
	MemberID    string `json:"member_id"`
	TeamID      string `json:"team_id"`
	// ChatID is parsed by handleDelivery as a composite "channel:chatID"
	// value (see splitChannelChatID) to route a reply to the originating
	// integration channel.
	ChatID      string                          `json:"chat_id,omitempty"`
	Content     *robottypes.DeliveryContent     `json:"content,omitempty"`
	Preferences *robottypes.DeliveryPreferences `json:"preferences,omitempty"`
	// Extra is merged into the reply metadata's Extra map by handleDelivery.
	Extra map[string]any `json:"extra,omitempty"`
}

// MessagePayload is the event payload for Message events (external channel messages).
type MessagePayload struct {
	RobotID  string                 `json:"robot_id"`
	Messages []agentcontext.Message `json:"messages"`
	Metadata *MessageMetadata       `json:"metadata"`
}

// MessageMetadata carries channel-specific information for routing and deduplication.
// Only Channel is always serialized; every other field is omitted when empty.
type MessageMetadata struct {
	Channel    string         `json:"channel"`
	MessageID  string         `json:"message_id,omitempty"`
	AppID      string         `json:"app_id,omitempty"`
	ChatID     string         `json:"chat_id,omitempty"`
	SenderID   string         `json:"sender_id,omitempty"`
	SenderName string         `json:"sender_name,omitempty"`
	Locale     string         `json:"locale,omitempty"`
	ReplyTo    string         `json:"reply_to,omitempty"`
	Extra      map[string]any `json:"extra,omitempty"`
}

// MessageResult is the result returned from handleMessage via event.Call.
type MessageResult struct {
	Message     *agentcontext.Message `json:"message,omitempty"`
	Action      *ActionResult         `json:"action,omitempty"`
	ExecutionID string                `json:"execution_id,omitempty"`
	Metadata    *MessageMetadata      `json:"metadata,omitempty"`
}

// ActionResult describes a detected action from the Host Agent's Next hook.
type ActionResult struct {
	Name    string `json:"name"`
	Payload any    `json:"payload,omitempty"`
}

// RobotConfigPayload is the event payload for robot.config.* events.
type RobotConfigPayload struct {
	MemberID string `json:"member_id"`
	TeamID   string `json:"team_id"`
}

// NormalizeLocale converts various language code formats (IETF BCP 47, etc.)
// into the lowercase hyphenated form used by agentcontext (e.g. "zh-cn", "en-us").
// // Mapping rules: // // "zh-hans", "zh-cn" → "zh-cn" // "zh-hant", "zh-tw", "zh-hk" → "zh-tw" // "zh" → "zh-cn" // "en-us" → "en-us" // "en-gb" → "en-gb" // "en" → "en" // "" → "en" (default) // other → lowercased as-is func NormalizeLocale(raw string) string { code := strings.ToLower(strings.TrimSpace(raw)) if code == "" { return "en" } // Normalize underscore to hyphen (e.g. zh_CN → zh-cn) code = strings.ReplaceAll(code, "_", "-") switch code { case "zh-hans", "zh-cn": return "zh-cn" case "zh-hant", "zh-tw", "zh-hk": return "zh-tw" case "zh": return "zh-cn" default: return code } } ================================================ FILE: agent/robot/events/events_test.go ================================================ package events import ( "encoding/json" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" robottypes "github.com/yaoapp/yao/agent/robot/types" ) func TestEventConstants(t *testing.T) { expected := map[string]string{ "TaskNeedInput": "robot.task.need_input", "TaskFailed": "robot.task.failed", "TaskCompleted": "robot.task.completed", "ExecWaiting": "robot.exec.waiting", "ExecResumed": "robot.exec.resumed", "ExecCompleted": "robot.exec.completed", "ExecFailed": "robot.exec.failed", "ExecCancelled": "robot.exec.cancelled", "Delivery": "robot.delivery", } actual := map[string]string{ "TaskNeedInput": TaskNeedInput, "TaskFailed": TaskFailed, "TaskCompleted": TaskCompleted, "ExecWaiting": ExecWaiting, "ExecResumed": ExecResumed, "ExecCompleted": ExecCompleted, "ExecFailed": ExecFailed, "ExecCancelled": ExecCancelled, "Delivery": Delivery, } for name, exp := range expected { assert.Equal(t, exp, actual[name], "Event constant %s mismatch", name) } assert.Len(t, actual, 9, "Expected exactly 9 event constants") } func TestNeedInputPayloadMarshalling(t *testing.T) { payload := NeedInputPayload{ ExecutionID: "exec-123", MemberID: "member-1", TeamID: "team-1", TaskID: "task-5", Question: "What date range?", ChatID: "chat-abc", 
} data, err := json.Marshal(payload) require.NoError(t, err) var parsed NeedInputPayload err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, payload, parsed) } func TestTaskPayloadMarshalling(t *testing.T) { t.Run("with error", func(t *testing.T) { payload := TaskPayload{ ExecutionID: "exec-1", MemberID: "member-1", TeamID: "team-1", TaskID: "task-1", Error: "timeout", ChatID: "chat-1", } data, err := json.Marshal(payload) require.NoError(t, err) var parsed TaskPayload err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, payload, parsed) }) t.Run("without error", func(t *testing.T) { payload := TaskPayload{ ExecutionID: "exec-2", MemberID: "member-2", TeamID: "team-2", TaskID: "task-2", } data, err := json.Marshal(payload) require.NoError(t, err) var parsed TaskPayload err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, payload, parsed) assert.Empty(t, parsed.Error) }) } func TestExecPayloadMarshalling(t *testing.T) { payload := ExecPayload{ ExecutionID: "exec-100", MemberID: "member-10", TeamID: "team-10", Status: "completed", ChatID: "chat-100", } data, err := json.Marshal(payload) require.NoError(t, err) var parsed ExecPayload err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, payload, parsed) } func TestDeliveryPayloadMarshalling(t *testing.T) { payload := DeliveryPayload{ ExecutionID: "exec-d1", MemberID: "member-d1", TeamID: "team-d1", ChatID: "chat-d1", Content: &robottypes.DeliveryContent{ Summary: "done", Body: "full report", }, Preferences: &robottypes.DeliveryPreferences{ Email: &robottypes.EmailPreference{ Enabled: true, Targets: []robottypes.EmailTarget{{To: []string{"a@b.com"}}}, }, }, } data, err := json.Marshal(payload) require.NoError(t, err) var parsed DeliveryPayload err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, "exec-d1", parsed.ExecutionID) assert.Equal(t, "member-d1", parsed.MemberID) assert.NotNil(t, parsed.Content) 
assert.Equal(t, "done", parsed.Content.Summary) assert.NotNil(t, parsed.Preferences) assert.NotNil(t, parsed.Preferences.Email) } ================================================ FILE: agent/robot/events/handlers.go ================================================ package events import ( "context" "net/http" "time" "github.com/yaoapp/yao/event" eventtypes "github.com/yaoapp/yao/event/types" ) func init() { event.Register("robot", &robotHandler{ httpClient: &http.Client{Timeout: 30 * time.Second}, }) } // robotHandler processes all robot.* events. type robotHandler struct { httpClient *http.Client } // Handle dispatches robot events by type. func (h *robotHandler) Handle(ctx context.Context, ev *eventtypes.Event, resp chan<- eventtypes.Result) { switch ev.Type { case Delivery: h.handleDelivery(ctx, ev, resp) case Message: h.handleMessage(ctx, ev, resp) default: log.Debug("robot handler: unhandled event type=%s id=%s", ev.Type, ev.ID) } } // Shutdown gracefully shuts down the robot handler. 
func (h *robotHandler) Shutdown(ctx context.Context) error { return nil } ================================================ FILE: agent/robot/events/handlers_test.go ================================================ package events import ( "context" "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" robottypes "github.com/yaoapp/yao/agent/robot/types" eventtypes "github.com/yaoapp/yao/event/types" ) func newTestHandler() *robotHandler { return &robotHandler{ httpClient: http.DefaultClient, } } func TestRobotHandler_DeliveryWebhook(t *testing.T) { var received map[string]interface{} server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { decoder := json.NewDecoder(r.Body) _ = decoder.Decode(&received) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"ok":true}`)) })) defer server.Close() handler := newTestHandler() ev := &eventtypes.Event{ Type: Delivery, ID: "test-ev-1", IsCall: true, Payload: DeliveryPayload{ ExecutionID: "exec-1", MemberID: "member-1", TeamID: "team-1", Content: &robottypes.DeliveryContent{ Summary: "test summary", Body: "test body", }, Preferences: &robottypes.DeliveryPreferences{ Webhook: &robottypes.WebhookPreference{ Enabled: true, Targets: []robottypes.WebhookTarget{ {URL: server.URL}, }, }, }, }, } resp := make(chan eventtypes.Result, 1) handler.Handle(context.Background(), ev, resp) result := <-resp require.NotNil(t, result.Data) assert.NoError(t, result.Err) data, ok := result.Data.(map[string]interface{}) require.True(t, ok) assert.Equal(t, "exec-1", data["execution_id"]) require.NotNil(t, received) assert.Equal(t, "robot.delivery", received["event"]) } func TestRobotHandler_DeliveryNoContent(t *testing.T) { handler := newTestHandler() ev := &eventtypes.Event{ Type: Delivery, ID: "test-ev-2", IsCall: true, Payload: DeliveryPayload{ ExecutionID: "exec-2", MemberID: "member-2", TeamID: "team-2", }, } resp := make(chan 
eventtypes.Result, 1) handler.Handle(context.Background(), ev, resp) result := <-resp assert.Equal(t, "no content", result.Data) } func TestRobotHandler_DeliveryNoPreferences(t *testing.T) { handler := newTestHandler() ev := &eventtypes.Event{ Type: Delivery, ID: "test-ev-3", IsCall: true, Payload: DeliveryPayload{ ExecutionID: "exec-3", MemberID: "member-3", TeamID: "team-3", Content: &robottypes.DeliveryContent{ Summary: "test", Body: "body", }, }, } resp := make(chan eventtypes.Result, 1) handler.Handle(context.Background(), ev, resp) result := <-resp assert.Equal(t, "no preferences, skipped", result.Data) } func TestRobotHandler_InvalidPayload(t *testing.T) { handler := newTestHandler() ev := &eventtypes.Event{ Type: Delivery, ID: "test-ev-4", IsCall: true, Payload: "invalid", } resp := make(chan eventtypes.Result, 1) handler.Handle(context.Background(), ev, resp) result := <-resp assert.Error(t, result.Err) } func TestRobotHandler_UnhandledEvent(t *testing.T) { handler := newTestHandler() ev := &eventtypes.Event{ Type: "robot.unknown", ID: "test-ev-5", } resp := make(chan eventtypes.Result, 1) handler.Handle(context.Background(), ev, resp) // Fire-and-forget, no response expected } func TestRobotHandler_Shutdown(t *testing.T) { handler := newTestHandler() err := handler.Shutdown(context.Background()) assert.NoError(t, err) } func TestVerifyHMACSignature(t *testing.T) { payload := []byte(`{"event":"robot.delivery"}`) secret := "test-secret" sig := ComputeHMACSignature(payload, secret) assert.True(t, VerifyHMACSignature(payload, secret, sig)) assert.False(t, VerifyHMACSignature(payload, "wrong-secret", sig)) } ================================================ FILE: agent/robot/events/integrations/dingtalk/dedup.go ================================================ package dingtalk import ( "sync" "time" ) const ( dedupTTL = 24 * time.Hour dedupCleanInterval = time.Hour ) type dedupStore struct { m sync.Map } func newDedupStore() *dedupStore { return &dedupStore{} } 
func (d *dedupStore) markSeen(key string) bool { now := time.Now().Unix() _, loaded := d.m.LoadOrStore(key, now) return !loaded } func (d *dedupStore) cleaner(stopCh <-chan struct{}) { ticker := time.NewTicker(dedupCleanInterval) defer ticker.Stop() for { select { case <-stopCh: return case <-ticker.C: cutoff := time.Now().Add(-dedupTTL).Unix() d.m.Range(func(key, value any) bool { if ts, ok := value.(int64); ok && ts < cutoff { d.m.Delete(key) } return true }) } } } ================================================ FILE: agent/robot/events/integrations/dingtalk/dingtalk.go ================================================ package dingtalk import ( "context" "sync" "github.com/yaoapp/yao/agent/robot/logger" robottypes "github.com/yaoapp/yao/agent/robot/types" dtapi "github.com/yaoapp/yao/integrations/dingtalk" ) var log = logger.New("dingtalk") // Adapter implements the integrations.Adapter interface for DingTalk. // // Architecture: // - One DingTalk Stream client per registered bot for real-time message reception // - One dedup cleaner goroutine removes expired keys every hour type Adapter struct { mu sync.RWMutex bots map[string]*botEntry // robotID -> *botEntry appIdx map[string]string // clientID -> robotID dedup *dedupStore stopCh chan struct{} } // botEntry holds the state for one robot's DingTalk integration. type botEntry struct { robotID string clientID string bot *dtapi.Bot cancelFn context.CancelFunc } // NewAdapter creates a new DingTalk adapter. func NewAdapter() *Adapter { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } go a.dedup.cleaner(a.stopCh) return a } // Apply is called by the Dispatcher when a robot config is created or updated. 
func (a *Adapter) Apply(ctx context.Context, robot *robottypes.Robot) { dtConf := extractConfig(robot) log.Debug("Apply robot=%s dtConf=%v", robot.MemberID, dtConf != nil) if dtConf == nil || !dtConf.Enabled || dtConf.ClientID == "" || dtConf.ClientSecret == "" { a.removeBot(robot.MemberID) return } a.mu.Lock() defer a.mu.Unlock() if existing, ok := a.bots[robot.MemberID]; ok { if existing.clientID == dtConf.ClientID { return } a.removeBotLocked(robot.MemberID) } bot := dtapi.NewBot(dtConf.ClientID, dtConf.ClientSecret) streamCtx, streamCancel := context.WithCancel(context.Background()) entry := &botEntry{ robotID: robot.MemberID, clientID: dtConf.ClientID, bot: bot, cancelFn: streamCancel, } a.bots[robot.MemberID] = entry a.appIdx[dtConf.ClientID] = robot.MemberID go a.streamLoop(streamCtx, entry) log.Info("dingtalk adapter: registered robot=%s client=%s", robot.MemberID, dtConf.ClientID) } // Remove is called by the Dispatcher when a robot is deleted. func (a *Adapter) Remove(ctx context.Context, robotID string) { a.removeBot(robotID) } // Shutdown stops all stream connections and dedup cleaner. 
func (a *Adapter) Shutdown() { close(a.stopCh) a.mu.Lock() for _, entry := range a.bots { if entry.cancelFn != nil { entry.cancelFn() } } a.mu.Unlock() log.Info("dingtalk adapter: shutdown complete") } func (a *Adapter) removeBot(robotID string) { a.mu.Lock() defer a.mu.Unlock() a.removeBotLocked(robotID) } func (a *Adapter) removeBotLocked(robotID string) { entry, ok := a.bots[robotID] if !ok { return } if entry.cancelFn != nil { entry.cancelFn() } if entry.clientID != "" { delete(a.appIdx, entry.clientID) } delete(a.bots, robotID) log.Info("dingtalk adapter: unregistered robot=%s", robotID) } func (a *Adapter) resolveByClientID(clientID string) (*botEntry, bool) { a.mu.RLock() defer a.mu.RUnlock() robotID, ok := a.appIdx[clientID] if !ok { return nil, false } entry, ok := a.bots[robotID] return entry, ok } func extractConfig(robot *robottypes.Robot) *robottypes.DingTalkConfig { if robot.Config == nil || robot.Config.Integrations == nil { return nil } return robot.Config.Integrations.DingTalk } ================================================ FILE: agent/robot/events/integrations/dingtalk/e2e_test.go ================================================ package dingtalk import ( "context" "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" robottypes "github.com/yaoapp/yao/agent/robot/types" dtapi "github.com/yaoapp/yao/integrations/dingtalk" ) var ( dtClientID string dtClientSecret string ) func TestMain(m *testing.M) { dtClientID = os.Getenv("DINGTALK_TEST_CLIENT_ID") dtClientSecret = os.Getenv("DINGTALK_TEST_CLIENT_SECRET") os.Exit(m.Run()) } func skipIfNoCreds(t *testing.T) { t.Helper() if dtClientID == "" || dtClientSecret == "" { t.Skip("DINGTALK_TEST_CLIENT_ID or DINGTALK_TEST_CLIENT_SECRET not set") } } // TestE2E_Adapter_Apply verifies that Apply correctly registers a bot. 
func TestE2E_Adapter_Apply(t *testing.T) { skipIfNoCreds(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_dt_adapter", TeamID: "team_e2e_dt", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ DingTalk: &robottypes.DingTalkConfig{ Enabled: true, ClientID: dtClientID, ClientSecret: dtClientSecret, }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() entry, ok := a.bots["robot_e2e_dt_adapter"] a.mu.RUnlock() require.True(t, ok, "bot should be registered") assert.Equal(t, dtClientID, entry.clientID) assert.NotNil(t, entry.bot) t.Logf("OK Apply: dingtalk bot registered robot=%s client=%s", robot.MemberID, entry.clientID) } // TestE2E_Adapter_Apply_Update verifies re-Apply with same clientID is a no-op. func TestE2E_Adapter_Apply_Update(t *testing.T) { skipIfNoCreds(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_dt_update", TeamID: "team_e2e_dt", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ DingTalk: &robottypes.DingTalkConfig{ Enabled: true, ClientID: dtClientID, ClientSecret: dtClientSecret, }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() _, ok := a.bots["robot_e2e_dt_update"] a.mu.RUnlock() require.True(t, ok) a.Apply(context.Background(), robot) a.mu.RLock() assert.Len(t, a.bots, 1) a.mu.RUnlock() a.Remove(context.Background(), "robot_e2e_dt_update") a.mu.RLock() _, ok = a.bots["robot_e2e_dt_update"] a.mu.RUnlock() assert.False(t, ok, "bot should be removed") t.Log("OK Apply/Remove lifecycle verified") } // TestE2E_Adapter_Dedup verifies deduplication works. 
func TestE2E_Adapter_Dedup(t *testing.T) { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) key := "dt:test-robot:msg-12345" assert.True(t, a.dedup.markSeen(key), "first time should return true") assert.False(t, a.dedup.markSeen(key), "second time should return false (dedup)") t.Log("OK dedup working correctly") } // TestE2E_Adapter_HandleMessages verifies message handling through the adapter. func TestE2E_Adapter_HandleMessages(t *testing.T) { skipIfNoCreds(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) entry := &botEntry{ robotID: "robot_e2e_dt_handle", clientID: dtClientID, bot: dtapi.NewBot(dtClientID, dtClientSecret), } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cms := []*dtapi.ConvertedMessage{ { MessageID: "test_msg_1", ConversationID: "test_conv_1", ConversationType: "1", SenderID: "test_sender_1", SenderNick: "Test User", Text: "Hello from E2E test", SessionWebhook: "https://oapi.dingtalk.com/robot/sendBySession/xxx", }, } a.handleMessages(ctx, entry, cms) assert.False(t, a.dedup.markSeen("dt:robot_e2e_dt_handle:test_msg_1"), "message should be marked as seen after handleMessages") t.Log("OK handleMessages processed 1 message") } // TestE2E_Adapter_ApplyDisabled verifies Apply removes bot when disabled. 
func TestE2E_Adapter_ApplyDisabled(t *testing.T) { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_dt_disabled", TeamID: "team_e2e_dt", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ DingTalk: &robottypes.DingTalkConfig{ Enabled: false, ClientID: "some_id", ClientSecret: "some_secret", }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() _, ok := a.bots["robot_e2e_dt_disabled"] a.mu.RUnlock() assert.False(t, ok, "disabled bot should not be registered") t.Log("OK disabled config not registered") } // TestE2E_Adapter_GetAccessToken verifies real DingTalk credentials work. func TestE2E_Adapter_GetAccessToken(t *testing.T) { skipIfNoCreds(t) b := dtapi.NewBot(dtClientID, dtClientSecret) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() token, err := b.GetAccessToken(ctx) require.NoError(t, err) assert.NotEmpty(t, token) t.Logf("OK DingTalk access token obtained, len=%d", len(token)) } ================================================ FILE: agent/robot/events/integrations/dingtalk/message.go ================================================ package dingtalk import ( "context" "fmt" "strings" agentcontext "github.com/yaoapp/yao/agent/context" events "github.com/yaoapp/yao/agent/robot/events" "github.com/yaoapp/yao/event" dtapi "github.com/yaoapp/yao/integrations/dingtalk" ) // handleMessages processes a batch of DingTalk messages. 
func (a *Adapter) handleMessages(ctx context.Context, entry *botEntry, cms []*dtapi.ConvertedMessage) { if len(cms) == 0 { return } var allParts []interface{} var lastCM *dtapi.ConvertedMessage for _, cm := range cms { if cm == nil { continue } dedupKey := fmt.Sprintf("dt:%s:%s", entry.robotID, cm.MessageID) if !a.dedup.markSeen(dedupKey) { continue } parts := buildContentParts(cm) if len(parts) == 0 { continue } allParts = append(allParts, parts...) lastCM = cm } if len(allParts) == 0 || lastCM == nil { return } content := mergeContentParts(allParts) msgPayload := events.MessagePayload{ RobotID: entry.robotID, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: content}, }, Metadata: &events.MessageMetadata{ Channel: "dingtalk", MessageID: lastCM.MessageID, AppID: entry.clientID, ChatID: lastCM.ConversationID, SenderID: lastCM.SenderID, SenderName: lastCM.SenderNick, Locale: "zh-cn", Extra: map[string]any{ "session_webhook": lastCM.SessionWebhook, "conversation_type": lastCM.ConversationType, "dt_message_id": lastCM.MessageID, }, }, } if _, err := event.Push(ctx, events.Message, msgPayload); err != nil { log.Error("dingtalk adapter: event.Push robot.message failed robot=%s: %v", entry.robotID, err) } } func buildContentParts(cm *dtapi.ConvertedMessage) []interface{} { var parts []interface{} if cm.HasText() { parts = append(parts, map[string]interface{}{ "type": "text", "text": cm.Text, }) } for _, mi := range cm.MediaItems { if mi.Wrapper == "" && mi.URL == "" { continue } url := mi.Wrapper if url == "" { url = mi.URL } parts = append(parts, map[string]interface{}{ "type": "file", "file_url": url, "mime_type": mi.MimeType, "file_name": mi.FileName, }) } return parts } func mergeContentParts(parts []interface{}) interface{} { allText := true for _, p := range parts { m, ok := p.(map[string]interface{}) if !ok || m["type"] != "text" { allText = false break } } if allText { var buf strings.Builder for i, p := range parts { if i > 0 { 
buf.WriteString("\n") } m := p.(map[string]interface{}) buf.WriteString(m["text"].(string)) } return buf.String() } return parts } ================================================ FILE: agent/robot/events/integrations/dingtalk/reply.go ================================================ package dingtalk import ( "context" "fmt" "strings" agentcontext "github.com/yaoapp/yao/agent/context" events "github.com/yaoapp/yao/agent/robot/events" dtapi "github.com/yaoapp/yao/integrations/dingtalk" ) // Reply sends the assistant message back to the originating DingTalk conversation. func (a *Adapter) Reply(ctx context.Context, msg *agentcontext.Message, metadata *events.MessageMetadata) error { if msg == nil || metadata == nil { return fmt.Errorf("nil message or metadata") } var sessionWebhook string if metadata.Extra != nil { if v, ok := metadata.Extra["session_webhook"]; ok { if s, ok := v.(string); ok { sessionWebhook = s } } } if sessionWebhook == "" { return fmt.Errorf("no session_webhook in metadata for dingtalk reply") } return sendContent(ctx, sessionWebhook, msg.Content) } func sendContent(ctx context.Context, sessionWebhook string, content interface{}) error { switch c := content.(type) { case string: if strings.TrimSpace(c) == "" { return nil } return dtapi.SendMarkdownMessage(ctx, sessionWebhook, "Reply", dtapi.FormatDingTalkMarkdown(c)) case []interface{}: return sendParts(ctx, sessionWebhook, c) default: parts, ok := toContentParts(content) if ok { return sendPartsTyped(ctx, sessionWebhook, parts) } return dtapi.SendTextMessage(ctx, sessionWebhook, fmt.Sprintf("%v", content)) } } func sendParts(ctx context.Context, sessionWebhook string, parts []interface{}) error { var textBuf strings.Builder for _, part := range parts { m, ok := part.(map[string]interface{}) if !ok { continue } partType, _ := m["type"].(string) switch partType { case "text": if text, ok := m["text"].(string); ok { textBuf.WriteString(text) } case "image_url": if err := flushText(ctx, 
sessionWebhook, &textBuf); err != nil { return err } if imgMap, ok := m["image_url"].(map[string]interface{}); ok { if url, ok := imgMap["url"].(string); ok { if strings.HasPrefix(url, "http") { textBuf.WriteString(fmt.Sprintf("\n![image](%s)\n", url)) } } } case "file": if err := flushText(ctx, sessionWebhook, &textBuf); err != nil { return err } fileURL, _ := m["file_url"].(string) fileName, _ := m["file_name"].(string) if fileURL == "" { if fileMap, ok := m["file"].(map[string]interface{}); ok { fileURL, _ = fileMap["url"].(string) if fn, ok := fileMap["filename"].(string); ok && fn != "" { fileName = fn } } } if fileURL != "" && strings.HasPrefix(fileURL, "http") { label := fileName if label == "" { label = "file" } textBuf.WriteString(fmt.Sprintf("\n[%s](%s)\n", label, fileURL)) } } } return flushText(ctx, sessionWebhook, &textBuf) } func sendPartsTyped(ctx context.Context, sessionWebhook string, parts []agentcontext.ContentPart) error { var textBuf strings.Builder for _, part := range parts { switch part.Type { case agentcontext.ContentText: textBuf.WriteString(part.Text) case agentcontext.ContentImageURL: if err := flushText(ctx, sessionWebhook, &textBuf); err != nil { return err } if part.ImageURL != nil && strings.HasPrefix(part.ImageURL.URL, "http") { textBuf.WriteString(fmt.Sprintf("\n![image](%s)\n", part.ImageURL.URL)) } case agentcontext.ContentFile: if err := flushText(ctx, sessionWebhook, &textBuf); err != nil { return err } if part.File != nil && part.File.URL != "" && strings.HasPrefix(part.File.URL, "http") { label := part.File.Filename if label == "" { label = "file" } textBuf.WriteString(fmt.Sprintf("\n[%s](%s)\n", label, part.File.URL)) } } } return flushText(ctx, sessionWebhook, &textBuf) } func flushText(ctx context.Context, sessionWebhook string, buf *strings.Builder) error { if buf.Len() == 0 { return nil } text := buf.String() buf.Reset() return dtapi.SendMarkdownMessage(ctx, sessionWebhook, "Reply", dtapi.FormatDingTalkMarkdown(text)) } 
// toContentParts reports whether content is a []agentcontext.ContentPart and,
// if so, returns it.
func toContentParts(content interface{}) ([]agentcontext.ContentPart, bool) {
	parts, ok := content.([]agentcontext.ContentPart)
	return parts, ok
}

================================================
FILE: agent/robot/events/integrations/dingtalk/stream.go
================================================
package dingtalk

import (
	"context"
	"strings"
	"time"

	dingstream "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
	dingclient "github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
	dtapi "github.com/yaoapp/yao/integrations/dingtalk"
)

// reconnectDelay is the pause between two consecutive stream client runs.
const reconnectDelay = 5 * time.Second

// streamLoop starts the DingTalk Stream client for a single bot.
// It automatically reconnects on failure.
//
// The loop ends when ctx is cancelled or the adapter's stop channel is closed.
// After every return of runStreamClient, with or without an error, it waits
// reconnectDelay before running the client again.
func (a *Adapter) streamLoop(ctx context.Context, entry *botEntry) {
	log.Info("dingtalk streamLoop started robot=%s client=%s", entry.robotID, entry.clientID)
	for {
		// Non-blocking check so a cancelled bot never starts another client.
		select {
		case <-ctx.Done():
			log.Info("dingtalk streamLoop stopped robot=%s", entry.robotID)
			return
		case <-a.stopCh:
			return
		default:
		}

		err := a.runStreamClient(ctx, entry)
		if err != nil {
			log.Error("dingtalk stream disconnected robot=%s: %v, reconnecting in %s", entry.robotID, err, reconnectDelay)
		}

		select {
		case <-ctx.Done():
			return
		case <-a.stopCh:
			return
		case <-time.After(reconnectDelay):
		}
	}
}

// runStreamClient creates a stream client with the bot's credentials,
// registers the chatbot callback, and calls Start in a goroutine. It returns
// ctx.Err() when ctx is cancelled, nil when the adapter is stopping, or
// whatever Start returned.
//
// NOTE(review): the client is never explicitly closed here, and it is not
// visible from this file whether Start blocks for the connection's lifetime —
// confirm against the SDK before changing the reconnect behaviour.
func (a *Adapter) runStreamClient(ctx context.Context, entry *botEntry) error {
	cli := dingclient.NewStreamClient(
		dingclient.WithAppCredential(dingclient.NewAppCredentialConfig(entry.clientID, entry.bot.ClientSecret())),
	)
	cli.RegisterChatBotCallbackRouter(func(c context.Context, data *dingstream.BotCallbackDataModel) ([]byte, error) {
		return a.onBotCallback(c, entry, data)
	})

	// Buffered so the goroutine can deliver its result and exit even when this
	// function has already returned through another select branch.
	errCh := make(chan error, 1)
	go func() {
		errCh <- cli.Start(ctx)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-a.stopCh:
		return nil
	case err := <-errCh:
		return err
	}
}

// onBotCallback converts one stream callback into a ConvertedMessage and hands
// it to handleMessages. Only the "text" message type populates the text field.
// It always returns (nil, nil).
func (a *Adapter) onBotCallback(ctx context.Context, entry *botEntry, data *dingstream.BotCallbackDataModel) ([]byte, error) {
	if data == nil {
		return nil, nil
	}

	cm := &dtapi.ConvertedMessage{
		MessageID:        data.MsgId,
		ConversationID:   data.ConversationId,
		ConversationType: data.ConversationType,
		SenderID:         data.SenderId,
		SenderNick:       data.SenderNick,
		SenderStaffID:    data.SenderStaffId,
		ChatbotUserID:    data.ChatbotUserId,
		IsInAtList:       data.IsInAtList,
		SessionWebhook:   data.SessionWebhook,
	}

	switch data.Msgtype {
	case "text":
		cm.Text = strings.TrimSpace(data.Text.Content)
	}

	// NOTE(review): nothing above fills in media items, so this branch looks
	// unreachable unless ConvertedMessage derives media elsewhere — confirm.
	if cm.HasMedia() {
		groups := []string{"dingtalk", entry.robotID}
		dtapi.ResolveMedia(ctx, cm, groups)
	}

	a.handleMessages(ctx, entry, []*dtapi.ConvertedMessage{cm})
	return nil, nil
}

================================================
FILE: agent/robot/events/integrations/discord/dedup.go
================================================
package discord

import (
	"sync"
	"time"
)

const (
	// dedupTTL is how long a message key is remembered.
	dedupTTL = 24 * time.Hour
	// dedupCleanInterval is how often expired keys are purged.
	dedupCleanInterval = time.Hour
)

// dedupStore remembers message keys; values are Unix timestamps (int64) of
// the moment the key was first seen.
type dedupStore struct {
	m sync.Map
}

// newDedupStore returns an empty dedupStore.
func newDedupStore() *dedupStore {
	return &dedupStore{}
}

// markSeen records key with the current Unix time and returns true when the
// key was not present before, false when it is a duplicate.
func (d *dedupStore) markSeen(key string) bool {
	now := time.Now().Unix()
	_, loaded := d.m.LoadOrStore(key, now)
	return !loaded
}

// cleaner runs until stopCh is closed, deleting on every dedupCleanInterval
// tick the keys whose timestamp is older than dedupTTL.
func (d *dedupStore) cleaner(stopCh <-chan struct{}) {
	ticker := time.NewTicker(dedupCleanInterval)
	defer ticker.Stop()
	for {
		select {
		case <-stopCh:
			return
		case <-ticker.C:
			cutoff := time.Now().Add(-dedupTTL).Unix()
			d.m.Range(func(key, value any) bool {
				if ts, ok := value.(int64); ok && ts < cutoff {
					d.m.Delete(key)
				}
				return true
			})
		}
	}
}

================================================
FILE: agent/robot/events/integrations/discord/discord.go
================================================
package discord

import (
	"context"
	"sync"

	"github.com/yaoapp/yao/agent/robot/logger"
	robottypes "github.com/yaoapp/yao/agent/robot/types"
	dcapi "github.com/yaoapp/yao/integrations/discord"
)

var log = logger.New("discord")

// Adapter implements the integrations.Adapter interface for Discord.
// // Architecture: // - One WebSocket Gateway connection per registered bot via discordgo // - One dedup cleaner goroutine removes expired keys every hour type Adapter struct { mu sync.RWMutex bots map[string]*botEntry // robotID -> *botEntry appIdx map[string]string // appID -> robotID dedup *dedupStore stopCh chan struct{} } // botEntry holds the state for one robot's Discord integration. type botEntry struct { robotID string appID string bot *dcapi.Bot cancelFn context.CancelFunc } // NewAdapter creates a new Discord adapter. func NewAdapter() *Adapter { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } go a.dedup.cleaner(a.stopCh) return a } // Apply is called by the Dispatcher when a robot config is created or updated. func (a *Adapter) Apply(ctx context.Context, robot *robottypes.Robot) { dcConf := extractConfig(robot) log.Debug("Apply robot=%s dcConf=%v", robot.MemberID, dcConf != nil) if dcConf == nil || !dcConf.Enabled || dcConf.BotToken == "" { a.removeBot(robot.MemberID) return } a.mu.Lock() defer a.mu.Unlock() if existing, ok := a.bots[robot.MemberID]; ok { if existing.bot.Token() == dcConf.BotToken { return } a.removeBotLocked(robot.MemberID) } bot, err := dcapi.NewBot(dcConf.BotToken, dcConf.AppID) if err != nil { log.Error("discord adapter: create bot failed robot=%s: %v", robot.MemberID, err) return } gwCtx, gwCancel := context.WithCancel(context.Background()) entry := &botEntry{ robotID: robot.MemberID, appID: dcConf.AppID, bot: bot, cancelFn: gwCancel, } a.bots[robot.MemberID] = entry if dcConf.AppID != "" { a.appIdx[dcConf.AppID] = robot.MemberID } go a.gatewayLoop(gwCtx, entry) log.Info("discord adapter: registered robot=%s app=%s", robot.MemberID, dcConf.AppID) } // Remove is called by the Dispatcher when a robot is deleted. 
func (a *Adapter) Remove(ctx context.Context, robotID string) { a.removeBot(robotID) } // Shutdown stops all gateway connections and dedup cleaner. func (a *Adapter) Shutdown() { close(a.stopCh) a.mu.Lock() for _, entry := range a.bots { if entry.cancelFn != nil { entry.cancelFn() } if entry.bot != nil && entry.bot.Session() != nil { entry.bot.Session().Close() } } a.mu.Unlock() log.Info("discord adapter: shutdown complete") } func (a *Adapter) removeBot(robotID string) { a.mu.Lock() defer a.mu.Unlock() a.removeBotLocked(robotID) } func (a *Adapter) removeBotLocked(robotID string) { entry, ok := a.bots[robotID] if !ok { return } if entry.cancelFn != nil { entry.cancelFn() } if entry.bot != nil && entry.bot.Session() != nil { entry.bot.Session().Close() } if entry.appID != "" { delete(a.appIdx, entry.appID) } delete(a.bots, robotID) log.Info("discord adapter: unregistered robot=%s", robotID) } func (a *Adapter) resolveByAppID(appID string) (*botEntry, bool) { a.mu.RLock() defer a.mu.RUnlock() robotID, ok := a.appIdx[appID] if !ok { return nil, false } entry, ok := a.bots[robotID] return entry, ok } func extractConfig(robot *robottypes.Robot) *robottypes.DiscordConfig { if robot.Config == nil || robot.Config.Integrations == nil { return nil } return robot.Config.Integrations.Discord } ================================================ FILE: agent/robot/events/integrations/discord/e2e_test.go ================================================ package discord import ( "context" "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" robottypes "github.com/yaoapp/yao/agent/robot/types" dcapi "github.com/yaoapp/yao/integrations/discord" ) var ( dcBotToken string dcAppID string ) func TestMain(m *testing.M) { dcBotToken = os.Getenv("DISCORD_TEST_BOT_TOKEN") dcAppID = os.Getenv("DISCORD_TEST_APP_ID") os.Exit(m.Run()) } func skipIfNoToken(t *testing.T) { t.Helper() if dcBotToken == "" { t.Skip("DISCORD_TEST_BOT_TOKEN not set") } } // 
TestE2E_Adapter_Apply verifies that Apply correctly registers a bot. func TestE2E_Adapter_Apply(t *testing.T) { skipIfNoToken(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_dc_adapter", TeamID: "team_e2e_dc", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Discord: &robottypes.DiscordConfig{ Enabled: true, BotToken: dcBotToken, AppID: dcAppID, }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() entry, ok := a.bots["robot_e2e_dc_adapter"] a.mu.RUnlock() require.True(t, ok, "bot should be registered") assert.Equal(t, dcBotToken, entry.bot.Token()) assert.Equal(t, dcAppID, entry.appID) t.Logf("OK Apply: discord bot registered robot=%s app=%s", robot.MemberID, entry.appID) } // TestE2E_Adapter_Apply_Update verifies re-Apply with same token is a no-op. func TestE2E_Adapter_Apply_Update(t *testing.T) { skipIfNoToken(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_dc_update", TeamID: "team_e2e_dc", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Discord: &robottypes.DiscordConfig{ Enabled: true, BotToken: dcBotToken, AppID: dcAppID, }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() _, ok := a.bots["robot_e2e_dc_update"] a.mu.RUnlock() require.True(t, ok) a.Apply(context.Background(), robot) a.mu.RLock() assert.Len(t, a.bots, 1) a.mu.RUnlock() a.Remove(context.Background(), "robot_e2e_dc_update") a.mu.RLock() _, ok = a.bots["robot_e2e_dc_update"] a.mu.RUnlock() assert.False(t, ok, "bot should be removed") t.Log("OK Apply/Remove lifecycle verified") } // TestE2E_Adapter_Dedup verifies deduplication works. 
func TestE2E_Adapter_Dedup(t *testing.T) { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) key := "dc:test-robot:msg-12345" assert.True(t, a.dedup.markSeen(key), "first time should return true") assert.False(t, a.dedup.markSeen(key), "second time should return false (dedup)") t.Log("OK dedup working correctly") } // TestE2E_Adapter_HandleMessages verifies message handling. func TestE2E_Adapter_HandleMessages(t *testing.T) { skipIfNoToken(t) bot, err := dcapi.NewBot(dcBotToken, dcAppID) require.NoError(t, err) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) entry := &botEntry{ robotID: "robot_e2e_dc_handle", appID: dcAppID, bot: bot, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cms := []*dcapi.ConvertedMessage{ { MessageID: "test_msg_1", ChannelID: "test_ch_1", AuthorID: "test_user_1", AuthorName: "TestUser", Text: "Hello from E2E test", }, } a.handleMessages(ctx, entry, cms) assert.False(t, a.dedup.markSeen("dc:robot_e2e_dc_handle:test_msg_1"), "message should be marked as seen after handleMessages") t.Log("OK handleMessages processed 1 message") } // TestE2E_Adapter_ApplyDisabled verifies Apply removes bot when disabled. 
func TestE2E_Adapter_ApplyDisabled(t *testing.T) { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_dc_disabled", TeamID: "team_e2e_dc", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Discord: &robottypes.DiscordConfig{ Enabled: false, BotToken: "some_token", }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() _, ok := a.bots["robot_e2e_dc_disabled"] a.mu.RUnlock() assert.False(t, ok, "disabled bot should not be registered") t.Log("OK disabled config not registered") } // TestE2E_BotUser verifies real Discord credentials. func TestE2E_BotUser(t *testing.T) { skipIfNoToken(t) bot, err := dcapi.NewBot(dcBotToken, dcAppID) require.NoError(t, err) user, err := bot.BotUser() require.NoError(t, err) assert.NotEmpty(t, user.ID) assert.NotEmpty(t, user.Username) assert.True(t, user.Bot) t.Logf("OK Discord bot verified: id=%s username=%s", user.ID, user.Username) } ================================================ FILE: agent/robot/events/integrations/discord/gateway.go ================================================ package discord import ( "context" "time" "github.com/bwmarrin/discordgo" dcapi "github.com/yaoapp/yao/integrations/discord" ) const reconnectDelay = 5 * time.Second // gatewayLoop starts the Discord WebSocket Gateway for a single bot. // It automatically reconnects on failure. 
func (a *Adapter) gatewayLoop(ctx context.Context, entry *botEntry) { log.Info("discord gatewayLoop started robot=%s app=%s", entry.robotID, entry.appID) for { select { case <-ctx.Done(): log.Info("discord gatewayLoop stopped robot=%s", entry.robotID) return case <-a.stopCh: return default: } err := a.runGateway(ctx, entry) if err != nil { log.Error("discord gateway disconnected robot=%s: %v, reconnecting in %s", entry.robotID, err, reconnectDelay) } select { case <-ctx.Done(): return case <-a.stopCh: return case <-time.After(reconnectDelay): } } } func (a *Adapter) runGateway(ctx context.Context, entry *botEntry) error { session := entry.bot.Session() session.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { a.onMessageCreate(ctx, entry, m) }) if err := session.Open(); err != nil { return err } // Block until context is cancelled or stop signal select { case <-ctx.Done(): case <-a.stopCh: } return session.Close() } func (a *Adapter) onMessageCreate(ctx context.Context, entry *botEntry, m *discordgo.MessageCreate) { if m == nil || m.Author == nil { return } // Ignore bot's own messages if m.Author.Bot { return } cm := dcapi.ConvertMessageCreate(m) if cm == nil { return } if cm.HasMedia() { groups := []string{"discord", entry.robotID} dcapi.ResolveMedia(ctx, cm, groups) } a.handleMessages(ctx, entry, []*dcapi.ConvertedMessage{cm}) } ================================================ FILE: agent/robot/events/integrations/discord/message.go ================================================ package discord import ( "context" "fmt" "strings" agentcontext "github.com/yaoapp/yao/agent/context" events "github.com/yaoapp/yao/agent/robot/events" "github.com/yaoapp/yao/event" dcapi "github.com/yaoapp/yao/integrations/discord" ) // handleMessages processes a batch of Discord messages. 
func (a *Adapter) handleMessages(ctx context.Context, entry *botEntry, cms []*dcapi.ConvertedMessage) { if len(cms) == 0 { return } var allParts []interface{} var lastCM *dcapi.ConvertedMessage for _, cm := range cms { if cm == nil { continue } // Skip bot commands (messages starting with /) if strings.HasPrefix(strings.TrimSpace(cm.Text), "/") && !cm.HasMedia() { continue } dedupKey := fmt.Sprintf("dc:%s:%s", entry.robotID, cm.MessageID) if !a.dedup.markSeen(dedupKey) { continue } parts := buildContentParts(cm) if len(parts) == 0 { continue } allParts = append(allParts, parts...) lastCM = cm } if len(allParts) == 0 || lastCM == nil { return } content := mergeContentParts(allParts) msgPayload := events.MessagePayload{ RobotID: entry.robotID, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: content}, }, Metadata: &events.MessageMetadata{ Channel: "discord", MessageID: lastCM.MessageID, AppID: entry.appID, ChatID: lastCM.ChannelID, SenderID: lastCM.AuthorID, SenderName: lastCM.AuthorName, Locale: events.NormalizeLocale(discordLocale(lastCM.Locale)), Extra: map[string]any{ "discord_message_id": lastCM.MessageID, "guild_id": lastCM.GuildID, "is_dm": lastCM.IsDM, }, }, } if _, err := event.Push(ctx, events.Message, msgPayload); err != nil { log.Error("discord adapter: event.Push robot.message failed robot=%s: %v", entry.robotID, err) } } func buildContentParts(cm *dcapi.ConvertedMessage) []interface{} { var parts []interface{} if cm.HasText() { parts = append(parts, map[string]interface{}{ "type": "text", "text": cm.Text, }) } for _, mi := range cm.MediaItems { url := mi.Wrapper if url == "" { url = mi.URL } if url == "" { continue } parts = append(parts, map[string]interface{}{ "type": "file", "file_url": url, "mime_type": mi.ContentType, "file_name": mi.FileName, }) } return parts } func mergeContentParts(parts []interface{}) interface{} { allText := true for _, p := range parts { m, ok := p.(map[string]interface{}) if !ok || m["type"] != 
"text" { allText = false break } } if allText { var buf strings.Builder for i, p := range parts { if i > 0 { buf.WriteString("\n") } m := p.(map[string]interface{}) buf.WriteString(m["text"].(string)) } return buf.String() } return parts } func discordLocale(locale string) string { if locale == "" { return "en" } return locale } ================================================ FILE: agent/robot/events/integrations/discord/reply.go ================================================ package discord import ( "context" "fmt" "strings" agentcontext "github.com/yaoapp/yao/agent/context" events "github.com/yaoapp/yao/agent/robot/events" dcapi "github.com/yaoapp/yao/integrations/discord" ) // Reply sends the assistant message back to the originating Discord channel. func (a *Adapter) Reply(ctx context.Context, msg *agentcontext.Message, metadata *events.MessageMetadata) error { if msg == nil || metadata == nil { return fmt.Errorf("nil message or metadata") } entry := a.resolveByChat(metadata) if entry == nil { return fmt.Errorf("no bot registered for discord metadata (appID=%s)", metadata.AppID) } var replyToID string if metadata.Extra != nil { if v, ok := metadata.Extra["discord_message_id"]; ok { if s, ok := v.(string); ok { replyToID = s } } } return a.sendContent(ctx, entry, metadata.ChatID, replyToID, msg.Content) } func (a *Adapter) sendContent(ctx context.Context, entry *botEntry, channelID, replyToID string, content interface{}) error { switch c := content.(type) { case string: if strings.TrimSpace(c) == "" { return nil } formatted := dcapi.FormatDiscordMarkdown(c) if replyToID != "" { _, err := entry.bot.SendMessageReply(channelID, formatted, replyToID) return err } _, err := entry.bot.SendMessage(channelID, formatted) return err case []interface{}: return a.sendParts(ctx, entry, channelID, replyToID, c) default: parts, ok := toContentParts(content) if ok { return a.sendPartsTyped(ctx, entry, channelID, replyToID, parts) } _, err := entry.bot.SendMessage(channelID, 
fmt.Sprintf("%v", content)) return err } } func (a *Adapter) sendParts(ctx context.Context, entry *botEntry, channelID, replyToID string, parts []interface{}) error { var textBuf strings.Builder for _, part := range parts { m, ok := part.(map[string]interface{}) if !ok { continue } partType, _ := m["type"].(string) switch partType { case "text": if text, ok := m["text"].(string); ok { textBuf.WriteString(text) } case "image_url": if err := a.flushText(entry, channelID, replyToID, &textBuf); err != nil { return err } if imgMap, ok := m["image_url"].(map[string]interface{}); ok { if url, ok := imgMap["url"].(string); ok { if err := sendFileOrWrapper(entry, channelID, url, ""); err != nil { log.Error("discord reply: send image: %v", err) } } } case "file": if err := a.flushText(entry, channelID, replyToID, &textBuf); err != nil { return err } fileURL, _ := m["file_url"].(string) if fileURL == "" { if fileMap, ok := m["file"].(map[string]interface{}); ok { fileURL, _ = fileMap["url"].(string) } } if fileURL != "" { if err := sendFileOrWrapper(entry, channelID, fileURL, ""); err != nil { log.Error("discord reply: send file: %v", err) } } } } return a.flushText(entry, channelID, replyToID, &textBuf) } func (a *Adapter) sendPartsTyped(ctx context.Context, entry *botEntry, channelID, replyToID string, parts []agentcontext.ContentPart) error { var textBuf strings.Builder for _, part := range parts { switch part.Type { case agentcontext.ContentText: textBuf.WriteString(part.Text) case agentcontext.ContentImageURL: if err := a.flushText(entry, channelID, replyToID, &textBuf); err != nil { return err } if part.ImageURL != nil { if err := sendFileOrWrapper(entry, channelID, part.ImageURL.URL, ""); err != nil { log.Error("discord reply: send image: %v", err) } } case agentcontext.ContentFile: if err := a.flushText(entry, channelID, replyToID, &textBuf); err != nil { return err } if part.File != nil { if err := sendFileOrWrapper(entry, channelID, part.File.URL, 
part.File.Filename); err != nil { log.Error("discord reply: send file: %v", err) } } } } return a.flushText(entry, channelID, replyToID, &textBuf) } func (a *Adapter) flushText(entry *botEntry, channelID, replyToID string, buf *strings.Builder) error { if buf.Len() == 0 { return nil } text := dcapi.FormatDiscordMarkdown(buf.String()) buf.Reset() if replyToID != "" { _, err := entry.bot.SendMessageReply(channelID, text, replyToID) return err } _, err := entry.bot.SendMessage(channelID, text) return err } func sendFileOrWrapper(entry *botEntry, channelID, url, caption string) error { if strings.Contains(url, "://") && !strings.HasPrefix(url, "http") { return entry.bot.SendMediaFromWrapper(channelID, url, caption) } if strings.HasPrefix(url, "http") { _, err := entry.bot.SendMessage(channelID, url) return err } return fmt.Errorf("unsupported file URL scheme: %s", url) } func toContentParts(content interface{}) ([]agentcontext.ContentPart, bool) { parts, ok := content.([]agentcontext.ContentPart) return parts, ok } func (a *Adapter) resolveByChat(metadata *events.MessageMetadata) *botEntry { if metadata.AppID != "" { if entry, ok := a.resolveByAppID(metadata.AppID); ok { return entry } } a.mu.RLock() defer a.mu.RUnlock() for _, entry := range a.bots { return entry } return nil } ================================================ FILE: agent/robot/events/integrations/dispatcher.go ================================================ package integrations import ( "context" "fmt" "github.com/yaoapp/gou/model" "github.com/yaoapp/kun/maps" agentcontext "github.com/yaoapp/yao/agent/context" robotcache "github.com/yaoapp/yao/agent/robot/cache" events "github.com/yaoapp/yao/agent/robot/events" "github.com/yaoapp/yao/agent/robot/logger" robottypes "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/event" eventtypes "github.com/yaoapp/yao/event/types" ) var log = logger.New("dispatcher") // Adapter is the interface each platform adapter implements. 
type Adapter interface { Apply(ctx context.Context, robot *robottypes.Robot) Remove(ctx context.Context, robotID string) Reply(ctx context.Context, msg *agentcontext.Message, metadata *events.MessageMetadata) error } // Dispatcher distributes Robot integration configs to platform adapters. type Dispatcher struct { robotCache *robotcache.Cache adapters map[string]Adapter // key matches Integrations field: "telegram", "discord", etc. stopCh chan struct{} subID string } // NewDispatcher creates a Dispatcher. // Each adapter has a fixed key matching the field name in robottypes.Integrations. func NewDispatcher(cache *robotcache.Cache, adapters map[string]Adapter) *Dispatcher { return &Dispatcher{ robotCache: cache, adapters: adapters, stopCh: make(chan struct{}), } } // Start loads all robots and subscribes to config change events. func (d *Dispatcher) Start(ctx context.Context) error { d.loadAll(ctx) events.RegisterReplyFunc(d.reply) ch := make(chan *eventtypes.Event, 64) d.subID = event.Subscribe("robot.config.*", ch) go d.watch(ctx, ch) log.Info("integration dispatcher: started with %d adapters", len(d.adapters)) return nil } // reply routes a reply to the correct adapter based on channel. // When channel is empty (e.g. delivery), broadcasts to all adapters. func (d *Dispatcher) reply(ctx context.Context, msg *agentcontext.Message, metadata *events.MessageMetadata) error { if metadata == nil { return fmt.Errorf("no metadata in reply") } if metadata.Channel != "" { adapter, ok := d.adapters[metadata.Channel] if !ok { return fmt.Errorf("no adapter for channel: %s", metadata.Channel) } return adapter.Reply(ctx, msg, metadata) } var lastErr error for name, adapter := range d.adapters { if err := adapter.Reply(ctx, msg, metadata); err != nil { log.Error("dispatcher reply: broadcast to %s failed: %v", name, err) lastErr = err } } return lastErr } // Stop unsubscribes from events. 
func (d *Dispatcher) Stop() { close(d.stopCh) if d.subID != "" { event.Unsubscribe(d.subID) } log.Info("integration dispatcher: stopped") } func (d *Dispatcher) loadAll(ctx context.Context) { robots := d.loadIntegrationRobots() for _, robot := range robots { d.robotCache.Add(robot) d.apply(ctx, robot) } log.Info("integration dispatcher: initial load complete, %d robots with integrations", len(robots)) } // loadIntegrationRobots queries all active robots that have a non-null // robot_config (which may contain integrations). This is independent of // autonomous_mode so non-autonomous robots with Telegram etc. are included. func (d *Dispatcher) loadIntegrationRobots() []*robottypes.Robot { m := model.Select("__yao.member") fields := []interface{}{ "id", "member_id", "team_id", "display_name", "bio", "system_prompt", "robot_status", "autonomous_mode", "robot_config", "robot_email", "agents", "mcp_servers", "manager_id", "language_model", } page := 1 pageSize := 100 var result []*robottypes.Robot for { res, err := m.Paginate(model.QueryParam{ Select: fields, Wheres: []model.QueryWhere{ {Column: "member_type", Value: "robot"}, {Column: "status", Value: "active"}, }, }, page, pageSize) if err != nil { log.Error("loadIntegrationRobots: query failed page=%d: %v", page, err) break } data, ok := res.Get("data").([]maps.MapStr) if !ok || len(data) == 0 { break } for _, record := range data { robot, err := robottypes.NewRobotFromMap(map[string]interface{}(record)) if err != nil { continue } if robot.Config != nil && robot.Config.Integrations != nil && len(parseIntegrations(robot.Config.Integrations)) > 0 { result = append(result, robot) } } total, _ := res.Get("total").(int) if page*pageSize >= total { break } page++ } return result } // apply parses which integrations the robot has configured, // and calls the matching adapter for each one. 
func (d *Dispatcher) apply(ctx context.Context, robot *robottypes.Robot) { if robot.Config == nil || robot.Config.Integrations == nil { return } for _, key := range parseIntegrations(robot.Config.Integrations) { if adapter, ok := d.adapters[key]; ok { adapter.Apply(ctx, robot) } } } func (d *Dispatcher) remove(ctx context.Context, robotID string) { for _, adapter := range d.adapters { adapter.Remove(ctx, robotID) } } // parseIntegrations returns the keys of integrations present in the config. func parseIntegrations(intg *robottypes.Integrations) []string { var keys []string if intg.Telegram != nil { keys = append(keys, "telegram") } if intg.Feishu != nil { keys = append(keys, "feishu") } if intg.DingTalk != nil { keys = append(keys, "dingtalk") } if intg.Discord != nil { keys = append(keys, "discord") } return keys } func (d *Dispatcher) watch(ctx context.Context, ch <-chan *eventtypes.Event) { for { select { case <-d.stopCh: return case <-ctx.Done(): return case ev, ok := <-ch: if !ok { return } d.dispatch(ctx, ev) } } } func (d *Dispatcher) dispatch(ctx context.Context, ev *eventtypes.Event) { var payload events.RobotConfigPayload if err := ev.Should(&payload); err != nil { log.Error("integration dispatcher: invalid config event: %v", err) return } switch ev.Type { case events.RobotConfigCreated, events.RobotConfigUpdated: robot := d.robotCache.Get(payload.MemberID) if robot == nil { rCtx := robottypes.NewContext(ctx, nil) loaded, err := d.robotCache.LoadByID(rCtx, payload.MemberID) if err != nil { log.Warn("integration dispatcher: failed to load robot from DB member=%s: %v", payload.MemberID, err) return } d.robotCache.Add(loaded) robot = loaded log.Info("integration dispatcher: loaded robot from DB member=%s", payload.MemberID) } d.apply(ctx, robot) case events.RobotConfigDeleted: d.remove(ctx, payload.MemberID) } } ================================================ FILE: agent/robot/events/integrations/dispatcher_test.go 
================================================ package integrations import ( "context" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentcontext "github.com/yaoapp/yao/agent/context" robotcache "github.com/yaoapp/yao/agent/robot/cache" events "github.com/yaoapp/yao/agent/robot/events" robottypes "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/event" eventtypes "github.com/yaoapp/yao/event/types" ) // mockAdapter records Apply/Remove calls for assertions. type mockAdapter struct { mu sync.Mutex applied []*robottypes.Robot removed []string } func (m *mockAdapter) Apply(ctx context.Context, robot *robottypes.Robot) { m.mu.Lock() defer m.mu.Unlock() m.applied = append(m.applied, robot) } func (m *mockAdapter) Remove(ctx context.Context, robotID string) { m.mu.Lock() defer m.mu.Unlock() m.removed = append(m.removed, robotID) } func (m *mockAdapter) Reply(ctx context.Context, msg *agentcontext.Message, metadata *events.MessageMetadata) error { return nil } func (m *mockAdapter) getApplied() []*robottypes.Robot { m.mu.Lock() defer m.mu.Unlock() cp := make([]*robottypes.Robot, len(m.applied)) copy(cp, m.applied) return cp } func (m *mockAdapter) getRemoved() []string { m.mu.Lock() defer m.mu.Unlock() cp := make([]string, len(m.removed)) copy(cp, m.removed) return cp } // noopHandler satisfies event.Handler so we can register "robot" prefix for Push/Call. 
type noopHandler struct{} func (h *noopHandler) Handle(ctx context.Context, ev *eventtypes.Event, resp chan<- eventtypes.Result) { if ev.IsCall { resp <- eventtypes.Result{} } } func (h *noopHandler) Shutdown(ctx context.Context) error { return nil } var eventOnce sync.Once func setupEventBus(t *testing.T) { t.Helper() eventOnce.Do(func() { event.Register("robot", &noopHandler{}) }) if err := event.Start(); err != nil && err != event.ErrAlreadyStart { t.Fatalf("event.Start: %v", err) } t.Cleanup(func() { _ = event.Stop(context.Background()) }) } func newRobot(memberID, teamID string, intg *robottypes.Integrations) *robottypes.Robot { return &robottypes.Robot{ MemberID: memberID, TeamID: teamID, AutonomousMode: true, Config: &robottypes.Config{ Integrations: intg, }, } } func TestLoadAll_OnlyTelegramConfigured(t *testing.T) { setupEventBus(t) cache := robotcache.New() tgRobot := newRobot("r-tg", "team1", &robottypes.Integrations{ Telegram: &robottypes.TelegramConfig{Enabled: true, BotToken: "tok"}, }) noIntgRobot := newRobot("r-plain", "team1", nil) cache.Add(tgRobot) cache.Add(noIntgRobot) tgAdapter := &mockAdapter{} d := NewDispatcher(cache, map[string]Adapter{"telegram": tgAdapter}) require.NoError(t, d.Start(context.Background())) defer d.Stop() applied := tgAdapter.getApplied() assert.Len(t, applied, 1) assert.Equal(t, "r-tg", applied[0].MemberID) } func TestLoadAll_NoIntegrations(t *testing.T) { setupEventBus(t) cache := robotcache.New() cache.Add(newRobot("r1", "team1", nil)) cache.Add(&robottypes.Robot{MemberID: "r2", TeamID: "team1"}) tgAdapter := &mockAdapter{} d := NewDispatcher(cache, map[string]Adapter{"telegram": tgAdapter}) require.NoError(t, d.Start(context.Background())) defer d.Stop() assert.Empty(t, tgAdapter.getApplied()) } func TestLoadAll_MultipleAdapters(t *testing.T) { setupEventBus(t) cache := robotcache.New() // Only Telegram configured, no Discord robot := newRobot("r-multi", "team1", &robottypes.Integrations{ Telegram: 
&robottypes.TelegramConfig{Enabled: true, BotToken: "tok"}, }) cache.Add(robot) tgAdapter := &mockAdapter{} discordAdapter := &mockAdapter{} d := NewDispatcher(cache, map[string]Adapter{ "telegram": tgAdapter, "discord": discordAdapter, }) require.NoError(t, d.Start(context.Background())) defer d.Stop() assert.Len(t, tgAdapter.getApplied(), 1) assert.Empty(t, discordAdapter.getApplied(), "discord adapter should not be called") } func TestConfigCreated_TriggersApply(t *testing.T) { setupEventBus(t) cache := robotcache.New() tgAdapter := &mockAdapter{} d := NewDispatcher(cache, map[string]Adapter{"telegram": tgAdapter}) require.NoError(t, d.Start(context.Background())) defer d.Stop() assert.Empty(t, tgAdapter.getApplied()) // Simulate: robot created with Telegram config, added to cache, event pushed robot := newRobot("r-new", "team1", &robottypes.Integrations{ Telegram: &robottypes.TelegramConfig{Enabled: true, BotToken: "new-tok"}, }) cache.Add(robot) event.Push(context.Background(), events.RobotConfigCreated, events.RobotConfigPayload{ MemberID: "r-new", TeamID: "team1", }) assert.Eventually(t, func() bool { return len(tgAdapter.getApplied()) == 1 }, 2*time.Second, 50*time.Millisecond) assert.Equal(t, "r-new", tgAdapter.getApplied()[0].MemberID) } func TestConfigUpdated_TriggersApply(t *testing.T) { setupEventBus(t) cache := robotcache.New() robot := newRobot("r-upd", "team1", &robottypes.Integrations{ Telegram: &robottypes.TelegramConfig{Enabled: true, BotToken: "old-tok"}, }) cache.Add(robot) tgAdapter := &mockAdapter{} d := NewDispatcher(cache, map[string]Adapter{"telegram": tgAdapter}) require.NoError(t, d.Start(context.Background())) defer d.Stop() // Initial load assert.Len(t, tgAdapter.getApplied(), 1) // Update config in cache robot.Config.Integrations.Telegram.BotToken = "new-tok" event.Push(context.Background(), events.RobotConfigUpdated, events.RobotConfigPayload{ MemberID: "r-upd", TeamID: "team1", }) assert.Eventually(t, func() bool { return 
len(tgAdapter.getApplied()) == 2 }, 2*time.Second, 50*time.Millisecond) assert.Equal(t, "new-tok", tgAdapter.getApplied()[1].Config.Integrations.Telegram.BotToken) } func TestConfigDeleted_TriggersRemove(t *testing.T) { setupEventBus(t) cache := robotcache.New() robot := newRobot("r-del", "team1", &robottypes.Integrations{ Telegram: &robottypes.TelegramConfig{Enabled: true, BotToken: "tok"}, }) cache.Add(robot) tgAdapter := &mockAdapter{} d := NewDispatcher(cache, map[string]Adapter{"telegram": tgAdapter}) require.NoError(t, d.Start(context.Background())) defer d.Stop() assert.Len(t, tgAdapter.getApplied(), 1) event.Push(context.Background(), events.RobotConfigDeleted, events.RobotConfigPayload{ MemberID: "r-del", TeamID: "team1", }) assert.Eventually(t, func() bool { return len(tgAdapter.getRemoved()) == 1 }, 2*time.Second, 50*time.Millisecond) assert.Equal(t, "r-del", tgAdapter.getRemoved()[0]) } func TestConfigCreated_RobotNotInCache(t *testing.T) { setupEventBus(t) cache := robotcache.New() tgAdapter := &mockAdapter{} d := NewDispatcher(cache, map[string]Adapter{"telegram": tgAdapter}) require.NoError(t, d.Start(context.Background())) defer d.Stop() // Push event but don't add robot to cache event.Push(context.Background(), events.RobotConfigCreated, events.RobotConfigPayload{ MemberID: "r-ghost", TeamID: "team1", }) time.Sleep(200 * time.Millisecond) assert.Empty(t, tgAdapter.getApplied()) } func TestParseIntegrations(t *testing.T) { tests := []struct { name string intg *robottypes.Integrations expected []string }{ {"nil", nil, nil}, {"empty", &robottypes.Integrations{}, nil}, {"telegram only", &robottypes.Integrations{ Telegram: &robottypes.TelegramConfig{Enabled: true}, }, []string{"telegram"}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.intg == nil { return } result := parseIntegrations(tt.intg) assert.Equal(t, tt.expected, result) }) } } ================================================ FILE: 
agent/robot/events/integrations/feishu/dedup.go ================================================ package feishu import ( "sync" "time" ) const ( dedupTTL = 24 * time.Hour dedupCleanInterval = time.Hour ) type dedupStore struct { m sync.Map } func newDedupStore() *dedupStore { return &dedupStore{} } func (d *dedupStore) markSeen(key string) bool { now := time.Now().Unix() _, loaded := d.m.LoadOrStore(key, now) return !loaded } func (d *dedupStore) cleaner(stopCh <-chan struct{}) { ticker := time.NewTicker(dedupCleanInterval) defer ticker.Stop() for { select { case <-stopCh: return case <-ticker.C: cutoff := time.Now().Add(-dedupTTL).Unix() d.m.Range(func(key, value any) bool { if ts, ok := value.(int64); ok && ts < cutoff { d.m.Delete(key) } return true }) } } } ================================================ FILE: agent/robot/events/integrations/feishu/e2e_test.go ================================================ package feishu import ( "context" "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" robottypes "github.com/yaoapp/yao/agent/robot/types" fsapi "github.com/yaoapp/yao/integrations/feishu" ) var ( fsAppID string fsAppSecret string ) func TestMain(m *testing.M) { fsAppID = os.Getenv("FEISHU_TEST_APP_ID") fsAppSecret = os.Getenv("FEISHU_TEST_APP_SECRET") os.Exit(m.Run()) } func skipIfNoCreds(t *testing.T) { t.Helper() if fsAppID == "" || fsAppSecret == "" { t.Skip("FEISHU_TEST_APP_ID or FEISHU_TEST_APP_SECRET not set") } } // TestE2E_Adapter_Apply verifies that Apply correctly registers a bot. 
func TestE2E_Adapter_Apply(t *testing.T) { skipIfNoCreds(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_feishu_adapter", TeamID: "team_e2e_fs", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Feishu: &robottypes.FeishuConfig{ Enabled: true, AppID: fsAppID, AppSecret: fsAppSecret, }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() entry, ok := a.bots["robot_e2e_feishu_adapter"] a.mu.RUnlock() require.True(t, ok, "bot should be registered") assert.Equal(t, fsAppID, entry.appID) assert.NotNil(t, entry.bot) t.Logf("OK Apply: feishu bot registered robot=%s app=%s", robot.MemberID, entry.appID) } // TestE2E_Adapter_Apply_Update verifies re-Apply with same appID is a no-op. func TestE2E_Adapter_Apply_Update(t *testing.T) { skipIfNoCreds(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_feishu_update", TeamID: "team_e2e_fs", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Feishu: &robottypes.FeishuConfig{ Enabled: true, AppID: fsAppID, AppSecret: fsAppSecret, }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() _, ok := a.bots["robot_e2e_feishu_update"] a.mu.RUnlock() require.True(t, ok) // Apply again — should be no-op a.Apply(context.Background(), robot) a.mu.RLock() assert.Len(t, a.bots, 1) a.mu.RUnlock() // Remove a.Remove(context.Background(), "robot_e2e_feishu_update") a.mu.RLock() _, ok = a.bots["robot_e2e_feishu_update"] a.mu.RUnlock() assert.False(t, ok, "bot should be removed") t.Log("OK Apply/Remove lifecycle verified") } // TestE2E_Adapter_Dedup verifies deduplication works. 
func TestE2E_Adapter_Dedup(t *testing.T) { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) key := "fs:test-robot:msg-12345" assert.True(t, a.dedup.markSeen(key), "first time should return true") assert.False(t, a.dedup.markSeen(key), "second time should return false (dedup)") t.Log("OK dedup working correctly") } // TestE2E_Adapter_HandleMessages verifies message handling through the adapter. func TestE2E_Adapter_HandleMessages(t *testing.T) { skipIfNoCreds(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) entry := &botEntry{ robotID: "robot_e2e_feishu_handle", appID: fsAppID, bot: fsapi.NewBot(fsAppID, fsAppSecret), } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cms := []*fsapi.ConvertedMessage{ { MessageID: "test_msg_1", ChatID: "test_chat_1", ChatType: "p2p", SenderID: "test_sender_1", Text: "Hello from E2E test", }, } // This should not panic even without event bus running a.handleMessages(ctx, entry, cms) // Verify dedup: should be marked as seen assert.False(t, a.dedup.markSeen("fs:robot_e2e_feishu_handle:test_msg_1"), "message should be marked as seen after handleMessages") t.Log("OK handleMessages processed 1 message") } // TestE2E_Adapter_ApplyDisabled verifies Apply removes bot when disabled. 
func TestE2E_Adapter_ApplyDisabled(t *testing.T) { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_feishu_disabled", TeamID: "team_e2e_fs", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Feishu: &robottypes.FeishuConfig{ Enabled: false, AppID: "some_app", AppSecret: "some_secret", }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() _, ok := a.bots["robot_e2e_feishu_disabled"] a.mu.RUnlock() assert.False(t, ok, "disabled bot should not be registered") t.Log("OK disabled config not registered") } ================================================ FILE: agent/robot/events/integrations/feishu/feishu.go ================================================ package feishu import ( "context" "sync" "github.com/yaoapp/yao/agent/robot/logger" robottypes "github.com/yaoapp/yao/agent/robot/types" fsapi "github.com/yaoapp/yao/integrations/feishu" ) var log = logger.New("feishu") // Adapter implements the integrations.Adapter interface for Feishu (Lark). // // Architecture: // - One event subscription per registered bot via Feishu SDK's long-poll/callback mechanism // - One dedup cleaner goroutine removes expired keys every hour type Adapter struct { mu sync.RWMutex bots map[string]*botEntry // robotID -> *botEntry appIdx map[string]string // appID -> robotID dedup *dedupStore stopCh chan struct{} } // botEntry holds the state for one robot's Feishu integration. type botEntry struct { robotID string appID string bot *fsapi.Bot cancelFn context.CancelFunc // cancels the event subscription goroutine } // NewAdapter creates a new Feishu adapter. 
func NewAdapter() *Adapter { a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } go a.dedup.cleaner(a.stopCh) return a } // Apply is called by the Dispatcher when a robot config is created or updated. func (a *Adapter) Apply(ctx context.Context, robot *robottypes.Robot) { fsConf := extractConfig(robot) log.Debug("Apply robot=%s fsConf=%v", robot.MemberID, fsConf != nil) if fsConf == nil || !fsConf.Enabled || fsConf.AppID == "" || fsConf.AppSecret == "" { a.removeBot(robot.MemberID) return } a.mu.Lock() defer a.mu.Unlock() if existing, ok := a.bots[robot.MemberID]; ok { if existing.appID == fsConf.AppID { return } a.removeBotLocked(robot.MemberID) } bot := fsapi.NewBot(fsConf.AppID, fsConf.AppSecret) streamCtx, streamCancel := context.WithCancel(context.Background()) entry := &botEntry{ robotID: robot.MemberID, appID: fsConf.AppID, bot: bot, cancelFn: streamCancel, } a.bots[robot.MemberID] = entry a.appIdx[fsConf.AppID] = robot.MemberID go a.eventLoop(streamCtx, entry) log.Info("feishu adapter: registered robot=%s app=%s", robot.MemberID, fsConf.AppID) } // Remove is called by the Dispatcher when a robot is deleted. func (a *Adapter) Remove(ctx context.Context, robotID string) { a.removeBot(robotID) } // Shutdown stops all event subscriptions and dedup cleaner. 
func (a *Adapter) Shutdown() { close(a.stopCh) a.mu.Lock() for _, entry := range a.bots { if entry.cancelFn != nil { entry.cancelFn() } } a.mu.Unlock() log.Info("feishu adapter: shutdown complete") } func (a *Adapter) removeBot(robotID string) { a.mu.Lock() defer a.mu.Unlock() a.removeBotLocked(robotID) } func (a *Adapter) removeBotLocked(robotID string) { entry, ok := a.bots[robotID] if !ok { return } if entry.cancelFn != nil { entry.cancelFn() } if entry.appID != "" { delete(a.appIdx, entry.appID) } delete(a.bots, robotID) log.Info("feishu adapter: unregistered robot=%s", robotID) } func (a *Adapter) resolveByAppID(appID string) (*botEntry, bool) { a.mu.RLock() defer a.mu.RUnlock() robotID, ok := a.appIdx[appID] if !ok { return nil, false } entry, ok := a.bots[robotID] return entry, ok } func extractConfig(robot *robottypes.Robot) *robottypes.FeishuConfig { if robot.Config == nil || robot.Config.Integrations == nil { return nil } return robot.Config.Integrations.Feishu } ================================================ FILE: agent/robot/events/integrations/feishu/message.go ================================================ package feishu import ( "context" "fmt" "strings" agentcontext "github.com/yaoapp/yao/agent/context" events "github.com/yaoapp/yao/agent/robot/events" "github.com/yaoapp/yao/event" fsapi "github.com/yaoapp/yao/integrations/feishu" ) // handleMessages processes a batch of Feishu messages for one chat. func (a *Adapter) handleMessages(ctx context.Context, entry *botEntry, cms []*fsapi.ConvertedMessage) { if len(cms) == 0 { return } var allParts []interface{} var lastCM *fsapi.ConvertedMessage for _, cm := range cms { if cm == nil { continue } dedupKey := fmt.Sprintf("fs:%s:%s", entry.robotID, cm.MessageID) if !a.dedup.markSeen(dedupKey) { continue } parts := buildContentParts(cm) if len(parts) == 0 { continue } allParts = append(allParts, parts...) 
lastCM = cm } if len(allParts) == 0 || lastCM == nil { return } content := mergeContentParts(allParts) msgPayload := events.MessagePayload{ RobotID: entry.robotID, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: content}, }, Metadata: &events.MessageMetadata{ Channel: "feishu", MessageID: lastCM.MessageID, AppID: entry.appID, ChatID: lastCM.ChatID, SenderID: lastCM.SenderID, SenderName: lastCM.SenderName, Locale: events.NormalizeLocale(lastCM.LanguageCode), Extra: map[string]any{ "feishu_message_id": lastCM.MessageID, }, }, } if _, err := event.Push(ctx, events.Message, msgPayload); err != nil { log.Error("feishu adapter: event.Push robot.message failed robot=%s: %v", entry.robotID, err) } } func buildContentParts(cm *fsapi.ConvertedMessage) []interface{} { var parts []interface{} if cm.HasText() { parts = append(parts, map[string]interface{}{ "type": "text", "text": cm.Text, }) } for _, mi := range cm.MediaItems { if mi.Wrapper == "" { continue } parts = append(parts, map[string]interface{}{ "type": "file", "file_url": mi.Wrapper, "mime_type": mi.MimeType, "file_name": mi.FileName, }) } return parts } func mergeContentParts(parts []interface{}) interface{} { allText := true for _, p := range parts { m, ok := p.(map[string]interface{}) if !ok || m["type"] != "text" { allText = false break } } if allText { var buf strings.Builder for i, p := range parts { if i > 0 { buf.WriteString("\n") } m := p.(map[string]interface{}) buf.WriteString(m["text"].(string)) } return buf.String() } return parts } ================================================ FILE: agent/robot/events/integrations/feishu/reply.go ================================================ package feishu import ( "context" "fmt" "strings" agentcontext "github.com/yaoapp/yao/agent/context" events "github.com/yaoapp/yao/agent/robot/events" fsapi "github.com/yaoapp/yao/integrations/feishu" ) // Reply sends the assistant message back to the originating Feishu chat. 
func (a *Adapter) Reply(ctx context.Context, msg *agentcontext.Message, metadata *events.MessageMetadata) error { if msg == nil || metadata == nil { return fmt.Errorf("nil message or metadata") } entry := a.resolveByChat(metadata) if entry == nil { return fmt.Errorf("no bot registered for feishu metadata (appID=%s)", metadata.AppID) } var replyToMsgID string if metadata.Extra != nil { if v, ok := metadata.Extra["feishu_message_id"]; ok { if s, ok := v.(string); ok { replyToMsgID = s } } } return a.sendContent(ctx, entry, metadata.ChatID, replyToMsgID, msg.Content) } func (a *Adapter) sendContent(ctx context.Context, entry *botEntry, chatID, replyToMsgID string, content interface{}) error { switch c := content.(type) { case string: if strings.TrimSpace(c) == "" { return nil } return a.sendMarkdown(ctx, entry, chatID, replyToMsgID, c) case []interface{}: return a.sendParts(ctx, entry, chatID, replyToMsgID, c) default: parts, ok := toContentParts(content) if ok { return a.sendPartsTyped(ctx, entry, chatID, replyToMsgID, parts) } return a.sendMarkdown(ctx, entry, chatID, replyToMsgID, fmt.Sprintf("%v", content)) } } // sendMarkdown converts standard Markdown to Feishu lark_md and sends as an interactive card. 
func (a *Adapter) sendMarkdown(ctx context.Context, entry *botEntry, chatID, replyToMsgID, text string) error { formatted := fsapi.FormatFeishuMarkdown(text) if replyToMsgID != "" { _, err := entry.bot.ReplyCardMessage(ctx, replyToMsgID, formatted) return err } _, err := entry.bot.SendCardMessage(ctx, chatID, formatted) return err } func (a *Adapter) sendParts(ctx context.Context, entry *botEntry, chatID, replyToMsgID string, parts []interface{}) error { var textBuf strings.Builder for _, part := range parts { m, ok := part.(map[string]interface{}) if !ok { continue } partType, _ := m["type"].(string) switch partType { case "text": if text, ok := m["text"].(string); ok { textBuf.WriteString(text) } case "image_url": if err := a.flushText(ctx, entry, chatID, replyToMsgID, &textBuf); err != nil { return err } if imgMap, ok := m["image_url"].(map[string]interface{}); ok { if url, ok := imgMap["url"].(string); ok { if err := sendImageOrWrapper(ctx, entry, chatID, url, ""); err != nil { log.Error("feishu reply: send image: %v", err) } } } case "file": if err := a.flushText(ctx, entry, chatID, replyToMsgID, &textBuf); err != nil { return err } fileURL, _ := m["file_url"].(string) if fileURL == "" { if fileMap, ok := m["file"].(map[string]interface{}); ok { fileURL, _ = fileMap["url"].(string) } } if fileURL != "" { if err := sendFileOrWrapper(ctx, entry, chatID, fileURL, ""); err != nil { log.Error("feishu reply: send file: %v", err) } } } } return a.flushText(ctx, entry, chatID, replyToMsgID, &textBuf) } func (a *Adapter) sendPartsTyped(ctx context.Context, entry *botEntry, chatID, replyToMsgID string, parts []agentcontext.ContentPart) error { var textBuf strings.Builder for _, part := range parts { switch part.Type { case agentcontext.ContentText: textBuf.WriteString(part.Text) case agentcontext.ContentImageURL: if err := a.flushText(ctx, entry, chatID, replyToMsgID, &textBuf); err != nil { return err } if part.ImageURL != nil { if err := sendImageOrWrapper(ctx, entry, 
chatID, part.ImageURL.URL, ""); err != nil { log.Error("feishu reply: send image: %v", err) } } case agentcontext.ContentFile: if err := a.flushText(ctx, entry, chatID, replyToMsgID, &textBuf); err != nil { return err } if part.File != nil { if err := sendFileOrWrapper(ctx, entry, chatID, part.File.URL, part.File.Filename); err != nil { log.Error("feishu reply: send file: %v", err) } } } } return a.flushText(ctx, entry, chatID, replyToMsgID, &textBuf) } func (a *Adapter) flushText(ctx context.Context, entry *botEntry, chatID, replyToMsgID string, buf *strings.Builder) error { if buf.Len() == 0 { return nil } text := buf.String() buf.Reset() return a.sendMarkdown(ctx, entry, chatID, replyToMsgID, text) } func sendImageOrWrapper(ctx context.Context, entry *botEntry, chatID, url, caption string) error { if isWrapper(url) { return entry.bot.SendImageFromWrapper(ctx, chatID, url, caption) } if strings.HasPrefix(url, "http") { text := url if caption != "" { text = caption + "\n" + url } _, err := entry.bot.SendTextMessage(ctx, chatID, text) return err } return fmt.Errorf("unsupported image URL scheme: %s", url) } func sendFileOrWrapper(ctx context.Context, entry *botEntry, chatID, url, caption string) error { if isWrapper(url) { return entry.bot.SendFileFromWrapper(ctx, chatID, url, caption) } if strings.HasPrefix(url, "http") { text := url if caption != "" { text = caption + "\n" + url } _, err := entry.bot.SendTextMessage(ctx, chatID, text) return err } return fmt.Errorf("unsupported file URL scheme: %s", url) } func isWrapper(url string) bool { return strings.Contains(url, "://") && !strings.HasPrefix(url, "http") } func toContentParts(content interface{}) ([]agentcontext.ContentPart, bool) { parts, ok := content.([]agentcontext.ContentPart) return parts, ok } func (a *Adapter) resolveByChat(metadata *events.MessageMetadata) *botEntry { if metadata.AppID != "" { if entry, ok := a.resolveByAppID(metadata.AppID); ok { return entry } } a.mu.RLock() defer a.mu.RUnlock() 
for _, entry := range a.bots { return entry } return nil } ================================================ FILE: agent/robot/events/integrations/feishu/stream.go ================================================ package feishu import ( "context" "time" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" larkws "github.com/larksuite/oapi-sdk-go/v3/ws" fsapi "github.com/yaoapp/yao/integrations/feishu" ) const reconnectDelay = 5 * time.Second // eventLoop starts the Feishu WebSocket event subscription for a single bot. // It automatically reconnects on failure. func (a *Adapter) eventLoop(ctx context.Context, entry *botEntry) { log.Info("feishu eventLoop started robot=%s app=%s", entry.robotID, entry.appID) for { select { case <-ctx.Done(): log.Info("feishu eventLoop stopped robot=%s", entry.robotID) return case <-a.stopCh: return default: } err := a.runWSClient(ctx, entry) if err != nil { log.Error("feishu ws disconnected robot=%s: %v, reconnecting in %s", entry.robotID, err, reconnectDelay) } select { case <-ctx.Done(): return case <-a.stopCh: return case <-time.After(reconnectDelay): } } } func (a *Adapter) runWSClient(ctx context.Context, entry *botEntry) error { eventHandler := dispatcher.NewEventDispatcher("", "") eventHandler.OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error { return a.onMessageReceive(ctx, entry, event) }) cli := larkws.NewClient(entry.bot.AppID(), entry.bot.AppSecret(), larkws.WithEventHandler(eventHandler), larkws.WithLogLevel(larkcore.LogLevelWarn), ) errCh := make(chan error, 1) go func() { errCh <- cli.Start(ctx) }() select { case <-ctx.Done(): return ctx.Err() case <-a.stopCh: return nil case err := <-errCh: return err } } func (a *Adapter) onMessageReceive(ctx context.Context, entry *botEntry, event *larkim.P2MessageReceiveV1) error { if event == nil || event.Event == nil || 
event.Event.Message == nil { return nil } msg := event.Event.Message sender := event.Event.Sender msgType := derefStr(msg.MessageType) content := derefStr(msg.Content) messageID := derefStr(msg.MessageId) chatID := derefStr(msg.ChatId) chatType := derefStr(msg.ChatType) text, media := fsapi.ParseMessageContent(msgType, content) cm := &fsapi.ConvertedMessage{ MessageID: messageID, ChatID: chatID, ChatType: chatType, Text: text, MediaItems: media, EventID: event.EventV2Base.Header.EventID, LanguageCode: "zh", } if sender != nil && sender.SenderId != nil { cm.SenderID = derefStr(sender.SenderId.OpenId) } if cm.HasMedia() { groups := []string{"feishu", entry.robotID} entry.bot.ResolveMedia(ctx, cm, groups) } a.handleMessages(ctx, entry, []*fsapi.ConvertedMessage{cm}) return nil } func derefStr(s *string) string { if s == nil { return "" } return *s } ================================================ FILE: agent/robot/events/integrations/telegram/dedup.go ================================================ package telegram import ( "sync" "time" ) const ( dedupTTL = 24 * time.Hour dedupCleanInterval = time.Hour ) // dedupStore is a lightweight in-memory deduplication store with TTL. // Used for message-level dedup (same update_id won't be processed twice). type dedupStore struct { m sync.Map // key -> int64 (unix timestamp) } func newDedupStore() *dedupStore { return &dedupStore{} } // markSeen returns true if this is the first time the key is seen. func (d *dedupStore) markSeen(key string) bool { now := time.Now().Unix() _, loaded := d.m.LoadOrStore(key, now) return !loaded } // cleaner periodically removes expired entries. Runs until stopCh is closed. 
// On every dedupCleanInterval tick it deletes entries whose stored unix
// timestamp is older than dedupTTL.
func (d *dedupStore) cleaner(stopCh <-chan struct{}) {
	ticker := time.NewTicker(dedupCleanInterval)
	defer ticker.Stop()
	for {
		select {
		case <-stopCh:
			return
		case <-ticker.C:
			cutoff := time.Now().Add(-dedupTTL).Unix()
			d.m.Range(func(key, value any) bool {
				if ts, ok := value.(int64); ok && ts < cutoff {
					d.m.Delete(key)
				}
				// Returning true keeps Range iterating over the remaining keys.
				return true
			})
		}
	}
}

================================================
FILE: agent/robot/events/integrations/telegram/e2e_test.go
================================================
package telegram

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/yaoapp/gou/model"
	"github.com/yaoapp/xun/capsule"
	robottypes "github.com/yaoapp/yao/agent/robot/types"
	"github.com/yaoapp/yao/agent/testutils"
	"github.com/yaoapp/yao/event"
	tgapi "github.com/yaoapp/yao/integrations/telegram"
)

// Test configuration, read from the environment in TestMain.
var (
	tgBotToken string
	tgHost     string
)

// TestMain loads TELEGRAM_TEST_BOT_TOKEN and TELEGRAM_TEST_HOST before
// running the package's tests.
func TestMain(m *testing.M) {
	tgBotToken = os.Getenv("TELEGRAM_TEST_BOT_TOKEN")
	tgHost = os.Getenv("TELEGRAM_TEST_HOST")
	os.Exit(m.Run())
}

// skipIfNoToken skips the calling test when no bot token is configured.
func skipIfNoToken(t *testing.T) {
	t.Helper()
	if tgBotToken == "" {
		t.Skip("TELEGRAM_TEST_BOT_TOKEN not set")
	}
}

// newTestBot builds a bot from the test token, pointing it at tgHost when
// that is set.
func newTestBot() *tgapi.Bot {
	var opts []tgapi.BotOption
	if tgHost != "" {
		opts = append(opts, tgapi.WithAPIBase(tgHost))
	}
	return tgapi.NewBot(tgBotToken, "", opts...)
}

// confirmPendingUpdates checks if there are pending updates from previous seeds.
// It fetches updates from offset 0 under a 15-second context and fails the
// test if the call errors.
func confirmPendingUpdates(t *testing.T) []*tgapi.ConvertedMessage {
	t.Helper()
	b := newTestBot()
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	msgs, err := b.GetUpdates(ctx, 0, 5, nil)
	require.NoError(t, err)
	return msgs
}

// TestE2E_Adapter_Apply verifies that Apply correctly registers a bot.
func TestE2E_Adapter_Apply(t *testing.T) { skipIfNoToken(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_tg_adapter", TeamID: "team_e2e_tg", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Telegram: &robottypes.TelegramConfig{ Enabled: true, BotToken: tgBotToken, Host: tgHost, AppID: "e2e-test-app", }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() entry, ok := a.bots["robot_e2e_tg_adapter"] a.mu.RUnlock() require.True(t, ok, "bot should be registered") assert.Equal(t, tgBotToken, entry.bot.Token()) assert.Equal(t, "e2e-test-app", entry.appID) // Verify GetMe works through the registered bot ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() me, err := entry.bot.GetMe(ctx) require.NoError(t, err) assert.True(t, me.IsBot) t.Logf("OK Apply: bot registered id=%d username=%s", me.ID, me.Username) // Verify ResolveBot resolved := a.ResolveBot("e2e-test-app") require.NotNil(t, resolved) assert.Equal(t, tgBotToken, resolved.Token()) } // TestE2E_Adapter_Apply_Update verifies that Apply with a different token replaces the bot. 
func TestE2E_Adapter_Apply_Update(t *testing.T) { skipIfNoToken(t) a := &Adapter{ bots: make(map[string]*botEntry), appIdx: make(map[string]string), dedup: newDedupStore(), stopCh: make(chan struct{}), } defer close(a.stopCh) robot := &robottypes.Robot{ MemberID: "robot_e2e_tg_update", TeamID: "team_e2e_tg", Config: &robottypes.Config{ Integrations: &robottypes.Integrations{ Telegram: &robottypes.TelegramConfig{ Enabled: true, BotToken: tgBotToken, Host: tgHost, }, }, }, } a.Apply(context.Background(), robot) a.mu.RLock() _, ok := a.bots["robot_e2e_tg_update"] a.mu.RUnlock() require.True(t, ok) // Apply again with same token — should be a no-op a.Apply(context.Background(), robot) a.mu.RLock() assert.Len(t, a.bots, 1) a.mu.RUnlock() // Remove a.Remove(context.Background(), "robot_e2e_tg_update") a.mu.RLock() _, ok = a.bots["robot_e2e_tg_update"] a.mu.RUnlock() assert.False(t, ok, "bot should be removed") t.Log("OK Apply/Remove lifecycle verified") } // TestE2E_Adapter_PollAll verifies that pollAll fetches updates from Telegram // and processes them through handleMessages. 
// The test is skipped when no bot token is configured or when Telegram has no
// pending updates for the bot.
func TestE2E_Adapter_PollAll(t *testing.T) {
	skipIfNoToken(t)
	testutils.Prepare(t)
	defer testutils.Clean(t)

	pending := confirmPendingUpdates(t)
	if len(pending) == 0 {
		t.Skip("no pending updates; run integrations/telegram seed first")
	}
	t.Logf("found %d pending updates", len(pending))

	// Create adapter WITHOUT auto-starting pollLoop
	a := &Adapter{
		bots:   make(map[string]*botEntry),
		appIdx: make(map[string]string),
		dedup:  newDedupStore(),
		stopCh: make(chan struct{}),
	}
	defer close(a.stopCh)

	memberID := "robot_e2e_tg_poll"
	setupTestRobot(t, memberID)
	defer cleanupTestRobots(t)

	var opts []tgapi.BotOption
	if tgHost != "" {
		opts = append(opts, tgapi.WithAPIBase(tgHost))
	}
	// Register the bot directly in the map (bypassing Apply); the appID index
	// is left empty.
	a.bots[memberID] = &botEntry{
		robotID: memberID,
		appID:   "e2e-poll-app",
		bot:     tgapi.NewBot(tgBotToken, "", opts...),
	}

	// Start event bus so event.Push works
	if err := event.Start(); err != nil && err != event.ErrAlreadyStart {
		t.Fatalf("event.Start: %v", err)
	}
	defer func() { _ = event.Stop(context.Background()) }()

	// Manually trigger one poll cycle
	a.pollAll()

	// Verify offset advanced (meaning updates were processed)
	a.mu.RLock()
	entry := a.bots[memberID]
	a.mu.RUnlock()
	assert.Greater(t, entry.offset, int64(0), "offset should have advanced after processing updates")
	t.Logf("OK pollAll: offset advanced to %d", entry.offset)
}

// TestE2E_Adapter_Dedup verifies that duplicate messages are not processed twice.
// It exercises only the dedup store: the first markSeen for a key returns
// true and the second returns false.
func TestE2E_Adapter_Dedup(t *testing.T) {
	skipIfNoToken(t)
	a := &Adapter{
		bots:   make(map[string]*botEntry),
		appIdx: make(map[string]string),
		dedup:  newDedupStore(),
		stopCh: make(chan struct{}),
	}
	defer close(a.stopCh)

	key := "tg:test-robot:12345"
	assert.True(t, a.dedup.markSeen(key), "first time should return true")
	assert.False(t, a.dedup.markSeen(key), "second time should return false (dedup)")
	t.Log("OK dedup working correctly")
}

// TestE2E_Adapter_HandleMessages_Integration verifies the full flow:
// GetUpdates → ConvertedMessage → handleMessages → event.Push
func TestE2E_Adapter_HandleMessages_Integration(t *testing.T) {
	skipIfNoToken(t)
	testutils.Prepare(t)
	defer testutils.Clean(t)

	b := newTestBot()
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	msgs, err := b.GetUpdates(ctx, 0, 5, nil)
	require.NoError(t, err)
	if len(msgs) == 0 {
		t.Skip("no pending updates; run integrations/telegram seed first")
	}

	memberID := "robot_e2e_tg_handle"
	setupTestRobot(t, memberID)
	defer cleanupTestRobots(t)

	if err := event.Start(); err != nil && err != event.ErrAlreadyStart {
		t.Fatalf("event.Start: %v", err)
	}
	defer func() { _ = event.Stop(context.Background()) }()

	a := &Adapter{
		bots:   make(map[string]*botEntry),
		appIdx: make(map[string]string),
		dedup:  newDedupStore(),
		stopCh: make(chan struct{}),
	}
	defer close(a.stopCh)

	entry := &botEntry{
		robotID: memberID,
		appID:   "e2e-handle-app",
		bot:     b,
	}

	// Group all messages by chatID (like pollAll does) and process each group
	grouped := groupByChatID(msgs)
	for chatID, chatMsgs := range grouped {
		t.Logf("processing chat=%d messages=%d", chatID, len(chatMsgs))
		for _, cm := range chatMsgs {
			t.Logf("  update_id=%d msg_id=%d text=%q media=%d", cm.UpdateID, cm.MessageID, truncate(cm.Text, 40), len(cm.MediaItems))
		}
		a.handleMessages(ctx, entry, chatMsgs)
	}

	// Verify dedup: all updates should be marked as seen
	// NOTE(review): handleMessages skips bot commands before marking them
	// seen, so this assertion would fail if msgs[0] is a "/command" — confirm
	// the seed data never starts with one.
	cm := msgs[0]
	assert.False(t, a.dedup.markSeen(fmt.Sprintf("tg:%s:%d", memberID, cm.UpdateID)),
		"update should be marked as seen after handleMessages")

	// Second call with same messages should be fully deduped (no-op)
	a.handleMessages(ctx, entry, msgs)
	t.Logf("OK handleMessages processed %d updates across %d chats", len(msgs), len(grouped))
}

// ==================== Helpers ====================

// setupTestRobot inserts a robot member row with a Telegram-enabled
// robot_config for memberID. The test is skipped when the __yao.member model
// is not loaded and fails when the insert errors.
func setupTestRobot(t *testing.T, memberID string) {
	t.Helper()
	m := model.Select("__yao.member")
	if m == nil {
		t.Skip("__yao.member model not loaded")
	}
	qb := capsule.Query()
	robotConfig := map[string]interface{}{
		"identity": map[string]interface{}{
			"role":   "Telegram E2E Test Robot",
			"duties": []string{"Process Telegram messages"},
		},
		"integrations": map[string]interface{}{
			"telegram": map[string]interface{}{
				"enabled":   true,
				"bot_token": tgBotToken,
				"host":      tgHost,
				"app_id":    "e2e-tg-app-" + memberID,
			},
		},
		"resources": map[string]interface{}{
			"phases": map[string]interface{}{
				"host": "robot.host",
			},
		},
	}
	// The marshal error is discarded; the config is a literal of plain maps
	// and strings.
	configJSON, _ := json.Marshal(robotConfig)
	err := qb.Table(m.MetaData.Table.Name).Insert([]map[string]interface{}{
		{
			"member_id":       memberID,
			"team_id":         "team_e2e_tg",
			"member_type":     "robot",
			"display_name":    "E2E TG Adapter Test " + memberID,
			"system_prompt":   "You are a test robot for Telegram adapter E2E testing.",
			"status":          "active",
			"role_id":         "member",
			"autonomous_mode": false,
			"robot_status":    "idle",
			"robot_config":    string(configJSON),
		},
	})
	if err != nil {
		t.Fatalf("setup robot %s: %v", memberID, err)
	}
}

// cleanupTestRobots deletes every member row whose member_id starts with
// "robot_e2e_tg". Delete errors are ignored (best-effort cleanup).
func cleanupTestRobots(t *testing.T) {
	t.Helper()
	m := model.Select("__yao.member")
	if m == nil {
		return
	}
	qb := capsule.Query()
	_, _ = qb.Table(m.MetaData.Table.Name).Where("member_id", "like", "robot_e2e_tg%").Delete()
}

// truncate shortens s to its first n bytes followed by "..." when s is longer
// than n bytes. The cut is byte-based, so it can split a multi-byte rune.
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	return s[:n] + "..."
}

================================================
FILE: agent/robot/events/integrations/telegram/message.go
================================================
package telegram

import (
	"context"
	"fmt"
	"strconv"
	"strings"

	agentcontext "github.com/yaoapp/yao/agent/context"
	events "github.com/yaoapp/yao/agent/robot/events"
	"github.com/yaoapp/yao/event"
	tgapi "github.com/yaoapp/yao/integrations/telegram"
)

// handleMessages builds a single event payload from a batch of ConvertedMessages
// belonging to the same chat. Consecutive user messages are merged into one to
// keep the messages array clean for the LLM.
//
// Nil messages, bot commands, updates already seen by the dedup store, and
// messages with no content parts are skipped. Metadata is taken from the last
// message that contributed parts. Push errors are logged, not returned.
func (a *Adapter) handleMessages(ctx context.Context, entry *botEntry, cms []*tgapi.ConvertedMessage) {
	if len(cms) == 0 {
		return
	}
	var allParts []interface{}
	var lastCM *tgapi.ConvertedMessage
	for _, cm := range cms {
		if cm == nil {
			continue
		}
		// Commands are dropped before the dedup check, so they are never
		// recorded as seen.
		if isBotCommand(cm) {
			continue
		}
		dedupKey := fmt.Sprintf("tg:%s:%d", entry.robotID, cm.UpdateID)
		if !a.dedup.markSeen(dedupKey) {
			continue
		}
		parts := buildContentParts(cm)
		if len(parts) == 0 {
			continue
		}
		allParts = append(allParts, parts...)
		lastCM = cm
	}
	if len(allParts) == 0 || lastCM == nil {
		return
	}
	content := mergeContentParts(allParts)
	msgPayload := events.MessagePayload{
		RobotID: entry.robotID,
		Messages: []agentcontext.Message{
			{Role: agentcontext.RoleUser, Content: content},
		},
		Metadata: &events.MessageMetadata{
			Channel:    "telegram",
			MessageID:  strconv.FormatInt(lastCM.UpdateID, 10),
			AppID:      entry.appID,
			ChatID:     strconv.FormatInt(lastCM.ChatID, 10),
			SenderID:   strconv.FormatInt(lastCM.SenderID, 10),
			SenderName: lastCM.SenderName,
			Locale:     events.NormalizeLocale(lastCM.LanguageCode),
			// Reply reads this key back to set the reply-to message.
			Extra: map[string]any{
				"tg_message_id": lastCM.MessageID,
			},
		},
	}
	if _, err := event.Push(ctx, events.Message, msgPayload); err != nil {
		log.Error("telegram adapter: event.Push robot.message failed robot=%s: %v", entry.robotID, err)
	}
}

// buildContentParts extracts content parts from a single ConvertedMessage.
func buildContentParts(cm *tgapi.ConvertedMessage) []interface{} { var parts []interface{} if cm.HasText() { parts = append(parts, map[string]interface{}{ "type": "text", "text": cm.Text, }) } for _, mi := range cm.MediaItems { if mi.Wrapper == "" { continue } parts = append(parts, map[string]interface{}{ "type": "file", "file_url": mi.Wrapper, "mime_type": mi.MimeType, "file_name": mi.FileName, }) } return parts } // mergeContentParts merges collected parts into a single content value. // If all parts are text-only, they are joined with newlines into a plain string. // Otherwise the full parts array is returned. func mergeContentParts(parts []interface{}) interface{} { allText := true for _, p := range parts { m, ok := p.(map[string]interface{}) if !ok || m["type"] != "text" { allText = false break } } if allText { var buf strings.Builder for i, p := range parts { if i > 0 { buf.WriteString("\n") } m := p.(map[string]interface{}) buf.WriteString(m["text"].(string)) } return buf.String() } return parts } // isBotCommand returns true if the message is a Telegram bot command (text starting with "/"). func isBotCommand(cm *tgapi.ConvertedMessage) bool { return !cm.HasMedia() && strings.HasPrefix(strings.TrimSpace(cm.Text), "/") } ================================================ FILE: agent/robot/events/integrations/telegram/polling.go ================================================ package telegram import ( "context" "time" tgapi "github.com/yaoapp/yao/integrations/telegram" ) const ( pollInterval = 60 * time.Second pollTimeout = 30 // seconds, Telegram long-polling timeout per request ) // pollLoop runs a single goroutine that iterates all registered bots // every pollInterval, calling getUpdates for each one sequentially. 
func (a *Adapter) pollLoop() { log.Info("pollLoop started, interval=%s", pollInterval) a.pollAll() ticker := time.NewTicker(pollInterval) defer ticker.Stop() for { select { case <-a.stopCh: log.Info("pollLoop stopped") return case <-ticker.C: a.pollAll() } } } func (a *Adapter) pollAll() { entries := a.snapshot() log.Debug("pollAll bots=%d", len(entries)) if len(entries) == 0 { return } ctx, cancel := context.WithTimeout(context.Background(), pollInterval) defer cancel() for _, entry := range entries { select { case <-a.stopCh: return default: } log.Debug("polling robot=%s offset=%d", entry.robotID, entry.offset) groups := []string{"telegram", entry.robotID} msgs, err := entry.bot.GetUpdates(ctx, entry.offset, pollTimeout, groups) if err != nil { log.Error("getUpdates failed robot=%s: %v", entry.robotID, err) continue } if len(msgs) == 0 { continue } // Advance offset for all received messages for _, cm := range msgs { if cm.UpdateID >= entry.offset { entry.offset = cm.UpdateID + 1 } } // Group by chatID, preserving order grouped := groupByChatID(msgs) log.Info("robot=%s got %d updates in %d chats", entry.robotID, len(msgs), len(grouped)) for chatID, chatMsgs := range grouped { log.Debug("robot=%s chat=%d messages=%d", entry.robotID, chatID, len(chatMsgs)) a.handleMessages(ctx, entry, chatMsgs) } } } // groupByChatID groups messages by chat ID, preserving chronological order. 
func groupByChatID(msgs []*tgapi.ConvertedMessage) map[int64][]*tgapi.ConvertedMessage { grouped := make(map[int64][]*tgapi.ConvertedMessage) for _, cm := range msgs { grouped[cm.ChatID] = append(grouped[cm.ChatID], cm) } return grouped } ================================================ FILE: agent/robot/events/integrations/telegram/reply.go ================================================ package telegram import ( "context" "fmt" "strconv" "strings" agentcontext "github.com/yaoapp/yao/agent/context" events "github.com/yaoapp/yao/agent/robot/events" tgapi "github.com/yaoapp/yao/integrations/telegram" ) // Reply sends the assistant message back to the originating Telegram chat. // Content may be a plain string or []ContentPart (text, image_url, file, etc.). // Each adapter is responsible for interpreting the standard message format. func (a *Adapter) Reply(ctx context.Context, msg *agentcontext.Message, metadata *events.MessageMetadata) error { if msg == nil || metadata == nil { return fmt.Errorf("nil message or metadata") } chatID, err := strconv.ParseInt(metadata.ChatID, 10, 64) if err != nil { return fmt.Errorf("invalid chat_id %q: %w", metadata.ChatID, err) } var replyTo int64 if metadata.Extra != nil { if v, ok := metadata.Extra["tg_message_id"]; ok { switch id := v.(type) { case int64: replyTo = id case float64: replyTo = int64(id) } } } entry := a.resolveByChat(metadata) if entry == nil { return fmt.Errorf("no bot registered for channel metadata (appID=%s)", metadata.AppID) } return a.sendContent(ctx, entry.bot, chatID, replyTo, msg.Content) } // sendContent dispatches based on the Content type. 
func (a *Adapter) sendContent(ctx context.Context, bot *tgapi.Bot, chatID, replyTo int64, content interface{}) error { switch c := content.(type) { case string: if strings.TrimSpace(c) == "" { return nil } return bot.SendMessage(ctx, chatID, c, replyTo) case []interface{}: return a.sendParts(ctx, bot, chatID, replyTo, c) default: parts, ok := toContentParts(content) if ok { return a.sendPartsTyped(ctx, bot, chatID, replyTo, parts) } return bot.SendMessage(ctx, chatID, fmt.Sprintf("%v", content), replyTo) } } // sendParts handles []interface{} content parts (common from JSON unmarshalling). func (a *Adapter) sendParts(ctx context.Context, bot *tgapi.Bot, chatID, replyTo int64, parts []interface{}) error { var textBuf strings.Builder for _, part := range parts { m, ok := part.(map[string]interface{}) if !ok { continue } partType, _ := m["type"].(string) switch partType { case "text": if text, ok := m["text"].(string); ok { textBuf.WriteString(text) } case "image_url": if err := a.flushText(ctx, bot, chatID, replyTo, &textBuf); err != nil { return err } if imgMap, ok := m["image_url"].(map[string]interface{}); ok { if url, ok := imgMap["url"].(string); ok { if err := sendFileOrWrapper(ctx, bot, chatID, replyTo, url, ""); err != nil { log.Error("telegram reply: send image: %v", err) } } } case "file": if err := a.flushText(ctx, bot, chatID, replyTo, &textBuf); err != nil { return err } if fileMap, ok := m["file"].(map[string]interface{}); ok { url, _ := fileMap["url"].(string) filename, _ := fileMap["filename"].(string) if url != "" { if err := sendFileOrWrapper(ctx, bot, chatID, replyTo, url, filename); err != nil { log.Error("telegram reply: send file: %v", err) } } } } } return a.flushText(ctx, bot, chatID, replyTo, &textBuf) } // sendPartsTyped handles typed []agentcontext.ContentPart slices. 
func (a *Adapter) sendPartsTyped(ctx context.Context, bot *tgapi.Bot, chatID, replyTo int64, parts []agentcontext.ContentPart) error { var textBuf strings.Builder for _, part := range parts { switch part.Type { case agentcontext.ContentText: textBuf.WriteString(part.Text) case agentcontext.ContentImageURL: if err := a.flushText(ctx, bot, chatID, replyTo, &textBuf); err != nil { return err } if part.ImageURL != nil { if err := sendFileOrWrapper(ctx, bot, chatID, replyTo, part.ImageURL.URL, ""); err != nil { log.Error("telegram reply: send image: %v", err) } } case agentcontext.ContentFile: if err := a.flushText(ctx, bot, chatID, replyTo, &textBuf); err != nil { return err } if part.File != nil { if err := sendFileOrWrapper(ctx, bot, chatID, replyTo, part.File.URL, part.File.Filename); err != nil { log.Error("telegram reply: send file: %v", err) } } } } return a.flushText(ctx, bot, chatID, replyTo, &textBuf) } func (a *Adapter) flushText(ctx context.Context, bot *tgapi.Bot, chatID, replyTo int64, buf *strings.Builder) error { if buf.Len() == 0 { return nil } err := bot.SendMessage(ctx, chatID, buf.String(), replyTo) buf.Reset() return err } // sendFileOrWrapper sends a file from a wrapper (__yao.attachment://xxx) or URL. func sendFileOrWrapper(ctx context.Context, bot *tgapi.Bot, chatID, replyTo int64, url, caption string) error { if strings.Contains(url, "://") && !strings.HasPrefix(url, "http") { return bot.SendMedia(ctx, chatID, url, caption, replyTo) } if strings.HasPrefix(url, "http") { mediaType := tgapi.DetectMediaType("") return bot.SendMediaByURL(ctx, chatID, mediaType, url, caption, replyTo) } return fmt.Errorf("unsupported file URL scheme: %s", url) } // toContentParts tries to type-assert content to []agentcontext.ContentPart. func toContentParts(content interface{}) ([]agentcontext.ContentPart, bool) { parts, ok := content.([]agentcontext.ContentPart) return parts, ok } // resolveByChat finds the bot entry matching the metadata. 
// It prefers the entry registered under metadata.AppID. When that lookup
// misses it falls back to an arbitrary registered bot (map iteration order),
// or nil when no bot is registered.
func (a *Adapter) resolveByChat(metadata *events.MessageMetadata) *botEntry {
	if metadata.AppID != "" {
		if entry, ok := a.resolveByAppID(metadata.AppID); ok {
			return entry
		}
	}
	a.mu.RLock()
	defer a.mu.RUnlock()
	// Returns on the first iteration: any registered bot will do.
	for _, entry := range a.bots {
		return entry
	}
	return nil
}

================================================
FILE: agent/robot/events/integrations/telegram/telegram.go
================================================
package telegram

import (
	"context"
	"sync"

	"github.com/yaoapp/yao/agent/robot/logger"
	robottypes "github.com/yaoapp/yao/agent/robot/types"
	tgapi "github.com/yaoapp/yao/integrations/telegram"
)

// log is the package logger, tagged "telegram".
var log = logger.New("telegram")

// Adapter implements the integrations.Adapter interface for Telegram Bot API.
//
// Architecture:
//   - One polling goroutine (ticker) iterates all registered bots every 60s
//   - One webhook goroutine listens to integration.webhook.telegram events
//   - One dedup cleaner goroutine removes expired keys every hour
type Adapter struct {
	mu      sync.RWMutex         // guards bots and appIdx
	bots    map[string]*botEntry // robotID -> *botEntry
	appIdx  map[string]string    // appID -> robotID (webhook routing)
	dedup   *dedupStore
	webhSub string        // webhook subscription ID; empty when not subscribed
	stopCh  chan struct{} // closed by Shutdown to stop background goroutines
}

// botEntry holds the state for one robot's Telegram integration.
type botEntry struct {
	robotID string
	appID   string
	bot     *tgapi.Bot // bound to this robot's token
	offset  int64      // polling offset
}

// NewAdapter creates a new Telegram adapter.
// It starts the dedup cleaner and the polling loop as goroutines; both stop
// when stopCh is closed.
func NewAdapter() *Adapter {
	a := &Adapter{
		bots:   make(map[string]*botEntry),
		appIdx: make(map[string]string),
		dedup:  newDedupStore(),
		stopCh: make(chan struct{}),
	}
	go a.dedup.cleaner(a.stopCh)
	go a.pollLoop()
	return a
}

// Apply is called by the Dispatcher when a robot config is created or updated.
func (a *Adapter) Apply(ctx context.Context, robot *robottypes.Robot) { tgConf := extractConfig(robot) log.Debug("Apply robot=%s tgConf=%v", robot.MemberID, tgConf != nil) if tgConf != nil { log.Debug("Apply robot=%s enabled=%v token_len=%d host=%q", robot.MemberID, tgConf.Enabled, len(tgConf.BotToken), tgConf.Host) } if tgConf == nil || !tgConf.Enabled || tgConf.BotToken == "" { a.removeBot(robot.MemberID) return } a.mu.Lock() defer a.mu.Unlock() if existing, ok := a.bots[robot.MemberID]; ok { if existing.bot.Token() == tgConf.BotToken { return } a.removeBotLocked(robot.MemberID) } var opts []tgapi.BotOption if tgConf.Host != "" { opts = append(opts, tgapi.WithAPIBase(tgConf.Host)) } entry := &botEntry{ robotID: robot.MemberID, appID: tgConf.AppID, bot: tgapi.NewBot(tgConf.BotToken, tgConf.WebhookSecret, opts...), } a.bots[robot.MemberID] = entry if tgConf.AppID != "" { a.appIdx[tgConf.AppID] = robot.MemberID } log.Info("telegram adapter: registered robot=%s", robot.MemberID) } // Remove is called by the Dispatcher when a robot is deleted. func (a *Adapter) Remove(ctx context.Context, robotID string) { a.removeBot(robotID) } // Shutdown stops the polling loop, webhook subscription, and dedup cleaner. func (a *Adapter) Shutdown() { close(a.stopCh) a.StopWebhookSubscription() log.Info("telegram adapter: shutdown complete") } // ResolveBot returns the tgapi.Bot for a given appID, used by the webhook // verification layer. Returns nil if not found. 
func (a *Adapter) ResolveBot(appID string) *tgapi.Bot {
	entry, ok := a.resolveByAppID(appID)
	if !ok {
		return nil
	}
	return entry.bot
}

// --- Bot registry ---

// removeBot unregisters robotID while holding the write lock.
func (a *Adapter) removeBot(robotID string) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.removeBotLocked(robotID)
}

// removeBotLocked deletes robotID from the bots map and its appID from the
// routing index. It takes no lock itself; removeBot and Apply call it with
// a.mu held. Unknown robot IDs are ignored.
func (a *Adapter) removeBotLocked(robotID string) {
	entry, ok := a.bots[robotID]
	if !ok {
		return
	}
	if entry.appID != "" {
		delete(a.appIdx, entry.appID)
	}
	delete(a.bots, robotID)
	log.Info("telegram adapter: unregistered robot=%s", robotID)
}

// snapshot returns a copy of all bot entries for safe iteration outside the lock.
// The slice is new but its elements are the same *botEntry pointers held in
// the registry.
func (a *Adapter) snapshot() []*botEntry {
	a.mu.RLock()
	defer a.mu.RUnlock()
	list := make([]*botEntry, 0, len(a.bots))
	for _, entry := range a.bots {
		list = append(list, entry)
	}
	return list
}

// resolveByAppID looks up the bot entry registered under appID via the
// appID -> robotID index. The boolean is false when either lookup misses.
func (a *Adapter) resolveByAppID(appID string) (*botEntry, bool) {
	a.mu.RLock()
	defer a.mu.RUnlock()
	robotID, ok := a.appIdx[appID]
	if !ok {
		return nil, false
	}
	entry, ok := a.bots[robotID]
	return entry, ok
}

// extractConfig returns the robot's Telegram integration config, or nil when
// the robot has no Config or no Integrations section.
func extractConfig(robot *robottypes.Robot) *robottypes.TelegramConfig {
	if robot.Config == nil || robot.Config.Integrations == nil {
		return nil
	}
	return robot.Config.Integrations.Telegram
}

================================================
FILE: agent/robot/events/integrations/telegram/webhook.go
================================================
package telegram

import (
	"context"
	"encoding/json"

	"github.com/go-telegram/bot/models"
	"github.com/yaoapp/yao/event"
	eventtypes "github.com/yaoapp/yao/event/types"
	tgapi "github.com/yaoapp/yao/integrations/telegram"
	webhooktypes "github.com/yaoapp/yao/openapi/integrations"
)

// StartWebhookSubscription subscribes to integration.webhook.telegram events.
// Call once after event.Start().
func (a *Adapter) StartWebhookSubscription() { ch := make(chan *eventtypes.Event, 128) a.webhSub = event.Subscribe("integration.webhook.telegram", ch) go a.handleWebhooks(ch) log.Info("telegram adapter: webhook subscription started") } // StopWebhookSubscription unsubscribes from webhook events. func (a *Adapter) StopWebhookSubscription() { if a.webhSub != "" { event.Unsubscribe(a.webhSub) a.webhSub = "" } } func (a *Adapter) handleWebhooks(ch <-chan *eventtypes.Event) { for ev := range ch { var payload webhooktypes.WebhookPayload if err := ev.Should(&payload); err != nil { log.Error("telegram adapter: invalid webhook event: %v", err) continue } entry, ok := a.resolveByAppID(payload.AppID) if !ok { log.Warn("telegram adapter: unknown app_id=%s", payload.AppID) continue } headerSecret := payload.Headers["X-Telegram-Bot-Api-Secret-Token"] if !entry.bot.VerifyWebhook(headerSecret) { log.Warn("telegram adapter: webhook secret mismatch app_id=%s", payload.AppID) continue } var update models.Update if err := json.Unmarshal(payload.Body, &update); err != nil { log.Error("telegram adapter: webhook unmarshal failed: %v", err) continue } cm := tgapi.ConvertUpdate(&update) if cm != nil && cm.HasMedia() { groups := []string{"telegram", entry.robotID} ctx := context.Background() entry.bot.ResolveMedia(ctx, cm, groups) } a.handleMessages(context.Background(), entry, []*tgapi.ConvertedMessage{cm}) } } ================================================ FILE: agent/robot/events/log.go ================================================ package events import "github.com/yaoapp/yao/agent/robot/logger" var log = logger.New("events") ================================================ FILE: agent/robot/events/message.go ================================================ package events import ( "context" "fmt" "strings" agent "github.com/yaoapp/yao/agent" "github.com/yaoapp/yao/agent/assistant" agentcontext "github.com/yaoapp/yao/agent/context" robotstore "github.com/yaoapp/yao/agent/robot/store" 
robottypes "github.com/yaoapp/yao/agent/robot/types" eventtypes "github.com/yaoapp/yao/event/types" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // handleMessage processes messages from external integrations (Telegram, etc.). // It calls the Host Agent with the provided messages and returns a MessageResult. // Action detection is done via the Host Agent's Next hook return value. func (h *robotHandler) handleMessage(ctx context.Context, ev *eventtypes.Event, resp chan<- eventtypes.Result) { var payload MessagePayload if err := ev.Should(&payload); err != nil { log.Error("message handler: invalid payload: %v", err) if ev.IsCall { resp <- eventtypes.Result{Err: err} } return } log.Info("message handler: robot=%s channel=%s msg_id=%s", payload.RobotID, payload.Metadata.Channel, payload.Metadata.MessageID) result, err := callHostAgent(ctx, &payload) if err != nil { log.Error("message handler: host agent call failed robot=%s: %v", payload.RobotID, err) if ev.IsCall { resp <- eventtypes.Result{Err: err} } return } if reply := getReplyFunc(); reply != nil && result.Message != nil { if err := reply(ctx, result.Message, payload.Metadata); err != nil { log.Error("message handler: reply failed robot=%s channel=%s: %v", payload.RobotID, payload.Metadata.Channel, err) } } if ev.IsCall { resp <- eventtypes.Result{Data: result} } } // callHostAgent resolves the Host Agent for the robot and calls it with messages. // This avoids importing executor/standard to prevent import cycles; instead it // calls the assistant directly via assistant.Get + ast.Stream (same as AgentCaller.Call). 
func callHostAgent(ctx context.Context, payload *MessagePayload) (*MessageResult, error) { hostID, record, err := resolveHostAssistantID(ctx, payload.RobotID) if err != nil { return nil, fmt.Errorf("failed to resolve host agent: %w", err) } ast, err := assistant.Get(hostID) if err != nil { return nil, fmt.Errorf("assistant not found: %s: %w", hostID, err) } opts := &agentcontext.Options{ Skip: &agentcontext.Skip{ Search: false, }, } authorized := &oauthtypes.AuthorizedInfo{ UserID: payload.Metadata.SenderID, } chatID := fmt.Sprintf("%s:%s", payload.Metadata.Channel, payload.Metadata.ChatID) agentCtx := agentcontext.New(ctx, authorized, chatID) agentCtx.AssistantID = hostID agentCtx.Referer = "integration" agentCtx.Locale = payload.Metadata.Locale agentCtx.Metadata = map[string]interface{}{ "robot_id": payload.RobotID, "channel": payload.Metadata.Channel, } if dsl := agent.GetAgent(); dsl != nil { if cache, err := dsl.GetCacheStore(); err == nil { agentCtx.Cache = cache } } defer agentCtx.Release() response, err := ast.Stream(agentCtx, payload.Messages, opts) if err != nil { return nil, fmt.Errorf("host agent call failed: %w", err) } result := &MessageResult{ Metadata: payload.Metadata, } if response.Completion != nil { result.Message = &agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: response.Completion.Content, } } // Detect action from Next hook return value log.Debug("response.Next type=%T value=%+v", response.Next, response.Next) if action := detectAction(response.Next); action != nil { log.Info("action detected: name=%s payload=%+v", action.Name, action.Payload) result.Action = action if action.Name == "robot.execute" { if execID := executeAction(ctx, payload, record, action); execID != "" { result.ExecutionID = execID result.Message = &agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: taskDeployedMessage(execID, payload.Metadata.Locale), } } } } else { log.Debug("no action detected from response.Next") } return result, nil } // 
executeAction triggers robot execution when the Host Agent returns a // confirmed action. Uses the injected TriggerFunc to call robotapi.TriggerManual // without creating a circular import. func executeAction(ctx context.Context, payload *MessagePayload, record *robotstore.RobotRecord, action *ActionResult) string { trigger := getTriggerFunc() if trigger == nil { log.Warn("message handler: trigger func not registered, cannot execute action for robot=%s", payload.RobotID) return "" } data, _ := action.Payload.(map[string]interface{}) goals, _ := data["goals"].(string) if goals == "" { log.Warn("message handler: confirmed action has no goals, robot=%s", payload.RobotID) return "" } triggerData := &robottypes.TriggerInput{ Data: map[string]interface{}{ "goals": goals, "channel": payload.Metadata.Channel, "chat_id": payload.Metadata.ChatID, "extra": payload.Metadata.Extra, }, } authorized := &oauthtypes.AuthorizedInfo{ UserID: record.MemberID, TeamID: record.TeamID, } rCtx := robottypes.NewContext(ctx, authorized) execID, accepted, err := trigger(rCtx, payload.RobotID, robottypes.TriggerHuman, triggerData) if err != nil { log.Error("message handler: execute action failed robot=%s: %v", payload.RobotID, err) return "" } if !accepted { log.Warn("message handler: execute action not accepted robot=%s", payload.RobotID) return "" } log.Info("message handler: execution triggered robot=%s exec_id=%s", payload.RobotID, execID) return execID } // detectAction checks the Next hook return value for a confirmed action. // The Host Agent returns { data: { confirmed: true, robot_id: "...", goals: "..." } } // when it detects a confirm_task tool call. func detectAction(next interface{}) *ActionResult { if next == nil { return nil } m, ok := next.(map[string]interface{}) if !ok { return nil } // Next hook may return { data: { confirmed, ... } } or flat { confirmed, ... 
} data, _ := m["data"].(map[string]interface{}) if data == nil { data = m } confirmed, _ := data["confirmed"].(bool) if !confirmed { return nil } return &ActionResult{ Name: "robot.execute", Payload: data, } } // resolveHostAssistantID resolves the host assistant ID from a robot member ID. // Mirrors the logic in openapi/agent/robot/completions.go. func resolveHostAssistantID(ctx context.Context, memberID string) (string, *robotstore.RobotRecord, error) { store := robotstore.NewRobotStore() record, err := store.Get(ctx, memberID) if err != nil { return "", nil, fmt.Errorf("failed to get robot: %w", err) } if record == nil { return "", nil, fmt.Errorf("robot not found: %s", memberID) } config, err := robottypes.ParseConfig(record.RobotConfig) if err != nil { return "", nil, fmt.Errorf("failed to parse robot config: %w", err) } var hostID string if config != nil && config.Resources != nil { hostID = config.Resources.GetPhaseAgent(robottypes.PhaseHost) } else { hostID = "__yao." + string(robottypes.PhaseHost) } return hostID, record, nil } func taskDeployedMessage(execID string, locale string) string { if strings.HasPrefix(locale, "zh") { return fmt.Sprintf("任务已部署(执行编号: %s),完成后会将结果发送给你。", execID) } return fmt.Sprintf("Task deployed (execution: %s). You will receive results once completed.", execID) } ================================================ FILE: agent/robot/executor/README.md ================================================ # Robot Executor Robot Executor provides pluggable execution strategies for robot phase execution. 
## Architecture

```
executor/
├── types/
│   ├── types.go      # Interface definitions and common types
│   └── helpers.go    # Shared helper functions
├── standard/
│   ├── executor.go   # Real Agent execution (production)
│   └── phases.go     # Phase implementations
├── dryrun/
│   └── executor.go   # Simulated execution (testing/demo)
├── sandbox/
│   └── executor.go   # Container-isolated execution (NOT IMPLEMENTED)
└── executor.go       # Factory functions and unified entry
```

## Execution Modes

### Standard Mode (Production)

Real Agent calls with full phase execution:

```go
exec := executor.New()

// or
exec := executor.NewWithConfig(executor.Config{
    OnPhaseStart: func(phase types.Phase) { ... },
    OnPhaseEnd:   func(phase types.Phase) { ... },
})
```

### DryRun Mode (Testing/Demo)

Simulates execution without real Agent calls:

```go
// Simple dry-run
exec := executor.NewDryRun()

// With delay simulation
exec := executor.NewDryRunWithDelay(100 * time.Millisecond)

// With full configuration
exec := executor.NewDryRunWithConfig(executor.DryRunConfig{
    Delay:   100 * time.Millisecond,
    OnStart: func() { ... },
    OnEnd:   func() { ... },
    Config: executor.Config{
        OnPhaseStart: func(phase types.Phase) { ... },
    },
})
```

### Sandbox Mode (NOT IMPLEMENTED)

> **⚠️ Not Implemented:** Sandbox mode requires container-level isolation (Docker/gVisor/Firecracker) for true security isolation. This is a future feature that depends on infrastructure support.

**Intended Design:**

Sandbox mode is designed for executing untrusted robot configurations in a fully isolated environment:

- **Container Isolation:** Each execution runs in a separate container
- **Resource Limits:** CPU, memory, disk, network quotas enforced by container runtime
- **Network Isolation:** Restricted network access via container networking
- **File System Isolation:** Read-only root filesystem, limited writable paths
- **Process Isolation:** Separate PID namespace, no access to host processes

**Future Implementation:**

```go
// Future API (not yet implemented)
exec := executor.NewSandboxWithConfig(executor.SandboxConfig{
    Image:         "yao-executor:latest",
    MaxDuration:   30 * time.Minute,
    MaxMemory:     512 * 1024 * 1024, // 512MB
    MaxCPU:        1.0,               // 1 CPU core
    NetworkPolicy: "restricted",      // restricted | none | full
    AllowedAgents: []string{"agent1", "agent2"},
})
```

**Current Placeholder:**

The current `sandbox/executor.go` is a placeholder that behaves like DryRun mode. It does NOT provide real security isolation.
## Mode Selection Select mode dynamically: ```go // By mode constant exec := executor.NewWithMode(executor.ModeDryRun) // From settings setting := &executor.Setting{ Mode: executor.ModeStandard, } exec := executor.NewWithSetting(setting) ``` ## Interface All executors implement the `Executor` interface: ```go type Executor interface { Execute(ctx *Context, robot *Robot, trigger TriggerType, data interface{}) (*Execution, error) ExecCount() int CurrentCount() int Reset() } ``` ## Use Cases | Mode | Use Case | Status | | -------- | --------------------------------------------------- | ------------------ | | Standard | Production environment with real Agent calls | ✅ Implemented | | DryRun | Unit tests, integration tests, demos, previews | ✅ Implemented | | Sandbox | Untrusted code execution, multi-tenant environments | ⬜ Not Implemented | ## Testing Tests use DryRun mode by default: ```go func TestSomething(t *testing.T) { exec := executor.NewDryRunWithDelay(50 * time.Millisecond) // ... test with simulated execution } ``` ## Manager Integration Inject executor into Manager: ```go exec := executor.NewDryRun() config := &manager.Config{ Executor: exec, } m := manager.NewWithConfig(config) ``` ================================================ FILE: agent/robot/executor/dryrun/executor.go ================================================ package dryrun import ( "fmt" "sync/atomic" "time" "github.com/yaoapp/yao/agent/robot/executor/types" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // Executor implements a dry-run executor that simulates execution // without making real Agent calls. 
Useful for: // - Testing scheduling and concurrency logic // - Demo and preview modes // - Debugging execution flow // - Performance testing type Executor struct { config types.DryRunConfig execCount atomic.Int32 currentCount atomic.Int32 } // New creates a new dry-run executor with default settings func New() *Executor { return &Executor{} } // NewWithDelay creates a dry-run executor with specified delay func NewWithDelay(delay time.Duration) *Executor { return &Executor{ config: types.DryRunConfig{ Delay: delay, }, } } // NewWithConfig creates a dry-run executor with full configuration func NewWithConfig(config types.DryRunConfig) *Executor { return &Executor{ config: config, } } // Execute simulates robot execution without real Agent calls (auto-generates ID) func (e *Executor) Execute(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}) (*robottypes.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, "", nil) } // ExecuteWithID simulates robot execution with a pre-generated execution ID (no control) func (e *Executor) ExecuteWithID(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string) (*robottypes.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, execID, nil) } // ExecuteWithControl simulates robot execution with execution control func (e *Executor) ExecuteWithControl(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string, control robottypes.ExecutionControl) (*robottypes.Execution, error) { if robot == nil { return nil, fmt.Errorf("robot cannot be nil") } // Determine starting phase startPhaseIndex := 0 if trigger == robottypes.TriggerHuman || trigger == robottypes.TriggerEvent { startPhaseIndex = 1 // Skip P0 } // Use provided execID or generate new one if execID == "" { execID = fmt.Sprintf("dryrun_%d", time.Now().UnixNano()) } // Create execution 
record exec := &robottypes.Execution{ ID: execID, MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: trigger, StartTime: time.Now(), Status: robottypes.ExecPending, Phase: robottypes.AllPhases[startPhaseIndex], Input: types.BuildTriggerInput(trigger, data), } // Set robot reference exec.SetRobot(robot) // Acquire slot if !robot.TryAcquireSlot(exec) { return nil, robottypes.ErrQuotaExceeded } defer robot.RemoveExecution(exec.ID) // Track counts e.execCount.Add(1) e.currentCount.Add(1) defer e.currentCount.Add(-1) // Start callback if e.config.OnStart != nil { e.config.OnStart() } if e.config.OnEnd != nil { defer e.config.OnEnd() } // Update status exec.Status = robottypes.ExecRunning // Simulate execution delay (once for entire execution, not per-phase) if e.config.Delay > 0 { time.Sleep(e.config.Delay) } // Check for simulated failure if dataStr, ok := data.(string); ok && dataStr == "simulate_failure" { exec.Status = robottypes.ExecFailed exec.Error = "simulated failure" return exec, nil } // Execute phases with mock data phases := robottypes.AllPhases[startPhaseIndex:] for _, phase := range phases { // Check if cancelled select { case <-ctx.Context.Done(): exec.Status = robottypes.ExecCancelled exec.Error = "execution cancelled" return exec, nil default: } // Wait if paused if control != nil { if err := control.WaitIfPaused(); err != nil { exec.Status = robottypes.ExecCancelled exec.Error = "execution cancelled while paused" return exec, nil } } exec.Phase = phase // Phase start callback if e.config.OnPhaseStart != nil { e.config.OnPhaseStart(phase) } // Generate mock output e.mockPhaseOutput(exec, phase) // Phase end callback if e.config.OnPhaseEnd != nil { e.config.OnPhaseEnd(phase) } } // Mark completed exec.Status = robottypes.ExecCompleted now := time.Now() exec.EndTime = &now return exec, nil } // mockPhaseOutput generates mock output for each phase func (e *Executor) mockPhaseOutput(exec *robottypes.Execution, phase robottypes.Phase) { switch 
phase { case robottypes.PhaseInspiration: exec.Inspiration = &robottypes.InspirationReport{ Clock: robottypes.NewClockContext(time.Now(), ""), Content: "## Dry-Run Inspiration\n\nThis is a simulated inspiration report for testing.", } case robottypes.PhaseGoals: exec.Goals = &robottypes.Goals{ Content: "## Dry-Run Goals\n\n1. [High] Simulated goal for testing", } case robottypes.PhaseTasks: exec.Tasks = []robottypes.Task{ { ID: "dryrun-task-1", GoalRef: "Goal 1", Source: robottypes.TaskSourceAuto, ExecutorType: robottypes.ExecutorAssistant, ExecutorID: "mock-agent", Status: robottypes.TaskPending, }, } case robottypes.PhaseRun: exec.Results = []robottypes.TaskResult{ { TaskID: "dryrun-task-1", Success: true, Output: map[string]interface{}{"mode": "dryrun", "result": "simulated"}, Duration: 100, Validation: &robottypes.ValidationResult{ Passed: true, Score: 1.0, }, }, } case robottypes.PhaseDelivery: exec.Delivery = &robottypes.DeliveryResult{ RequestID: "dryrun-" + exec.ID, Content: &robottypes.DeliveryContent{ Summary: "Dry-run delivery completed", Body: "# Dry-run Delivery\n\nThis is a simulated delivery result.", }, Success: true, } case robottypes.PhaseLearning: exec.Learning = []robottypes.LearningEntry{ { Type: robottypes.LearnExecution, Content: "Dry-run execution completed successfully", }, } } } // ExecCount returns total execution count func (e *Executor) ExecCount() int { return int(e.execCount.Load()) } // CurrentCount returns currently running execution count func (e *Executor) CurrentCount() int { return int(e.currentCount.Load()) } // Reset resets the executor counters func (e *Executor) Reset() { e.execCount.Store(0) e.currentCount.Store(0) } // Resume is not supported in dry-run mode func (e *Executor) Resume(ctx *robottypes.Context, execID string, reply string) error { return fmt.Errorf("resume is not supported in dry-run executor") } // Verify Executor implements types.Executor var _ types.Executor = (*Executor)(nil) 
================================================ FILE: agent/robot/executor/executor.go ================================================ // Package executor provides robot execution strategies // // Architecture: // // executor/ // ├── types/ // │ ├── types.go # Interface definitions and common types // │ └── helpers.go # Shared helper functions // ├── standard/ // │ ├── executor.go # Real Agent execution (production) // │ ├── agent.go # AgentCaller for LLM calls // │ ├── input.go # InputFormatter for prompts // │ ├── inspiration.go # P0: Inspiration phase // │ ├── goals.go # P1: Goals phase // │ ├── tasks.go # P2: Tasks phase // │ ├── run.go # P3: Run phase // │ ├── delivery.go # P4: Delivery phase // │ └── learning.go # P5: Learning phase // ├── dryrun/ // │ └── executor.go # Simulated execution (testing/demo) // ├── sandbox/ // │ └── executor.go # Container-isolated execution (NOT IMPLEMENTED) // └── executor.go # Factory functions (this file) // // Usage: // // // Production - real Agent calls // exec := executor.New() // // // Testing - simulated execution // exec := executor.NewDryRun() // // // Sandbox - NOT IMPLEMENTED (requires container infrastructure) // // exec := executor.NewSandbox() // placeholder only // // // With mode selection // exec := executor.NewWithMode(executor.ModeDryRun) package executor import ( "time" "github.com/yaoapp/yao/agent/robot/executor/dryrun" "github.com/yaoapp/yao/agent/robot/executor/sandbox" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/executor/types" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // Re-export types for convenience type ( Executor = types.Executor Config = types.Config DryRunConfig = types.DryRunConfig SandboxConfig = types.SandboxConfig Mode = types.Mode Setting = types.Setting ) // Re-export mode constants const ( ModeStandard = types.ModeStandard ModeDryRun = types.ModeDryRun ModeSandbox = types.ModeSandbox ) // ==================== Factory Functions 
==================== // New creates a new standard executor (production mode) // Uses real Agent calls for phase execution func New() Executor { return standard.New() } // NewWithConfig creates a standard executor with configuration func NewWithConfig(config Config) Executor { return standard.NewWithConfig(config) } // NewDryRun creates a dry-run executor (testing/demo mode) // Simulates execution without real Agent calls func NewDryRun() Executor { return dryrun.New() } // NewDryRunWithDelay creates a dry-run executor with specified delay func NewDryRunWithDelay(delay time.Duration) *DryRunExecutor { return dryrun.NewWithDelay(delay) } // NewDryRunWithConfig creates a dry-run executor with full configuration func NewDryRunWithConfig(config DryRunConfig) *DryRunExecutor { return dryrun.NewWithConfig(config) } // NewDryRunWithCallbacks creates a dry-run executor with start/end callbacks func NewDryRunWithCallbacks(delay time.Duration, onStart, onEnd func()) *DryRunExecutor { return dryrun.NewWithConfig(DryRunConfig{ Delay: delay, OnStart: onStart, OnEnd: onEnd, }) } // NewSandbox creates a sandbox executor placeholder // // ⚠️ NOT IMPLEMENTED: True sandbox requires container-level isolation // (Docker/gVisor/Firecracker). Current implementation behaves like DryRun. func NewSandbox() Executor { return sandbox.New() } // NewSandboxWithConfig creates a sandbox executor placeholder with configuration // // ⚠️ NOT IMPLEMENTED: Config options are placeholders. Current implementation // behaves like DryRun and does NOT provide real security isolation. 
func NewSandboxWithConfig(config SandboxConfig) Executor { return sandbox.NewWithConfig(config) } // NewWithMode creates an executor based on the specified mode func NewWithMode(mode Mode) Executor { switch mode { case ModeDryRun: return NewDryRun() case ModeSandbox: return NewSandbox() default: return New() } } // NewWithSetting creates an executor based on configuration settings func NewWithSetting(setting *Setting) Executor { if setting == nil { return New() } switch setting.Mode { case ModeDryRun: return NewDryRun() case ModeSandbox: return NewSandboxWithConfig(SandboxConfig{ MaxDuration: setting.MaxDuration, MaxMemory: setting.MaxMemory, AllowedAgents: setting.AllowedAgents, NetworkAccess: setting.NetworkAccess, FileAccess: setting.FileAccess, }) default: return New() } } // ==================== Concrete Types ==================== // Export concrete executor types for direct access when needed // DryRunExecutor is the concrete dry-run executor type type DryRunExecutor = dryrun.Executor // StandardExecutor is the concrete standard executor type type StandardExecutor = standard.Executor // SandboxExecutor is the concrete sandbox executor type type SandboxExecutor = sandbox.Executor // ==================== Interface Verification ==================== // Verify all executors implement the Executor interface var ( _ Executor = (*standard.Executor)(nil) _ Executor = (*dryrun.Executor)(nil) _ Executor = (*sandbox.Executor)(nil) ) // Verify standard executor implements PhaseExecutor var _ types.PhaseExecutor = (*standard.Executor)(nil) // ==================== Helper Types ==================== // DefaultSetting returns default executor settings func DefaultSetting() *Setting { return types.DefaultSetting() } // PhaseExecutor is the interface for phase execution type PhaseExecutor = types.PhaseExecutor // ==================== Context Helpers ==================== // These are re-exported from robot types for convenience type ( Context = robottypes.Context Robot = 
robottypes.Robot Execution = robottypes.Execution Phase = robottypes.Phase ) ================================================ FILE: agent/robot/executor/executor_test.go ================================================ package executor import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/types" ) // Smoke tests to verify basic flow works // Real integration tests are in manager_test.go func TestExecutorSmoke(t *testing.T) { exec := NewDryRunWithDelay(0) robot := &types.Robot{ MemberID: "test-smoke", TeamID: "team-1", Config: &types.Config{Quota: &types.Quota{Max: 1}}, } ctx := types.NewContext(context.Background(), nil) result, err := exec.Execute(ctx, robot, types.TriggerClock, nil) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.ExecCompleted, result.Status) assert.Equal(t, types.TriggerClock, result.TriggerType) // Clock trigger executes all phases (P0-P5) assert.NotNil(t, result.Inspiration, "P0 should be executed for clock trigger") assert.NotNil(t, result.Goals, "P1 should be executed") assert.NotEmpty(t, result.Tasks, "P2 should generate tasks") assert.NotEmpty(t, result.Results, "P3 should generate results") assert.NotNil(t, result.Delivery, "P4 should be executed") assert.NotEmpty(t, result.Learning, "P5 should be executed") } func TestExecutorHumanTriggerSkipsP0(t *testing.T) { exec := NewDryRunWithDelay(0) robot := &types.Robot{ MemberID: "test-human", TeamID: "team-1", Config: &types.Config{Quota: &types.Quota{Max: 1}}, } ctx := types.NewContext(context.Background(), nil) result, err := exec.Execute(ctx, robot, types.TriggerHuman, nil) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.ExecCompleted, result.Status) // Human trigger skips P0 (Inspiration) assert.Nil(t, result.Inspiration, "P0 should be skipped for human trigger") assert.NotNil(t, result.Goals, "P1 should be executed") } func TestExecutorEventTriggerSkipsP0(t *testing.T) { exec := 
NewDryRunWithDelay(0) robot := &types.Robot{ MemberID: "test-event", TeamID: "team-1", Config: &types.Config{Quota: &types.Quota{Max: 1}}, } ctx := types.NewContext(context.Background(), nil) result, err := exec.Execute(ctx, robot, types.TriggerEvent, nil) assert.NoError(t, err) assert.Nil(t, result.Inspiration, "P0 should be skipped for event trigger") assert.NotNil(t, result.Goals) } func TestExecutorNilRobot(t *testing.T) { exec := NewDryRunWithDelay(0) ctx := types.NewContext(context.Background(), nil) result, err := exec.Execute(ctx, nil, types.TriggerClock, nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "robot cannot be nil") } func TestExecutorSimulatedFailure(t *testing.T) { exec := NewDryRunWithDelay(0) robot := &types.Robot{ MemberID: "test-fail", TeamID: "team-1", Config: &types.Config{Quota: &types.Quota{Max: 1}}, } ctx := types.NewContext(context.Background(), nil) // Pass "simulate_failure" to trigger simulated failure result, err := exec.Execute(ctx, robot, types.TriggerClock, "simulate_failure") assert.NoError(t, err) // Execute returns nil error, failure is in result assert.NotNil(t, result) assert.Equal(t, types.ExecFailed, result.Status) assert.Equal(t, "simulated failure", result.Error) } func TestExecutorCounters(t *testing.T) { exec := NewDryRunWithDelay(0) robot := &types.Robot{ MemberID: "test-counter", TeamID: "team-1", Config: &types.Config{Quota: &types.Quota{Max: 10}}, } ctx := types.NewContext(context.Background(), nil) assert.Equal(t, 0, exec.ExecCount()) assert.Equal(t, 0, exec.CurrentCount()) _, _ = exec.Execute(ctx, robot, types.TriggerClock, nil) assert.Equal(t, 1, exec.ExecCount()) assert.Equal(t, 0, exec.CurrentCount()) // Completed, so 0 _, _ = exec.Execute(ctx, robot, types.TriggerClock, nil) assert.Equal(t, 2, exec.ExecCount()) exec.Reset() assert.Equal(t, 0, exec.ExecCount()) } ================================================ FILE: agent/robot/executor/sandbox/executor.go 
================================================ package sandbox import ( "context" "fmt" "sync/atomic" "time" "github.com/yaoapp/yao/agent/robot/executor/types" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // Executor implements a sandboxed executor placeholder. // // ⚠️ NOT IMPLEMENTED: True sandbox mode requires container-level isolation // (Docker/gVisor/Firecracker) for security. This placeholder currently // behaves like DryRun mode and does NOT provide real security isolation. // // Future Implementation: // - Container isolation: Each execution in separate container // - Resource limits: CPU, memory, disk enforced by container runtime // - Network isolation: Restricted network via container networking // - File system isolation: Read-only root, limited writable paths // - Process isolation: Separate PID namespace // // Current behavior: Simulates execution with mock data (same as DryRun) type Executor struct { config types.SandboxConfig execCount atomic.Int32 currentCount atomic.Int32 } // New creates a new sandbox executor with default settings func New() *Executor { return &Executor{ config: types.SandboxConfig{ MaxDuration: 30 * time.Minute, NetworkAccess: true, FileAccess: false, }, } } // NewWithConfig creates a sandbox executor with custom configuration func NewWithConfig(config types.SandboxConfig) *Executor { return &Executor{ config: config, } } // Execute runs robot execution within sandbox constraints (auto-generates ID) func (e *Executor) Execute(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}) (*robottypes.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, "", nil) } // ExecuteWithID runs robot execution within sandbox constraints with a pre-generated execution ID (no control) func (e *Executor) ExecuteWithID(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string) (*robottypes.Execution, error) { return 
e.ExecuteWithControl(ctx, robot, trigger, data, execID, nil) } // ExecuteWithControl runs robot execution within sandbox constraints with execution control func (e *Executor) ExecuteWithControl(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string, control robottypes.ExecutionControl) (*robottypes.Execution, error) { if robot == nil { return nil, fmt.Errorf("robot cannot be nil") } // Create timeout context execCtx, cancel := context.WithTimeout(ctx.Context, e.config.MaxDuration) defer cancel() // Create new context with timeout sandboxCtx := robottypes.NewContext(execCtx, ctx.Auth) // Determine starting phase startPhaseIndex := 0 if trigger == robottypes.TriggerHuman || trigger == robottypes.TriggerEvent { startPhaseIndex = 1 } // Use provided execID or generate new one if execID == "" { execID = fmt.Sprintf("sandbox_%d", time.Now().UnixNano()) } // Create execution record exec := &robottypes.Execution{ ID: execID, MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: trigger, StartTime: time.Now(), Status: robottypes.ExecPending, Phase: robottypes.AllPhases[startPhaseIndex], Input: types.BuildTriggerInput(trigger, data), } // Set robot reference exec.SetRobot(robot) // Acquire slot if !robot.TryAcquireSlot(exec) { return nil, robottypes.ErrQuotaExceeded } defer robot.RemoveExecution(exec.ID) // Track counts e.execCount.Add(1) e.currentCount.Add(1) defer e.currentCount.Add(-1) // Update status exec.Status = robottypes.ExecRunning // Execute phases with sandbox constraints phases := robottypes.AllPhases[startPhaseIndex:] for _, phase := range phases { // Check timeout or cancellation select { case <-execCtx.Done(): exec.Status = robottypes.ExecFailed exec.Error = "execution timeout exceeded" return exec, nil default: } // Wait if paused if control != nil { if err := control.WaitIfPaused(); err != nil { exec.Status = robottypes.ExecCancelled exec.Error = "execution cancelled while paused" return 
exec, nil } } exec.Phase = phase if e.config.OnPhaseStart != nil { e.config.OnPhaseStart(phase) } // Execute phase with sandbox constraints if err := e.runSandboxedPhase(sandboxCtx, exec, phase, data); err != nil { exec.Status = robottypes.ExecFailed exec.Error = err.Error() return exec, nil } if e.config.OnPhaseEnd != nil { e.config.OnPhaseEnd(phase) } } // Mark completed exec.Status = robottypes.ExecCompleted now := time.Now() exec.EndTime = &now return exec, nil } // runSandboxedPhase executes a phase with sandbox constraints func (e *Executor) runSandboxedPhase(ctx *robottypes.Context, exec *robottypes.Execution, phase robottypes.Phase, data interface{}) error { // Validate agent is allowed (if whitelist is set) if len(e.config.AllowedAgents) > 0 { robot := exec.GetRobot() if robot != nil && robot.Config != nil && robot.Config.Resources != nil { agentID := robot.Config.Resources.GetPhaseAgent(phase) if !e.isAgentAllowed(agentID) { return fmt.Errorf("agent %s is not allowed in sandbox", agentID) } } } // For now, generate mock output (real implementation would call agents with restrictions) e.mockPhaseOutput(exec, phase) return nil } // isAgentAllowed checks if an agent is in the whitelist func (e *Executor) isAgentAllowed(agentID string) bool { for _, allowed := range e.config.AllowedAgents { if allowed == agentID || allowed == "*" { return true } } return false } // mockPhaseOutput generates mock output for each phase func (e *Executor) mockPhaseOutput(exec *robottypes.Execution, phase robottypes.Phase) { switch phase { case robottypes.PhaseInspiration: exec.Inspiration = &robottypes.InspirationReport{ Clock: robottypes.NewClockContext(time.Now(), ""), Content: "## Sandbox Inspiration\n\nExecuted in isolated sandbox environment.", } case robottypes.PhaseGoals: exec.Goals = &robottypes.Goals{ Content: "## Sandbox Goals\n\n1. 
[High] Sandboxed goal execution", } case robottypes.PhaseTasks: exec.Tasks = []robottypes.Task{ { ID: "sandbox-task-1", GoalRef: "Goal 1", Source: robottypes.TaskSourceAuto, ExecutorType: robottypes.ExecutorAssistant, ExecutorID: "sandbox-agent", Status: robottypes.TaskPending, }, } case robottypes.PhaseRun: exec.Results = []robottypes.TaskResult{ { TaskID: "sandbox-task-1", Success: true, Output: map[string]interface{}{"mode": "sandbox", "isolated": true}, Duration: 50, Validation: &robottypes.ValidationResult{ Passed: true, Score: 1.0, }, }, } case robottypes.PhaseDelivery: exec.Delivery = &robottypes.DeliveryResult{ RequestID: "sandbox-" + exec.ID, Content: &robottypes.DeliveryContent{ Summary: "Sandbox delivery completed", Body: "# Sandbox Delivery\n\nThis is a simulated sandbox delivery result.", }, Success: true, } case robottypes.PhaseLearning: exec.Learning = []robottypes.LearningEntry{ { Type: robottypes.LearnExecution, Content: "Sandbox execution completed within constraints", }, } } } // ExecCount returns total execution count func (e *Executor) ExecCount() int { return int(e.execCount.Load()) } // CurrentCount returns currently running execution count func (e *Executor) CurrentCount() int { return int(e.currentCount.Load()) } // Reset resets the executor counters func (e *Executor) Reset() { e.execCount.Store(0) e.currentCount.Store(0) } // Resume is not supported in sandbox mode func (e *Executor) Resume(ctx *robottypes.Context, execID string, reply string) error { return fmt.Errorf("resume is not supported in sandbox executor") } // Verify Executor implements types.Executor var _ types.Executor = (*Executor)(nil) ================================================ FILE: agent/robot/executor/standard/agent.go ================================================ package standard import ( "fmt" "github.com/yaoapp/gou/text" "github.com/yaoapp/yao/agent/assistant" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" 
robottypes "github.com/yaoapp/yao/agent/robot/types"
	oauthtypes "github.com/yaoapp/yao/openapi/oauth/types"
)

// StreamCallback receives text chunks during streaming agent calls.
// Return 0 to continue, non-zero to stop.
type StreamCallback func(chunk *StreamChunk) int

// StreamChunk represents a single chunk in a streaming response.
type StreamChunk struct {
	Type    string // "text" or "thinking" are the kinds CallStream forwards; "event" is reserved
	Content string // chunk payload
	Delta   bool   // true when Content is an incremental piece rather than a full message
}

// AgentCaller provides unified interface for calling AI assistants
// It wraps the Yao Assistant framework and handles:
// - Getting assistant by ID
// - Single call with messages (streaming)
// - Multi-turn conversation with session state
// - Parsing responses (text, JSON, Next hook data)
type AgentCaller struct {
	// SkipOutput skips sending output to client (for internal calls)
	SkipOutput bool

	// SkipHistory skips saving to chat history (default: true for robot)
	// Set to false to enable multi-turn conversation with history
	SkipHistory bool

	// SkipSearch skips auto search
	SkipSearch bool

	// ChatID is used for multi-turn conversations to maintain session state
	// If empty, each call is independent (no history)
	ChatID string

	// Connector overrides the assistant's default LLM connector (from Robot.LanguageModel).
	// When non-empty, passed as opts.Connector to ast.Stream so the agent uses the Robot's model.
	Connector string

	// log is an optional structured logger; when set, Call emits agent-call logs.
	log *execLogger
}

// NewAgentCaller creates a new AgentCaller with default settings (single-call mode)
func NewAgentCaller() *AgentCaller {
	return &AgentCaller{
		SkipOutput:  true, // Robot executions don't send to UI
		SkipHistory: true, // Robot executions don't save to chat history
		SkipSearch:  true, // Robot executions don't trigger auto search
	}
}

// NewConversationCaller creates an AgentCaller for multi-turn conversations
// chatID is used to maintain session state across calls
// This is useful for:
// - P2 (Tasks): Iterative task refinement with user feedback
// - P3 (Run): Multi-step task execution with intermediate results
func NewConversationCaller(chatID string) *AgentCaller {
	return &AgentCaller{
		SkipOutput:  true,
		SkipHistory: false, // Enable history for multi-turn
		SkipSearch:  true,
		ChatID:      chatID,
	}
}

// CallResult holds the result of an agent call.
// All of its methods are safe to call on a nil *CallResult.
type CallResult struct {
	// Content is the raw text content from LLM completion
	Content string

	// Next is the data returned from Next hook (if any)
	// This is typically a structured response from the assistant
	Next interface{}

	// Response is the full response object (for advanced use)
	Response *agentcontext.Response
}

// IsEmpty returns true if the result has no content.
// A nil result is reported as empty.
func (r *CallResult) IsEmpty() bool {
	if r == nil {
		return true
	}
	return r.Content == "" && r.Next == nil
}

// GetText returns the text content, preferring Content over Next.
//
// When Content is empty the Next hook data is inspected in this order: a plain
// string, a map with a string "content" key, then a map whose "data" entry is a
// map with a string "content" key. An empty string is returned when none match.
func (r *CallResult) GetText() string {
	if r == nil {
		return ""
	}
	if r.Content != "" {
		return r.Content
	}

	switch next := r.Next.(type) {
	case string:
		return next
	case map[string]interface{}:
		if content, ok := next["content"].(string); ok {
			return content
		}
		if data, ok := next["data"].(map[string]interface{}); ok {
			if content, ok := data["content"].(string); ok {
				return content
			}
		}
	}
	return ""
}

// GetJSON attempts to parse the result as JSON
// It tries in order:
//  1. Next hook data (already structured); a map "data" wrapper is unwrapped
//  2. Content parsed using gou/text.ExtractJSON (fault-tolerant)
//
// Returns the parsed data and any error
func (r *CallResult) GetJSON() (map[string]interface{}, error) {
	if r == nil {
		return nil, fmt.Errorf("no content to parse")
	}

	// Try Next hook data first
	if m, ok := r.Next.(map[string]interface{}); ok {
		if data, ok := m["data"].(map[string]interface{}); ok {
			return data, nil
		}
		return m, nil
	}

	if r.Content == "" {
		return nil, fmt.Errorf("no content to parse")
	}

	// Try parsing Content using gou/text
	if m, ok := text.ExtractJSON(r.Content).(map[string]interface{}); ok {
		return m, nil
	}
	return nil, fmt.Errorf("content is not a JSON object")
}

// GetJSONArray attempts to parse the result as JSON array
// Similar to GetJSON but for array responses
func (r *CallResult) GetJSONArray() ([]interface{}, error) {
	if r == nil {
		return nil, fmt.Errorf("no content to parse")
	}

	// Try Next hook data first
	switch next := r.Next.(type) {
	case []interface{}:
		return next, nil
	case map[string]interface{}:
		// Check for "data" wrapper
		if data, ok := next["data"].([]interface{}); ok {
			return data, nil
		}
	}

	if r.Content == "" {
		return nil, fmt.Errorf("no content to parse")
	}

	// Try parsing Content using gou/text (handles markdown blocks, JSON, YAML)
	if arr, ok := text.ExtractJSON(r.Content).([]interface{}); ok {
		return arr, nil
	}
	return nil, fmt.Errorf("content is not a JSON array")
}

// Call calls an assistant with messages and returns the result
// This is the main entry point for agent calls
//
// An error is returned when the assistant cannot be found, when the stream
// call fails, or when the assistant reports success without a response object.
func (c *AgentCaller) Call(ctx *robottypes.Context, assistantID string, messages []agentcontext.Message) (*CallResult, error) {
	// Get assistant
	ast, err := assistant.Get(assistantID)
	if err != nil {
		return nil, fmt.Errorf("assistant not found: %s: %w", assistantID, err)
	}

	// Build options
	opts := &agentcontext.Options{
		Skip: &agentcontext.Skip{
			Output:  c.SkipOutput,
			History: c.SkipHistory,
			Search:  c.SkipSearch,
		},
		Connector: c.Connector,
	}

	// Convert robot context to agent context
	agentCtx := c.buildAgentContext(ctx)
	defer agentCtx.Release() // IMPORTANT: Release agent context to prevent resource leaks

	// Call assistant with streaming
	response, err := ast.Stream(agentCtx, messages, opts)
	if err != nil {
		return nil, fmt.Errorf("assistant call failed: %w", err)
	}
	if response == nil {
		// Guard the field accesses below against a nil response.
		return nil, fmt.Errorf("assistant call failed: empty response from %s", assistantID)
	}

	// Build result
	result := &CallResult{
		Response: response,
	}

	// Extract Next hook data
	if response.Next != nil {
		result.Next = response.Next
	}

	// Extract Content from Completion (only string content is supported)
	if response.Completion != nil {
		if content, ok := response.Completion.Content.(string); ok {
			result.Content = content
		}
	}

	if c.log != nil {
		c.log.logAgentCall(assistantID, result)
	}

	return result, nil
}

// CallWithMessages is a convenience method that builds messages from a single user input
func (c *AgentCaller) CallWithMessages(ctx *robottypes.Context, assistantID string, userContent string) (*CallResult, error) {
	messages := []agentcontext.Message{
		{
			Role:    agentcontext.RoleUser,
			Content: userContent,
		},
	}
	return c.Call(ctx, assistantID, messages)
}

// CallWithSystemAndUser calls with both system and user messages
func (c *AgentCaller) CallWithSystemAndUser(ctx *robottypes.Context, assistantID string, systemContent, userContent string) (*CallResult, error) {
	messages := []agentcontext.Message{
		{
			Role:    agentcontext.RoleSystem,
			Content: systemContent,
		},
		{
			Role:    agentcontext.RoleUser,
			Content: userContent,
		},
	}
	return c.Call(ctx, assistantID, messages)
}

// CallStream calls an assistant with messages and streams text chunks via callback.
// The callback receives each text delta in real-time while the response is being generated.
// After streaming completes, the full CallResult is returned.
func (c *AgentCaller) CallStream(ctx *robottypes.Context, assistantID string, messages []agentcontext.Message, streamFn StreamCallback) (*CallResult, error) { ast, err := assistant.Get(assistantID) if err != nil { return nil, fmt.Errorf("assistant not found: %s: %w", assistantID, err) } opts := &agentcontext.Options{ Skip: &agentcontext.Skip{ Output: c.SkipOutput, History: c.SkipHistory, Search: c.SkipSearch, }, Connector: c.Connector, } // Hook OnMessage to intercept streaming chunks and forward to callback if streamFn != nil { opts.OnMessage = func(msg *message.Message) int { if msg == nil { return 0 } switch msg.Type { case message.TypeText: if msg.Delta { content, _ := msg.Props["content"].(string) if content != "" { return streamFn(&StreamChunk{Type: "text", Content: content, Delta: true}) } } case message.TypeThinking: if msg.Delta { content, _ := msg.Props["content"].(string) if content != "" { return streamFn(&StreamChunk{Type: "thinking", Content: content, Delta: true}) } } } return 0 } } agentCtx := c.buildAgentContext(ctx) defer agentCtx.Release() response, err := ast.Stream(agentCtx, messages, opts) if err != nil { return nil, fmt.Errorf("assistant call failed: %w", err) } result := &CallResult{Response: response} if response.Next != nil { result.Next = response.Next } if response.Completion != nil { if content, ok := response.Completion.Content.(string); ok { result.Content = content } } if c.log != nil { c.log.logAgentCall(assistantID, result) } return result, nil } // CallWithMessagesStream is a convenience method that streams a single user input. 
func (c *AgentCaller) CallWithMessagesStream(ctx *robottypes.Context, assistantID string, userContent string, streamFn StreamCallback) (*CallResult, error) { messages := []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: userContent, }, } return c.CallStream(ctx, assistantID, messages, streamFn) } // CallStreamRaw calls an assistant with streaming, passing raw message.Message objects // to the callback without any type filtering or degradation. This preserves all CUI // message protocol fields (chunk_id, message_id, block_id, delta_path, etc.) // for direct SSE passthrough to the frontend. func (c *AgentCaller) CallStreamRaw(ctx *robottypes.Context, assistantID string, messages []agentcontext.Message, onMessage agentcontext.OnMessageFunc) (*CallResult, error) { ast, err := assistant.Get(assistantID) if err != nil { return nil, fmt.Errorf("assistant not found: %s: %w", assistantID, err) } opts := &agentcontext.Options{ Skip: &agentcontext.Skip{ Output: c.SkipOutput, History: c.SkipHistory, Search: c.SkipSearch, }, Connector: c.Connector, } if onMessage != nil { opts.OnMessage = onMessage } agentCtx := c.buildAgentContext(ctx) defer agentCtx.Release() response, err := ast.Stream(agentCtx, messages, opts) if err != nil { return nil, fmt.Errorf("assistant call failed: %w", err) } result := &CallResult{Response: response} if response.Next != nil { result.Next = response.Next } if response.Completion != nil { if content, ok := response.Completion.Content.(string); ok { result.Content = content } } if c.log != nil { c.log.logAgentCall(assistantID, result) } return result, nil } // CallWithMessagesStreamRaw is a convenience method that streams raw messages for a single user input. 
func (c *AgentCaller) CallWithMessagesStreamRaw(ctx *robottypes.Context, assistantID string, userContent string, onMessage agentcontext.OnMessageFunc) (*CallResult, error) { messages := []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: userContent, }, } return c.CallStreamRaw(ctx, assistantID, messages, onMessage) } // buildAgentContext converts robot context to agent context func (c *AgentCaller) buildAgentContext(ctx *robottypes.Context) *agentcontext.Context { // Build authorized info for agent context var authorized *oauthtypes.AuthorizedInfo if ctx.Auth != nil { authorized = &oauthtypes.AuthorizedInfo{ UserID: ctx.Auth.UserID, TeamID: ctx.Auth.TeamID, } } // Create a new agent context // Use ChatID for multi-turn conversations, empty for single calls agentCtx := agentcontext.New(ctx.Context, authorized, c.ChatID) // Set locale if available if ctx.Locale != "" { agentCtx.Locale = ctx.Locale } // Use noop logger to suppress LLM debug output for robot executions // Robot executions run in background and don't need console output if agentCtx.Logger != nil { agentCtx.Logger.Close() } agentCtx.Logger = agentcontext.Noop() return agentCtx } // ExtractCodeBlock extracts the first code block from content using gou/text // Returns the CodeBlock with type, content, and parsed data (for JSON/YAML) func ExtractCodeBlock(content string) *text.CodeBlock { return text.ExtractFirst(content) } // ExtractAllCodeBlocks extracts all code blocks from content using gou/text func ExtractAllCodeBlocks(content string) []text.CodeBlock { return text.Extract(content) } // ============================================================================ // Conversation - Multi-turn dialogue support // ============================================================================ // Conversation manages a multi-turn dialogue with an assistant // Useful for: // - P2 (Tasks): Iterative task planning with clarification // - P3 (Run): Multi-step execution with intermediate validation // 
- Complex reasoning that requires back-and-forth type Conversation struct { caller *AgentCaller assistantID string messages []agentcontext.Message maxTurns int } // TurnResult holds the result of a single conversation turn type TurnResult struct { Turn int // Turn number (1-based) Input string // User input for this turn Result *CallResult // Agent response Messages []agentcontext.Message // Full message history after this turn } // NewConversation creates a new multi-turn conversation // assistantID: the assistant to converse with // chatID: session ID for maintaining state (use exec.ID for robot executions) // maxTurns: maximum number of turns (0 = unlimited) func NewConversation(assistantID, chatID string, maxTurns int) *Conversation { return &Conversation{ caller: NewConversationCaller(chatID), assistantID: assistantID, messages: make([]agentcontext.Message, 0), maxTurns: maxTurns, } } // WithCaller sets a custom AgentCaller for the conversation // Useful for customizing SkipSearch or other options func (c *Conversation) WithCaller(caller *AgentCaller) *Conversation { c.caller = caller return c } // WithSystemPrompt adds a system prompt at the beginning of the conversation func (c *Conversation) WithSystemPrompt(systemPrompt string) *Conversation { if systemPrompt != "" { c.messages = append([]agentcontext.Message{{ Role: agentcontext.RoleSystem, Content: systemPrompt, }}, c.messages...) } return c } // WithHistory initializes the conversation with existing message history // Note: Message structs are copied, but Content (interface{}) is a shallow copy func (c *Conversation) WithHistory(messages []agentcontext.Message) *Conversation { c.messages = append(c.messages, messages...) 
return c } // Turn executes a single turn in the conversation // userInput: the user's message for this turn // Returns the turn result with agent response func (c *Conversation) Turn(ctx *robottypes.Context, userInput string) (*TurnResult, error) { // Check max turns turnNum := c.TurnCount() + 1 if c.maxTurns > 0 && turnNum > c.maxTurns { return nil, fmt.Errorf("max turns (%d) exceeded", c.maxTurns) } // Build messages with user input (don't modify history yet) userMsg := agentcontext.Message{ Role: agentcontext.RoleUser, Content: userInput, } // Create a new slice to avoid modifying c.messages if capacity allows append in-place messagesWithInput := make([]agentcontext.Message, len(c.messages)+1) copy(messagesWithInput, c.messages) messagesWithInput[len(c.messages)] = userMsg // Call assistant with full history result, err := c.caller.Call(ctx, c.assistantID, messagesWithInput) if err != nil { return nil, fmt.Errorf("turn %d failed: %w", turnNum, err) } // Only update history after successful call c.messages = append(c.messages, userMsg) // Add assistant response to history if result.Content != "" { c.messages = append(c.messages, agentcontext.Message{ Role: agentcontext.RoleAssistant, Content: result.Content, }) } // Return a copy of messages to prevent external modification messagesCopy := make([]agentcontext.Message, len(c.messages)) copy(messagesCopy, c.messages) return &TurnResult{ Turn: turnNum, Input: userInput, Result: result, Messages: messagesCopy, }, nil } // TurnCount returns the number of user turns so far func (c *Conversation) TurnCount() int { count := 0 for _, msg := range c.messages { if msg.Role == agentcontext.RoleUser { count++ } } return count } // Messages returns a copy of the current message history func (c *Conversation) Messages() []agentcontext.Message { messagesCopy := make([]agentcontext.Message, len(c.messages)) copy(messagesCopy, c.messages) return messagesCopy } // LastResponse returns a copy of the last assistant response, or nil 
if none func (c *Conversation) LastResponse() *agentcontext.Message { for i := len(c.messages) - 1; i >= 0; i-- { if c.messages[i].Role == agentcontext.RoleAssistant { // Return a copy to prevent external modification msg := c.messages[i] return &msg } } return nil } // Reset clears the conversation history (keeps system prompt if any) func (c *Conversation) Reset() { // Keep system prompt if present var systemPrompt *agentcontext.Message if len(c.messages) > 0 && c.messages[0].Role == agentcontext.RoleSystem { systemPrompt = &c.messages[0] } c.messages = make([]agentcontext.Message, 0) if systemPrompt != nil { c.messages = append(c.messages, *systemPrompt) } } // RunUntil runs the conversation until a condition is met // checkFn: called after each turn, returns (done, error) // Returns all turn results func (c *Conversation) RunUntil( ctx *robottypes.Context, inputFn func(turn int, lastResult *CallResult) (string, error), checkFn func(turn int, result *CallResult) (done bool, err error), ) ([]*TurnResult, error) { var results []*TurnResult for { turnNum := c.TurnCount() + 1 // Check max turns if c.maxTurns > 0 && turnNum > c.maxTurns { return results, fmt.Errorf("max turns (%d) exceeded without completion", c.maxTurns) } // Get input for this turn var lastResult *CallResult if len(results) > 0 { lastResult = results[len(results)-1].Result } input, err := inputFn(turnNum, lastResult) if err != nil { return results, fmt.Errorf("input generation failed at turn %d: %w", turnNum, err) } // Execute turn turnResult, err := c.Turn(ctx, input) if err != nil { return results, err } results = append(results, turnResult) // Check completion condition done, err := checkFn(turnNum, turnResult.Result) if err != nil { return results, fmt.Errorf("check failed at turn %d: %w", turnNum, err) } if done { return results, nil } } } ================================================ FILE: agent/robot/executor/standard/agent_stream_test.go ================================================ 
package standard_test import ( "context" "strings" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) func TestAgentCallerCallStream(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test (requires LLM)") } testutils.Prepare(t) defer testutils.Clean(t) caller := standard.NewAgentCaller() ctx := types.NewContext(context.Background(), testAuth()) t.Run("streams text chunks and returns result", func(t *testing.T) { var mu sync.Mutex var chunks []string streamFn := func(chunk *standard.StreamChunk) int { mu.Lock() defer mu.Unlock() if chunk.Type == "text" && chunk.Delta { chunks = append(chunks, chunk.Content) } return 0 } result, err := caller.CallWithMessagesStream(ctx, "tests.robot-single", "Hello, test message", streamFn) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsEmpty()) mu.Lock() combined := strings.Join(chunks, "") chunkCount := len(chunks) mu.Unlock() t.Logf("Received %d text chunks, total length: %d", chunkCount, len(combined)) assert.Greater(t, chunkCount, 0, "should have received at least one text chunk") assert.NotEmpty(t, combined, "combined chunks should not be empty") }) t.Run("nil callback works like non-stream call", func(t *testing.T) { result, err := caller.CallStream(ctx, "tests.robot-single", []agentcontext.Message{{Role: "user", Content: "Hello"}}, nil, ) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsEmpty()) }) t.Run("stream returns parseable JSON", func(t *testing.T) { var mu sync.Mutex var chunks []string streamFn := func(chunk *standard.StreamChunk) int { mu.Lock() defer mu.Unlock() if chunk.Type == "text" && chunk.Delta { chunks = append(chunks, chunk.Content) } return 0 } result, err := caller.CallWithMessagesStream(ctx, 
"tests.robot-single", "Generate inspiration report", streamFn) require.NoError(t, err) require.NotNil(t, result) data, err := result.GetJSON() require.NoError(t, err) assert.NotNil(t, data) assert.Contains(t, data, "type") mu.Lock() chunkCount := len(chunks) mu.Unlock() t.Logf("Received %d chunks for JSON response", chunkCount) }) t.Run("assistant not found returns error", func(t *testing.T) { result, err := caller.CallWithMessagesStream(ctx, "non.existent", "hello", nil) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "assistant not found") }) } ================================================ FILE: agent/robot/executor/standard/agent_test.go ================================================ package standard_test import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // testAuth returns a test auth info for agent calls func testAuth() *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "test-user-1", TeamID: "test-team-1", } } // ============================================================================ // AgentCaller Tests - Single Call Mode // ============================================================================ func TestAgentCallerSingleCall(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) caller := standard.NewAgentCaller() ctx := types.NewContext(context.Background(), testAuth()) // Test basic call - verify assistant responds and returns parseable JSON // Note: LLM outputs are non-deterministic, so we test structure not exact values t.Run("basic call returns response", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "tests.robot-single", "Hello, test message") require.NoError(t, 
err) require.NotNil(t, result) assert.False(t, result.IsEmpty(), "result should not be empty") // Should be able to get text content text := result.GetText() assert.NotEmpty(t, text, "should have text content") }) t.Run("call returns parseable JSON", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "tests.robot-single", "Generate inspiration report") require.NoError(t, err) require.NotNil(t, result) // Should return parseable JSON (content may vary) data, err := result.GetJSON() require.NoError(t, err) assert.NotNil(t, data) // Verify it has "type" field (all test responses should have this) assert.Contains(t, data, "type", "response should have type field") }) } func TestAgentCallerNextHookData(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) caller := standard.NewAgentCaller() ctx := types.NewContext(context.Background(), testAuth()) t.Run("next_hook inspiration returns structured data", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "tests.robot-single", "next_hook inspiration test") require.NoError(t, err) require.NotNil(t, result) // Next hook should return structured data data, err := result.GetJSON() require.NoError(t, err) assert.Equal(t, "inspiration", data["type"]) assert.Equal(t, "next_hook", data["source"]) }) t.Run("next_hook goals returns structured data", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "tests.robot-single", "next_hook goals test") require.NoError(t, err) require.NotNil(t, result) data, err := result.GetJSON() require.NoError(t, err) assert.Equal(t, "goals", data["type"]) assert.Equal(t, "next_hook", data["source"]) }) t.Run("next_hook tasks returns structured data", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "tests.robot-single", "next_hook tasks test") require.NoError(t, err) require.NotNil(t, result) data, err := result.GetJSON() require.NoError(t, err) assert.Equal(t, "tasks", 
data["type"]) assert.Equal(t, "next_hook", data["source"]) }) } func TestAgentCallerJSONArray(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) caller := standard.NewAgentCaller() ctx := types.NewContext(context.Background(), testAuth()) t.Run("array_test returns JSON array", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "tests.robot-single", "array_test") require.NoError(t, err) require.NotNil(t, result) arr, err := result.GetJSONArray() require.NoError(t, err) assert.Len(t, arr, 3) // Verify first item structure item1, ok := arr[0].(map[string]interface{}) require.True(t, ok, "first item should be a map") assert.Equal(t, float64(1), item1["id"]) assert.Equal(t, "Item 1", item1["name"]) }) } func TestAgentCallerEmptyResponse(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) caller := standard.NewAgentCaller() ctx := types.NewContext(context.Background(), testAuth()) t.Run("empty_test falls back to completion content", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "tests.robot-single", "empty_test") require.NoError(t, err) require.NotNil(t, result) // When Next hook returns null, should use Completion content assert.False(t, result.IsEmpty()) }) } func TestAgentCallerAssistantNotFound(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) caller := standard.NewAgentCaller() ctx := types.NewContext(context.Background(), testAuth()) t.Run("non-existent assistant returns error", func(t *testing.T) { result, err := caller.CallWithMessages(ctx, "non.existent.assistant", "hello") assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "assistant not found") }) } func TestAgentCallerWithSystemAndUser(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer 
testutils.Clean(t)

	caller := standard.NewAgentCaller()
	ctx := types.NewContext(context.Background(), testAuth())

	// Calls the single-shot test agent with an explicit system prompt plus a
	// user message and only checks that something non-empty comes back.
	t.Run("call with system and user messages", func(t *testing.T) {
		result, err := caller.CallWithSystemAndUser(
			ctx,
			"tests.robot-single",
			"You are a helpful assistant.",
			"Generate inspiration report",
		)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.False(t, result.IsEmpty())
	})
}

// ============================================================================
// Conversation Tests - Multi-Turn Mode
// ============================================================================

// TestConversationMultiTurn drives three turns against the
// "tests.robot-conversation" agent and checks the turn counter and the JSON
// status fields returned on each turn. Skipped under -short.
func TestConversationMultiTurn(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("multi-turn conversation maintains state", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-1", 10)

		// Turn 1: Start planning
		turn1, err := conv.Turn(ctx, "Plan tasks for sending weekly report")
		require.NoError(t, err)
		require.NotNil(t, turn1)
		assert.Equal(t, 1, turn1.Turn)

		data1, err := turn1.Result.GetJSON()
		require.NoError(t, err)
		// Verify basic structure - turn number and completed flag
		assert.Contains(t, data1, "turn")
		assert.Contains(t, data1, "status")
		assert.Contains(t, data1, "completed")

		// Turn 2: Continue conversation
		turn2, err := conv.Turn(ctx, "Send to managers, include sales data")
		require.NoError(t, err)
		require.NotNil(t, turn2)
		assert.Equal(t, 2, turn2.Turn)

		data2, err := turn2.Result.GetJSON()
		require.NoError(t, err)
		assert.Contains(t, data2, "turn")
		assert.Contains(t, data2, "status")

		// Turn 3: Complete with confirm/skip
		turn3, err := conv.Turn(ctx, "skip") // Use skip for deterministic completion
		require.NoError(t, err)
		require.NotNil(t, turn3)
		assert.Equal(t, 3, turn3.Turn)

		data3, err := turn3.Result.GetJSON()
		require.NoError(t, err)
		assert.Equal(t, "completed", data3["status"])
		assert.Equal(t, true, data3["completed"])
	})
}

// TestConversationTurnCount checks that TurnCount starts at zero and grows by
// one for every successful Turn call. Skipped under -short.
func TestConversationTurnCount(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("turn count increments correctly", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-2", 10)

		assert.Equal(t, 0, conv.TurnCount())

		_, err := conv.Turn(ctx, "First message")
		require.NoError(t, err)
		assert.Equal(t, 1, conv.TurnCount())

		_, err = conv.Turn(ctx, "Second message")
		require.NoError(t, err)
		assert.Equal(t, 2, conv.TurnCount())
	})
}

// TestConversationMaxTurns builds a conversation limited to two turns and
// expects the third Turn call to fail with a "max turns (2) exceeded" error.
// Skipped under -short.
func TestConversationMaxTurns(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("exceeding max turns returns error", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-3", 2)

		_, err := conv.Turn(ctx, "First")
		require.NoError(t, err)

		_, err = conv.Turn(ctx, "Second")
		require.NoError(t, err)

		// Third turn should fail
		// NOTE(review): assert.Error does not stop the test, so err.Error()
		// below panics on a nil error; require.Error would fail cleanly.
		_, err = conv.Turn(ctx, "Third")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "max turns (2) exceeded")
	})
}

// TestConversationMessages checks that Messages is empty for a new
// conversation and holds at least one entry after the first turn.
// Skipped under -short.
func TestConversationMessages(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("messages history is maintained", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-4", 5)

		// Initially empty
		assert.Empty(t, conv.Messages())

		// After first turn
		_, err := conv.Turn(ctx, "Hello")
		require.NoError(t, err)

		msgs := conv.Messages()
		assert.GreaterOrEqual(t, len(msgs), 1) // At least user message
	})
}

// TestConversationLastResponse checks that LastResponse is nil before any
// turn and is an "assistant"-role message afterwards. Skipped under -short.
func TestConversationLastResponse(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("last response returns assistant message", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-5", 5)

		// No response yet
		assert.Nil(t, conv.LastResponse())

		// After turn
		_, err := conv.Turn(ctx, "Start planning")
		require.NoError(t, err)

		lastResp := conv.LastResponse()
		assert.NotNil(t, lastResp)
		assert.Equal(t, "assistant", string(lastResp.Role))
	})
}

// TestConversationReset checks that Reset returns the turn counter to zero
// and empties the message history. Skipped under -short.
func TestConversationReset(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("reset clears conversation history", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-6", 5)

		_, err := conv.Turn(ctx, "First message")
		require.NoError(t, err)
		assert.Equal(t, 1, conv.TurnCount())

		conv.Reset()
		assert.Equal(t, 0, conv.TurnCount())
		assert.Empty(t, conv.Messages())
	})
}

// TestConversationWithSystemPrompt checks that a system prompt installed via
// WithSystemPrompt is the only message before the first turn and is still the
// only message after Reset. Skipped under -short.
func TestConversationWithSystemPrompt(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("system prompt is preserved after reset", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-7", 5).
			WithSystemPrompt("You are a task planner.")

		msgs := conv.Messages()
		require.Len(t, msgs, 1)
		assert.Equal(t, "system", string(msgs[0].Role))

		_, err := conv.Turn(ctx, "Hello")
		require.NoError(t, err)

		conv.Reset()

		// System prompt should be preserved
		msgs = conv.Messages()
		require.Len(t, msgs, 1)
		assert.Equal(t, "system", string(msgs[0].Role))
	})
}

// TestConversationSpecialCommands sends the literal inputs "skip", "abort"
// and "reset" to the test agent and checks the status it reports for each.
// Skipped under -short.
func TestConversationSpecialCommands(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("skip command jumps to completed", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-8", 5)

		turn, err := conv.Turn(ctx, "skip")
		require.NoError(t, err)

		data, err := turn.Result.GetJSON()
		require.NoError(t, err)
		assert.Equal(t, "completed", data["status"])
		assert.Equal(t, true, data["completed"])
	})

	t.Run("abort command ends conversation", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-9", 5)

		turn, err := conv.Turn(ctx, "abort")
		require.NoError(t, err)

		data, err := turn.Result.GetJSON()
		require.NoError(t, err)
		assert.Equal(t, "aborted", data["status"])
		assert.Equal(t, true, data["completed"])
	})

	t.Run("reset command resets conversation state", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-10", 5)

		// First do a turn
		_, err := conv.Turn(ctx, "Start planning")
		require.NoError(t, err)

		// Then reset via command
		turn, err := conv.Turn(ctx, "reset")
		require.NoError(t, err)

		data, err := turn.Result.GetJSON()
		require.NoError(t, err)
		assert.Equal(t, "reset", data["status"])
	})
}

// TestConversationRunUntil feeds three scripted inputs through RunUntil and
// stops as soon as a turn's JSON reports completed=true; exactly three turn
// results are expected. Skipped under -short.
func TestConversationRunUntil(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("run until completion", func(t *testing.T) {
		conv := standard.NewConversation("tests.robot-conversation", "test-conv-11", 10)

		inputs := []string{
			"Plan weekly report tasks",
			"Send to team leads, include metrics",
			"confirm",
		}
		inputIdx := 0

		results, err := conv.RunUntil(
			ctx,
			// Input provider: replay the scripted inputs, then keep confirming.
			func(turn int, lastResult *standard.CallResult) (string, error) {
				if inputIdx < len(inputs) {
					input := inputs[inputIdx]
					inputIdx++
					return input, nil
				}
				return "confirm", nil
			},
			// Stop condition: a non-JSON reply is treated as "not done yet".
			func(turn int, result *standard.CallResult) (bool, error) {
				data, err := result.GetJSON()
				if err != nil {
					return false, nil
				}
				completed, ok := data["completed"].(bool)
				return ok && completed, nil
			},
		)
		require.NoError(t, err)
		require.Len(t, results, 3, "should complete in 3 turns")

		// Final result should be completed
		finalData, err := results[len(results)-1].Result.GetJSON()
		require.NoError(t, err)
		assert.Equal(t, true, finalData["completed"])
	})
}

// ============================================================================
// CallResult Tests
// ============================================================================

// TestCallResultGetText checks that GetText returns Content when set and the
// empty string for a zero-value CallResult.
func TestCallResultGetText(t *testing.T) {
	t.Run("returns content when available", func(t *testing.T) {
		result := &standard.CallResult{Content: "Hello World"}
		assert.Equal(t, "Hello World", result.GetText())
	})

	t.Run("returns empty for empty result", func(t *testing.T) {
		result := &standard.CallResult{}
		assert.Equal(t, "", result.GetText())
	})
}

// TestCallResultIsEmpty checks that IsEmpty is true only when both Content
// and Next are unset.
func TestCallResultIsEmpty(t *testing.T) {
	t.Run("empty when no content and no next", func(t *testing.T) {
		result := &standard.CallResult{}
		assert.True(t, result.IsEmpty())
	})

	t.Run("not empty when has content", func(t *testing.T) {
		result := &standard.CallResult{Content: "test"}
		assert.False(t, result.IsEmpty())
	})

	t.Run("not empty when has next", func(t *testing.T) {
		result := &standard.CallResult{Next: map[string]interface{}{"key": "value"}}
		assert.False(t, result.IsEmpty())
	})
}

// ============================================================================
// ExtractCodeBlock Tests
//
// ============================================================================

// TestExtractCodeBlock checks ExtractCodeBlock on a fenced JSON snippet and on
// plain text without any fence.
func TestExtractCodeBlock(t *testing.T) {
	t.Run("extracts JSON code block", func(t *testing.T) {
		content := "Here is the result:\n```json\n{\"key\": \"value\"}\n```"
		block := standard.ExtractCodeBlock(content)
		require.NotNil(t, block)
		assert.Equal(t, "json", block.Type)
		assert.Contains(t, block.Content, "key")
	})

	// NOTE(review): the subtest name says "returns nil", but the assertions
	// below expect a non-nil block of type "text".
	t.Run("returns nil for no code block", func(t *testing.T) {
		content := "Just plain text"
		block := standard.ExtractCodeBlock(content)
		// gou/text returns text type for plain text
		require.NotNil(t, block)
		assert.Equal(t, "text", block.Type)
	})
}

// TestExtractAllCodeBlocks checks that two fenced blocks in one string are
// both returned by ExtractAllCodeBlocks.
func TestExtractAllCodeBlocks(t *testing.T) {
	t.Run("extracts multiple code blocks", func(t *testing.T) {
		content := "```json\n{}\n```\n\n```python\nprint('hello')\n```"
		blocks := standard.ExtractAllCodeBlocks(content)
		assert.Len(t, blocks, 2)
	})
}

================================================
FILE: agent/robot/executor/standard/delivery.go
================================================
package standard

import (
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/yaoapp/gou/model"
	kunlog "github.com/yaoapp/kun/log"
	robotevents "github.com/yaoapp/yao/agent/robot/events"
	robottypes "github.com/yaoapp/yao/agent/robot/types"
	"github.com/yaoapp/yao/event"
)

// RunDelivery executes P4: Delivery phase
//
// Process:
//  1. Call Delivery Agent with full execution context
//  2. Agent generates DeliveryContent (summary, body, attachments)
//  3.
Push delivery event for asynchronous routing via handlers func (e *Executor) RunDelivery(ctx *robottypes.Context, exec *robottypes.Execution, _ interface{}) error { robot := exec.GetRobot() if robot == nil { return fmt.Errorf("robot not found in execution") } locale := getEffectiveLocale(robot, exec.Input) e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "generating_delivery")) agentID := "__yao.delivery" if robot.Config != nil && robot.Config.Resources != nil { agentID = robot.Config.Resources.GetPhaseAgent(robottypes.PhaseDelivery) } formatter := NewInputFormatter() userContent := formatter.FormatDeliveryInput(exec, robot) if userContent == "" { return fmt.Errorf("no content available for delivery generation") } caller := NewAgentCaller() caller.Connector = robot.LanguageModel result, err := caller.CallWithMessages(ctx, agentID, userContent) if err != nil { return fmt.Errorf("delivery agent (%s) call failed: %w", agentID, err) } data, err := result.GetJSON() if err != nil { content := result.GetText() if content == "" { return fmt.Errorf("delivery agent returned empty response") } exec.Delivery = &robottypes.DeliveryResult{ RequestID: generateRequestID(exec.ID), Content: &robottypes.DeliveryContent{ Summary: truncateSummary(content, 200), Body: content, }, Success: true, } return e.pushDeliveryEvent(ctx, exec, robot) } content := parseDeliveryContent(data) if content == nil { return fmt.Errorf("delivery agent (%s) returned invalid content", agentID) } exec.Delivery = &robottypes.DeliveryResult{ RequestID: generateRequestID(exec.ID), Content: content, Success: true, } return e.pushDeliveryEvent(ctx, exec, robot) } // pushDeliveryEvent pushes a delivery event to the event bus. // Registered handlers (see events/handlers.go) route to email/webhook/process channels. 
func (e *Executor) pushDeliveryEvent(ctx *robottypes.Context, exec *robottypes.Execution, robot *robottypes.Robot) error { prefs := buildDeliveryPreferences(robot) chatID := exec.ChatID var extra map[string]any if exec.Input != nil && exec.Input.Data != nil { if sourceChatID, ok := exec.Input.Data["chat_id"].(string); ok && sourceChatID != "" { if channel, ok := exec.Input.Data["channel"].(string); ok && channel != "" { chatID = channel + ":" + sourceChatID } } if e, ok := exec.Input.Data["extra"].(map[string]any); ok { extra = e } } _, err := event.Push(ctx.Context, robotevents.Delivery, robotevents.DeliveryPayload{ ExecutionID: exec.ID, MemberID: exec.MemberID, TeamID: exec.TeamID, ChatID: chatID, Content: exec.Delivery.Content, Preferences: prefs, Extra: extra, }) if err != nil { kunlog.Error("delivery event push failed: execution=%s error=%v", exec.ID, err) } return nil } // parseDeliveryContent parses the Delivery Agent response into DeliveryContent func parseDeliveryContent(data map[string]interface{}) *robottypes.DeliveryContent { if data == nil { return nil } contentData, ok := data["content"].(map[string]interface{}) if !ok { contentData = data } content := &robottypes.DeliveryContent{} if summary, ok := contentData["summary"].(string); ok { content.Summary = summary } if body, ok := contentData["body"].(string); ok { content.Body = body } if attachments, ok := contentData["attachments"].([]interface{}); ok { for _, att := range attachments { if attMap, ok := att.(map[string]interface{}); ok { attachment := parseDeliveryAttachment(attMap) if attachment != nil { content.Attachments = append(content.Attachments, *attachment) } } } } if content.Summary == "" && content.Body == "" { return nil } return content } func parseDeliveryAttachment(data map[string]interface{}) *robottypes.DeliveryAttachment { if data == nil { return nil } att := &robottypes.DeliveryAttachment{} if title, ok := data["title"].(string); ok { att.Title = title } if desc, ok := 
data["description"].(string); ok { att.Description = desc } if taskID, ok := data["task_id"].(string); ok { att.TaskID = taskID } if file, ok := data["file"].(string); ok { att.File = file } if att.Title == "" || att.File == "" { return nil } return att } func generateRequestID(execID string) string { return fmt.Sprintf("dlv-%s-%d", execID, time.Now().UnixNano()%1000000) } func getTaskDescription(task robottypes.Task) string { if len(task.Messages) == 0 { return task.GoalRef } for _, msg := range task.Messages { if content, ok := msg.Content.(string); ok && content != "" { if len(content) > 100 { return content[:97] + "..." } return content } } if task.GoalRef != "" { return task.GoalRef } return "Task " + task.ID } func truncateSummary(text string, maxLen int) string { if len(text) <= maxLen { return text } truncated := text[:maxLen] if idx := strings.LastIndex(truncated, " "); idx > maxLen/2 { return truncated[:idx] + "..." } return truncated + "..." } func buildDeliveryPreferences(robot *robottypes.Robot) *robottypes.DeliveryPreferences { if robot == nil { return nil } prefs := &robottypes.DeliveryPreferences{} managerEmail := robot.ManagerEmail if managerEmail == "" && robot.ManagerID != "" { managerEmail = getManagerEmail(robot.ManagerID) if managerEmail != "" { robot.ManagerEmail = managerEmail } } var emailTargets []robottypes.EmailTarget if managerEmail != "" { emailTargets = append(emailTargets, robottypes.EmailTarget{ To: []string{managerEmail}, }) } if robot.Config != nil && robot.Config.Delivery != nil && robot.Config.Delivery.Email != nil { for _, target := range robot.Config.Delivery.Email.Targets { if len(target.To) > 0 { emailTargets = append(emailTargets, target) } } } if len(emailTargets) > 0 { prefs.Email = &robottypes.EmailPreference{ Enabled: true, Targets: emailTargets, } } if robot.Config != nil && robot.Config.Delivery != nil && robot.Config.Delivery.Webhook != nil { if robot.Config.Delivery.Webhook.Enabled && 
len(robot.Config.Delivery.Webhook.Targets) > 0 { prefs.Webhook = robot.Config.Delivery.Webhook } } if robot.Config != nil && robot.Config.Delivery != nil && robot.Config.Delivery.Process != nil { if robot.Config.Delivery.Process.Enabled && len(robot.Config.Delivery.Process.Targets) > 0 { prefs.Process = robot.Config.Delivery.Process } } return prefs } func getManagerEmail(managerID string) string { if managerID == "" { return "" } m := model.Select("__yao.member") if m == nil { return "" } records, err := m.Get(model.QueryParam{ Select: []interface{}{"email"}, Wheres: []model.QueryWhere{ {Column: "member_id", Value: managerID}, }, Limit: 1, }) if err != nil || len(records) == 0 { return "" } if email, ok := records[0]["email"].(string); ok { return email } return "" } // FormatDeliveryInput formats the full execution context for the Delivery Agent func (f *InputFormatter) FormatDeliveryInput(exec *robottypes.Execution, robot *robottypes.Robot) string { if exec == nil { return "" } var sb strings.Builder if robot != nil && robot.Config != nil && robot.Config.Identity != nil { sb.WriteString("## Robot Identity\n\n") sb.WriteString(fmt.Sprintf("- **Role**: %s\n", robot.Config.Identity.Role)) if len(robot.Config.Identity.Duties) > 0 { sb.WriteString("- **Duties**: ") sb.WriteString(strings.Join(robot.Config.Identity.Duties, ", ")) sb.WriteString("\n") } sb.WriteString("\n") } sb.WriteString("## Execution Context\n\n") sb.WriteString(fmt.Sprintf("- **Trigger**: %s\n", exec.TriggerType)) sb.WriteString(fmt.Sprintf("- **Status**: %s\n", exec.Status)) sb.WriteString(fmt.Sprintf("- **Start Time**: %s\n", exec.StartTime.Format("2006-01-02 15:04:05"))) if exec.EndTime != nil { duration := exec.EndTime.Sub(exec.StartTime) sb.WriteString(fmt.Sprintf("- **Duration**: %s\n", duration.String())) } sb.WriteString("\n") if exec.Inspiration != nil && exec.Inspiration.Content != "" { sb.WriteString("## Inspiration (P0)\n\n") sb.WriteString(exec.Inspiration.Content) 
sb.WriteString("\n\n") } if exec.Goals != nil && exec.Goals.Content != "" { sb.WriteString("## Goals (P1)\n\n") sb.WriteString(exec.Goals.Content) sb.WriteString("\n\n") } if len(exec.Tasks) > 0 { sb.WriteString("## Tasks (P2)\n\n") for i, task := range exec.Tasks { taskDesc := getTaskDescription(task) sb.WriteString(fmt.Sprintf("%d. **%s** - %s\n", i+1, task.ID, taskDesc)) sb.WriteString(fmt.Sprintf(" - Executor: %s (%s)\n", task.ExecutorID, task.ExecutorType)) sb.WriteString(fmt.Sprintf(" - Status: %s\n", task.Status)) if task.ExpectedOutput != "" { sb.WriteString(fmt.Sprintf(" - Expected: %s\n", task.ExpectedOutput)) } } sb.WriteString("\n") } if len(exec.Results) > 0 { sb.WriteString("## Results (P3)\n\n") successCount := 0 failCount := 0 for _, result := range exec.Results { if result.Success { successCount++ sb.WriteString(fmt.Sprintf("### ✓ Task: %s\n\n", result.TaskID)) } else { failCount++ sb.WriteString(fmt.Sprintf("### ✗ Task: %s\n\n", result.TaskID)) } sb.WriteString(fmt.Sprintf("- **Duration**: %dms\n", result.Duration)) if result.Validation != nil { if result.Validation.Passed { sb.WriteString(fmt.Sprintf("- **Validation**: ✓ Passed (score: %.2f)\n", result.Validation.Score)) } else { sb.WriteString("- **Validation**: ✗ Failed\n") if len(result.Validation.Issues) > 0 { for _, issue := range result.Validation.Issues { sb.WriteString(fmt.Sprintf(" - %s\n", issue)) } } } } if result.Output != nil { sb.WriteString("\n**Output**:\n") if output, err := json.MarshalIndent(result.Output, "", " "); err == nil { sb.WriteString("```json\n") sb.WriteString(string(output)) sb.WriteString("\n```\n") } else { sb.WriteString(fmt.Sprintf("%v\n", result.Output)) } } if result.Error != "" { sb.WriteString(fmt.Sprintf("\n**Error**: %s\n", result.Error)) } sb.WriteString("\n") } sb.WriteString(fmt.Sprintf("### Summary\n\n- **Total Tasks**: %d\n- **Succeeded**: %d\n- **Failed**: %d\n\n", len(exec.Results), successCount, failCount)) } return sb.String() } 
================================================
FILE: agent/robot/executor/standard/delivery_test.go
================================================
package standard_test

import (
	"context"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/yaoapp/yao/agent/robot/executor/standard"
	"github.com/yaoapp/yao/agent/robot/types"
	"github.com/yaoapp/yao/agent/testutils"
)

// ============================================================================
// P4 Delivery Phase Tests
// ============================================================================

// TestRunDeliveryBasic runs the delivery phase against the "robot.delivery"
// agent, once with all task results successful and once with one failed
// result, and checks that delivery content is produced. Skipped under -short.
func TestRunDeliveryBasic(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("generates delivery content from execution results", func(t *testing.T) {
		robot := createDeliveryTestRobot(t, "robot.delivery")
		exec := createDeliveryTestExecution(robot)

		exec.Inspiration = &types.InspirationReport{
			Content: "Morning analysis suggests focus on Q4 review.",
		}
		exec.Goals = &types.Goals{
			Content: "## Goals\n1. Review Q4 data\n2. Generate summary report",
		}
		exec.Tasks = []types.Task{
			{ID: "task-001", ExecutorID: "experts.data-analyst", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted},
			{ID: "task-002", ExecutorID: "experts.summarizer", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted},
		}
		exec.Results = []types.TaskResult{
			{TaskID: "task-001", Success: true, Duration: 1500, Output: map[string]interface{}{"total_sales": 1500000}},
			{TaskID: "task-002", Success: true, Duration: 800, Output: "Q4 sales exceeded expectations by 15%."},
		}

		e := standard.New()
		err := e.RunDelivery(ctx, exec, nil)

		require.NoError(t, err)
		require.NotNil(t, exec.Delivery)
		require.NotNil(t, exec.Delivery.Content)

		assert.NotEmpty(t, exec.Delivery.Content.Summary)
		assert.NotEmpty(t, exec.Delivery.Content.Body)
		assert.True(t, exec.Delivery.Success)
	})

	t.Run("handles partial failure in results", func(t *testing.T) {
		robot := createDeliveryTestRobot(t, "robot.delivery")
		exec := createDeliveryTestExecution(robot)

		exec.Goals = &types.Goals{
			Content: "## Goals\n1. Analyze data\n2. Generate report",
		}
		exec.Tasks = []types.Task{
			{ID: "task-001", ExecutorID: "experts.data-analyst", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted},
			{ID: "task-002", ExecutorID: "experts.summarizer", ExecutorType: types.ExecutorAssistant, Status: types.TaskFailed},
		}
		exec.Results = []types.TaskResult{
			{TaskID: "task-001", Success: true, Duration: 1500, Output: map[string]interface{}{"data": "analyzed"}},
			{TaskID: "task-002", Success: false, Duration: 500, Error: "Summarization failed: timeout"},
		}

		e := standard.New()
		err := e.RunDelivery(ctx, exec, nil)

		require.NoError(t, err)
		require.NotNil(t, exec.Delivery)
		require.NotNil(t, exec.Delivery.Content)

		// The agent's wording is not deterministic, so the check is lenient:
		// either the body mentions the failure or a summary was produced.
		body := strings.ToLower(exec.Delivery.Content.Body)
		hasFailureInfo := strings.Contains(body, "fail") ||
			strings.Contains(body, "error") ||
			strings.Contains(body, "partial") ||
			strings.Contains(body, "✗")
		assert.True(t, hasFailureInfo || exec.Delivery.Content.Summary != "", "should mention failure or have valid summary")
	})
}

// TestRunDeliveryErrorHandling checks the error returned by RunDelivery when
// the execution has no robot attached and when the configured delivery agent
// does not exist. Skipped under -short.
func TestRunDeliveryErrorHandling(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("returns error when robot is nil", func(t *testing.T) {
		exec := &types.Execution{
			ID:          "test-exec-1",
			TriggerType: types.TriggerClock,
		}

		e := standard.New()
		err := e.RunDelivery(ctx, exec, nil)

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "robot not found")
	})

	t.Run("returns error when agent not found", func(t *testing.T) {
		robot := &types.Robot{
			MemberID: "test-robot-1",
			TeamID:   "test-team-1",
			Config: &types.Config{
				Identity: &types.Identity{Role: "Test"},
				Resources: &types.Resources{
					Phases: map[types.Phase]string{
						types.PhaseDelivery: "non.existent.agent",
					},
				},
			},
		}
		exec := createDeliveryTestExecution(robot)
		exec.Results = []types.TaskResult{
			{TaskID: "task-001", Success: true, Duration: 100},
		}

		e := standard.New()
		err := e.RunDelivery(ctx, exec, nil)

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "call failed")
	})
}

// ============================================================================
// Email Channel Config Tests
// ============================================================================

// TestDefaultEmailChannel checks the default email channel name, that it can
// be overridden, and that setting an empty name leaves it unchanged. Each
// mutating subtest restores the original value on exit.
func TestDefaultEmailChannel(t *testing.T) {
	t.Run("returns default email channel", func(t *testing.T) {
		assert.Equal(t, "default", types.DefaultEmailChannel())
	})

	t.Run("can set custom email channel", func(t *testing.T) {
		original := types.DefaultEmailChannel()
		defer types.SetDefaultEmailChannel(original)

		types.SetDefaultEmailChannel("custom-email")
		assert.Equal(t, "custom-email", types.DefaultEmailChannel())
	})

	t.Run("ignores empty channel", func(t *testing.T) {
		original := types.DefaultEmailChannel()
		defer types.SetDefaultEmailChannel(original)

		types.SetDefaultEmailChannel("")
		assert.Equal(t, original, types.DefaultEmailChannel())
	})
}

// TestRobotEmailInDelivery checks that NewRobotFromMap copies "robot_email"
// into Robot.RobotEmail and leaves it empty when the key is absent.
func TestRobotEmailInDelivery(t *testing.T) {
	t.Run("robot email field is loaded from map", func(t *testing.T) {
		data := map[string]interface{}{
			"member_id":   "robot-001",
			"team_id":     "team-001",
			"robot_email": "robot@example.com",
		}

		robot, err := types.NewRobotFromMap(data)
		require.NoError(t, err)
		assert.Equal(t, "robot@example.com", robot.RobotEmail)
	})

	t.Run("robot email can be empty", func(t *testing.T) {
		data := map[string]interface{}{
			"member_id": "robot-001",
			"team_id":   "team-001",
		}

		robot, err := types.NewRobotFromMap(data)
		require.NoError(t, err)
		assert.Empty(t, robot.RobotEmail)
	})
}

// ============================================================================
// FormatDeliveryInput Tests
// ============================================================================

// TestFormatDeliveryInput checks the section headings and key values that
// FormatDeliveryInput writes for a fully populated execution, and that a bare
// execution without a robot still gets the "Execution Context" section.
func TestFormatDeliveryInput(t *testing.T) {
	formatter := standard.NewInputFormatter()

	t.Run("formats complete execution context", func(t *testing.T) {
		robot := &types.Robot{
			MemberID: "test-robot",
			Config: &types.Config{
				Identity: &types.Identity{
					Role:   "Sales Analyst",
					Duties: []string{"Analyze data", "Generate reports"},
				},
			},
		}

		startTime := time.Now().Add(-5 * time.Minute)
		endTime := time.Now()

		exec := &types.Execution{
			ID:          "exec-123",
			TriggerType: types.TriggerClock,
			Status:      types.ExecCompleted,
			StartTime:   startTime,
			EndTime:     &endTime,
			Inspiration: &types.InspirationReport{
				Content: "Morning analysis suggests focus on Q4.",
			},
			Goals: &types.Goals{
				Content: "## Goals\n1. Review Q4 data",
			},
			Tasks: []types.Task{
				{ID: "task-001", ExecutorID: "data-analyst", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted, ExpectedOutput: "JSON with sales data"},
			},
			Results: []types.TaskResult{
				{TaskID: "task-001", Success: true, Duration: 1500, Output: map[string]interface{}{"sales": 1000000}},
			},
		}

		result := formatter.FormatDeliveryInput(exec, robot)

		assert.Contains(t, result, "## Robot Identity")
		assert.Contains(t, result, "Sales Analyst")
		assert.Contains(t, result, "## Execution Context")
		assert.Contains(t, result, "clock")
		assert.Contains(t, result, "## Inspiration (P0)")
		assert.Contains(t, result, "Morning analysis")
		assert.Contains(t, result, "## Goals (P1)")
		assert.Contains(t, result, "Review Q4 data")
		assert.Contains(t, result, "## Tasks (P2)")
		assert.Contains(t, result, "task-001")
		assert.Contains(t, result, "## Results (P3)")
		assert.Contains(t, result, "✓ Task: task-001")
	})

	t.Run("handles empty execution", func(t *testing.T) {
		exec := &types.Execution{
			ID:          "exec-empty",
			TriggerType: types.TriggerHuman,
			Status:      types.ExecPending,
			StartTime:   time.Now(),
		}

		result := formatter.FormatDeliveryInput(exec, nil)

		assert.Contains(t, result, "## Execution Context")
		assert.Contains(t, result, "human")
	})
}

// ============================================================================
// Helper Functions
// ============================================================================

// createDeliveryTestRobot builds an in-memory robot whose delivery phase is
// mapped to agentID.
func createDeliveryTestRobot(t *testing.T, agentID string) *types.Robot {
	t.Helper()
	return &types.Robot{
		MemberID:    "test-robot-1",
		TeamID:      "test-team-1",
		DisplayName: "Test 
Robot",
		Config: &types.Config{
			Identity: &types.Identity{
				Role:   "Test Assistant",
				Duties: []string{"Testing", "Validation"},
			},
			Resources: &types.Resources{
				Phases: map[types.Phase]string{
					types.PhaseDelivery: agentID,
				},
			},
		},
	}
}

// createDeliveryTestExecution builds a running, clock-triggered execution
// positioned at the delivery phase and attaches robot to it via SetRobot.
func createDeliveryTestExecution(robot *types.Robot) *types.Execution {
	exec := &types.Execution{
		ID:          "test-exec-delivery-1",
		MemberID:    robot.MemberID,
		TeamID:      robot.TeamID,
		TriggerType: types.TriggerClock,
		StartTime:   time.Now(),
		Status:      types.ExecRunning,
		Phase:       types.PhaseDelivery,
	}
	exec.SetRobot(robot)
	return exec
}

================================================
FILE: agent/robot/executor/standard/executor.go
================================================
package standard

import (
	"fmt"
	"strings"
	"sync/atomic"
	"time"

	kunlog "github.com/yaoapp/kun/log"
	agentcontext "github.com/yaoapp/yao/agent/context"
	robotevents "github.com/yaoapp/yao/agent/robot/events"
	"github.com/yaoapp/yao/agent/robot/executor/types"
	"github.com/yaoapp/yao/agent/robot/store"
	robottypes "github.com/yaoapp/yao/agent/robot/types"
	"github.com/yaoapp/yao/agent/robot/utils"
	"github.com/yaoapp/yao/event"
)

// Executor implements the standard executor with real Agent calls
// This is the production executor that:
//   - Persists execution history to database
//   - Calls real Agents via Assistant.Stream()
//   - Logs phase transitions and errors using kun/log
type Executor struct {
	config       types.Config          // executor options; SkipPersistence disables all store writes
	store        *store.ExecutionStore // execution records (status, phase data)
	robotStore   *store.RobotStore     // robot status (working / idle)
	execCount    atomic.Int32          // executions started that acquired a slot
	currentCount atomic.Int32          // executions currently in flight
	onStart      func()                // optional hook called after a slot is acquired
	onEnd        func()                // optional hook deferred until the execution returns
}

// New creates a new standard executor with default (zero-value) configuration
// and freshly constructed execution and robot stores.
func New() *Executor {
	return &Executor{
		store:      store.NewExecutionStore(),
		robotStore: store.NewRobotStore(),
	}
}

// NewWithConfig creates a new standard executor with configuration
func NewWithConfig(config types.Config) *Executor {
	return &Executor{
		config:     config,
		store:      store.NewExecutionStore(),
		robotStore: store.NewRobotStore(),
	}
}

// Execute runs a robot through all applicable phases with real 
// Agent calls (auto-generates ID)
//
// It delegates to ExecuteWithControl with an empty execution ID and no
// execution control.
func (e *Executor) Execute(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}) (*robottypes.Execution, error) {
	return e.ExecuteWithControl(ctx, robot, trigger, data, "", nil)
}

// ExecuteWithID runs a robot through all applicable phases with a pre-generated execution ID (no control)
func (e *Executor) ExecuteWithID(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string) (*robottypes.Execution, error) {
	return e.ExecuteWithControl(ctx, robot, trigger, data, execID, nil)
}

// ExecuteWithControl runs a robot through all applicable phases with execution control
// control: optional, allows pause/resume functionality during execution
//
// Flow:
//  1. Build the Execution; human- and event-triggered runs start at index 1
//     of AllPhases, skipping P0 (Inspiration).
//  2. Merge Goals/Tasks/Input from a stored execution with the same ID, and
//     inject goals passed via Input.Data["goals"].
//  3. Persist the record, then acquire an execution slot on the robot.
//  4. Run each phase in order (PhaseHost is skipped) and persist the final
//     status.
//
// Returns:
//   - (nil, error) when robot is nil
//   - (nil, ErrQuotaExceeded) when no execution slot is available
//   - (exec, ErrExecutionSuspended) when a phase suspends the execution
//   - (exec, nil) otherwise; cancellation and phase failures are reported
//     through exec.Status and exec.Error rather than the error value
//
// Persistence failures are logged and never abort the execution.
func (e *Executor) ExecuteWithControl(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string, control robottypes.ExecutionControl) (*robottypes.Execution, error) {
	if robot == nil {
		return nil, fmt.Errorf("robot cannot be nil")
	}

	// Determine starting phase based on trigger type
	startPhaseIndex := 0
	if trigger == robottypes.TriggerHuman || trigger == robottypes.TriggerEvent {
		startPhaseIndex = 1 // Skip P0 (Inspiration)
	}

	// Use provided execID or generate new one
	if execID == "" {
		execID = utils.NewID()
	}

	// Create execution (Job system removed, using ExecutionStore only)
	input := types.BuildTriggerInput(trigger, data)
	exec := &robottypes.Execution{
		ID:          execID,
		MemberID:    robot.MemberID,
		TeamID:      robot.TeamID,
		TriggerType: trigger,
		StartTime:   time.Now(),
		Status:      robottypes.ExecPending,
		Phase:       robottypes.AllPhases[startPhaseIndex],
		Input:       input,
		ChatID:      fmt.Sprintf("robot_%s_%s", robot.MemberID, execID),
	}

	// Load pre-existing Goals/Tasks from store when resuming a confirmed execution.
	// RunGoals and RunTasks have skip logic when these are already populated.
	// NOTE(review): execID was filled in above when empty, so the execID != ""
	// test presumably never fails here and the store is also queried for
	// freshly generated IDs — confirm whether that lookup is wanted.
	if execID != "" && !e.config.SkipPersistence && e.store != nil {
		if existing, err := e.store.Get(ctx.Context, execID); err == nil && existing != nil {
			exec.Goals = existing.Goals
			exec.Tasks = existing.Tasks
			if existing.Input != nil {
				exec.Input = existing.Input
			}
		}
	}

	// If goals are pre-confirmed (passed via Input.Data["goals"]), inject them directly.
	// RunGoals will skip LLM call when exec.Goals is already populated (§18.2).
	if exec.Goals == nil && input != nil && input.Data != nil {
		if goalsStr, ok := input.Data["goals"].(string); ok && goalsStr != "" {
			exec.Goals = &robottypes.Goals{Content: goalsStr}
		}
	}

	// Initialize UI display fields (with i18n support)
	exec.Name, exec.CurrentTaskName = e.initUIFields(trigger, input, robot)

	// Set robot reference for phase methods
	exec.SetRobot(robot)

	// Persist execution record to database
	// Robot is identified by member_id (globally unique in __yao.member table)
	if !e.config.SkipPersistence && e.store != nil {
		record := store.FromExecution(exec)
		if err := e.store.Save(ctx.Context, record); err != nil {
			// Log warning but don't fail execution
			kunlog.With(kunlog.F{
				"execution_id": exec.ID,
				"member_id":    exec.MemberID,
				"error":        err,
			}).Warn("Failed to persist execution record: %v", err)
		}

		// If goals were pre-injected, persist them and update the execution title
		if exec.Goals != nil && exec.Goals.Content != "" {
			if err := e.store.UpdatePhase(ctx.Context, exec.ID, robottypes.PhaseGoals, exec.Goals); err != nil {
				kunlog.With(kunlog.F{
					"execution_id": exec.ID,
					"member_id":    exec.MemberID,
					"error":        err,
				}).Warn("Failed to persist pre-confirmed goals: %v", err)
			}
			if goalName := extractGoalName(exec.Goals); goalName != "" {
				e.updateUIFields(ctx, exec, goalName, "")
			}
		}
	}

	// Acquire execution slot
	// NOTE(review): the record saved above is not updated when the slot is
	// refused, so it appears to stay in its pending state — confirm.
	if !robot.TryAcquireSlot(exec) {
		kunlog.With(kunlog.F{
			"execution_id": exec.ID,
			"member_id":    exec.MemberID,
		}).Warn("Execution quota exceeded")
		return nil, robottypes.ErrQuotaExceeded
	}

	// Defer: remove execution from robot's tracking (unless suspended) and update robot status
	defer func() {
		// Suspended executions stay in tracking — they are still "alive"
		if exec.Status == robottypes.ExecWaiting {
			return
		}
		robot.RemoveExecution(exec.ID)

		// Update robot status to idle if no more running executions
		if robot.RunningCount() == 0 && !e.config.SkipPersistence && e.robotStore != nil {
			if err := e.robotStore.UpdateStatus(ctx.Context, robot.MemberID, robottypes.RobotIdle); err != nil {
				kunlog.With(kunlog.F{
					"member_id": robot.MemberID,
					"error":     err,
				}).Warn("Failed to update robot status to idle: %v", err)
			}
		}
	}()

	// Track execution count
	e.execCount.Add(1)
	e.currentCount.Add(1)
	defer e.currentCount.Add(-1)

	// Callbacks
	if e.onStart != nil {
		e.onStart()
	}
	if e.onEnd != nil {
		defer e.onEnd()
	}

	// Update status to running
	exec.Status = robottypes.ExecRunning

	kunlog.With(kunlog.F{
		"execution_id": exec.ID,
		"member_id":    exec.MemberID,
		"trigger_type": string(exec.TriggerType),
	}).Info("Execution started")

	// Persist running status
	if !e.config.SkipPersistence && e.store != nil {
		if err := e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecRunning, ""); err != nil {
			kunlog.With(kunlog.F{
				"execution_id": exec.ID,
				"error":        err,
			}).Warn("Failed to persist running status: %v", err)
		}
	}

	// Update robot status to working (when execution starts)
	if !e.config.SkipPersistence && e.robotStore != nil {
		if err := e.robotStore.UpdateStatus(ctx.Context, robot.MemberID, robottypes.RobotWorking); err != nil {
			kunlog.With(kunlog.F{
				"member_id": robot.MemberID,
				"error":     err,
			}).Warn("Failed to update robot status to working: %v", err)
		}
	}

	// Check for simulated failure (for testing)
	if dataStr, ok := data.(string); ok && dataStr == "simulate_failure" {
		exec.Status = robottypes.ExecFailed
		exec.Error = "simulated failure"
		kunlog.With(kunlog.F{
			"execution_id": exec.ID,
			"member_id":    exec.MemberID,
		}).Warn("Simulated failure triggered")
		// Persist failed status
		if !e.config.SkipPersistence && e.store != nil {
			// Best-effort: a persistence error is ignored here.
			_ = e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecFailed, "simulated failure")
		}
		return exec, nil
	}

	// Determine locale for UI messages
	locale := getEffectiveLocale(robot, exec.Input)

	// Execute phases (PhaseHost is not part of the normal pipeline — it is only for Interact)
	phases := robottypes.AllPhases[startPhaseIndex:]
	for _, phase := range phases {
		if phase == robottypes.PhaseHost {
			continue
		}

		if err := e.runPhase(ctx, exec, phase, data, control); err != nil {
			// Check if execution was suspended (needs human input)
			if err == robottypes.ErrExecutionSuspended {
				kunlog.With(kunlog.F{
					"execution_id": exec.ID,
					"member_id":    exec.MemberID,
					"phase":        string(phase),
				}).Info("Execution suspended during phase %s", phase)
				return exec, robottypes.ErrExecutionSuspended
			}

			// Check if execution was cancelled
			if err == robottypes.ErrExecutionCancelled {
				exec.Status = robottypes.ExecCancelled
				exec.Error = "execution cancelled by user"
				now := time.Now()
				exec.EndTime = &now

				// Update UI field for cancellation with i18n
				e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "cancelled"))

				kunlog.With(kunlog.F{
					"execution_id": exec.ID,
					"member_id":    exec.MemberID,
					"phase":        string(phase),
				}).Info("Execution cancelled by user")

				// Persist cancelled status
				if !e.config.SkipPersistence && e.store != nil {
					// Best-effort: a persistence error is ignored here.
					_ = e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecCancelled, "execution cancelled by user")
				}
				return exec, nil
			}

			// Normal failure case
			// NOTE(review): unlike the cancelled and completed paths, EndTime
			// is not set here — confirm whether that is intended.
			exec.Status = robottypes.ExecFailed
			exec.Error = err.Error()

			// Update UI field for failure with i18n
			failedPrefix := getLocalizedMessage(locale, "failed_prefix")
			phaseName := getLocalizedMessage(locale, "phase_"+string(phase))
			failureMsg := failedPrefix + phaseName
			e.updateUIFields(ctx, exec, "", failureMsg)

			kunlog.With(kunlog.F{
				"execution_id": exec.ID,
				"member_id":    exec.MemberID,
				"phase":        string(phase),
				"error":        err.Error(),
			}).Error("Phase execution failed: %v", err)

			// Persist failed status
			if !e.config.SkipPersistence && e.store != nil {
				// Best-effort: a persistence error is ignored here.
				_ = e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecFailed, err.Error())
			}
			// The failure is reported through exec, not the error value.
			return exec, nil
		}
	}

	// Mark completed
	exec.Status = robottypes.ExecCompleted
	now := time.Now()
	exec.EndTime = &now

	// Update UI field for completion with i18n
	e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "completed"))

	duration := now.Sub(exec.StartTime)
	kunlog.With(kunlog.F{
		"execution_id": exec.ID,
		"member_id":    exec.MemberID,
		"duration_ms":  duration.Milliseconds(),
	}).Info("Execution completed successfully")

	// Persist completed status
	if !e.config.SkipPersistence && e.store != nil {
		if err := e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecCompleted, ""); err != nil {
			kunlog.With(kunlog.F{
				"execution_id": exec.ID,
				"error":        err,
			}).Warn("Failed to persist completed status: %v", err)
		}
	}

	// Announce completion on the event bus; the push result is not checked.
	event.Push(ctx.Context, robotevents.ExecCompleted, robotevents.ExecPayload{
		ExecutionID: exec.ID,
		MemberID:    exec.MemberID,
		TeamID:      exec.TeamID,
		Status:      string(robottypes.ExecCompleted),
		ChatID:      exec.ChatID,
	})

	return exec, nil
}

// runPhase executes a single phase
func (e *Executor) runPhase(ctx *robottypes.Context, exec *robottypes.Execution, phase robottypes.Phase, data interface{}, control robottypes.ExecutionControl) error {
	// Check if context is cancelled before starting this phase
	select {
	case <-ctx.Context.Done():
		return robottypes.ErrExecutionCancelled
	default:
	}

	// Wait if execution is paused (blocks until resumed or cancelled)
	if control != nil {
		if err := control.WaitIfPaused(); err != nil {
			return err // Returns ErrExecutionCancelled if cancelled while paused
		}
	}

	exec.Phase = phase

	kunlog.With(kunlog.F{
		"execution_id": exec.ID,
		"member_id":    exec.MemberID,
		"phase":        string(phase),
	}).Info("Phase started: %s", phase)

	// Persist phase change immediately (so frontend sees current phase)
	if !e.config.SkipPersistence && e.store != nil {
		if err := e.store.UpdatePhase(ctx.Context, exec.ID, phase, nil); err != nil {
			kunlog.With(kunlog.F{
"execution_id": exec.ID, "phase": string(phase), "error": err, }).Warn("Failed to persist phase start: %v", err) } } if e.config.OnPhaseStart != nil { e.config.OnPhaseStart(phase) } phaseStart := time.Now() // Execute phase-specific logic var err error switch phase { case robottypes.PhaseInspiration: err = e.RunInspiration(ctx, exec, data) case robottypes.PhaseGoals: err = e.RunGoals(ctx, exec, data) case robottypes.PhaseTasks: err = e.RunTasks(ctx, exec, data) case robottypes.PhaseRun: err = e.RunExecution(ctx, exec, data) case robottypes.PhaseDelivery: err = e.RunDelivery(ctx, exec, data) case robottypes.PhaseLearning: err = e.RunLearning(ctx, exec, data) } if err != nil { if err == robottypes.ErrExecutionSuspended { kunlog.With(kunlog.F{ "execution_id": exec.ID, "member_id": exec.MemberID, "phase": string(phase), }).Info("Phase suspended: %s (waiting for human input)", phase) return err } kunlog.With(kunlog.F{ "execution_id": exec.ID, "member_id": exec.MemberID, "phase": string(phase), "error": err.Error(), }).Error("Phase failed: %s - %v", phase, err) return err } // Persist phase output to database if !e.config.SkipPersistence && e.store != nil { phaseData := e.getPhaseData(exec, phase) if phaseData != nil { if err := e.store.UpdatePhase(ctx.Context, exec.ID, phase, phaseData); err != nil { // Log warning but don't fail execution kunlog.With(kunlog.F{ "execution_id": exec.ID, "phase": string(phase), "error": err, }).Warn("Failed to persist phase %s data: %v", phase, err) } } } if e.config.OnPhaseEnd != nil { e.config.OnPhaseEnd(phase) } phaseDuration := time.Since(phaseStart).Milliseconds() kunlog.With(kunlog.F{ "execution_id": exec.ID, "member_id": exec.MemberID, "phase": string(phase), "duration_ms": phaseDuration, }).Info("Phase completed: %s (took %dms)", phase, phaseDuration) return nil } // getPhaseData extracts the output data for a specific phase from execution func (e *Executor) getPhaseData(exec *robottypes.Execution, phase robottypes.Phase) 
interface{} { switch phase { case robottypes.PhaseInspiration: return exec.Inspiration case robottypes.PhaseGoals: return exec.Goals case robottypes.PhaseTasks: return exec.Tasks case robottypes.PhaseRun: return exec.Results case robottypes.PhaseDelivery: return exec.Delivery case robottypes.PhaseLearning: return exec.Learning default: return nil } } // ExecCount returns total execution count func (e *Executor) ExecCount() int { return int(e.execCount.Load()) } // CurrentCount returns currently running execution count func (e *Executor) CurrentCount() int { return int(e.currentCount.Load()) } // Reset resets the executor counters func (e *Executor) Reset() { e.execCount.Store(0) e.currentCount.Store(0) } // DefaultStreamDelay is the simulated delay for Agent Stream calls // This will be removed when real Agent calls are implemented const DefaultStreamDelay = 50 * time.Millisecond // simulateStreamDelay simulates the delay of an Agent Stream call func (e *Executor) simulateStreamDelay() { time.Sleep(DefaultStreamDelay) } // initUIFields initializes UI display fields based on trigger type with i18n support // Returns (name, currentTaskName) func (e *Executor) initUIFields(trigger robottypes.TriggerType, input *robottypes.TriggerInput, robot *robottypes.Robot) (string, string) { // Determine locale for UI messages locale := getEffectiveLocale(robot, input) // Get localized default messages name := getLocalizedMessage(locale, "preparing") currentTaskName := getLocalizedMessage(locale, "starting") switch trigger { case robottypes.TriggerHuman: // For human trigger, extract name from first message if input != nil && len(input.Messages) > 0 { if content, ok := input.Messages[0].GetContentAsString(); ok && content != "" { // Use first 100 chars of message as name name = content if len(name) > 100 { name = name[:100] + "..." 
// uiMessages holds the localized strings used for UI display fields.
// The outer key is a simple lowercase language code (en, zh); the inner key is
// the message identifier passed to getLocalizedMessage.
var uiMessages = map[string]map[string]string{
	"en": {
		"preparing":           "Preparing...",
		"starting":            "Starting...",
		"scheduled_execution": "Scheduled execution",
		"event_prefix":        "Event: ",
		"event_triggered":     "Event triggered",
		"analyzing_context":   "Analyzing context...",
		"planning_goals":      "Planning goals...",
		"breaking_down_tasks": "Breaking down tasks...",
		"generating_delivery": "Generating delivery content...",
		"sending_delivery":    "Sending delivery...",
		"learning_from_exec":  "Learning from execution...",
		"completed":           "Completed",
		"cancelled":           "Cancelled",
		"failed_prefix":       "Failed at ",
		"task_prefix":         "Task",
		// Phase names for failure messages
		"phase_inspiration": "inspiration",
		"phase_goals":       "goals",
		"phase_tasks":       "tasks",
		"phase_run":         "execution",
		"phase_delivery":    "delivery",
		"phase_learning":    "learning",
	},
	"zh": {
		"preparing":           "准备中...",
		"starting":            "启动中...",
		"scheduled_execution": "定时执行",
		"event_prefix":        "事件: ",
		"event_triggered":     "事件触发",
		"analyzing_context":   "分析上下文...",
		"planning_goals":      "规划目标...",
		"breaking_down_tasks": "分解任务...",
		"generating_delivery": "生成交付内容...",
		"sending_delivery":    "正在发送...",
		"learning_from_exec":  "学习执行经验...",
		"completed":           "已完成",
		"cancelled":           "已取消",
		"failed_prefix":       "失败于",
		"task_prefix":         "任务",
		// Phase names for failure messages
		"phase_inspiration": "灵感阶段",
		"phase_goals":       "目标阶段",
		"phase_tasks":       "任务阶段",
		"phase_run":         "执行阶段",
		"phase_delivery":    "交付阶段",
		"phase_learning":    "学习阶段",
	},
}

// getLocalizedMessage returns the UI string registered under key for locale.
//
// Lookup order:
//  1. the locale exactly as given;
//  2. its primary language subtag, lowercased — so "zh-CN", "zh_TW" or "EN"
//     resolve to the "zh" / "en" tables instead of silently falling back to
//     English;
//  3. the English table.
//
// When no table contains key, key itself is returned so the caller always gets
// a non-empty string for a non-empty key.
func getLocalizedMessage(locale string, key string) string {
	// Reduce a region/script-qualified tag to its language part.
	primary := locale
	if sep := strings.IndexAny(primary, "-_"); sep >= 0 {
		primary = primary[:sep]
	}
	primary = strings.ToLower(primary)

	// Indexing a missing locale yields a nil inner map, whose lookup simply
	// reports "not found", so no separate existence check is needed.
	for _, candidate := range [...]string{locale, primary, "en"} {
		if msg, ok := uiMessages[candidate][key]; ok {
			return msg
		}
	}

	return key // Return key as fallback
}
err) } } // extractGoalName extracts the execution name from goals output func extractGoalName(goals *robottypes.Goals) string { if goals == nil || goals.Content == "" { return "" } // Extract first non-empty, non-markdown-header line as the goal name content := goals.Content lines := strings.Split(content, "\n") for _, line := range lines { line = strings.TrimSpace(line) if line == "" { continue } // Skip markdown headers (# ## ### etc.) if strings.HasPrefix(line, "#") { continue } // Skip markdown horizontal rules (--- or ***) if strings.HasPrefix(line, "---") || strings.HasPrefix(line, "***") { continue } // Found a content line - strip markdown formatting line = stripMarkdownFormatting(line) // Limit length if len(line) > 150 { line = line[:150] + "..." } return line } // Fallback: if all lines are headers, use first header without # prefix for _, line := range lines { line = strings.TrimSpace(line) if line == "" { continue } // Strip leading # symbols line = strings.TrimLeft(line, "#") line = strings.TrimSpace(line) line = stripMarkdownFormatting(line) if line != "" { if len(line) > 150 { line = line[:150] + "..." 
// stripMarkdownFormatting reduces a markdown fragment to plain text.
//
// Every '*', '_' and '`' character is dropped (this covers bold, italic and
// inline-code markers, and also removes those characters from ordinary words).
// Bracket pairs are then unwrapped left to right: "[text](url)" becomes
// "text", and a "[text]" that is not followed by a closed "(...)" simply loses
// its brackets. Scanning stops at the first '[' that has no ']' after it.
// The result is trimmed of surrounding whitespace.
func stripMarkdownFormatting(s string) string {
	// Emphasis and inline-code markers: remove each marker character outright.
	s = strings.NewReplacer("*", "", "_", "", "`", "").Replace(s)

	for {
		open := strings.IndexByte(s, '[')
		if open < 0 {
			break
		}
		rel := strings.IndexByte(s[open:], ']')
		if rel < 0 {
			break
		}

		closing := open + rel
		label := s[open+1 : closing]
		rest := s[closing+1:]

		// "[label](target)": keep the label and drop everything through ")".
		if strings.HasPrefix(rest, "(") {
			if paren := strings.IndexByte(rest, ')'); paren >= 0 {
				s = s[:open] + label + rest[paren+1:]
				continue
			}
		}

		// Plain "[label]": drop just the two brackets.
		s = s[:open] + label + rest
	}

	return strings.TrimSpace(s)
}
func (e *Executor) Suspend(ctx *robottypes.Context, exec *robottypes.Execution, taskIndex int, question string) error { now := time.Now() taskID := "" if taskIndex >= 0 && taskIndex < len(exec.Tasks) { taskID = exec.Tasks[taskIndex].ID exec.Tasks[taskIndex].Status = robottypes.TaskWaitingInput } exec.Status = robottypes.ExecWaiting exec.WaitingTaskID = taskID exec.WaitingQuestion = question exec.WaitingSince = &now exec.ResumeContext = &robottypes.ResumeContext{ TaskIndex: taskIndex, PreviousResults: exec.Results, } if !e.config.SkipPersistence && e.store != nil { // Persist task state (waiting_input on the specific task) e.updateTasksState(ctx, exec) // Persist P3 results so UI can show completed tasks while waiting (§16.26) if err := e.store.UpdatePhase(ctx.Context, exec.ID, robottypes.PhaseRun, exec.Results); err != nil { kunlog.With(kunlog.F{ "execution_id": exec.ID, "error": err, }).Warn("Failed to persist partial results on suspend: %v", err) } // Persist suspend state atomically if err := e.store.UpdateSuspendState(ctx.Context, exec.ID, taskID, question, exec.ResumeContext); err != nil { kunlog.With(kunlog.F{ "execution_id": exec.ID, "task_id": taskID, "error": err, }).Warn("Failed to persist suspend state: %v", err) } } kunlog.With(kunlog.F{ "execution_id": exec.ID, "member_id": exec.MemberID, "task_id": taskID, "question": question, }).Info("Execution suspended, waiting for human input") // Fire event (best-effort, errors are ignored) event.Push(ctx.Context, robotevents.ExecWaiting, robotevents.NeedInputPayload{ ExecutionID: exec.ID, MemberID: exec.MemberID, TeamID: exec.TeamID, TaskID: taskID, Question: question, ChatID: exec.ChatID, }) return robottypes.ErrExecutionSuspended } // Resume resumes a suspended execution with human-provided input. // Loads execution from DB, restores state, injects reply, and continues from the suspended task. 
// Resume continues a suspended execution using the human-provided reply.
//
// It loads the execution record and its robot from the stores, rebuilds the
// runtime execution, applies the reply, re-runs the Run phase (P3) from the
// task recorded in the resume context, and then runs Delivery (P4) and
// Learning (P5).
//
// Parameters:
//   - ctx: required; its Context is used for every store call and event.
//   - execID: ID of the execution to resume; must be in ExecWaiting status.
//   - reply: text appended as a user message to the waiting task. The special
//     value "__skip__" marks the waiting task as skipped, records a failed
//     "skipped" result for it and advances to the next task instead.
//
// Returns an error when an argument or store is missing, when the execution or
// robot cannot be loaded, when the execution is not waiting, or when a phase
// fails. robottypes.ErrExecutionSuspended is returned unchanged when the
// execution suspends again; in that case it stays in the robot's tracking.
//
// NOTE(review): unlike the completion path of ExecuteWithControl in this file,
// a successful resume does not push robotevents.ExecCompleted, and a
// robottypes.ErrExecutionCancelled coming out of a phase is recorded as
// ExecFailed rather than ExecCancelled — confirm both are intended.
func (e *Executor) Resume(ctx *robottypes.Context, execID string, reply string) error {
	// Argument and dependency checks: nothing is mutated before these pass.
	if ctx == nil {
		return fmt.Errorf("context is required for resume")
	}
	if execID == "" {
		return fmt.Errorf("execID cannot be empty")
	}
	if e.store == nil {
		return fmt.Errorf("store is required for resume")
	}

	// Load execution record from DB
	record, err := e.store.Get(ctx.Context, execID)
	if err != nil {
		return fmt.Errorf("failed to load execution: %w", err)
	}
	if record == nil {
		return fmt.Errorf("execution not found: %s", execID)
	}
	// Only a waiting execution can be resumed.
	if record.Status != robottypes.ExecWaiting {
		return fmt.Errorf("execution %s is not in waiting status (current: %s)", execID, record.Status)
	}

	// Restore runtime execution from record
	exec := record.ToExecution()

	// Load robot from store
	if e.robotStore == nil {
		return fmt.Errorf("robot store is required for resume")
	}
	robotRecord, err := e.robotStore.Get(ctx.Context, exec.MemberID)
	if err != nil {
		return fmt.Errorf("failed to load robot: %w", err)
	}
	if robotRecord == nil {
		return fmt.Errorf("robot not found: %s", exec.MemberID)
	}
	robot, err := robotRecord.ToRobot()
	if err != nil {
		return fmt.Errorf("failed to convert robot record: %w", err)
	}
	exec.SetRobot(robot)

	// Re-add execution to robot's in-memory tracking (skips quota check per §16.30)
	robot.AddExecution(exec)

	// Maintain executor concurrency count (§16.21)
	e.currentCount.Add(1)
	defer e.currentCount.Add(-1)

	// Defer cleanup: mirror ExecuteWithControl's defer logic (§16.21).
	// Runs on every return below, so it sees the final exec.Status.
	defer func() {
		if exec.Status == robottypes.ExecWaiting {
			return // re-suspended, keep tracking
		}
		robot.RemoveExecution(exec.ID)
		// Flip the robot back to idle once it has nothing left running.
		if robot.RunningCount() == 0 && !e.config.SkipPersistence && e.robotStore != nil {
			if err := e.robotStore.UpdateStatus(ctx.Context, robot.MemberID, robottypes.RobotIdle); err != nil {
				kunlog.With(kunlog.F{
					"member_id": robot.MemberID,
					"error":     err,
				}).Warn("Failed to update robot status to idle after resume: %v", err)
			}
		}
	}()

	// Handle __skip__: mark waiting task as skipped and advance to next task
	if reply == "__skip__" && exec.ResumeContext != nil {
		ti := exec.ResumeContext.TaskIndex
		if ti >= 0 && ti < len(exec.Tasks) {
			task := &exec.Tasks[ti]
			task.Status = robottypes.TaskSkipped
			// Record a placeholder result so the skipped task still has an entry.
			exec.ResumeContext.PreviousResults = append(exec.ResumeContext.PreviousResults, robottypes.TaskResult{
				TaskID:   task.ID,
				Success:  false,
				Output:   "skipped",
				Duration: 0,
			})
			exec.ResumeContext.TaskIndex = ti + 1
			if !e.config.SkipPersistence && e.store != nil {
				e.updateTasksState(ctx, exec)
			}
		}
		reply = "" // Don't inject __skip__ as a message
	}

	// Inject reply into the waiting task's messages so the re-executed task gets context
	if exec.ResumeContext != nil {
		ti := exec.ResumeContext.TaskIndex
		if ti >= 0 && ti < len(exec.Tasks) && reply != "" {
			exec.Tasks[ti].Messages = append(exec.Tasks[ti].Messages, agentcontext.Message{
				Role:    agentcontext.RoleUser,
				Content: fmt.Sprintf("[Human reply] %s", reply),
			})
		}
	}

	// Clear waiting fields and transition back to running
	exec.Status = robottypes.ExecRunning
	exec.WaitingTaskID = ""
	exec.WaitingQuestion = ""
	exec.WaitingSince = nil

	// A failure to persist the resume state is logged and does not stop the resume.
	if !e.config.SkipPersistence && e.store != nil {
		if err := e.store.UpdateResumeState(ctx.Context, exec.ID); err != nil {
			kunlog.With(kunlog.F{
				"execution_id": exec.ID,
				"error":        err,
			}).Warn("Failed to persist resume state: %v", err)
		}
	}

	kunlog.With(kunlog.F{
		"execution_id": exec.ID,
		"member_id":    exec.MemberID,
		"reply_len":    len(reply),
	}).Info("Execution resumed")

	event.Push(ctx.Context, robotevents.ExecResumed, robotevents.ExecPayload{
		ExecutionID: exec.ID,
		MemberID:    exec.MemberID,
		TeamID:      exec.TeamID,
		ChatID:      exec.ChatID,
	})

	// Continue P3 (Run) from where it was suspended.
	// RunExecution is called directly here, not through runPhase.
	if err := e.RunExecution(ctx, exec, nil); err != nil {
		if err == robottypes.ErrExecutionSuspended {
			return err
		}
		exec.Status = robottypes.ExecFailed
		exec.Error = err.Error()
		if !e.config.SkipPersistence && e.store != nil {
			// Best-effort: the store error is deliberately discarded, the phase error is returned instead.
			_ = e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecFailed, err.Error())
		}
		return err
	}

	// Clear resume context after successful P3 completion
	exec.ResumeContext = nil

	// Continue with P4 (Delivery) and P5 (Learning)
	locale := getEffectiveLocale(robot, exec.Input)
	for _, phase := range []robottypes.Phase{robottypes.PhaseDelivery, robottypes.PhaseLearning} {
		// No input data and no pause control are passed for the remaining phases.
		if err := e.runPhase(ctx, exec, phase, nil, nil); err != nil {
			if err == robottypes.ErrExecutionSuspended {
				return err
			}
			exec.Status = robottypes.ExecFailed
			exec.Error = err.Error()

			// Localized "failed at <phase>" message for the UI.
			failedPrefix := getLocalizedMessage(locale, "failed_prefix")
			phaseName := getLocalizedMessage(locale, "phase_"+string(phase))
			e.updateUIFields(ctx, exec, "", failedPrefix+phaseName)

			if !e.config.SkipPersistence && e.store != nil {
				// Best-effort: the store error is deliberately discarded.
				_ = e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecFailed, err.Error())
			}
			return fmt.Errorf("resume phase %s failed: %w", phase, err)
		}
	}

	// Mark completed
	exec.Status = robottypes.ExecCompleted
	now := time.Now()
	exec.EndTime = &now
	e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "completed"))

	if !e.config.SkipPersistence && e.store != nil {
		// Best-effort: a failure to store the final status is ignored.
		_ = e.store.UpdateStatus(ctx.Context, exec.ID, robottypes.ExecCompleted, "")
	}

	return nil
}
// TestExecutorPersistence checks that the standard executor writes execution
// records to the execution store when persistence is enabled and writes
// nothing when it is disabled.
//
// Every subtest passes the string "simulate_failure" as trigger data, which
// makes the executor end the run as failed with the error "simulated failure"
// before any phase runs, so the outcome is deterministic.
// Skipped in -short mode because it needs the prepared test environment.
func TestExecutorPersistence(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	// A record with the execution's identity and final status must exist after the run.
	t.Run("persists_execution_record_on_start", func(t *testing.T) {
		ctx := robottypes.NewContext(context.Background(), &oauthTypes.AuthorizedInfo{
			UserID: "user_persist_001",
			TeamID: "team_persist_001",
		})
		robot := createPersistenceTestRobot("member_persist_001", "team_persist_001")

		// Create executor with persistence enabled
		e := standard.NewWithConfig(types.Config{
			SkipPersistence: false,
		})

		// Execute with simulated failure to ensure we get a result
		exec, err := e.Execute(ctx, robot, robottypes.TriggerHuman, "simulate_failure")
		require.NoError(t, err)
		require.NotNil(t, exec)

		// Verify execution record was persisted
		s := store.NewExecutionStore()
		record, err := s.Get(context.Background(), exec.ID)
		require.NoError(t, err)
		require.NotNil(t, record)

		assert.Equal(t, exec.ID, record.ExecutionID)
		assert.Equal(t, "member_persist_001", record.MemberID)
		assert.Equal(t, "team_persist_001", record.TeamID)
		assert.Equal(t, robottypes.TriggerHuman, record.TriggerType)
		assert.Equal(t, robottypes.ExecFailed, record.Status)
		assert.Equal(t, "simulated failure", record.Error)

		// Cleanup
		_ = s.Delete(context.Background(), exec.ID)
	})

	// The stored record must carry the failed status, the error text and a start time.
	t.Run("persists_failed_status_with_error", func(t *testing.T) {
		ctx := robottypes.NewContext(context.Background(), &oauthTypes.AuthorizedInfo{
			UserID: "user_persist_002",
			TeamID: "team_persist_002",
		})
		robot := createPersistenceTestRobot("member_persist_002", "team_persist_002")

		e := standard.NewWithConfig(types.Config{
			SkipPersistence: false,
		})

		// Execute with simulated failure
		exec, err := e.Execute(ctx, robot, robottypes.TriggerHuman, "simulate_failure")
		require.NoError(t, err)
		require.NotNil(t, exec)

		// Verify the record has failed status with error message
		s := store.NewExecutionStore()
		record, err := s.Get(context.Background(), exec.ID)
		require.NoError(t, err)
		require.NotNil(t, record)

		assert.Equal(t, robottypes.ExecFailed, record.Status)
		assert.Equal(t, "simulated failure", record.Error)
		assert.NotNil(t, record.StartTime)

		// Cleanup
		_ = s.Delete(context.Background(), exec.ID)
	})

	// With SkipPersistence set, looking the execution up must succeed but find nothing.
	t.Run("skips_persistence_when_disabled", func(t *testing.T) {
		ctx := robottypes.NewContext(context.Background(), &oauthTypes.AuthorizedInfo{
			UserID: "user_persist_003",
			TeamID: "team_persist_003",
		})
		robot := createPersistenceTestRobot("member_persist_003", "team_persist_003")

		// Create executor with persistence disabled
		e := standard.NewWithConfig(types.Config{
			SkipPersistence: true,
		})

		exec, err := e.Execute(ctx, robot, robottypes.TriggerHuman, "simulate_failure")
		require.NoError(t, err)
		require.NotNil(t, exec)

		// Verify no record was created
		s := store.NewExecutionStore()
		record, err := s.Get(context.Background(), exec.ID)
		require.NoError(t, err)
		assert.Nil(t, record) // Should not exist
	})
}
// TestExecutorGoalsInjection verifies how the executor treats pre-confirmed
// goals passed in TriggerInput.Data["goals"] with a TriggerHuman trigger:
// a non-empty string must end up in exec.Goals and replace the placeholder
// execution name, an empty string must not be injected, and a run without a
// TriggerInput must leave exec.Goals nil.
// Skipped in -short mode because it needs the prepared test environment.
func TestExecutorGoalsInjection(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	t.Run("goals_injected_from_trigger_input_data", func(t *testing.T) {
		ctx := robottypes.NewContext(context.Background(), &oauthTypes.AuthorizedInfo{
			UserID: "user_goals_inject_001",
			TeamID: "team_goals_inject_001",
		})
		robot := createPersistenceTestRobot("member_goals_inject_001", "team_goals_inject_001")

		e := standard.NewWithConfig(types.Config{
			SkipPersistence: false,
		})

		// Simulate Host Agent confirmed goals passed via TriggerInput.Data
		triggerInput := &robottypes.TriggerInput{
			Data: map[string]interface{}{
				"goals":   "Create a mecha image with sci-fi style",
				"chat_id": "robot_member_goals_inject_001_1234567890",
			},
		}

		exec, err := e.Execute(ctx, robot, robottypes.TriggerHuman, triggerInput)
		require.NoError(t, err)
		require.NotNil(t, exec)

		// Goals should be injected from TriggerInput.Data
		require.NotNil(t, exec.Goals, "Goals should be injected from TriggerInput.Data")
		assert.Equal(t, "Create a mecha image with sci-fi style", exec.Goals.Content,
			"Goals content should match the pre-confirmed goals")

		// Verify a record for this execution exists in the store.
		// Only the record's presence is checked here, not its stored goals.
		s := store.NewExecutionStore()
		record, err := s.Get(context.Background(), exec.ID)
		require.NoError(t, err)
		require.NotNil(t, record)

		// Execution name should reflect the goals (not "Preparing...")
		assert.NotEmpty(t, exec.Name, "Execution name should be set from goals")
		assert.NotEqual(t, "Preparing...", exec.Name, "Name should not be the default placeholder")

		// Cleanup
		_ = s.Delete(context.Background(), exec.ID)

		t.Logf("✓ Goals injected from TriggerInput.Data: goals=%q, name=%q", exec.Goals.Content, exec.Name)
	})

	t.Run("empty_goals_falls_through_to_goals_agent", func(t *testing.T) {
		// When TriggerInput.Data["goals"] is an empty string, the executor
		// does NOT inject pre-confirmed goals and falls through to RunGoals.
		// RunGoals will call the Goals Agent (LLM), which may succeed or fail
		// depending on the environment. We only verify that the executor returns
		// without a panic and that no pre-confirmed goals were force-injected.
		//
		// This test requires a running AI environment; skip in short mode.
		// (The parent test already skips in short mode, so this guard is a
		// second line of defence.)
		if testing.Short() {
			t.Skip("Skipping: requires LLM for RunGoals fallback")
		}

		ctx := robottypes.NewContext(context.Background(), &oauthTypes.AuthorizedInfo{
			UserID: "user_goals_empty_002",
			TeamID: "team_goals_empty_002",
		})
		robot := createPersistenceTestRobot("member_goals_empty_002", "team_goals_empty_002")

		e := standard.NewWithConfig(types.Config{
			SkipPersistence: true,
		})

		triggerInput := &robottypes.TriggerInput{
			Data: map[string]interface{}{
				"goals": "", // empty — should not be injected as pre-confirmed
			},
		}

		exec, err := e.Execute(ctx, robot, robottypes.TriggerHuman, triggerInput)
		require.NoError(t, err)
		require.NotNil(t, exec)

		// If Goals was set, it came from the Goals Agent, NOT from the empty string injection.
		// Either nil (agent skipped) or non-nil (agent ran) is acceptable.
		if exec.Goals != nil {
			assert.NotEmpty(t, exec.Goals.Content, "If Goals Agent ran, content should be non-empty")
		}

		t.Logf("✓ Empty goals falls through to Goals Agent (goals=%v)", exec.Goals != nil)
	})

	t.Run("no_trigger_input_uses_normal_flow", func(t *testing.T) {
		ctx := robottypes.NewContext(context.Background(), &oauthTypes.AuthorizedInfo{
			UserID: "user_goals_normal_001",
			TeamID: "team_goals_normal_001",
		})
		robot := createPersistenceTestRobot("member_goals_normal_001", "team_goals_normal_001")

		e := standard.NewWithConfig(types.Config{
			SkipPersistence: true,
		})

		// No TriggerInput — simulate plain string fallback (old API usage).
		// "simulate_failure" ends the run before any phase executes.
		exec, err := e.Execute(ctx, robot, robottypes.TriggerHuman, "simulate_failure")
		require.NoError(t, err)
		require.NotNil(t, exec)

		// Goals nil is expected — RunGoals would normally call the LLM
		assert.Nil(t, exec.Goals, "Without pre-confirmed goals, Goals should remain nil")

		t.Logf("✓ Normal flow (no pre-confirmed goals) proceeds without injection")
	})
}
"experts.data-analyst"}, }, }, } } ================================================ FILE: agent/robot/executor/standard/goals.go ================================================ package standard import ( "fmt" "strings" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // RunGoals executes P1: Goals phase // Calls the Goals Agent to plan daily objectives // // Input: // - InspirationReport (from P0) for clock trigger // - TriggerInput for human/event trigger // // Output: // - Goals with markdown content and delivery info func (e *Executor) RunGoals(ctx *robottypes.Context, exec *robottypes.Execution, _ interface{}) error { // §18.2: confirming phase may have already populated Goals — skip regeneration if exec.Goals != nil && exec.Goals.Content != "" { return nil } // Get robot for identity and resources robot := exec.GetRobot() if robot == nil { return fmt.Errorf("robot not found in execution") } // Update UI field with i18n locale := getEffectiveLocale(robot, exec.Input) e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "planning_goals")) // Get agent ID for goals phase agentID := "__yao.goals" // default if robot.Config != nil && robot.Config.Resources != nil { agentID = robot.Config.Resources.GetPhaseAgent(robottypes.PhaseGoals) } // Build prompt based on trigger type formatter := NewInputFormatter() var userContent string switch exec.TriggerType { case robottypes.TriggerClock: // For clock trigger: use InspirationReport from P0 if exec.Inspiration != nil { userContent = formatter.FormatInspirationReport(exec.Inspiration) } else { // Fallback: if no inspiration report, create minimal context userContent = formatter.FormatClockContext( robottypes.NewClockContext(exec.StartTime, ""), robot, ) } case robottypes.TriggerHuman, robottypes.TriggerEvent: // For human/event trigger: use TriggerInput directly if exec.Input != nil { userContent = formatter.FormatTriggerInput(exec.Input) } } // Add robot identity context if not already included // For clock 
trigger with inspiration report, identity is not in the report // For human/event trigger, identity provides context if robot.Config != nil && robot.Config.Identity != nil { if !strings.Contains(userContent, "## Robot Identity") { userContent = formatter.FormatRobotIdentity(robot) + "\n\n" + userContent } } // Add available resources - critical for generating achievable goals // Without knowing what tools are available, goals might be unachievable resourcesContent := formatter.FormatAvailableResources(robot) if resourcesContent != "" { userContent += "\n\n" + resourcesContent } if userContent == "" { return fmt.Errorf("no input available for goals generation") } // Call agent caller := NewAgentCaller() caller.Connector = robot.LanguageModel result, err := caller.CallWithMessages(ctx, agentID, userContent) if err != nil { return fmt.Errorf("goals agent (%s) call failed: %w", agentID, err) } // Parse response as JSON // Goals Agent returns: { "content": "...", "delivery": {...} } data, err := result.GetJSON() if err != nil { // Fallback: if not JSON, use raw text as content content := result.GetText() if content == "" { return fmt.Errorf("goals agent returned empty response") } exec.Goals = &robottypes.Goals{ Content: content, } return nil } // Build Goals from JSON exec.Goals = &robottypes.Goals{} // Extract content (markdown) if content, ok := data["content"].(string); ok { exec.Goals.Content = content } // Extract delivery if delivery, ok := data["delivery"].(map[string]interface{}); ok { exec.Goals.Delivery = ParseDelivery(delivery) } // Validate: content is required if exec.Goals.Content == "" { return fmt.Errorf("goals agent (%s) returned empty content", agentID) } // Update Name from goals content (extract first line as execution title) if goalName := extractGoalName(exec.Goals); goalName != "" { e.updateUIFields(ctx, exec, goalName, "") } return nil } // ParseDelivery converts map to DeliveryTarget struct // Returns nil if data is nil or type is 
invalid/missing func ParseDelivery(data map[string]interface{}) *robottypes.DeliveryTarget { if data == nil { return nil } // Type is required - if missing or invalid, return nil t, ok := data["type"].(string) if !ok || t == "" { return nil } deliveryType := robottypes.DeliveryType(t) if !IsValidDeliveryType(deliveryType) { // Invalid type - return nil to indicate parsing failure return nil } target := &robottypes.DeliveryTarget{ Type: deliveryType, } // Parse recipients if recipients, ok := data["recipients"].([]interface{}); ok { for _, r := range recipients { if s, ok := r.(string); ok { target.Recipients = append(target.Recipients, s) } } } // Parse format if format, ok := data["format"].(string); ok { target.Format = format } // Parse template if template, ok := data["template"].(string); ok { target.Template = template } // Parse options if options, ok := data["options"].(map[string]interface{}); ok { target.Options = options } return target } // IsValidDeliveryType checks if the delivery type is valid func IsValidDeliveryType(t robottypes.DeliveryType) bool { switch t { case robottypes.DeliveryEmail, robottypes.DeliveryWebhook, robottypes.DeliveryProcess, robottypes.DeliveryNotify: return true default: return false } } ================================================ FILE: agent/robot/executor/standard/goals_test.go ================================================ package standard_test import ( "context" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ============================================================================ // P1 Goals Phase Tests // ============================================================================ func TestRunGoalsBasic(t *testing.T) { if testing.Short() { t.Skip("Skipping integration 
test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("generates goals from inspiration report (clock trigger)", func(t *testing.T) { // Create robot with goals agent configured robot := createGoalsTestRobot(t, "robot.goals") // Create execution with inspiration report (from P0) exec := createGoalsTestExecution(robot, types.TriggerClock) exec.Inspiration = &types.InspirationReport{ Clock: types.NewClockContext(time.Now(), ""), Content: "## Summary\nToday is Monday morning. Focus on weekly planning.\n\n## Highlights\n- New sales leads arrived\n- Weekly report due Friday", } // Run goals phase e := standard.New() err := e.RunGoals(ctx, exec, nil) require.NoError(t, err) require.NotNil(t, exec.Goals) assert.NotEmpty(t, exec.Goals.Content) }) t.Run("includes priority markers in output", func(t *testing.T) { robot := createGoalsTestRobot(t, "robot.goals") exec := createGoalsTestExecution(robot, types.TriggerClock) exec.Inspiration = &types.InspirationReport{ Clock: types.NewClockContext(time.Now(), ""), Content: "## Summary\nUrgent: Customer complaint needs attention.\n\n## Highlights\n- Critical bug reported\n- Regular maintenance scheduled", } e := standard.New() err := e.RunGoals(ctx, exec, nil) require.NoError(t, err) content := exec.Goals.Content // Verify expected structure in markdown output // Note: LLM output is non-deterministic, so we check for likely patterns hasGoals := strings.Contains(content, "Goal") || strings.Contains(content, "##") || strings.Contains(content, "High") || strings.Contains(content, "Normal") || strings.Contains(content, "1.") assert.True(t, hasGoals, "should contain goals structure, got: %s", content) }) } func TestRunGoalsHumanTrigger(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("generates goals from human intervention", func(t 
*testing.T) { robot := createGoalsTestRobot(t, "robot.goals") exec := createGoalsTestExecution(robot, types.TriggerHuman) // Set human intervention input exec.Input = &types.TriggerInput{ Action: "task.add", UserID: "user-123", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Please analyze the Q4 sales data and prepare a summary report for the management meeting tomorrow."}, }, } e := standard.New() err := e.RunGoals(ctx, exec, nil) require.NoError(t, err) require.NotNil(t, exec.Goals) assert.NotEmpty(t, exec.Goals.Content) // Goals should be related to the user request content := strings.ToLower(exec.Goals.Content) hasRelevantContent := strings.Contains(content, "sales") || strings.Contains(content, "report") || strings.Contains(content, "analysis") || strings.Contains(content, "data") || strings.Contains(content, "q4") assert.True(t, hasRelevantContent, "goals should relate to user request, got: %s", exec.Goals.Content) }) t.Run("includes robot identity for human trigger", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-1", TeamID: "test-team-1", DisplayName: "Sales Analyst", Config: &types.Config{ Identity: &types.Identity{ Role: "Sales Analyst", Duties: []string{"Analyze sales data", "Generate reports"}, Rules: []string{"Focus on actionable insights"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseGoals: "robot.goals", }, }, }, } exec := createGoalsTestExecution(robot, types.TriggerHuman) exec.Input = &types.TriggerInput{ Action: "instruct", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "What should I focus on today?"}, }, } e := standard.New() err := e.RunGoals(ctx, exec, nil) require.NoError(t, err) assert.NotEmpty(t, exec.Goals.Content) }) } func TestRunGoalsEventTrigger(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) 
t.Run("generates goals from event trigger", func(t *testing.T) { robot := createGoalsTestRobot(t, "robot.goals") exec := createGoalsTestExecution(robot, types.TriggerEvent) // Set event input exec.Input = &types.TriggerInput{ Source: "webhook", EventType: "lead.created", Data: map[string]interface{}{ "lead_id": "lead-456", "company": "BigCorp Inc", "contact_name": "John Smith", "email": "john@bigcorp.com", "interest": "Enterprise plan", }, } e := standard.New() err := e.RunGoals(ctx, exec, nil) require.NoError(t, err) require.NotNil(t, exec.Goals) assert.NotEmpty(t, exec.Goals.Content) // Goals should be related to the event content := strings.ToLower(exec.Goals.Content) hasRelevantContent := strings.Contains(content, "lead") || strings.Contains(content, "bigcorp") || strings.Contains(content, "contact") || strings.Contains(content, "follow") || strings.Contains(content, "qualify") assert.True(t, hasRelevantContent, "goals should relate to event, got: %s", exec.Goals.Content) }) } func TestRunGoalsErrorHandling(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("returns error when robot is nil", func(t *testing.T) { exec := &types.Execution{ ID: "test-exec-1", TriggerType: types.TriggerClock, } // Don't set robot e := standard.New() err := e.RunGoals(ctx, exec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "robot not found") }) t.Run("returns error when agent not found", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-1", TeamID: "test-team-1", Config: &types.Config{ Identity: &types.Identity{Role: "Test"}, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseGoals: "non.existent.agent", }, }, }, } exec := createGoalsTestExecution(robot, types.TriggerClock) exec.Inspiration = &types.InspirationReport{ Content: "Test content", } e := standard.New() err := e.RunGoals(ctx, exec, nil) 
assert.Error(t, err) assert.Contains(t, err.Error(), "call failed") }) t.Run("returns error when no input available and no identity", func(t *testing.T) { // Robot without identity - should fail when no input is provided robot := &types.Robot{ MemberID: "test-robot-1", TeamID: "test-team-1", Config: &types.Config{ // No Identity - so no fallback content Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseGoals: "robot.goals", }, }, }, } exec := createGoalsTestExecution(robot, types.TriggerHuman) exec.Input = nil // No input e := standard.New() err := e.RunGoals(ctx, exec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "no input available") }) } func TestRunGoalsFallbackBehavior(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("falls back to clock context when no inspiration report", func(t *testing.T) { robot := createGoalsTestRobot(t, "robot.goals") exec := createGoalsTestExecution(robot, types.TriggerClock) exec.Inspiration = nil // No inspiration report e := standard.New() err := e.RunGoals(ctx, exec, nil) // Should still work with fallback clock context require.NoError(t, err) require.NotNil(t, exec.Goals) assert.NotEmpty(t, exec.Goals.Content) }) } // ============================================================================ // Delivery Parsing Tests // ============================================================================ func TestParseDeliveryFromGoalsResponse(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("parses delivery when agent returns it", func(t *testing.T) { robot := createGoalsTestRobot(t, "robot.goals") exec := createGoalsTestExecution(robot, types.TriggerHuman) // Request that explicitly asks for email delivery exec.Input = 
&types.TriggerInput{ Action: "task.add", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Prepare a sales report and send it to team@example.com via email"}, }, } e := standard.New() err := e.RunGoals(ctx, exec, nil) require.NoError(t, err) require.NotNil(t, exec.Goals) assert.NotEmpty(t, exec.Goals.Content) // Delivery may or may not be present depending on LLM response // If present, verify structure if exec.Goals.Delivery != nil { // Type should be valid if present if exec.Goals.Delivery.Type != "" { validTypes := []types.DeliveryType{ types.DeliveryEmail, types.DeliveryWebhook, types.DeliveryProcess, types.DeliveryNotify, } found := false for _, vt := range validTypes { if exec.Goals.Delivery.Type == vt { found = true break } } // Note: LLM might return non-standard types, we accept them but log t.Logf("Delivery type: %s (valid: %v)", exec.Goals.Delivery.Type, found) } } }) } func TestDeliveryTypeValidation(t *testing.T) { t.Run("valid delivery types", func(t *testing.T) { validTypes := []types.DeliveryType{ types.DeliveryEmail, types.DeliveryWebhook, types.DeliveryProcess, types.DeliveryNotify, } for _, dt := range validTypes { assert.True(t, standard.IsValidDeliveryType(dt), "should be valid: %s", dt) } }) t.Run("invalid delivery types", func(t *testing.T) { invalidTypes := []types.DeliveryType{ "invalid", "sms", "", } for _, dt := range invalidTypes { assert.False(t, standard.IsValidDeliveryType(dt), "should be invalid: %s", dt) } }) } func TestParseDelivery(t *testing.T) { t.Run("parses valid delivery with all fields", func(t *testing.T) { data := map[string]interface{}{ "type": "email", "recipients": []interface{}{"user@example.com", "team@example.com"}, "format": "markdown", "template": "weekly-report", "options": map[string]interface{}{ "subject": "Weekly Report", }, } result := standard.ParseDelivery(data) require.NotNil(t, result) assert.Equal(t, types.DeliveryEmail, result.Type) assert.Equal(t, []string{"user@example.com", 
"team@example.com"}, result.Recipients) assert.Equal(t, "markdown", result.Format) assert.Equal(t, "weekly-report", result.Template) assert.Equal(t, "Weekly Report", result.Options["subject"]) }) t.Run("returns nil for nil data", func(t *testing.T) { result := standard.ParseDelivery(nil) assert.Nil(t, result) }) t.Run("returns nil for missing type", func(t *testing.T) { data := map[string]interface{}{ "recipients": []interface{}{"user@example.com"}, } result := standard.ParseDelivery(data) assert.Nil(t, result) }) t.Run("returns nil for empty type", func(t *testing.T) { data := map[string]interface{}{ "type": "", "recipients": []interface{}{"user@example.com"}, } result := standard.ParseDelivery(data) assert.Nil(t, result) }) t.Run("returns nil for invalid type", func(t *testing.T) { data := map[string]interface{}{ "type": "sms", "recipients": []interface{}{"user@example.com"}, } result := standard.ParseDelivery(data) assert.Nil(t, result) }) t.Run("handles missing optional fields", func(t *testing.T) { data := map[string]interface{}{ "type": "webhook", } result := standard.ParseDelivery(data) require.NotNil(t, result) assert.Equal(t, types.DeliveryWebhook, result.Type) assert.Empty(t, result.Recipients) assert.Empty(t, result.Format) assert.Empty(t, result.Template) assert.Nil(t, result.Options) }) t.Run("handles non-string recipients gracefully", func(t *testing.T) { data := map[string]interface{}{ "type": "email", "recipients": []interface{}{"valid@example.com", 123, nil, "another@example.com"}, } result := standard.ParseDelivery(data) require.NotNil(t, result) // Only string recipients should be included assert.Equal(t, []string{"valid@example.com", "another@example.com"}, result.Recipients) }) t.Run("parses all valid delivery types", func(t *testing.T) { validTypes := []string{"email", "webhook", "process", "notify"} for _, dt := range validTypes { data := map[string]interface{}{ "type": dt, } result := standard.ParseDelivery(data) require.NotNil(t, result, 
"should parse type: %s", dt) assert.Equal(t, types.DeliveryType(dt), result.Type) } }) } // ============================================================================ // InputFormatter Tests for P1 // ============================================================================ func TestInputFormatterFormatRobotIdentity(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats robot identity correctly", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot", Config: &types.Config{ Identity: &types.Identity{ Role: "Sales Analyst", Duties: []string{"Analyze sales data", "Generate reports"}, Rules: []string{"Be accurate", "Be concise"}, }, }, } content := formatter.FormatRobotIdentity(robot) assert.Contains(t, content, "## Robot Identity") assert.Contains(t, content, "Sales Analyst") assert.Contains(t, content, "Analyze sales data") assert.Contains(t, content, "Generate reports") assert.Contains(t, content, "Be accurate") assert.Contains(t, content, "Be concise") }) t.Run("returns empty for nil robot", func(t *testing.T) { content := formatter.FormatRobotIdentity(nil) assert.Empty(t, content) }) t.Run("returns empty for robot without config", func(t *testing.T) { robot := &types.Robot{MemberID: "test"} content := formatter.FormatRobotIdentity(robot) assert.Empty(t, content) }) t.Run("returns empty for robot without identity", func(t *testing.T) { robot := &types.Robot{ MemberID: "test", Config: &types.Config{}, } content := formatter.FormatRobotIdentity(robot) assert.Empty(t, content) }) t.Run("handles identity with only role", func(t *testing.T) { robot := &types.Robot{ MemberID: "test", Config: &types.Config{ Identity: &types.Identity{ Role: "Simple Bot", }, }, } content := formatter.FormatRobotIdentity(robot) assert.Contains(t, content, "## Robot Identity") assert.Contains(t, content, "Simple Bot") assert.NotContains(t, content, "Duties") assert.NotContains(t, content, "Rules") }) } // 
============================================================================ // Helper Functions // ============================================================================ // createGoalsTestRobot creates a test robot with specified goals agent // Includes available expert agents so the Goals Agent knows what resources are available // // Note: The agent IDs listed in Resources.Agents must exist in yao-dev-app/assistants/experts/ // Current available experts: data-analyst, summarizer, text-writer, web-reader func createGoalsTestRobot(t *testing.T, agentID string) *types.Robot { t.Helper() return &types.Robot{ MemberID: "test-robot-1", TeamID: "test-team-1", DisplayName: "Test Robot", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Assistant", Duties: []string{"Testing", "Validation", "Data Analysis", "Report Generation"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseGoals: agentID, }, // Available expert agents that can be delegated to // These IDs correspond to assistants in yao-dev-app/assistants/experts/ Agents: []string{ "experts.data-analyst", // Data analysis and insights "experts.summarizer", // Content summarization "experts.text-writer", // Report and document generation "experts.web-reader", // Web content extraction }, }, // Knowledge base collections (if any) KB: &types.KB{ Collections: []string{"test-knowledge"}, }, }, } } // createGoalsTestExecution creates a test execution for goals phase func createGoalsTestExecution(robot *types.Robot, trigger types.TriggerType) *types.Execution { exec := &types.Execution{ ID: "test-exec-goals-1", MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: trigger, StartTime: time.Now(), Status: types.ExecRunning, Phase: types.PhaseGoals, } exec.SetRobot(robot) return exec } ================================================ FILE: agent/robot/executor/standard/host.go ================================================ package standard import ( "encoding/json" "fmt" 
kunlog "github.com/yaoapp/kun/log" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // CallHostAgent calls the Host Agent with structured input and parses structured output. // The Host Agent mediates all human-robot interactions through three scenarios: // - "assign": new task assignment with multi-round confirmation // - "guide": guidance during execution // - "clarify": answering questions from waiting tasks func (e *Executor) CallHostAgent(ctx *robottypes.Context, robot *robottypes.Robot, input *robottypes.HostInput, chatID string) (*robottypes.HostOutput, error) { if robot == nil { return nil, fmt.Errorf("robot cannot be nil") } agentID := "" if robot.Config != nil && robot.Config.Resources != nil { agentID = robot.Config.Resources.GetPhaseAgent(robottypes.PhaseHost) } if agentID == "" { return nil, fmt.Errorf("no Host Agent configured for robot %s", robot.MemberID) } inputJSON, err := json.Marshal(input) if err != nil { return nil, fmt.Errorf("failed to marshal host input: %w", err) } kunlog.Info("calling Host Agent %s for scenario=%s chatID=%s", agentID, input.Scenario, chatID) caller := NewConversationCaller(chatID) result, err := caller.CallWithMessages(ctx, agentID, string(inputJSON)) if err != nil { return nil, fmt.Errorf("host agent (%s) call failed: %w", agentID, err) } data, err := result.GetJSON() if err != nil { text := result.GetText() kunlog.Warn("Host Agent returned non-JSON response, treating as confirm: %s", text) return &robottypes.HostOutput{ Reply: text, Action: robottypes.HostActionConfirm, }, nil } output := &robottypes.HostOutput{} raw, _ := json.Marshal(data) if err := json.Unmarshal(raw, output); err != nil { return &robottypes.HostOutput{ Reply: result.GetText(), Action: robottypes.HostActionConfirm, }, nil } return output, nil } ================================================ FILE: agent/robot/executor/standard/host_test.go ================================================ package standard_test import ( "context" "testing" 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/executor/standard" robottypes "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) func hostTestAuth() *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: "test-user-host", TeamID: "test-team-host", } } // H1: nil robot func TestCallHostAgent_NilRobot(t *testing.T) { e := standard.New() ctx := robottypes.NewContext(context.Background(), nil) _, err := e.CallHostAgent(ctx, nil, &robottypes.HostInput{Scenario: "assign"}, "chat-1") require.Error(t, err) assert.Contains(t, err.Error(), "robot cannot be nil") } // H2: no Host Agent configured func TestCallHostAgent_NoHostAgent(t *testing.T) { e := standard.New() ctx := robottypes.NewContext(context.Background(), nil) t.Run("nil config", func(t *testing.T) { robot := &robottypes.Robot{MemberID: "member-h2a"} _, err := e.CallHostAgent(ctx, robot, &robottypes.HostInput{Scenario: "assign"}, "chat-1") require.Error(t, err) assert.Contains(t, err.Error(), "no Host Agent configured") }) t.Run("nil resources", func(t *testing.T) { robot := &robottypes.Robot{ MemberID: "member-h2b", Config: &robottypes.Config{}, } _, err := e.CallHostAgent(ctx, robot, &robottypes.HostInput{Scenario: "assign"}, "chat-1") require.Error(t, err) assert.Contains(t, err.Error(), "no Host Agent configured") }) } // H3: valid JSON response from Host Agent func TestCallHostAgent_ValidJSONResponse(t *testing.T) { if testing.Short() { t.Skip("Requires assistant framework and LLM") } testutils.Prepare(t) defer testutils.Clean(t) e := standard.New() ctx := robottypes.NewContext(context.Background(), hostTestAuth()) robot := &robottypes.Robot{ MemberID: "member-h3", Config: &robottypes.Config{ Resources: &robottypes.Resources{ Phases: map[robottypes.Phase]string{ robottypes.PhaseHost: "tests.host-json", }, }, }, } input := &robottypes.HostInput{ 
Scenario: "assign", Context: &robottypes.HostContext{ RobotStatus: &robottypes.RobotStatusSnapshot{ActiveCount: 0, MaxQuota: 10}, }, } output, err := e.CallHostAgent(ctx, robot, input, "chat-h3") require.NoError(t, err, "CallHostAgent should not error for valid JSON host agent") require.NotNil(t, output, "output should not be nil") assert.NotEmpty(t, output.Reply, "reply should not be empty") assert.Equal(t, robottypes.HostActionConfirm, output.Action, "action should be 'confirm' for the JSON host agent") assert.False(t, output.WaitForMore, "wait_for_more should be false") } // H4: plain text response (non-JSON fallback) func TestCallHostAgent_PlaintextFallback(t *testing.T) { if testing.Short() { t.Skip("Requires assistant framework and LLM") } testutils.Prepare(t) defer testutils.Clean(t) e := standard.New() ctx := robottypes.NewContext(context.Background(), hostTestAuth()) robot := &robottypes.Robot{ MemberID: "member-h4", Config: &robottypes.Config{ Resources: &robottypes.Resources{ Phases: map[robottypes.Phase]string{ robottypes.PhaseHost: "tests.host-plaintext", }, }, }, } input := &robottypes.HostInput{ Scenario: "assign", Context: &robottypes.HostContext{ RobotStatus: &robottypes.RobotStatusSnapshot{ActiveCount: 0, MaxQuota: 10}, }, } output, err := e.CallHostAgent(ctx, robot, input, "chat-h4") require.NoError(t, err, "non-JSON response should fallback gracefully, not error") require.NotNil(t, output, "output should not be nil") assert.NotEmpty(t, output.Reply, "reply should contain the plaintext response") assert.Equal(t, robottypes.HostActionConfirm, output.Action, "action should fallback to 'confirm' for non-JSON response") } // H5: JSON with wrong structure (no action/reply fields) func TestCallHostAgent_BadJSONStructureFallback(t *testing.T) { if testing.Short() { t.Skip("Requires assistant framework and LLM") } testutils.Prepare(t) defer testutils.Clean(t) e := standard.New() ctx := robottypes.NewContext(context.Background(), hostTestAuth()) robot := 
&robottypes.Robot{ MemberID: "member-h5", Config: &robottypes.Config{ Resources: &robottypes.Resources{ Phases: map[robottypes.Phase]string{ robottypes.PhaseHost: "tests.host-badjson", }, }, }, } input := &robottypes.HostInput{ Scenario: "assign", Context: &robottypes.HostContext{ RobotStatus: &robottypes.RobotStatusSnapshot{ActiveCount: 0, MaxQuota: 10}, }, } output, err := e.CallHostAgent(ctx, robot, input, "chat-h5") require.NoError(t, err, "bad JSON structure should not error") require.NotNil(t, output, "output should not be nil") // The JSON is valid but has no action/reply fields. // json.Unmarshal won't error — Action will be zero value (""). // Verify the output is returned (either with empty action or fallback to confirm). if output.Action == "" { assert.Empty(t, output.Action, "action should be empty when JSON has no action field") } else { assert.Equal(t, robottypes.HostActionConfirm, output.Action, "action should be 'confirm' if fallback is triggered") } } // H6: assistant not found func TestCallHostAgent_AssistantNotFound(t *testing.T) { if testing.Short() { t.Skip("Requires assistant framework initialization") } testutils.Prepare(t) defer testutils.Clean(t) e := standard.New() ctx := robottypes.NewContext(context.Background(), hostTestAuth()) robot := &robottypes.Robot{ MemberID: "member-h6", Config: &robottypes.Config{ Resources: &robottypes.Resources{ Phases: map[robottypes.Phase]string{ robottypes.PhaseHost: "nonexistent-assistant", }, }, }, } input := &robottypes.HostInput{Scenario: "assign"} _, err := e.CallHostAgent(ctx, robot, input, "chat-h6") require.Error(t, err) assert.Contains(t, err.Error(), "host agent") } // H7: input marshalling verification (pure unit test, no LLM needed) func TestCallHostAgent_InputMarshalling(t *testing.T) { input := &robottypes.HostInput{ Scenario: "clarify", Context: &robottypes.HostContext{ RobotStatus: &robottypes.RobotStatusSnapshot{ ActiveCount: 2, MaxQuota: 5, }, AgentReply: "What format?", }, } 
assert.NotEmpty(t, input.Scenario) assert.NotNil(t, input.Context) assert.Equal(t, 2, input.Context.RobotStatus.ActiveCount) } ================================================ FILE: agent/robot/executor/standard/input.go ================================================ package standard import ( "context" "encoding/json" "fmt" "strings" "github.com/yaoapp/gou/mcp" "github.com/yaoapp/yao/agent/assistant" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // InputFormatter provides methods to format input data for assistant prompts // Each phase has specific input requirements: // - P0 (Inspiration): ClockContext + Robot identity + Available resources // - P1 (Goals): InspirationReport/TriggerInput + Robot identity + Available resources // - P2 (Tasks): Goals + Available resources // - P3 (Run): Tasks // - P4 (Delivery): Task results // - P5 (Learning): Execution summary type InputFormatter struct{} // NewInputFormatter creates a new InputFormatter func NewInputFormatter() *InputFormatter { return &InputFormatter{} } // FormatClockContext formats ClockContext as user message content // Used by P0 (Inspiration) phase func (f *InputFormatter) FormatClockContext(clock *robottypes.ClockContext, robot *robottypes.Robot) string { if clock == nil { return "" } var sb strings.Builder // Time context section sb.WriteString("## Current Time Context\n\n") sb.WriteString(fmt.Sprintf("- **Now**: %s\n", clock.Now.Format("2006-01-02 15:04:05"))) sb.WriteString(fmt.Sprintf("- **Day**: %s\n", clock.DayOfWeek)) sb.WriteString(fmt.Sprintf("- **Date**: %d/%d/%d\n", clock.Year, clock.Month, clock.DayOfMonth)) sb.WriteString(fmt.Sprintf("- **Week**: %d of year\n", clock.WeekOfYear)) sb.WriteString(fmt.Sprintf("- **Hour**: %d\n", clock.Hour)) sb.WriteString(fmt.Sprintf("- **Timezone**: %s\n", clock.TZ)) // Time markers - show all markers with check/cross for context awareness sb.WriteString("\n### Time 
Markers\n") sb.WriteString(fmt.Sprintf("- %s Weekend\n", boolMark(clock.IsWeekend))) sb.WriteString(fmt.Sprintf("- %s Month Start (1st-3rd)\n", boolMark(clock.IsMonthStart))) sb.WriteString(fmt.Sprintf("- %s Month End (last 3 days)\n", boolMark(clock.IsMonthEnd))) sb.WriteString(fmt.Sprintf("- %s Quarter End\n", boolMark(clock.IsQuarterEnd))) sb.WriteString(fmt.Sprintf("- %s Year End\n", boolMark(clock.IsYearEnd))) // Robot identity section // Priority: Config.Identity > Robot fields (DisplayName, Bio, SystemPrompt) if robot != nil { if robot.Config != nil && robot.Config.Identity != nil { // Use structured identity from config sb.WriteString("\n## Robot Identity\n\n") sb.WriteString(fmt.Sprintf("- **Role**: %s\n", robot.Config.Identity.Role)) if len(robot.Config.Identity.Duties) > 0 { sb.WriteString("- **Duties**:\n") for _, duty := range robot.Config.Identity.Duties { sb.WriteString(fmt.Sprintf(" - %s\n", duty)) } } if len(robot.Config.Identity.Rules) > 0 { sb.WriteString("- **Rules**:\n") for _, rule := range robot.Config.Identity.Rules { sb.WriteString(fmt.Sprintf(" - %s\n", rule)) } } } else if robot.DisplayName != "" || robot.Bio != "" || robot.SystemPrompt != "" { // Fallback: build identity from Robot fields (from __yao.member table) sb.WriteString("\n## Robot Identity\n\n") if robot.DisplayName != "" { sb.WriteString(fmt.Sprintf("- **Role**: %s\n", robot.DisplayName)) } if robot.Bio != "" { sb.WriteString(fmt.Sprintf("- **Description**: %s\n", robot.Bio)) } if robot.SystemPrompt != "" { sb.WriteString(fmt.Sprintf("- **Instructions**:\n%s\n", robot.SystemPrompt)) } } } return sb.String() } // boolMark returns ✓ for true and ✗ for false func boolMark(v bool) string { if v { return "✓" } return "✗" } // FormatRobotIdentity formats robot identity as user message content // Used to provide context about the robot's role and duties func (f *InputFormatter) FormatRobotIdentity(robot *robottypes.Robot) string { if robot == nil || robot.Config == nil || 
robot.Config.Identity == nil { return "" } var sb strings.Builder identity := robot.Config.Identity sb.WriteString("## Robot Identity\n\n") sb.WriteString(fmt.Sprintf("- **Role**: %s\n", identity.Role)) if len(identity.Duties) > 0 { sb.WriteString("- **Duties**:\n") for _, duty := range identity.Duties { sb.WriteString(fmt.Sprintf(" - %s\n", duty)) } } if len(identity.Rules) > 0 { sb.WriteString("- **Rules**:\n") for _, rule := range identity.Rules { sb.WriteString(fmt.Sprintf(" - %s\n", rule)) } } return sb.String() } // FormatAvailableResources formats available resources (agents, MCP tools, KB, DB) as user message content // Used by P0 (Inspiration) and P1 (Goals) to inform the agent what tools are available // This is critical for generating achievable goals - without knowing available tools, // the agent might generate goals that cannot be accomplished func (f *InputFormatter) FormatAvailableResources(robot *robottypes.Robot) string { locale := "en" // default locale if robot != nil && robot.Config != nil { locale = robot.Config.GetDefaultLocale() } return f.FormatAvailableResourcesWithLocale(robot, locale) } // FormatAvailableResourcesWithLocale formats available resources with specific locale for i18n support func (f *InputFormatter) FormatAvailableResourcesWithLocale(robot *robottypes.Robot, locale string) string { if robot == nil || robot.Config == nil { return "" } var sb strings.Builder hasContent := false // Available Agents - with detailed information (name, description) if robot.Config.Resources != nil && len(robot.Config.Resources.Agents) > 0 { if !hasContent { sb.WriteString("## Available Resources\n\n") hasContent = true } sb.WriteString("### Agents\n") sb.WriteString("These are the AI assistants you can delegate tasks to:\n\n") for _, agentID := range robot.Config.Resources.Agents { // Try to get agent details ast, err := assistant.Get(agentID) if err != nil { // Fallback to just ID if agent not found sb.WriteString(fmt.Sprintf("- **%s**\n", 
agentID)) continue } // Get localized name and description name := i18n.Translate(agentID, locale, ast.Name).(string) description := "" if ast.Description != "" { description = i18n.Translate(agentID, locale, ast.Description).(string) } // Format agent info sb.WriteString(fmt.Sprintf("- **%s** (`%s`)\n", name, agentID)) if description != "" { sb.WriteString(fmt.Sprintf(" - %s\n", description)) } if ast.Capabilities != "" { capabilities := i18n.Translate(agentID, locale, ast.Capabilities).(string) sb.WriteString(fmt.Sprintf(" - **Capabilities**: %s\n", capabilities)) } } sb.WriteString("\n") } // Available MCP Tools - with detailed tool information if robot.Config.Resources != nil && len(robot.Config.Resources.MCP) > 0 { if !hasContent { sb.WriteString("## Available Resources\n\n") hasContent = true } sb.WriteString("### MCP Tools\n") sb.WriteString("These are the external tools and services you can use:\n\n") for _, mcpConfig := range robot.Config.Resources.MCP { // Try to get MCP client and list tools client, err := mcp.Select(mcpConfig.ID) if err != nil { // Fallback to basic info if client not found if len(mcpConfig.Tools) > 0 { sb.WriteString(fmt.Sprintf("- **%s**: %s\n", mcpConfig.ID, strings.Join(mcpConfig.Tools, ", "))) } else { sb.WriteString(fmt.Sprintf("- **%s**: all tools available\n", mcpConfig.ID)) } continue } // Get client info for name and description clientInfo := client.Info() clientName := mcpConfig.ID clientDesc := "" if clientInfo != nil { if clientInfo.Name != "" { clientName = clientInfo.Name } if clientInfo.Description != "" { clientDesc = clientInfo.Description } } // Write MCP header if clientDesc != "" { sb.WriteString(fmt.Sprintf("#### %s (`%s`)\n", clientName, mcpConfig.ID)) sb.WriteString(fmt.Sprintf("%s\n\n", clientDesc)) } else { sb.WriteString(fmt.Sprintf("#### %s (`%s`)\n\n", clientName, mcpConfig.ID)) } // Try to list tools with context ctx := context.Background() toolsResp, err := client.ListTools(ctx, "") if err != nil || 
toolsResp == nil { // Fallback to configured tools if len(mcpConfig.Tools) > 0 { sb.WriteString("Available tools: ") sb.WriteString(strings.Join(mcpConfig.Tools, ", ")) sb.WriteString("\n\n") } else { sb.WriteString("All tools available\n\n") } continue } // Filter tools if specific tools are configured toolsToShow := toolsResp.Tools if len(mcpConfig.Tools) > 0 { // Create a map for quick lookup allowedTools := make(map[string]bool) for _, t := range mcpConfig.Tools { allowedTools[t] = true } // Filter tools var filteredTools []struct { Name string Description string } for _, tool := range toolsResp.Tools { if allowedTools[tool.Name] { filteredTools = append(filteredTools, struct { Name string Description string }{tool.Name, tool.Description}) } } // Write filtered tools if len(filteredTools) > 0 { sb.WriteString("| Tool | Description |\n") sb.WriteString("|------|-------------|\n") for _, tool := range filteredTools { desc := tool.Description if len(desc) > 100 { desc = desc[:97] + "..." } // Escape pipe characters in description desc = strings.ReplaceAll(desc, "|", "\\|") sb.WriteString(fmt.Sprintf("| `%s` | %s |\n", tool.Name, desc)) } } else { sb.WriteString("Configured tools: ") sb.WriteString(strings.Join(mcpConfig.Tools, ", ")) } } else if len(toolsToShow) > 0 { // Show all available tools sb.WriteString("| Tool | Description |\n") sb.WriteString("|------|-------------|\n") for _, tool := range toolsToShow { desc := tool.Description if len(desc) > 100 { desc = desc[:97] + "..." 
} // Escape pipe characters in description desc = strings.ReplaceAll(desc, "|", "\\|") sb.WriteString(fmt.Sprintf("| `%s` | %s |\n", tool.Name, desc)) } } else { sb.WriteString("No tools available\n") } sb.WriteString("\n") } } // Available Knowledge Base if robot.Config.KB != nil && len(robot.Config.KB.Collections) > 0 { if !hasContent { sb.WriteString("## Available Resources\n\n") hasContent = true } sb.WriteString("### Knowledge Base\n") sb.WriteString("You have access to these knowledge collections:\n") for _, collection := range robot.Config.KB.Collections { sb.WriteString(fmt.Sprintf("- %s\n", collection)) } sb.WriteString("\n") } // Available Database Models if robot.Config.DB != nil && len(robot.Config.DB.Models) > 0 { if !hasContent { sb.WriteString("## Available Resources\n\n") hasContent = true } sb.WriteString("### Database\n") sb.WriteString("You can query these database models:\n") for _, model := range robot.Config.DB.Models { sb.WriteString(fmt.Sprintf("- %s\n", model)) } sb.WriteString("\n") } if !hasContent { return "" } sb.WriteString("**Important**: Only plan goals and tasks that can be accomplished with the above resources.\n") return sb.String() } // FormatInspirationReport formats InspirationReport as user message content // Used by P1 (Goals) phase when trigger is Clock func (f *InputFormatter) FormatInspirationReport(report *robottypes.InspirationReport) string { if report == nil { return "" } var sb strings.Builder // Clock context summary (if available) if report.Clock != nil { sb.WriteString("## Time Context\n\n") sb.WriteString(fmt.Sprintf("- **Time**: %s %s\n", report.Clock.DayOfWeek, report.Clock.Now.Format("15:04"))) sb.WriteString(fmt.Sprintf("- **Date**: %d/%d/%d\n", report.Clock.Year, report.Clock.Month, report.Clock.DayOfMonth)) // Add relevant time markers var markers []string if report.Clock.IsWeekend { markers = append(markers, "Weekend") } if report.Clock.IsMonthStart { markers = append(markers, "Month Start") } if 
report.Clock.IsMonthEnd { markers = append(markers, "Month End") } if report.Clock.IsQuarterEnd { markers = append(markers, "Quarter End") } if len(markers) > 0 { sb.WriteString(fmt.Sprintf("- **Markers**: %s\n", strings.Join(markers, ", "))) } sb.WriteString("\n") } // Inspiration content if report.Content != "" { sb.WriteString("## Inspiration Report\n\n") sb.WriteString(report.Content) sb.WriteString("\n") } return sb.String() } // FormatTriggerInput formats TriggerInput as user message content // Used by P1 (Goals) phase when trigger is Human or Event func (f *InputFormatter) FormatTriggerInput(input *robottypes.TriggerInput) string { if input == nil { return "" } var sb strings.Builder // Human intervention if input.Action != "" { sb.WriteString("## Human Intervention\n\n") sb.WriteString(fmt.Sprintf("- **Action**: %s\n", input.Action)) if input.UserID != "" { sb.WriteString(fmt.Sprintf("- **User**: %s\n", input.UserID)) } // Messages if len(input.Messages) > 0 { sb.WriteString("\n### User Input\n\n") for _, msg := range input.Messages { if content, ok := msg.Content.(string); ok { sb.WriteString(content) sb.WriteString("\n") } } } return sb.String() } // Event trigger if input.Source != "" { sb.WriteString("## Event Trigger\n\n") sb.WriteString(fmt.Sprintf("- **Source**: %s\n", input.Source)) sb.WriteString(fmt.Sprintf("- **Event Type**: %s\n", input.EventType)) // Event data if input.Data != nil { sb.WriteString("\n### Event Data\n\n") sb.WriteString("```json\n") if data, err := json.MarshalIndent(input.Data, "", " "); err == nil { sb.WriteString(string(data)) } sb.WriteString("\n```\n") } return sb.String() } return "" } // FormatGoals formats Goals as user message content // Used by P2 (Tasks) phase func (f *InputFormatter) FormatGoals(goals *robottypes.Goals, robot *robottypes.Robot) string { if goals == nil { return "" } var sb strings.Builder // Goals content sb.WriteString("## Goals\n\n") sb.WriteString(goals.Content) sb.WriteString("\n") // Delivery 
target (from P1) - important for task planning // Tasks should be designed to produce output suitable for the delivery method if goals.Delivery != nil { sb.WriteString("\n## Delivery Target\n\n") sb.WriteString(fmt.Sprintf("- **Type**: %s\n", goals.Delivery.Type)) if len(goals.Delivery.Recipients) > 0 { sb.WriteString(fmt.Sprintf("- **Recipients**: %s\n", strings.Join(goals.Delivery.Recipients, ", "))) } if goals.Delivery.Format != "" { sb.WriteString(fmt.Sprintf("- **Format**: %s\n", goals.Delivery.Format)) } if goals.Delivery.Template != "" { sb.WriteString(fmt.Sprintf("- **Template**: %s\n", goals.Delivery.Template)) } sb.WriteString("\n**Note**: Design tasks to produce output suitable for this delivery method.\n") } // Available resources - reuse FormatAvailableResources for consistency resourcesContent := f.FormatAvailableResources(robot) if resourcesContent != "" { sb.WriteString("\n") sb.WriteString(resourcesContent) } return sb.String() } // FormatTasks formats Tasks as user message content // Used by P3 (Run) phase func (f *InputFormatter) FormatTasks(tasks []robottypes.Task) string { if len(tasks) == 0 { return "No tasks to execute." 
} var sb strings.Builder sb.WriteString("## Tasks to Execute\n\n") for i, task := range tasks { sb.WriteString(fmt.Sprintf("### Task %d: %s\n\n", i+1, task.ID)) sb.WriteString(fmt.Sprintf("- **Goal Reference**: %s\n", task.GoalRef)) sb.WriteString(fmt.Sprintf("- **Source**: %s\n", task.Source)) sb.WriteString(fmt.Sprintf("- **Executor**: %s (%s)\n", task.ExecutorID, task.ExecutorType)) // Task content if len(task.Messages) > 0 { sb.WriteString("\n**Instructions**:\n") for _, msg := range task.Messages { if content, ok := msg.Content.(string); ok { sb.WriteString(content) sb.WriteString("\n") } } } // Arguments if len(task.Args) > 0 { sb.WriteString("\n**Arguments**:\n") if args, err := json.MarshalIndent(task.Args, "", " "); err == nil { sb.WriteString("```json\n") sb.WriteString(string(args)) sb.WriteString("\n```\n") } } sb.WriteString("\n") } return sb.String() } // FormatTaskResults formats TaskResults as user message content // Used by P4 (Delivery) and P5 (Learning) phases func (f *InputFormatter) FormatTaskResults(results []robottypes.TaskResult) string { if len(results) == 0 { return "No task results." 
} var sb strings.Builder sb.WriteString("## Task Results\n\n") successCount := 0 failCount := 0 validatedPassedCount := 0 validatedTotalCount := 0 for _, result := range results { if result.Success { successCount++ } else { failCount++ } if result.Validation != nil { validatedTotalCount++ if result.Validation.Passed { validatedPassedCount++ } } sb.WriteString(fmt.Sprintf("### Task: %s\n\n", result.TaskID)) if result.Success { sb.WriteString("- **Status**: ✓ Success\n") } else { sb.WriteString("- **Status**: ✗ Failed\n") } sb.WriteString(fmt.Sprintf("- **Duration**: %dms\n", result.Duration)) // Validation result (P3) if result.Validation != nil { if result.Validation.Passed { sb.WriteString(fmt.Sprintf("- **Validation**: ✓ Passed (score: %.2f)\n", result.Validation.Score)) } else { sb.WriteString("- **Validation**: ✗ Failed\n") if len(result.Validation.Issues) > 0 { sb.WriteString(" - Issues:\n") for _, issue := range result.Validation.Issues { sb.WriteString(fmt.Sprintf(" - %s\n", issue)) } } } } // Output if result.Output != nil { sb.WriteString("\n**Output**:\n") if output, err := json.MarshalIndent(result.Output, "", " "); err == nil { sb.WriteString("```json\n") sb.WriteString(string(output)) sb.WriteString("\n```\n") } else { sb.WriteString(fmt.Sprintf("%v\n", result.Output)) } } // Error if result.Error != "" { sb.WriteString(fmt.Sprintf("\n**Error**: %s\n", result.Error)) } sb.WriteString("\n") } // Summary sb.WriteString(fmt.Sprintf("## Summary\n\n- Total: %d tasks\n- Success: %d\n- Failed: %d\n- Validated: %d/%d\n", len(results), successCount, failCount, validatedPassedCount, validatedTotalCount)) return sb.String() } // FormatExecutionSummary formats the entire execution for P5 (Learning) phase func (f *InputFormatter) FormatExecutionSummary(exec *robottypes.Execution) string { if exec == nil { return "" } var sb strings.Builder // Execution metadata sb.WriteString("## Execution Summary\n\n") sb.WriteString(fmt.Sprintf("- **ID**: %s\n", exec.ID)) 
sb.WriteString(fmt.Sprintf("- **Trigger**: %s\n", exec.TriggerType)) sb.WriteString(fmt.Sprintf("- **Status**: %s\n", exec.Status)) sb.WriteString(fmt.Sprintf("- **Start Time**: %s\n", exec.StartTime.Format("2006-01-02 15:04:05"))) if exec.EndTime != nil { sb.WriteString(fmt.Sprintf("- **End Time**: %s\n", exec.EndTime.Format("2006-01-02 15:04:05"))) duration := exec.EndTime.Sub(exec.StartTime) sb.WriteString(fmt.Sprintf("- **Duration**: %s\n", duration.String())) } if exec.Error != "" { sb.WriteString(fmt.Sprintf("- **Error**: %s\n", exec.Error)) } sb.WriteString("\n") // Inspiration (P0) if exec.Inspiration != nil && exec.Inspiration.Content != "" { sb.WriteString("## Inspiration (P0)\n\n") sb.WriteString(exec.Inspiration.Content) sb.WriteString("\n\n") } // Goals (P1) if exec.Goals != nil && exec.Goals.Content != "" { sb.WriteString("## Goals (P1)\n\n") sb.WriteString(exec.Goals.Content) sb.WriteString("\n\n") } // Tasks (P2) if len(exec.Tasks) > 0 { sb.WriteString("## Tasks (P2)\n\n") for i, task := range exec.Tasks { sb.WriteString(fmt.Sprintf("%d. 
[%s] %s (executor: %s)\n", i+1, task.Status, task.ID, task.ExecutorID)) } sb.WriteString("\n") } // Results (P3) if len(exec.Results) > 0 { sb.WriteString("## Results (P3)\n\n") for _, result := range exec.Results { status := "✓" if !result.Success { status = "✗" } sb.WriteString(fmt.Sprintf("- %s %s (%dms)\n", status, result.TaskID, result.Duration)) } sb.WriteString("\n") } // Delivery (P4) if exec.Delivery != nil { sb.WriteString("## Delivery (P4)\n\n") if exec.Delivery.Content != nil { sb.WriteString(fmt.Sprintf("- **Summary**: %s\n", exec.Delivery.Content.Summary)) } if exec.Delivery.Success { sb.WriteString("- **Status**: ✓ Success\n") } else { sb.WriteString(fmt.Sprintf("- **Status**: ✗ Failed (%s)\n", exec.Delivery.Error)) } if len(exec.Delivery.Results) > 0 { sb.WriteString(fmt.Sprintf("- **Channels**: %d\n", len(exec.Delivery.Results))) } sb.WriteString("\n") } return sb.String() } // BuildMessages is a convenience method to build messages array from content func (f *InputFormatter) BuildMessages(userContent string) []agentcontext.Message { return []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: userContent, }, } } // BuildMessagesWithSystem builds messages array with system and user content func (f *InputFormatter) BuildMessagesWithSystem(systemContent, userContent string) []agentcontext.Message { return []agentcontext.Message{ { Role: agentcontext.RoleSystem, Content: systemContent, }, { Role: agentcontext.RoleUser, Content: userContent, }, } } ================================================ FILE: agent/robot/executor/standard/input_integration_test.go ================================================ package standard_test import ( "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // 
============================================================================ // InputFormatter Integration Tests with Real Data // These tests use the yao-dev-app environment with real assistants and MCPs // ============================================================================ func TestFormatAvailableResourcesIntegration(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) formatter := standard.NewInputFormatter() t.Run("formats_real_agents_with_details", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-agents", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Robot", Duties: []string{"Testing agent formatting"}, }, DefaultLocale: "en", Resources: &types.Resources{ // Use real agents from yao-dev-app/assistants Agents: []string{ "experts.data-analyst", "experts.text-writer", "experts.summarizer", }, }, }, } result := formatter.FormatAvailableResources(robot) // Verify structure assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### Agents") assert.Contains(t, result, "These are the AI assistants you can delegate tasks to:") // Verify real agent details are included // experts.data-analyst should show name and description assert.Contains(t, result, "experts.data-analyst") assert.Contains(t, result, "Data Analyst Expert") // Name from package.yao // experts.text-writer assert.Contains(t, result, "experts.text-writer") // experts.summarizer assert.Contains(t, result, "experts.summarizer") // Verify important note is present assert.Contains(t, result, "Only plan goals and tasks that can be accomplished") t.Logf("Formatted agents result:\n%s", result) }) t.Run("formats_real_mcp_with_tool_details", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-mcp", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Robot", Duties: []string{"Testing MCP formatting"}, }, DefaultLocale: "en", Resources: &types.Resources{ // 
Use real MCPs from yao-dev-app/mcps MCP: []types.MCPConfig{ {ID: "echo", Tools: []string{"ping", "status"}}, // Specific tools {ID: "echo"}, // All tools }, }, }, } result := formatter.FormatAvailableResources(robot) // Verify structure assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### MCP Tools") assert.Contains(t, result, "These are the external tools and services you can use:") // Verify MCP details assert.Contains(t, result, "echo") // Verify important note is present assert.Contains(t, result, "Only plan goals and tasks that can be accomplished") t.Logf("Formatted MCP result:\n%s", result) }) t.Run("formats_combined_resources", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-combined", Config: &types.Config{ Identity: &types.Identity{ Role: "Sales Analyst Robot", Duties: []string{"Analyze sales data", "Generate reports"}, Rules: []string{"Be accurate", "Be concise"}, }, DefaultLocale: "en", Resources: &types.Resources{ Agents: []string{ "experts.data-analyst", "experts.summarizer", }, MCP: []types.MCPConfig{ {ID: "echo", Tools: []string{"ping", "echo"}}, }, }, KB: &types.KB{ Collections: []string{"sales-policies", "product-catalog"}, }, DB: &types.DB{ Models: []string{"sales", "customers", "orders"}, }, }, } result := formatter.FormatAvailableResources(robot) // Verify all sections are present assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### Agents") assert.Contains(t, result, "### MCP Tools") assert.Contains(t, result, "### Knowledge Base") assert.Contains(t, result, "### Database") // Verify agents assert.Contains(t, result, "experts.data-analyst") assert.Contains(t, result, "experts.summarizer") // Verify MCP assert.Contains(t, result, "echo") // Verify KB assert.Contains(t, result, "sales-policies") assert.Contains(t, result, "product-catalog") // Verify DB assert.Contains(t, result, "sales") assert.Contains(t, result, "customers") assert.Contains(t, result, "orders") 
t.Logf("Formatted combined resources result:\n%s", result) }) t.Run("handles_locale_zh", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-zh", Config: &types.Config{ Identity: &types.Identity{ Role: "测试机器人", Duties: []string{"测试国际化"}, }, DefaultLocale: "zh-cn", Resources: &types.Resources{ // Use agents that have zh-cn locales Agents: []string{ "hello", // This agent has locales/zh-cn.yml "mohe", // This agent also has locales/zh-cn.yml }, }, }, } result := formatter.FormatAvailableResourcesWithLocale(robot, "zh-cn") // Verify structure assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### Agents") // Verify agents are listed assert.Contains(t, result, "hello") assert.Contains(t, result, "mohe") t.Logf("Formatted zh-cn result:\n%s", result) }) t.Run("gracefully_handles_missing_agents", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-missing", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Robot", }, Resources: &types.Resources{ Agents: []string{ "non-existent-agent", "experts.data-analyst", // This one exists "another-missing-agent", }, }, }, } result := formatter.FormatAvailableResources(robot) // Should not panic, should include fallback for missing agents assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### Agents") // Missing agents should still be listed with just ID assert.Contains(t, result, "non-existent-agent") assert.Contains(t, result, "another-missing-agent") // Existing agent should have full details assert.Contains(t, result, "experts.data-analyst") assert.Contains(t, result, "Data Analyst Expert") t.Logf("Formatted with missing agents:\n%s", result) }) t.Run("gracefully_handles_missing_mcp", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-missing-mcp", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Robot", }, Resources: &types.Resources{ MCP: []types.MCPConfig{ {ID: "non-existent-mcp", Tools: 
[]string{"tool1", "tool2"}}, {ID: "echo"}, // This one exists }, }, }, } result := formatter.FormatAvailableResources(robot) // Should not panic, should include fallback for missing MCP assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### MCP Tools") // Missing MCP should still be listed with fallback assert.Contains(t, result, "non-existent-mcp") assert.Contains(t, result, "tool1, tool2") // Existing MCP should have details assert.Contains(t, result, "echo") t.Logf("Formatted with missing MCP:\n%s", result) }) } func TestFormatAvailableResourcesTableFormat(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) formatter := standard.NewInputFormatter() t.Run("mcp_tools_in_table_format", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-table", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Robot", }, Resources: &types.Resources{ MCP: []types.MCPConfig{ {ID: "echo"}, // All tools - should show table }, }, }, } result := formatter.FormatAvailableResources(robot) // Check if table format is used when tools are available // Table headers: | Tool | Description | if strings.Contains(result, "| Tool | Description |") { assert.Contains(t, result, "|------|-------------|") t.Logf("MCP tools displayed in table format:\n%s", result) } else { // Fallback format t.Logf("MCP tools displayed in fallback format:\n%s", result) } }) } func TestFormatClockContextWithRobotIntegration(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) formatter := standard.NewInputFormatter() t.Run("full_context_for_inspiration", func(t *testing.T) { // Create a realistic robot configuration robot := &types.Robot{ MemberID: "sales-robot-001", TeamID: "team-001", DisplayName: "Sales Analyst Robot", AutonomousMode: true, Config: &types.Config{ Identity: &types.Identity{ Role: "Sales Analyst", Duties: 
[]string{"Monitor sales performance", "Generate daily reports", "Alert on anomalies"}, Rules: []string{"Only use approved data sources", "Maintain confidentiality"}, }, DefaultLocale: "en", Resources: &types.Resources{ Agents: []string{ "experts.data-analyst", "experts.summarizer", }, MCP: []types.MCPConfig{ {ID: "echo", Tools: []string{"ping"}}, }, }, }, } // Create clock context clock := types.NewClockContext(time.Now(), "UTC") // Format clock context (includes robot identity) clockContent := formatter.FormatClockContext(clock, robot) // Format available resources resourcesContent := formatter.FormatAvailableResources(robot) // Combine for full context (as done in inspiration.go) fullContext := clockContent + "\n\n" + resourcesContent // Verify full context contains all necessary information require.NotEmpty(t, fullContext) // Time context assert.Contains(t, fullContext, "## Current Time Context") assert.Contains(t, fullContext, "### Time Markers") // Robot identity assert.Contains(t, fullContext, "## Robot Identity") assert.Contains(t, fullContext, "Sales Analyst") assert.Contains(t, fullContext, "Monitor sales performance") assert.Contains(t, fullContext, "Only use approved data sources") // Available resources assert.Contains(t, fullContext, "## Available Resources") assert.Contains(t, fullContext, "### Agents") assert.Contains(t, fullContext, "experts.data-analyst") assert.Contains(t, fullContext, "### MCP Tools") t.Logf("Full context for inspiration:\n%s", fullContext) }) } ================================================ FILE: agent/robot/executor/standard/input_test.go ================================================ package standard_test import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" ) // 
============================================================================ // InputFormatter Tests // ============================================================================ func TestInputFormatterFormatClockContext(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats clock context with all fields", func(t *testing.T) { now := time.Date(2024, 1, 15, 9, 30, 0, 0, time.UTC) clock := types.NewClockContext(now, "UTC") result := formatter.FormatClockContext(clock, nil) assert.Contains(t, result, "## Current Time Context") assert.Contains(t, result, "2024-01-15 09:30:00") assert.Contains(t, result, "Monday") assert.Contains(t, result, "UTC") assert.Contains(t, result, "**Hour**: 9") assert.Contains(t, result, "### Time Markers") }) t.Run("shows all time markers with check/cross", func(t *testing.T) { // Regular weekday, not month start/end now := time.Date(2024, 1, 15, 14, 0, 0, 0, time.UTC) clock := types.NewClockContext(now, "UTC") result := formatter.FormatClockContext(clock, nil) // Should show all markers, even when false assert.Contains(t, result, "✗ Weekend") assert.Contains(t, result, "✗ Month Start") assert.Contains(t, result, "✗ Month End") assert.Contains(t, result, "✗ Quarter End") assert.Contains(t, result, "✗ Year End") }) t.Run("includes robot identity when provided", func(t *testing.T) { now := time.Now() clock := types.NewClockContext(now, "UTC") robot := &types.Robot{ MemberID: "test-robot", Config: &types.Config{ Identity: &types.Identity{ Role: "Sales Analyst", Duties: []string{"Analyze sales data", "Generate reports"}, Rules: []string{"Be accurate", "Be concise"}, }, }, } result := formatter.FormatClockContext(clock, robot) assert.Contains(t, result, "## Robot Identity") assert.Contains(t, result, "Sales Analyst") assert.Contains(t, result, "Analyze sales data") assert.Contains(t, result, "Be accurate") }) t.Run("returns empty for nil clock", func(t *testing.T) { result := formatter.FormatClockContext(nil, nil) 
assert.Empty(t, result) }) t.Run("uses DisplayName/Bio/SystemPrompt when Identity is nil", func(t *testing.T) { now := time.Now() clock := types.NewClockContext(now, "UTC") robot := &types.Robot{ MemberID: "test-robot", DisplayName: "SEO Specialist", Bio: "Focuses on content optimization", SystemPrompt: "You are an SEO assistant.\n\n## Core Duties\n- Analyze keywords", Config: &types.Config{}, // Identity is nil } result := formatter.FormatClockContext(clock, robot) assert.Contains(t, result, "## Robot Identity") assert.Contains(t, result, "SEO Specialist") assert.Contains(t, result, "Focuses on content optimization") assert.Contains(t, result, "SEO assistant") assert.Contains(t, result, "Analyze keywords") }) t.Run("marks weekend correctly", func(t *testing.T) { // Saturday saturday := time.Date(2024, 1, 13, 10, 0, 0, 0, time.UTC) clock := types.NewClockContext(saturday, "UTC") result := formatter.FormatClockContext(clock, nil) assert.Contains(t, result, "✓ Weekend") }) t.Run("marks month start correctly", func(t *testing.T) { // 2nd of month monthStart := time.Date(2024, 1, 2, 10, 0, 0, 0, time.UTC) clock := types.NewClockContext(monthStart, "UTC") result := formatter.FormatClockContext(clock, nil) assert.Contains(t, result, "✓ Month Start") }) t.Run("marks month end correctly", func(t *testing.T) { // 30th of January (last 3 days) monthEnd := time.Date(2024, 1, 30, 10, 0, 0, 0, time.UTC) clock := types.NewClockContext(monthEnd, "UTC") result := formatter.FormatClockContext(clock, nil) assert.Contains(t, result, "✓ Month End") }) } func TestInputFormatterFormatInspirationReport(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats inspiration report with clock", func(t *testing.T) { now := time.Date(2024, 1, 15, 9, 30, 0, 0, time.UTC) clock := types.NewClockContext(now, "UTC") report := &types.InspirationReport{ Clock: clock, Content: "Today is a good day to analyze sales data.", } result := formatter.FormatInspirationReport(report) 
assert.Contains(t, result, "## Time Context") assert.Contains(t, result, "Monday") assert.Contains(t, result, "## Inspiration Report") assert.Contains(t, result, "analyze sales data") }) t.Run("formats inspiration report without clock", func(t *testing.T) { report := &types.InspirationReport{ Content: "Focus on quarterly review.", } result := formatter.FormatInspirationReport(report) assert.NotContains(t, result, "## Time Context") assert.Contains(t, result, "## Inspiration Report") assert.Contains(t, result, "quarterly review") }) t.Run("returns empty for nil report", func(t *testing.T) { result := formatter.FormatInspirationReport(nil) assert.Empty(t, result) }) } func TestInputFormatterFormatAvailableResources(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats all resource types", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot", Config: &types.Config{ Resources: &types.Resources{ Agents: []string{"data-analyst", "chart-gen", "report-writer"}, MCP: []types.MCPConfig{ {ID: "database", Tools: []string{"query", "insert"}}, {ID: "email", Tools: []string{}}, // all tools }, }, KB: &types.KB{ Collections: []string{"sales-policies", "products"}, }, DB: &types.DB{ Models: []string{"sales", "customers", "orders"}, }, }, } result := formatter.FormatAvailableResources(robot) // Check structure assert.Contains(t, result, "## Available Resources") // Check agents assert.Contains(t, result, "### Agents") assert.Contains(t, result, "data-analyst") assert.Contains(t, result, "chart-gen") assert.Contains(t, result, "report-writer") // Check MCP tools assert.Contains(t, result, "### MCP Tools") assert.Contains(t, result, "database") assert.Contains(t, result, "query, insert") assert.Contains(t, result, "email") assert.Contains(t, result, "all tools available") // Check KB assert.Contains(t, result, "### Knowledge Base") assert.Contains(t, result, "sales-policies") assert.Contains(t, result, "products") // Check DB assert.Contains(t, result, 
"### Database") assert.Contains(t, result, "sales") assert.Contains(t, result, "customers") assert.Contains(t, result, "orders") // Check important note assert.Contains(t, result, "Only plan goals and tasks that can be accomplished") }) t.Run("returns empty for nil robot", func(t *testing.T) { result := formatter.FormatAvailableResources(nil) assert.Empty(t, result) }) t.Run("returns empty for robot without config", func(t *testing.T) { robot := &types.Robot{MemberID: "test"} result := formatter.FormatAvailableResources(robot) assert.Empty(t, result) }) t.Run("returns empty for robot without resources", func(t *testing.T) { robot := &types.Robot{ MemberID: "test", Config: &types.Config{}, } result := formatter.FormatAvailableResources(robot) assert.Empty(t, result) }) t.Run("handles partial resources", func(t *testing.T) { robot := &types.Robot{ MemberID: "test", Config: &types.Config{ Resources: &types.Resources{ Agents: []string{"single-agent"}, }, }, } result := formatter.FormatAvailableResources(robot) assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### Agents") assert.Contains(t, result, "single-agent") assert.NotContains(t, result, "### MCP Tools") assert.NotContains(t, result, "### Knowledge Base") assert.NotContains(t, result, "### Database") }) } func TestInputFormatterFormatTriggerInput(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats human intervention", func(t *testing.T) { input := &types.TriggerInput{ Action: "task.add", UserID: "user-123", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Please add a task to review Q4 sales"}, }, } result := formatter.FormatTriggerInput(input) assert.Contains(t, result, "## Human Intervention") assert.Contains(t, result, "task.add") assert.Contains(t, result, "user-123") assert.Contains(t, result, "### User Input") assert.Contains(t, result, "review Q4 sales") }) t.Run("formats event trigger", func(t *testing.T) { input := 
&types.TriggerInput{ Source: "webhook", EventType: "order.created", Data: map[string]interface{}{ "order_id": "12345", "amount": 99.99, }, } result := formatter.FormatTriggerInput(input) assert.Contains(t, result, "## Event Trigger") assert.Contains(t, result, "webhook") assert.Contains(t, result, "order.created") assert.Contains(t, result, "### Event Data") assert.Contains(t, result, "order_id") assert.Contains(t, result, "12345") }) t.Run("returns empty for nil input", func(t *testing.T) { result := formatter.FormatTriggerInput(nil) assert.Empty(t, result) }) t.Run("returns empty for empty input", func(t *testing.T) { input := &types.TriggerInput{} result := formatter.FormatTriggerInput(input) assert.Empty(t, result) }) } func TestInputFormatterFormatGoals(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats goals with resources", func(t *testing.T) { goals := &types.Goals{ Content: "1. Analyze sales data\n2. Generate report\n3. Send to stakeholders", } robot := &types.Robot{ MemberID: "test-robot", Config: &types.Config{ Resources: &types.Resources{ Agents: []string{"data-analyzer", "report-generator"}, MCP: []types.MCPConfig{ {ID: "database", Tools: []string{"query", "insert"}}, {ID: "email"}, }, }, }, } result := formatter.FormatGoals(goals, robot) assert.Contains(t, result, "## Goals") assert.Contains(t, result, "Analyze sales data") assert.Contains(t, result, "## Available Resources") assert.Contains(t, result, "### Agents") assert.Contains(t, result, "data-analyzer") assert.Contains(t, result, "### MCP Tools") assert.Contains(t, result, "database") assert.Contains(t, result, "query, insert") assert.Contains(t, result, "email") assert.Contains(t, result, "all tools available") }) t.Run("formats goals without robot", func(t *testing.T) { goals := &types.Goals{ Content: "Complete the task.", } result := formatter.FormatGoals(goals, nil) assert.Contains(t, result, "## Goals") assert.Contains(t, result, "Complete the task") 
assert.NotContains(t, result, "## Available Resources") }) t.Run("returns empty for nil goals", func(t *testing.T) { result := formatter.FormatGoals(nil, nil) assert.Empty(t, result) }) } func TestInputFormatterFormatTasks(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats multiple tasks", func(t *testing.T) { tasks := []types.Task{ { ID: "task-1", GoalRef: "goal-1", Source: types.TaskSourceAuto, ExecutorType: types.ExecutorMCP, ExecutorID: "database.query", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Query sales data for Q4"}, }, Args: []any{"sales", "Q4"}, }, { ID: "task-2", GoalRef: "goal-1", Source: types.TaskSourceAuto, ExecutorType: types.ExecutorAssistant, ExecutorID: "report-generator", }, } result := formatter.FormatTasks(tasks) assert.Contains(t, result, "## Tasks to Execute") assert.Contains(t, result, "### Task 1: task-1") assert.Contains(t, result, "goal-1") assert.Contains(t, result, "database.query") assert.Contains(t, result, "**Instructions**") assert.Contains(t, result, "Query sales data") assert.Contains(t, result, "**Arguments**") assert.Contains(t, result, "### Task 2: task-2") assert.Contains(t, result, "report-generator") }) t.Run("returns message for empty tasks", func(t *testing.T) { result := formatter.FormatTasks(nil) assert.Equal(t, "No tasks to execute.", result) result = formatter.FormatTasks([]types.Task{}) assert.Equal(t, "No tasks to execute.", result) }) } func TestInputFormatterFormatTaskResults(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats task results with summary", func(t *testing.T) { results := []types.TaskResult{ { TaskID: "task-1", Success: true, Duration: 150, Validation: &types.ValidationResult{ Passed: true, Score: 0.95, }, Output: map[string]interface{}{"rows": 100}, }, { TaskID: "task-2", Success: false, Duration: 50, Validation: &types.ValidationResult{ Passed: false, Issues: []string{"Connection timeout"}, }, Error: "Connection timeout", 
}, } result := formatter.FormatTaskResults(results) assert.Contains(t, result, "## Task Results") assert.Contains(t, result, "### Task: task-1") assert.Contains(t, result, "✓ Success") assert.Contains(t, result, "150ms") assert.Contains(t, result, "**Validation**: ✓ Passed") assert.Contains(t, result, "score: 0.95") assert.Contains(t, result, "**Output**") assert.Contains(t, result, "### Task: task-2") assert.Contains(t, result, "✗ Failed") assert.Contains(t, result, "**Validation**: ✗ Failed") assert.Contains(t, result, "Connection timeout") assert.Contains(t, result, "## Summary") assert.Contains(t, result, "Total: 2 tasks") assert.Contains(t, result, "Success: 1") assert.Contains(t, result, "Failed: 1") assert.Contains(t, result, "Validated: 1/2") }) t.Run("returns message for empty results", func(t *testing.T) { result := formatter.FormatTaskResults(nil) assert.Equal(t, "No task results.", result) }) } func TestInputFormatterFormatExecutionSummary(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("formats complete execution summary", func(t *testing.T) { startTime := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC) endTime := time.Date(2024, 1, 15, 9, 5, 0, 0, time.UTC) exec := &types.Execution{ ID: "exec-123", TriggerType: types.TriggerClock, Status: types.ExecCompleted, StartTime: startTime, EndTime: &endTime, Inspiration: &types.InspirationReport{ Content: "Morning analysis suggests high activity.", }, Goals: &types.Goals{ Content: "1. Review data\n2. 
Generate report", }, Tasks: []types.Task{ {ID: "t1", Status: types.TaskCompleted, ExecutorID: "db.query"}, {ID: "t2", Status: types.TaskCompleted, ExecutorID: "report.gen"}, }, Results: []types.TaskResult{ {TaskID: "t1", Success: true, Duration: 100}, {TaskID: "t2", Success: true, Duration: 200}, }, Delivery: &types.DeliveryResult{ RequestID: "test-delivery-001", Content: &types.DeliveryContent{ Summary: "Test delivery completed", Body: "# Test Delivery\n\nTest delivery body.", }, Success: true, }, } result := formatter.FormatExecutionSummary(exec) assert.Contains(t, result, "## Execution Summary") assert.Contains(t, result, "exec-123") assert.Contains(t, result, "clock") assert.Contains(t, result, "completed") assert.Contains(t, result, "**Duration**:") assert.Contains(t, result, "## Inspiration (P0)") assert.Contains(t, result, "Morning analysis") assert.Contains(t, result, "## Goals (P1)") assert.Contains(t, result, "Review data") assert.Contains(t, result, "## Tasks (P2)") assert.Contains(t, result, "db.query") assert.Contains(t, result, "## Results (P3)") assert.Contains(t, result, "✓ t1") assert.Contains(t, result, "## Delivery (P4)") assert.Contains(t, result, "Test delivery completed") }) t.Run("formats execution with error", func(t *testing.T) { startTime := time.Now() exec := &types.Execution{ ID: "exec-456", TriggerType: types.TriggerHuman, Status: types.ExecFailed, StartTime: startTime, Error: "Task execution failed", } result := formatter.FormatExecutionSummary(exec) assert.Contains(t, result, "exec-456") assert.Contains(t, result, "failed") assert.Contains(t, result, "**Error**: Task execution failed") }) t.Run("returns empty for nil execution", func(t *testing.T) { result := formatter.FormatExecutionSummary(nil) assert.Empty(t, result) }) } func TestInputFormatterBuildMessages(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("builds user message", func(t *testing.T) { msgs := formatter.BuildMessages("Hello, world!") require.Len(t, 
msgs, 1) assert.Equal(t, agentcontext.RoleUser, msgs[0].Role) assert.Equal(t, "Hello, world!", msgs[0].Content) }) } func TestInputFormatterBuildMessagesWithSystem(t *testing.T) { formatter := standard.NewInputFormatter() t.Run("builds system and user messages", func(t *testing.T) { msgs := formatter.BuildMessagesWithSystem( "You are a helpful assistant.", "What is the weather?", ) require.Len(t, msgs, 2) assert.Equal(t, agentcontext.RoleSystem, msgs[0].Role) assert.Equal(t, "You are a helpful assistant.", msgs[0].Content) assert.Equal(t, agentcontext.RoleUser, msgs[1].Role) assert.Equal(t, "What is the weather?", msgs[1].Content) }) } ================================================ FILE: agent/robot/executor/standard/inspiration.go ================================================ package standard import ( "fmt" "time" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // RunInspiration executes P0: Inspiration phase // Calls the Inspiration Agent to generate daily briefing // // Input: // - ClockContext from trigger input or current time // - Robot identity and resources // // Output: // - InspirationReport with markdown content func (e *Executor) RunInspiration(ctx *robottypes.Context, exec *robottypes.Execution, _ interface{}) error { // Get robot for identity and resources robot := exec.GetRobot() if robot == nil { return fmt.Errorf("robot not found in execution") } // Update UI field with i18n locale := getEffectiveLocale(robot, exec.Input) e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "analyzing_context")) // Build clock context from trigger input or current time var clock *robottypes.ClockContext if exec.Input != nil && exec.Input.Clock != nil { clock = exec.Input.Clock } else { clock = robottypes.NewClockContext(time.Now(), "") } // Get agent ID for inspiration phase agentID := "__yao.inspiration" // default if robot.Config != nil && robot.Config.Resources != nil { agentID = 
robot.Config.Resources.GetPhaseAgent(robottypes.PhaseInspiration) } // Build prompt using InputFormatter formatter := NewInputFormatter() userContent := formatter.FormatClockContext(clock, robot) // Add available resources - critical for generating achievable insights resourcesContent := formatter.FormatAvailableResources(robot) if resourcesContent != "" { userContent += "\n\n" + resourcesContent } // Call agent caller := NewAgentCaller() caller.Connector = robot.LanguageModel result, err := caller.CallWithMessages(ctx, agentID, userContent) if err != nil { return fmt.Errorf("inspiration agent (%s) call failed: %w", agentID, err) } // Parse response - get markdown content content := result.GetText() if content == "" { return fmt.Errorf("inspiration agent (%s) returned empty response", agentID) } // Build InspirationReport exec.Inspiration = &robottypes.InspirationReport{ Clock: clock, Content: content, } return nil } ================================================ FILE: agent/robot/executor/standard/inspiration_test.go ================================================ package standard_test import ( "context" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ============================================================================ // P0 Inspiration Phase Tests // ============================================================================ func TestRunInspirationBasic(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("generates inspiration report with clock context", func(t *testing.T) { // Create robot with inspiration agent configured robot := createTestRobot(t, "robot.inspiration") // Create executor and execution exec := 
createTestExecution(robot, types.TriggerClock) // Run inspiration phase e := standard.New() err := e.RunInspiration(ctx, exec, nil) require.NoError(t, err) require.NotNil(t, exec.Inspiration) assert.NotEmpty(t, exec.Inspiration.Content) assert.NotNil(t, exec.Inspiration.Clock) }) t.Run("includes expected markdown sections", func(t *testing.T) { robot := createTestRobot(t, "robot.inspiration") exec := createTestExecution(robot, types.TriggerClock) e := standard.New() err := e.RunInspiration(ctx, exec, nil) require.NoError(t, err) content := exec.Inspiration.Content // Verify expected sections in markdown output // Note: LLM output is non-deterministic, so we check for likely sections hasSection := strings.Contains(content, "##") || strings.Contains(content, "Summary") || strings.Contains(content, "Highlight") || strings.Contains(content, "Recommend") assert.True(t, hasSection, "should contain markdown sections, got: %s", content) }) } func TestRunInspirationClockContext(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("uses clock from trigger input", func(t *testing.T) { robot := createTestRobot(t, "robot.inspiration") exec := createTestExecution(robot, types.TriggerClock) // Set specific clock context specificTime := time.Date(2024, 12, 31, 17, 0, 0, 0, time.UTC) exec.Input = &types.TriggerInput{ Clock: types.NewClockContext(specificTime, "UTC"), } e := standard.New() err := e.RunInspiration(ctx, exec, nil) require.NoError(t, err) require.NotNil(t, exec.Inspiration.Clock) // Clock should match input assert.Equal(t, specificTime.Year(), exec.Inspiration.Clock.Year) assert.Equal(t, int(specificTime.Month()), exec.Inspiration.Clock.Month) assert.Equal(t, specificTime.Day(), exec.Inspiration.Clock.DayOfMonth) }) t.Run("creates clock context when not provided", func(t *testing.T) { robot := createTestRobot(t, "robot.inspiration") exec 
:= createTestExecution(robot, types.TriggerClock) exec.Input = nil // No input e := standard.New() err := e.RunInspiration(ctx, exec, nil) require.NoError(t, err) require.NotNil(t, exec.Inspiration.Clock) // Clock should be current time (approximately) now := time.Now() assert.Equal(t, now.Year(), exec.Inspiration.Clock.Year) }) } func TestRunInspirationRobotIdentity(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("robot identity influences output", func(t *testing.T) { // Create robot with specific identity robot := &types.Robot{ MemberID: "test-robot-1", TeamID: "test-team-1", DisplayName: "Sales Assistant", Config: &types.Config{ Identity: &types.Identity{ Role: "Sales Assistant", Duties: []string{"Track sales metrics", "Prepare weekly reports"}, Rules: []string{"Focus on actionable insights"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseInspiration: "robot.inspiration", }, }, }, } exec := createTestExecution(robot, types.TriggerClock) e := standard.New() err := e.RunInspiration(ctx, exec, nil) require.NoError(t, err) assert.NotEmpty(t, exec.Inspiration.Content) // The content should be influenced by robot identity // (exact content varies due to LLM non-determinism) }) } func TestRunInspirationErrorHandling(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("returns error when robot is nil", func(t *testing.T) { exec := &types.Execution{ ID: "test-exec-1", TriggerType: types.TriggerClock, } // Don't set robot e := standard.New() err := e.RunInspiration(ctx, exec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "robot not found") }) t.Run("returns error when agent not found", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-1", TeamID: 
"test-team-1", Config: &types.Config{ Identity: &types.Identity{Role: "Test"}, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseInspiration: "non.existent.agent", }, }, }, } exec := createTestExecution(robot, types.TriggerClock) e := standard.New() err := e.RunInspiration(ctx, exec, nil) // Real AgentCaller returns error for non-existent agent assert.Error(t, err) assert.Contains(t, err.Error(), "call failed") }) } func TestRunInspirationWithDefaultAgent(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("uses default agent when not configured", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-1", TeamID: "test-team-1", Config: &types.Config{ Identity: &types.Identity{Role: "Test Robot"}, // No Resources configured - should use default __yao.inspiration }, } exec := createTestExecution(robot, types.TriggerClock) e := standard.New() err := e.RunInspiration(ctx, exec, nil) // This will fail if __yao.inspiration doesn't exist // In test environment, we expect it to fail with "agent not found" // In production, it would use the default agent if err != nil { assert.Contains(t, err.Error(), "call failed") } }) } // ============================================================================ // InputFormatter Tests for P0 // ============================================================================ func TestInputFormatterClockContext(t *testing.T) { t.Run("formats clock context correctly", func(t *testing.T) { formatter := standard.NewInputFormatter() // Create a specific clock context clock := types.NewClockContext( time.Date(2024, 12, 31, 17, 30, 0, 0, time.UTC), "UTC", ) robot := &types.Robot{ Config: &types.Config{ Identity: &types.Identity{ Role: "Sales Assistant", Duties: []string{"Track metrics", "Send reports"}, }, }, } content := formatter.FormatClockContext(clock, robot) // Verify time 
context assert.Contains(t, content, "Current Time Context") assert.Contains(t, content, "2024") assert.Contains(t, content, "12") assert.Contains(t, content, "31") assert.Contains(t, content, "Tuesday") // Dec 31, 2024 is Tuesday // Verify robot identity assert.Contains(t, content, "Robot Identity") assert.Contains(t, content, "Sales Assistant") assert.Contains(t, content, "Track metrics") }) t.Run("handles nil clock", func(t *testing.T) { formatter := standard.NewInputFormatter() content := formatter.FormatClockContext(nil, nil) assert.Empty(t, content) }) t.Run("handles nil robot", func(t *testing.T) { formatter := standard.NewInputFormatter() clock := types.NewClockContext(time.Now(), "") content := formatter.FormatClockContext(clock, nil) // Should have time context but no robot identity assert.Contains(t, content, "Current Time Context") assert.NotContains(t, content, "Robot Identity") }) t.Run("includes time markers", func(t *testing.T) { formatter := standard.NewInputFormatter() // Create a weekend + month start clock context // Jan 1, 2028 is Saturday (weekend + month start) clock := types.NewClockContext( time.Date(2028, 1, 1, 10, 0, 0, 0, time.UTC), "UTC", ) content := formatter.FormatClockContext(clock, nil) assert.Contains(t, content, "Weekend") assert.Contains(t, content, "Month Start") }) } // ============================================================================ // Helper Functions // ============================================================================ // createTestRobot creates a test robot with specified inspiration agent // Includes available expert agents so the Inspiration Agent knows what resources are available // // Note: The agent IDs listed in Resources.Agents must exist in yao-dev-app/assistants/experts/ // Current available experts: data-analyst, summarizer, text-writer, web-reader func createTestRobot(t *testing.T, agentID string) *types.Robot { t.Helper() return &types.Robot{ MemberID: "test-robot-1", TeamID: 
"test-team-1", DisplayName: "Test Robot", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Assistant", Duties: []string{"Testing", "Data Analysis", "Report Generation"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseInspiration: agentID, }, // Available expert agents that can be delegated to // These IDs correspond to assistants in yao-dev-app/assistants/experts/ Agents: []string{ "experts.data-analyst", // Data analysis and insights "experts.summarizer", // Content summarization "experts.text-writer", // Report and document generation "experts.web-reader", // Web content extraction }, }, // Knowledge base collections (if any) KB: &types.KB{ Collections: []string{"test-knowledge"}, }, }, } } // createTestExecution creates a test execution for a robot func createTestExecution(robot *types.Robot, trigger types.TriggerType) *types.Execution { exec := &types.Execution{ ID: "test-exec-1", MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: trigger, StartTime: time.Now(), Status: types.ExecRunning, Phase: types.PhaseInspiration, Input: &types.TriggerInput{ Clock: types.NewClockContext(time.Now(), ""), }, } exec.SetRobot(robot) return exec } ================================================ FILE: agent/robot/executor/standard/learning.go ================================================ package standard import ( robottypes "github.com/yaoapp/yao/agent/robot/types" ) // RunLearning executes P5: Learning phase // Extracts learnings and saves to knowledge base // // Input: // - Execution summary (all phases) // // Output: // - LearningEntry list with extracted knowledge // // Learning Types: // - LearnExecution: Execution patterns // - LearnTask: Task-specific insights // - LearnError: Error patterns for improvement // // TODO: Implement real learning extraction func (e *Executor) RunLearning(ctx *robottypes.Context, exec *robottypes.Execution, _ interface{}) error { // Get robot for locale robot := exec.GetRobot() // Update 
UI field with i18n locale := getEffectiveLocale(robot, exec.Input) e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "learning_from_exec")) e.simulateStreamDelay() exec.Learning = []robottypes.LearningEntry{ { Type: robottypes.LearnExecution, Content: "Completed daily tasks successfully", }, } return nil } ================================================ FILE: agent/robot/executor/standard/log.go ================================================ package standard import ( "encoding/json" "fmt" "strings" kunlog "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/agent/robot/logger" robottypes "github.com/yaoapp/yao/agent/robot/types" ) var log = logger.New("exec") type execLogger struct { robot *robottypes.Robot execID string } func newExecLogger(robot *robottypes.Robot, execID string) *execLogger { return &execLogger{robot: robot, execID: execID} } func (l *execLogger) robotID() string { if l.robot != nil { return l.robot.MemberID } return "" } func (l *execLogger) connector() string { if l.robot != nil { return l.robot.LanguageModel } return "" } // --------------------------------------------------------------------------- // P2: Task Overview // --------------------------------------------------------------------------- func (l *execLogger) logTaskOverview(tasks []robottypes.Task) { if config.IsDevelopment() { l.devTaskOverview(tasks) } kunlog.With(kunlog.F{ "robot_id": l.robotID(), "execution_id": l.execID, "phase": "tasks", "task_count": len(tasks), "language_model": l.connector(), }).Info("P2 task overview: %d tasks generated", len(tasks)) } func (l *execLogger) devTaskOverview(tasks []robottypes.Task) { w := logger.Gray h := logger.BoldCyan v := logger.White r := logger.Reset var sb strings.Builder sb.WriteString(fmt.Sprintf("\n%s%s%s\n", h, strings.Repeat("═", 60), r)) sb.WriteString(fmt.Sprintf("%s TASK OVERVIEW%s\n", h, r)) sb.WriteString(fmt.Sprintf("%s%s%s\n", h, strings.Repeat("─", 60), r)) 
sb.WriteString(fmt.Sprintf("%s Robot: %s%s%s\n", w, v, l.robotID(), r)) sb.WriteString(fmt.Sprintf("%s Exec: %s%s%s\n", w, v, l.execID, r)) if l.connector() != "" { sb.WriteString(fmt.Sprintf("%s Model: %s%s%s\n", w, v, l.connector(), r)) } sb.WriteString(fmt.Sprintf("%s%s%s\n", w, strings.Repeat("─", 60), r)) for i, t := range tasks { desc := t.Description if desc == "" && len(t.Messages) > 0 { if s, ok := t.Messages[0].GetContentAsString(); ok { desc = s } } desc = truncate(desc, 72) sb.WriteString(fmt.Sprintf("%s #%d %s%s%s [%s:%s]\n", w, i+1, v, t.ID, r, t.ExecutorType, t.ExecutorID)) sb.WriteString(fmt.Sprintf("%s %s%s\n", w, desc, r)) } sb.WriteString(fmt.Sprintf("%s%s%s\n", w, strings.Repeat("─", 60), r)) sb.WriteString(fmt.Sprintf("%s Total: %s%d tasks%s\n", w, v, len(tasks), r)) sb.WriteString(fmt.Sprintf("%s%s%s\n", h, strings.Repeat("═", 60), r)) logger.Raw(sb.String()) } // --------------------------------------------------------------------------- // P3: Task Input // --------------------------------------------------------------------------- func (l *execLogger) logTaskInput(task *robottypes.Task, prompt string) { if config.IsDevelopment() { l.devTaskInput(task, prompt) } kunlog.With(kunlog.F{ "robot_id": l.robotID(), "execution_id": l.execID, "task_id": task.ID, "executor_type": string(task.ExecutorType), "executor_id": task.ExecutorID, "prompt_len": len(prompt), "language_model": l.connector(), }).Info("Task input: %s [%s]", task.ID, task.ExecutorID) } func (l *execLogger) devTaskInput(task *robottypes.Task, prompt string) { w := logger.Gray v := logger.White r := logger.Reset var sb strings.Builder sb.WriteString(fmt.Sprintf("%s ▶ Task %s%s%s [%s:%s] Prompt: %d chars%s\n", w, v, task.ID, w, task.ExecutorType, task.ExecutorID, len(prompt), r)) logger.Raw(sb.String()) } // --------------------------------------------------------------------------- // P3: Task Output // --------------------------------------------------------------------------- func 
(l *execLogger) logTaskOutput(task *robottypes.Task, result *robottypes.TaskResult) { if config.IsDevelopment() { l.devTaskOutput(task, result) } fields := kunlog.F{ "robot_id": l.robotID(), "execution_id": l.execID, "task_id": result.TaskID, "success": result.Success, "duration_ms": result.Duration, "language_model": l.connector(), } if result.Output != nil { fields["output_type"] = fmt.Sprintf("%T", result.Output) fields["output_len"] = outputLen(result.Output) } if result.Error != "" { fields["error"] = result.Error } if result.Success { kunlog.With(fields).Info("Task completed: %s (%dms)", result.TaskID, result.Duration) } else { kunlog.With(fields).Warn("Task failed: %s (%dms) %s", result.TaskID, result.Duration, result.Error) } } func (l *execLogger) devTaskOutput(task *robottypes.Task, result *robottypes.TaskResult) { w := logger.Gray v := logger.White g := logger.BoldGreen rd := logger.BoldRed r := logger.Reset var sb strings.Builder if result.Success { sb.WriteString(fmt.Sprintf("%s ✓ %s%s%s completed %s(%dms)%s\n", g, v, result.TaskID, g, w, result.Duration, r)) out := outputSummary(result.Output) if len(out) > 120 { out = out[:120] + "..." 
} sb.WriteString(fmt.Sprintf("%s Output: %s%s%s\n", w, v, out, r)) } else { sb.WriteString(fmt.Sprintf("%s ✗ %s%s%s failed %s(%dms)%s\n", rd, v, result.TaskID, rd, w, result.Duration, r)) sb.WriteString(fmt.Sprintf("%s Error: %s%s%s\n", w, logger.Red, result.Error, r)) } logger.Raw(sb.String()) } // --------------------------------------------------------------------------- // Agent Call // --------------------------------------------------------------------------- func (l *execLogger) logAgentCall(agentID string, result *CallResult) { if result == nil { return } if config.IsDevelopment() { l.devAgentCall(agentID, result) } fields := kunlog.F{ "robot_id": l.robotID(), "execution_id": l.execID, "agent_id": agentID, "content_len": len(result.Content), "language_model": l.connector(), } if result.Next != nil { fields["next_type"] = fmt.Sprintf("%T", result.Next) fields["next_len"] = outputLen(result.Next) } kunlog.With(fields).Info("Agent call: %s (content=%d, next=%T)", agentID, len(result.Content), result.Next) } func (l *execLogger) devAgentCall(agentID string, result *CallResult) { w := logger.Gray v := logger.White c := logger.Cyan r := logger.Reset nextInfo := "—" if result.Next != nil { nextInfo = fmt.Sprintf("%T (len=%d)", result.Next, outputLen(result.Next)) } var sb strings.Builder sb.WriteString(fmt.Sprintf("%s → Agent(%s%s%s) Content: %s%d%s chars Next: %s%s%s\n", c, v, agentID, c, v, len(result.Content), w, v, nextInfo, r)) logger.Raw(sb.String()) } // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- func (l *execLogger) prefix() string { if l.connector() != "" { return fmt.Sprintf("[robot:%s|exec:%s|model:%s]", l.robotID(), l.execID, l.connector()) } return fmt.Sprintf("[robot:%s|exec:%s]", l.robotID(), l.execID) } func truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } return s[:maxLen] + "..." 
// outputSummary renders a short, single-string description of an arbitrary
// task output for logging.
//
//   - nil yields the empty string.
//   - A string yields "string(len=N) <text>", where N is the byte length; text
//     longer than 500 bytes is shortened and followed by "...".
//   - A map[string]interface{} yields "map{k1, k2, ...}" listing only its keys
//     (in map iteration order).
//   - Any other value is JSON-encoded and yields "<type> <json>", or
//     "<type>(len=N) <json>..." when the encoding exceeds 500 bytes, or
//     "<type>(marshal-error)" when it cannot be encoded.
//
// Shortening never splits a multi-byte UTF-8 character: the cut is moved back
// to the nearest rune boundary at or below the 500-byte limit.
func outputSummary(v interface{}) string {
	const limit = 500

	// clip returns the longest prefix of s that ends on a rune boundary and is
	// at most limit bytes long. It is only called when len(s) > limit.
	clip := func(s string) string {
		end := 0
		// Ranging over a string yields the starting byte offset of each rune.
		for i := range s {
			if i > limit {
				break
			}
			end = i
		}
		return s[:end]
	}

	if v == nil {
		return ""
	}

	switch val := v.(type) {
	case string:
		if len(val) > limit {
			return fmt.Sprintf("string(len=%d) %s...", len(val), clip(val))
		}
		return fmt.Sprintf("string(len=%d) %s", len(val), val)
	case map[string]interface{}:
		keys := make([]string, 0, len(val))
		for k := range val {
			keys = append(keys, k)
		}
		return fmt.Sprintf("map{%s}", strings.Join(keys, ", "))
	default:
		raw, err := json.Marshal(v)
		if err != nil {
			return fmt.Sprintf("%T(marshal-error)", v)
		}
		s := string(raw)
		if len(s) > limit {
			return fmt.Sprintf("%T(len=%d) %s...", v, len(s), clip(s))
		}
		return fmt.Sprintf("%T %s", v, s)
	}
}
types.NewContext(context.Background(), testAuth()) err := e.Resume(ctx, "", "some reply") require.Error(t, err) assert.Contains(t, err.Error(), "empty") }) // R2: Resume with non-existent execID returns error (requires DB) t.Run("R2: Resume with non-existent execID returns error", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) e := standard.New() ctx := types.NewContext(context.Background(), testAuth()) err := e.Resume(ctx, "non-existent-exec-id-12345", "reply") require.Error(t, err) assert.Contains(t, err.Error(), "execution not found") }) // R3: Resume with execution not in waiting status returns error (requires DB) t.Run("R3: Resume with execution not in waiting status returns error", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) robot := createResumeTestRobot(t) exec := createResumeTestExecution(robot) exec.Status = types.ExecRunning // Not waiting record := store.FromExecution(exec) require.NoError(t, store.NewExecutionStore().Save(ctx.Context, record)) // Save robot to robot store for Resume to load robotRecord := store.FromRobot(robot) require.NoError(t, store.NewRobotStore().Save(ctx.Context, robotRecord)) e := standard.New() err := e.Resume(ctx, exec.ID, "reply") require.Error(t, err) assert.Contains(t, err.Error(), "not in waiting status") }) // R4: Verify Resume loads execution from store (requires DB) t.Run("R4: Resume loads execution from store", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) robot := createResumeTestRobot(t) exec := createSuspendedResumeTestExecution(robot) execStore := store.NewExecutionStore() robotStore := store.NewRobotStore() record := store.FromExecution(exec) require.NoError(t, 
execStore.Save(ctx.Context, record)) robotRecord := store.FromRobot(robot) require.NoError(t, robotStore.Save(ctx.Context, robotRecord)) e := standard.New() err := e.Resume(ctx, exec.ID, "User provided answer") require.NoError(t, err) // Verify execution was loaded and completed loaded, err := execStore.Get(ctx.Context, exec.ID) require.NoError(t, err) require.NotNil(t, loaded) assert.Equal(t, types.ExecCompleted, loaded.Status) }) // R5: Resume restores robot from execution record (requires DB) t.Run("R5: Resume restores robot from execution record", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) robot := createResumeTestRobot(t) exec := createSuspendedResumeTestExecution(robot) execStore := store.NewExecutionStore() robotStore := store.NewRobotStore() record := store.FromExecution(exec) require.NoError(t, execStore.Save(ctx.Context, record)) robotRecord := store.FromRobot(robot) require.NoError(t, robotStore.Save(ctx.Context, robotRecord)) e := standard.New() err := e.Resume(ctx, exec.ID, "Answer for the question") require.NoError(t, err) // If we get here without "robot not found", Resume successfully restored robot }) // R6: Resume with __skip__ reply marks task as skipped (requires DB) t.Run("R6: Resume with __skip__ reply marks task as skipped", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) robot := createResumeTestRobot(t) exec := createSuspendedResumeTestExecution(robot) // Ensure we have a task at index 0 that is waiting exec.Tasks[0].Status = types.TaskWaitingInput exec.ResumeContext = &types.ResumeContext{ TaskIndex: 0, PreviousResults: []types.TaskResult{}, } execStore := store.NewExecutionStore() robotStore := store.NewRobotStore() record := store.FromExecution(exec) require.NoError(t, 
execStore.Save(ctx.Context, record)) robotRecord := store.FromRobot(robot) require.NoError(t, robotStore.Save(ctx.Context, robotRecord)) e := standard.New() err := e.Resume(ctx, exec.ID, "__skip__") require.NoError(t, err) loaded, err := execStore.Get(ctx.Context, exec.ID) require.NoError(t, err) require.NotNil(t, loaded) require.Len(t, loaded.Tasks, 1) assert.Equal(t, types.TaskSkipped, loaded.Tasks[0].Status) }) // R7: Resume sends ErrExecutionSuspended when execution suspends again (requires DB) t.Run("R7: Resume sends ErrExecutionSuspended when execution suspends again", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) // Use robot-need-input assistant that suspends ctx := types.NewContext(context.Background(), testAuth()) robot := createResumeNeedInputRobot(t) exec := createSuspendedResumeNeedInputExecution(robot) execStore := store.NewExecutionStore() robotStore := store.NewRobotStore() record := store.FromExecution(exec) require.NoError(t, execStore.Save(ctx.Context, record)) robotRecord := store.FromRobot(robot) require.NoError(t, robotStore.Save(ctx.Context, robotRecord)) e := standard.New() err := e.Resume(ctx, exec.ID, "some reply") // May return ErrExecutionSuspended if assistant suspends again if err != nil { assert.ErrorIs(t, err, types.ErrExecutionSuspended) } }) // R8: Resume increments exec counter t.Run("R8: Resume increments exec counter", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) robot := createResumeTestRobot(t) exec := createSuspendedResumeTestExecution(robot) execStore := store.NewExecutionStore() robotStore := store.NewRobotStore() record := store.FromExecution(exec) require.NoError(t, execStore.Save(ctx.Context, record)) robotRecord := store.FromRobot(robot) require.NoError(t, robotStore.Save(ctx.Context, robotRecord)) e := standard.New() 
e.Reset() before := e.CurrentCount() err := e.Resume(ctx, exec.ID, "answer") after := e.CurrentCount() require.NoError(t, err) // During Resume, currentCount was incremented; after completion it's decremented assert.Equal(t, before, after, "currentCount should be back to original after Resume completes") }) // R9: Resume decrements exec counter on completion t.Run("R9: Resume decrements exec counter on completion", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) robot := createResumeTestRobot(t) exec := createSuspendedResumeTestExecution(robot) execStore := store.NewExecutionStore() robotStore := store.NewRobotStore() record := store.FromExecution(exec) require.NoError(t, execStore.Save(ctx.Context, record)) robotRecord := store.FromRobot(robot) require.NoError(t, robotStore.Save(ctx.Context, robotRecord)) e := standard.New() e.Reset() err := e.Resume(ctx, exec.ID, "reply") require.NoError(t, err) // After Resume completes, currentCount should be 0 (no leak) assert.Equal(t, 0, e.CurrentCount()) }) // R10: Resume with nil context returns error t.Run("R10: Resume with nil context returns error", func(t *testing.T) { e := standard.NewWithConfig(executortypes.Config{SkipPersistence: true}) err := e.Resume(nil, "some-exec-id", "reply") require.Error(t, err) assert.Contains(t, err.Error(), "context") }) } // ============================================================================ // Helpers for Resume tests // ============================================================================ func createResumeTestRobot(t *testing.T) *types.Robot { t.Helper() return &types.Robot{ MemberID: "test-robot-resume", TeamID: "test-team-1", DisplayName: "Resume Test Robot", SystemPrompt: "You are a helpful assistant.", Config: &types.Config{ Identity: &types.Identity{ Role: "Test", Duties: []string{"Execute tasks"}, }, Resources: &types.Resources{ Phases: 
map[types.Phase]string{ types.PhaseDelivery: "robot.delivery", types.PhaseLearning: "robot.learning", }, Agents: []string{"experts.text-writer"}, }, Quota: &types.Quota{Max: 5}, }, } } func createResumeTestExecution(robot *types.Robot) *types.Execution { exec := &types.Execution{ ID: "test-exec-resume-1", MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: types.TriggerClock, StartTime: time.Now(), Status: types.ExecRunning, Phase: types.PhaseRun, Goals: &types.Goals{Content: "## Goals\n\n1. Test resume"}, Tasks: []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write 'hello'"}, }, Order: 0, Status: types.TaskPending, }, }, ChatID: "robot_test-robot-resume_test-exec-resume-1", } exec.SetRobot(robot) return exec } func createSuspendedResumeTestExecution(robot *types.Robot) *types.Execution { exec := createResumeTestExecution(robot) exec.Status = types.ExecWaiting exec.WaitingTaskID = "task-001" exec.WaitingQuestion = "What should we do?" 
now := time.Now() exec.WaitingSince = &now exec.ResumeContext = &types.ResumeContext{ TaskIndex: 0, PreviousResults: []types.TaskResult{}, } exec.Tasks[0].Status = types.TaskWaitingInput return exec } func createResumeNeedInputRobot(t *testing.T) *types.Robot { t.Helper() return &types.Robot{ MemberID: "test-robot-resume-need-input", TeamID: "test-team-1", DisplayName: "Resume Need Input Robot", SystemPrompt: "You are a helpful assistant.", Config: &types.Config{ Identity: &types.Identity{ Role: "Test", Duties: []string{"Execute tasks"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseDelivery: "robot.delivery", types.PhaseLearning: "robot.learning", }, Agents: []string{"tests.robot-need-input"}, }, Quota: &types.Quota{Max: 5}, }, } } func createSuspendedResumeNeedInputExecution(robot *types.Robot) *types.Execution { exec := &types.Execution{ ID: "test-exec-resume-need-input-1", MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: types.TriggerClock, StartTime: time.Now(), Status: types.ExecWaiting, Phase: types.PhaseRun, Goals: &types.Goals{Content: "## Goals\n\n1. 
// RunConfig holds the options that control how the P3 (Run) phase behaves.
type RunConfig struct {
	// ContinueOnFailure, when true, lets the run proceed to the next task
	// after a task fails instead of stopping.
	// V2 default: true — the Robot is an orchestrator, not a judge.
	// Failed tasks are recorded and evaluated by the Delivery Agent.
	ContinueOnFailure bool
}

// DefaultRunConfig returns a freshly allocated RunConfig carrying the P3
// defaults (ContinueOnFailure enabled).
func DefaultRunConfig() *RunConfig {
	cfg := new(RunConfig)
	cfg.ContinueOnFailure = true
	return cfg
}
func (e *Executor) RunExecution(ctx *robottypes.Context, exec *robottypes.Execution, data interface{}) error { robot := exec.GetRobot() if robot == nil { return fmt.Errorf("robot not found in execution") } if len(exec.Tasks) == 0 { return fmt.Errorf("no tasks to execute") } // Get run configuration from data or use default var config *RunConfig if cfg, ok := data.(*RunConfig); ok && cfg != nil { config = cfg } else { config = DefaultRunConfig() } // Determine locale for UI messages locale := getEffectiveLocale(robot, exec.Input) // Determine start index and restore results from resume context startIndex := 0 if exec.ResumeContext != nil { startIndex = exec.ResumeContext.TaskIndex exec.Results = exec.ResumeContext.PreviousResults } else { exec.Results = make([]robottypes.TaskResult, 0, len(exec.Tasks)) } // Create task runner with execution-level chatID (§8.4) runner := NewRunner(ctx, robot, config, exec.ChatID, exec.ID) // Execute tasks sequentially from startIndex for i := startIndex; i < len(exec.Tasks); i++ { task := &exec.Tasks[i] // Update current state for tracking exec.Current = &robottypes.CurrentState{ Task: task, TaskIndex: i, Progress: fmt.Sprintf("%d/%d tasks", i+1, len(exec.Tasks)), } // Update UI field with current task description (i18n) taskName := formatTaskProgressName(task, i, len(exec.Tasks), locale) e.updateUIFields(ctx, exec, "", taskName) // Mark task as running task.Status = robottypes.TaskRunning now := time.Now() task.StartTime = &now // Persist running state to database e.updateTasksState(ctx, exec) // Build task context with previous results taskCtx := runner.BuildTaskContext(exec, i) // Execute task (single call, no validation loop) result := runner.ExecuteTask(task, taskCtx) // Task needs human input — suspend execution without recording a half-result if result.NeedInput { return e.Suspend(ctx, exec, i, result.InputQuestion) } // Update task status based on result endTime := time.Now() task.EndTime = &endTime if result.Success { 
task.Status = robottypes.TaskCompleted event.Push(ctx.Context, robotevents.TaskCompleted, robotevents.TaskPayload{ ExecutionID: exec.ID, MemberID: exec.MemberID, TeamID: exec.TeamID, TaskID: task.ID, ChatID: exec.ChatID, }) } else { task.Status = robottypes.TaskFailed event.Push(ctx.Context, robotevents.TaskFailed, robotevents.TaskPayload{ ExecutionID: exec.ID, MemberID: exec.MemberID, TeamID: exec.TeamID, TaskID: task.ID, Error: result.Error, ChatID: exec.ChatID, }) } // Store result exec.Results = append(exec.Results, *result) // Persist completed/failed state to database e.updateTasksState(ctx, exec) // Check if we should continue on failure if !result.Success && !config.ContinueOnFailure { // Mark remaining tasks as skipped for j := i + 1; j < len(exec.Tasks); j++ { exec.Tasks[j].Status = robottypes.TaskSkipped } // Persist skipped state to database e.updateTasksState(ctx, exec) return fmt.Errorf("task %s failed: %s", task.ID, result.Error) } } // Clear current state and resume context after successful completion exec.Current = nil exec.ResumeContext = nil return nil } // formatTaskProgressName formats a progress name for the current task (used for UI with i18n) func formatTaskProgressName(task *robottypes.Task, index int, total int, locale string) string { taskPrefix := getLocalizedMessage(locale, "task_prefix") prefix := fmt.Sprintf("%s %d/%d: ", taskPrefix, index+1, total) // Priority 1: Use Description field if available if task.Description != "" { desc := task.Description if len(desc) > 80 { desc = desc[:80] + "..." } return prefix + desc } // Priority 2: Try to get description from first message if len(task.Messages) > 0 { if content, ok := task.Messages[0].GetContentAsString(); ok && content != "" { // Truncate if too long if len(content) > 80 { content = content[:80] + "..." 
} return prefix + content } } // Fallback to executor info return prefix + string(task.ExecutorType) + ":" + task.ExecutorID } ================================================ FILE: agent/robot/executor/standard/run_test.go ================================================ package standard_test import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ============================================================================ // P3 Run Phase Tests - RunExecution // ============================================================================ func TestRunExecutionBasic(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("executes single task successfully", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) // Pre-built task (simulating P2 output) exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write a short greeting message for a company newsletter. 
Keep it under 50 words."}, }, ExpectedOutput: "A friendly greeting message suitable for a newsletter", Order: 0, Status: types.TaskPending, }, } e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err) require.Len(t, exec.Results, 1) result := exec.Results[0] assert.Equal(t, "task-001", result.TaskID) assert.True(t, result.Success, "task should succeed") assert.NotNil(t, result.Output, "should have output") assert.Greater(t, result.Duration, int64(0), "should have duration") // Task status should be updated assert.Equal(t, types.TaskCompleted, exec.Tasks[0].Status) assert.NotNil(t, exec.Tasks[0].StartTime) assert.NotNil(t, exec.Tasks[0].EndTime) }) t.Run("executes multiple tasks in order", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) // Multiple tasks that depend on each other exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.data-analyst", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Analyze this data: Sales Q1: $100K, Q2: $150K, Q3: $120K, Q4: $180K. 
Calculate the total and average."}, }, ExpectedOutput: "JSON with total and average sales figures", ValidationRules: []string{ "output must be valid JSON", }, Order: 0, Status: types.TaskPending, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.summarizer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Summarize the key findings from the previous analysis in 2-3 sentences."}, }, ExpectedOutput: "A brief summary of the sales analysis", Order: 1, Status: types.TaskPending, }, } e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err) require.Len(t, exec.Results, 2) // Both tasks should complete assert.True(t, exec.Results[0].Success, "first task should succeed") assert.True(t, exec.Results[1].Success, "second task should succeed") // Second task should have access to first task's result (via context) assert.Equal(t, types.TaskCompleted, exec.Tasks[0].Status) assert.Equal(t, types.TaskCompleted, exec.Tasks[1].Status) t.Logf("Task 1 output: %v", exec.Results[0].Output) t.Logf("Task 2 output: %v", exec.Results[1].Output) }) t.Run("passes previous results as context to subsequent tasks", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) // First task generates data, second task uses it exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Generate a list of 3 product names for a tech company. 
Output as JSON array."}, }, ExpectedOutput: "JSON array with 3 product names", Order: 0, Status: types.TaskPending, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Using the product names from the previous task, write a one-line tagline for each product."}, }, ExpectedOutput: "Taglines for each product", Order: 1, Status: types.TaskPending, }, } e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err) require.Len(t, exec.Results, 2) // Both should succeed assert.True(t, exec.Results[0].Success) assert.True(t, exec.Results[1].Success) // Second task output should reference products from first task t.Logf("Task 1 (products): %v", exec.Results[0].Output) t.Logf("Task 2 (taglines): %v", exec.Results[1].Output) }) } func TestRunExecutionTaskStatus(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("updates task status during execution", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Say 'Hello World'"}, }, Order: 0, Status: types.TaskPending, }, } // Verify initial status assert.Equal(t, types.TaskPending, exec.Tasks[0].Status) e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err) // Verify final status assert.Equal(t, types.TaskCompleted, exec.Tasks[0].Status) assert.NotNil(t, exec.Tasks[0].StartTime) assert.NotNil(t, exec.Tasks[0].EndTime) }) t.Run("marks remaining tasks as skipped on failure with ContinueOnFailure=false", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) exec.Tasks = 
[]types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "non.existent.assistant.xyz123", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "This will fail"}, }, Order: 0, Status: types.TaskPending, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write another greeting"}, }, Order: 1, Status: types.TaskPending, }, { ID: "task-003", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write yet another greeting"}, }, Order: 2, Status: types.TaskPending, }, } config := &standard.RunConfig{ContinueOnFailure: false} e := standard.New() err := e.RunExecution(ctx, exec, config) assert.Error(t, err) assert.Contains(t, err.Error(), "task-001") assert.Equal(t, types.TaskFailed, exec.Tasks[0].Status) assert.Equal(t, types.TaskSkipped, exec.Tasks[1].Status) assert.Equal(t, types.TaskSkipped, exec.Tasks[2].Status) }) t.Run("continues on failure with default V2 config", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "non.existent.assistant.xyz123", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "This will fail"}, }, Order: 0, Status: types.TaskPending, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Say hello"}, }, Order: 1, Status: types.TaskPending, }, } e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err, "V2 default ContinueOnFailure=true should not return error") assert.Equal(t, types.TaskFailed, exec.Tasks[0].Status) assert.Equal(t, types.TaskCompleted, exec.Tasks[1].Status) assert.Len(t, exec.Results, 
// TestRunExecutionErrorHandling checks how RunExecution reports invalid input
// and task failures:
//   - an execution without a robot yields an error containing "robot not found";
//   - an execution with an empty task list yields an error containing "no tasks";
//   - a task addressed to an unknown assistant is recorded as failed, but no
//     error is returned because the default config continues on failure.
//
// The test is skipped in -short mode.
func TestRunExecutionErrorHandling(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	// NOTE(review): Prepare/Clean presumably set up and tear down the shared
	// test environment — confirm against agent/testutils.
	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("returns error when robot is nil", func(t *testing.T) {
		exec := &types.Execution{
			ID:          "test-exec-1",
			TriggerType: types.TriggerClock,
			Tasks: []types.Task{
				{ID: "task-001", ExecutorID: "test"},
			},
		}
		// Don't set robot

		e := standard.New()
		err := e.RunExecution(ctx, exec, nil)

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "robot not found")
	})

	t.Run("returns error when no tasks", func(t *testing.T) {
		robot := createRunTestRobot(t)
		exec := createRunTestExecution(robot)
		exec.Tasks = []types.Task{} // Empty

		e := standard.New()
		err := e.RunExecution(ctx, exec, nil)

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "no tasks")
	})

	t.Run("records failure for non-existent assistant", func(t *testing.T) {
		robot := createRunTestRobot(t)
		exec := createRunTestExecution(robot)

		// The executor ID below does not name a known assistant, so the task
		// is expected to fail.
		exec.Tasks = []types.Task{
			{
				ID:           "task-001",
				ExecutorType: types.ExecutorAssistant,
				ExecutorID:   "non.existent.agent",
				Messages: []agentcontext.Message{
					{Role: agentcontext.RoleUser, Content: "Test"},
				},
				Order:  0,
				Status: types.TaskPending,
			},
		}

		e := standard.New()
		err := e.RunExecution(ctx, exec, nil)

		// V2 default ContinueOnFailure=true, so no error is returned
		assert.NoError(t, err)
		// The failure is visible on the task status and on the stored result.
		assert.Equal(t, types.TaskFailed, exec.Tasks[0].Status)
		assert.Len(t, exec.Results, 1)
		assert.False(t, exec.Results[0].Success)
		assert.NotEmpty(t, exec.Results[0].Error)
	})
}
createRunTestExecution(robot) exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "non.existent.assistant.xyz123", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "This will fail"}, }, Order: 0, Status: types.TaskPending, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write a greeting"}, }, Order: 1, Status: types.TaskPending, }, } config := &standard.RunConfig{ContinueOnFailure: false} e := standard.New() err := e.RunExecution(ctx, exec, config) assert.Error(t, err) assert.Contains(t, err.Error(), "task-001") assert.Len(t, exec.Results, 1) assert.Equal(t, types.TaskFailed, exec.Tasks[0].Status) assert.Equal(t, types.TaskSkipped, exec.Tasks[1].Status) }) t.Run("continues execution when ContinueOnFailure is true", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) // First task will fail, but second should still execute exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "non.existent.assistant.xyz123", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "This will fail"}, }, Order: 0, Status: types.TaskPending, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write a short greeting message"}, }, ExpectedOutput: "A greeting message", Order: 1, Status: types.TaskPending, }, { ID: "task-003", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write a farewell message"}, }, ExpectedOutput: "A farewell message", Order: 2, Status: types.TaskPending, }, } // V2 default ContinueOnFailure=true e := standard.New() err := e.RunExecution(ctx, exec, nil) // Should NOT return 
error when ContinueOnFailure is true assert.NoError(t, err) // All tasks should have results assert.Len(t, exec.Results, 3) // First task failed assert.Equal(t, types.TaskFailed, exec.Tasks[0].Status) assert.False(t, exec.Results[0].Success) // Second and third tasks should have executed and completed assert.Equal(t, types.TaskCompleted, exec.Tasks[1].Status) assert.True(t, exec.Results[1].Success) assert.Equal(t, types.TaskCompleted, exec.Tasks[2].Status) assert.True(t, exec.Results[2].Success) t.Logf("Task 1 (failed): %v", exec.Results[0].Error) t.Logf("Task 2 (success): %v", exec.Results[1].Output) t.Logf("Task 3 (success): %v", exec.Results[2].Output) }) t.Run("multiple failures with ContinueOnFailure", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) // Mix of failing and succeeding tasks exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "non.existent.assistant.1", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Fail 1"}, }, Order: 0, Status: types.TaskPending, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Say hello"}, }, Order: 1, Status: types.TaskPending, }, { ID: "task-003", ExecutorType: types.ExecutorAssistant, ExecutorID: "non.existent.assistant.2", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Fail 2"}, }, Order: 2, Status: types.TaskPending, }, { ID: "task-004", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Say goodbye"}, }, Order: 3, Status: types.TaskPending, }, } // V2 default ContinueOnFailure=true e := standard.New() err := e.RunExecution(ctx, exec, nil) assert.NoError(t, err) assert.Len(t, exec.Results, 4) // Check status pattern: fail, success, fail, success assert.Equal(t, 
types.TaskFailed, exec.Tasks[0].Status) assert.Equal(t, types.TaskCompleted, exec.Tasks[1].Status) assert.Equal(t, types.TaskFailed, exec.Tasks[2].Status) assert.Equal(t, types.TaskCompleted, exec.Tasks[3].Status) // Count successes and failures successCount := 0 failCount := 0 for _, result := range exec.Results { if result.Success { successCount++ } else { failCount++ } } assert.Equal(t, 2, successCount) assert.Equal(t, 2, failCount) }) } func TestRunExecutionNoValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("V2 runner does not set Validation on results", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.data-analyst", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Return a JSON object with fields: name (string), count (number). 
Example: {\"name\": \"test\", \"count\": 5}"}, }, ExpectedOutput: "JSON object with name and count fields", Order: 0, Status: types.TaskPending, }, } e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err) require.Len(t, exec.Results, 1) result := exec.Results[0] assert.True(t, result.Success, "task should succeed if assistant call returns") assert.NotNil(t, result.Output, "output should be present") assert.Nil(t, result.Validation, "V2 runner does not run validation") t.Logf("Output: %v", result.Output) }) } // ============================================================================ // Helper Functions // ============================================================================ // createRunTestRobot creates a test robot for P3 run tests func createRunTestRobot(t *testing.T) *types.Robot { t.Helper() return &types.Robot{ MemberID: "test-robot-run", TeamID: "test-team-1", DisplayName: "Test Robot for Run", SystemPrompt: "You are a helpful assistant.", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Assistant", Duties: []string{"Execute tasks", "Generate content"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseRun: "robot.validation", "validation": "robot.validation", // For semantic validation agent }, Agents: []string{ "experts.data-analyst", "experts.summarizer", "experts.text-writer", }, }, }, } } // createRunTestExecution creates a test execution for P3 run tests func createRunTestExecution(robot *types.Robot) *types.Execution { exec := &types.Execution{ ID: "test-exec-run-1", MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: types.TriggerClock, StartTime: time.Now(), Status: types.ExecRunning, Phase: types.PhaseRun, Goals: &types.Goals{ Content: "## Goals\n\n1. 
Execute test tasks", }, } exec.SetRobot(robot) return exec } ================================================ FILE: agent/robot/executor/standard/runner.go ================================================ package standard import ( "encoding/json" "fmt" "strings" "time" "github.com/yaoapp/gou/mcp" "github.com/yaoapp/gou/process" agentcontext "github.com/yaoapp/yao/agent/context" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // Runner handles execution of individual tasks type Runner struct { ctx *robottypes.Context robot *robottypes.Robot config *RunConfig chatID string // execution-level chatID for conversation persistence (§8.4) log *execLogger } // NewRunner creates a new task runner func NewRunner(ctx *robottypes.Context, robot *robottypes.Robot, config *RunConfig, chatID string, execID string) *Runner { return &Runner{ ctx: ctx, robot: robot, config: config, chatID: chatID, log: newExecLogger(robot, execID), } } // RunnerContext provides context for task execution type RunnerContext struct { // PreviousResults contains results from previously executed tasks PreviousResults []robottypes.TaskResult // Goals contains the goals from P1 (for context) Goals *robottypes.Goals // SystemPrompt is the robot's system prompt SystemPrompt string } // BuildTaskContext builds context for a task including previous results func (r *Runner) BuildTaskContext(exec *robottypes.Execution, taskIndex int) *RunnerContext { ctx := &RunnerContext{ Goals: exec.Goals, SystemPrompt: r.robot.SystemPrompt, } // Include results from previous tasks (with bounds check) if taskIndex > 0 && len(exec.Results) > 0 { endIndex := taskIndex if endIndex > len(exec.Results) { endIndex = len(exec.Results) } ctx.PreviousResults = exec.Results[:endIndex] } return ctx } // ExecuteTask executes a single task (V2 simplified: single call, no validation loop). // Success is determined purely by whether the call itself succeeds without error. 
// Quality evaluation is deferred to the Delivery Agent (P4) using ExpectedOutput. func (r *Runner) ExecuteTask(task *robottypes.Task, taskCtx *RunnerContext) *robottypes.TaskResult { startTime := time.Now() result := &robottypes.TaskResult{ TaskID: task.ID, } // For non-assistant tasks (MCP, Process), single-call execution if task.ExecutorType != robottypes.ExecutorAssistant { output, err := r.executeNonAssistantTask(task, taskCtx) if err != nil { result.Success = false result.Error = fmt.Sprintf("execution failed: %s", err.Error()) result.Duration = time.Since(startTime).Milliseconds() r.log.logTaskOutput(task, result) return result } result.Output = output result.Success = true result.Duration = time.Since(startTime).Milliseconds() r.log.logTaskOutput(task, result) return result } // For assistant tasks, single call via conversation output, callResult, err := r.executeAssistantTask(task, taskCtx) if err != nil { result.Success = false result.Error = err.Error() result.Duration = time.Since(startTime).Milliseconds() r.log.logTaskOutput(task, result) return result } result.Output = output result.Success = true result.Duration = time.Since(startTime).Milliseconds() // Check if assistant signals it needs human input (V2 suspend protocol) if needInput, question := detectNeedMoreInfo(callResult); needInput { result.NeedInput = true result.InputQuestion = question } r.log.logTaskOutput(task, result) return result } // executeNonAssistantTask executes MCP or Process tasks (single-call, no multi-turn) func (r *Runner) executeNonAssistantTask(task *robottypes.Task, taskCtx *RunnerContext) (interface{}, error) { switch task.ExecutorType { case robottypes.ExecutorMCP: return r.ExecuteMCPTask(task, taskCtx) case robottypes.ExecutorProcess: return r.ExecuteProcessTask(task, taskCtx) default: return nil, fmt.Errorf("unsupported executor type: %s (expected mcp or process)", task.ExecutorType) } } // executeAssistantTask executes an assistant task with a single conversation 
turn. // Returns the extracted output, the raw CallResult (for need_input detection), and any error. func (r *Runner) executeAssistantTask(task *robottypes.Task, taskCtx *RunnerContext) (interface{}, *CallResult, error) { caller := NewAgentCaller() caller.log = r.log caller.Connector = r.robot.LanguageModel caller.ChatID = r.chatID messages := r.BuildAssistantMessages(task, taskCtx) input := r.FormatMessagesAsText(messages) if strings.TrimSpace(input) == "" { return nil, nil, fmt.Errorf("no valid input messages for task %s", task.ID) } if taskCtx.SystemPrompt != "" { input = "## Context\n\n" + taskCtx.SystemPrompt + "\n\n## Task\n\n" + input } r.log.logTaskInput(task, input) result, err := caller.CallWithMessages(r.ctx, task.ExecutorID, input) if err != nil { return nil, nil, fmt.Errorf("assistant call failed: %w", err) } output := r.extractOutput(result) return output, result, nil } // detectNeedMoreInfo checks if the assistant's response signals it needs human input. // The protocol: Next hook returns {data: {status: "need_input", question: "..."}}. // Also handles the unwrapped form {status: "need_input", question: "..."} for robustness. func detectNeedMoreInfo(result *CallResult) (bool, string) { if result == nil || result.Next == nil { return false, "" } m, ok := result.Next.(map[string]interface{}) if !ok { return false, "" } // Unwrap "data" envelope if present (Next hook standard: {data: {status: ...}}) if data, ok := m["data"].(map[string]interface{}); ok { m = data } status, _ := m["status"].(string) if status != "need_input" { return false, "" } question, _ := m["question"].(string) if question == "" { question = result.GetText() } return true, question } // extractOutput extracts the output from a CallResult // Priority: Next hook data > LLM Completion content // Next is the agent's formal A2A output (could be string, map, array, number, etc.) 
// Content is the raw LLM completion text (fallback only when Next is absent) func (r *Runner) extractOutput(result *CallResult) interface{} { if result == nil { return nil } if result.Next != nil { return result.Next } if result.Content != "" { return result.Content } return nil } // ExecuteMCPTask executes a task using an MCP tool // Requires task.MCPServer and task.MCPTool fields to be set // executor_id is the combined form: "mcp_server.mcp_tool" (e.g., "ark.image.text2img.generate") func (r *Runner) ExecuteMCPTask(task *robottypes.Task, taskCtx *RunnerContext) (interface{}, error) { // Validate MCP-specific fields if task.MCPServer == "" || task.MCPTool == "" { return nil, fmt.Errorf("MCP task requires mcp_server and mcp_tool fields (executor_id: %s)", task.ExecutorID) } // Get MCP client client, err := mcp.Select(task.MCPServer) if err != nil { return nil, fmt.Errorf("MCP server not found: %s: %w", task.MCPServer, err) } // Build arguments map from task.Args args := make(map[string]interface{}) if len(task.Args) > 0 { // First argument should be a map of tool arguments if argsMap, ok := task.Args[0].(map[string]interface{}); ok { args = argsMap } else { // If not a map, try to convert single argument args["input"] = task.Args[0] } } // Call MCP tool result, err := client.CallTool(r.ctx.Context, task.MCPTool, args) if err != nil { return nil, fmt.Errorf("MCP tool call failed (%s.%s): %w", task.MCPServer, task.MCPTool, err) } return result, nil } // ExecuteProcessTask executes a task using a Yao process // ExecutorID is the process name (e.g., "models.user.Find", "scripts.myScript.Run") func (r *Runner) ExecuteProcessTask(task *robottypes.Task, taskCtx *RunnerContext) (interface{}, error) { // Create process with task arguments proc, err := process.Of(task.ExecutorID, task.Args...) 
if err != nil { return nil, fmt.Errorf("process creation failed: %w", err) } // Set context for timeout and cancellation proc.Context = r.ctx.Context // Execute the process if err := proc.Execute(); err != nil { return nil, fmt.Errorf("process execution failed: %w", err) } defer proc.Release() // Return the result return proc.Value(), nil } // BuildAssistantMessages builds messages for an assistant task func (r *Runner) BuildAssistantMessages(task *robottypes.Task, taskCtx *RunnerContext) []agentcontext.Message { messages := make([]agentcontext.Message, 0) // Add context from previous tasks if available if len(taskCtx.PreviousResults) > 0 { contextMsg := r.FormatPreviousResultsAsContext(taskCtx.PreviousResults) if contextMsg != "" { messages = append(messages, agentcontext.Message{ Role: agentcontext.RoleUser, Content: contextMsg, }) } } // Add task messages messages = append(messages, task.Messages...) return messages } // FormatMessagesAsText converts messages to a single text string func (r *Runner) FormatMessagesAsText(messages []agentcontext.Message) string { var result string for _, msg := range messages { switch content := msg.Content.(type) { case string: result += content + "\n\n" case []interface{}: // Handle multi-part content (e.g., text + images) for _, part := range content { if textPart, ok := part.(map[string]interface{}); ok { if text, ok := textPart["text"].(string); ok { result += text + "\n\n" } } } default: // Try JSON marshaling as fallback if content != nil { if jsonBytes, err := json.Marshal(content); err == nil { result += string(jsonBytes) + "\n\n" } } } } return result } // FormatPreviousResultsAsContext formats previous task results as context func (r *Runner) FormatPreviousResultsAsContext(results []robottypes.TaskResult) string { if len(results) == 0 { return "" } var sb strings.Builder sb.WriteString("## Previous Task Results\n\n") sb.WriteString("The following tasks have been completed. 
Use their results as needed:\n\n") for _, result := range results { sb.WriteString(fmt.Sprintf("### Task: %s\n", result.TaskID)) if result.Success { sb.WriteString("- Status: ✓ Success\n") } else { sb.WriteString("- Status: ✗ Failed\n") } if result.Output != nil { outputJSON, err := json.MarshalIndent(result.Output, "", " ") if err == nil { sb.WriteString(fmt.Sprintf("- Output:\n```json\n%s\n```\n", string(outputJSON))) } else { sb.WriteString(fmt.Sprintf("- Output: %v\n", result.Output)) } } sb.WriteString("\n") } return sb.String() } ================================================ FILE: agent/robot/executor/standard/runner_test.go ================================================ package standard_test import ( "context" "testing" "github.com/stretchr/testify/assert" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ============================================================================ // Runner Tests - V2 Simplified Execution (single call, no validation loop) // ============================================================================ func TestRunnerExecuteTask(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("executes assistant task successfully", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") task := &types.Task{ ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write a haiku about coding. 
Format: three lines with 5-7-5 syllables."}, }, ExpectedOutput: "A haiku poem about coding", Status: types.TaskPending, } taskCtx := &standard.RunnerContext{ SystemPrompt: robot.SystemPrompt, } result := runner.ExecuteTask(task, taskCtx) assert.True(t, result.Success, "task should succeed") assert.NotNil(t, result.Output, "output should not be nil") assert.Empty(t, result.Error, "error should be empty on success") assert.Greater(t, result.Duration, int64(0), "duration should be positive") t.Logf("Output: %v", result.Output) }) t.Run("returns success without validation for assistant tasks", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") task := &types.Task{ ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.data-analyst", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Return a JSON object with exactly these fields: status (string 'ok'), count (number greater than 0)."}, }, ExpectedOutput: "JSON with status='ok' and count>0", Status: types.TaskPending, } taskCtx := &standard.RunnerContext{ SystemPrompt: robot.SystemPrompt, } result := runner.ExecuteTask(task, taskCtx) // V2: success is determined by the call succeeding, not by validation assert.True(t, result.Success, "task should succeed if assistant call returns") assert.NotNil(t, result.Output, "output should not be nil") assert.Nil(t, result.Validation, "V2 does not set Validation in runner") t.Logf("Success: %v, Output: %v", result.Success, result.Output) }) t.Run("handles empty messages gracefully", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") task := &types.Task{ ID: "task-003", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{}, Status: types.TaskPending, } taskCtx := 
&standard.RunnerContext{ SystemPrompt: robot.SystemPrompt, } result := runner.ExecuteTask(task, taskCtx) assert.False(t, result.Success, "task should fail with empty messages") assert.NotEmpty(t, result.Error, "error should describe the failure") t.Logf("Error: %s", result.Error) }) } func TestRunnerBuildTaskContext(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("includes previous results in context", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") exec := &types.Execution{ ID: "test-exec", MemberID: robot.MemberID, TeamID: robot.TeamID, Goals: &types.Goals{ Content: "Test goals", }, Results: []types.TaskResult{ { TaskID: "task-001", Success: true, Output: map[string]interface{}{"data": "previous result"}, }, { TaskID: "task-002", Success: true, Output: "Another result", }, }, } exec.SetRobot(robot) // Build context for task at index 2 (should include results 0 and 1) taskCtx := runner.BuildTaskContext(exec, 2) assert.NotNil(t, taskCtx) assert.Len(t, taskCtx.PreviousResults, 2) assert.Equal(t, "task-001", taskCtx.PreviousResults[0].TaskID) assert.Equal(t, "task-002", taskCtx.PreviousResults[1].TaskID) assert.Equal(t, robot.SystemPrompt, taskCtx.SystemPrompt) }) t.Run("handles first task with no previous results", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") exec := &types.Execution{ ID: "test-exec", MemberID: robot.MemberID, TeamID: robot.TeamID, Goals: &types.Goals{ Content: "Test goals", }, Results: []types.TaskResult{}, } exec.SetRobot(robot) taskCtx := runner.BuildTaskContext(exec, 0) assert.NotNil(t, taskCtx) assert.Empty(t, taskCtx.PreviousResults) }) t.Run("handles bounds check for task index", func(t 
*testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") exec := &types.Execution{ ID: "test-exec", MemberID: robot.MemberID, TeamID: robot.TeamID, Results: []types.TaskResult{ {TaskID: "task-001", Success: true}, }, } exec.SetRobot(robot) // Task index 5, but only 1 result exists taskCtx := runner.BuildTaskContext(exec, 5) assert.NotNil(t, taskCtx) assert.Len(t, taskCtx.PreviousResults, 1) // Should only include available results }) } func TestRunnerFormatPreviousResultsAsContext(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("formats previous results as markdown", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") results := []types.TaskResult{ { TaskID: "task-001", Success: true, Output: map[string]interface{}{"key": "value", "count": 42}, }, { TaskID: "task-002", Success: false, Output: "Partial result", Error: "Validation failed", }, } formatted := runner.FormatPreviousResultsAsContext(results) assert.Contains(t, formatted, "## Previous Task Results") assert.Contains(t, formatted, "task-001") assert.Contains(t, formatted, "task-002") assert.Contains(t, formatted, "Success") assert.Contains(t, formatted, "Failed") assert.Contains(t, formatted, "key") assert.Contains(t, formatted, "value") t.Logf("Formatted context:\n%s", formatted) }) t.Run("returns empty string for no results", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") formatted := runner.FormatPreviousResultsAsContext([]types.TaskResult{}) assert.Empty(t, formatted) }) } func TestRunnerBuildAssistantMessages(t *testing.T) { if testing.Short() { t.Skip("Skipping integration 
test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("builds messages with task content", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") task := &types.Task{ ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write a greeting"}, }, } taskCtx := &standard.RunnerContext{ SystemPrompt: "You are helpful", } messages := runner.BuildAssistantMessages(task, taskCtx) assert.NotEmpty(t, messages) // Should contain task message found := false for _, msg := range messages { if content, ok := msg.Content.(string); ok && content == "Write a greeting" { found = true break } } assert.True(t, found, "should contain task message") }) t.Run("includes previous results in messages", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") task := &types.Task{ ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Continue from previous"}, }, } taskCtx := &standard.RunnerContext{ PreviousResults: []types.TaskResult{ {TaskID: "task-001", Success: true, Output: "Previous output"}, }, SystemPrompt: "You are helpful", } messages := runner.BuildAssistantMessages(task, taskCtx) assert.NotEmpty(t, messages) // Should have context message with previous results formatted := runner.FormatMessagesAsText(messages) assert.Contains(t, formatted, "Previous Task Results") }) } func TestRunnerFormatMessagesAsText(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("formats string content", 
func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") messages := []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Hello"}, {Role: agentcontext.RoleUser, Content: "World"}, } text := runner.FormatMessagesAsText(messages) assert.Contains(t, text, "Hello") assert.Contains(t, text, "World") }) t.Run("handles multipart content", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") messages := []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: []interface{}{ map[string]interface{}{"type": "text", "text": "Part 1"}, map[string]interface{}{"type": "text", "text": "Part 2"}, }, }, } text := runner.FormatMessagesAsText(messages) assert.Contains(t, text, "Part 1") assert.Contains(t, text, "Part 2") }) t.Run("handles map content via JSON", func(t *testing.T) { robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") messages := []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: map[string]interface{}{"key": "value"}, }, } text := runner.FormatMessagesAsText(messages) assert.Contains(t, text, "key") assert.Contains(t, text, "value") }) } // ============================================================================ // Non-Assistant Task Tests (MCP, Process) // ============================================================================ func TestRunnerExecuteNonAssistantTask(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("executes unsupported type returns error", func(t *testing.T) { ctx := types.NewContext(context.Background(), testAuth()) robot := createRunnerTestRobot(t) config := standard.DefaultRunConfig() runner := standard.NewRunner(ctx, robot, config, "", "test") task := 
&types.Task{ ID: "task-unknown", ExecutorType: "unsupported", ExecutorID: "anything", Status: types.TaskPending, } taskCtx := &standard.RunnerContext{} result := runner.ExecuteTask(task, taskCtx) assert.False(t, result.Success, "unsupported executor type should fail") assert.Contains(t, result.Error, "unsupported executor type") assert.Nil(t, result.Validation, "V2 does not set Validation in runner") }) } // ============================================================================ // Helper Functions // ============================================================================ // createRunnerTestRobot creates a test robot for runner tests func createRunnerTestRobot(t *testing.T) *types.Robot { t.Helper() return &types.Robot{ MemberID: "test-robot-runner", TeamID: "test-team-1", DisplayName: "Test Robot for Runner", SystemPrompt: "You are a helpful assistant. Follow instructions carefully and provide clear responses.", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Assistant", Duties: []string{"Execute tasks", "Generate content"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseRun: "robot.run", }, Agents: []string{ "experts.data-analyst", "experts.summarizer", "experts.text-writer", }, }, }, } } // Note: createRunnerTestExecution is available if needed for future tests // that require a full Execution object instead of just RunnerContext ================================================ FILE: agent/robot/executor/standard/suspend_resume_test.go ================================================ package standard_test import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" executortypes "github.com/yaoapp/yao/agent/robot/executor/types" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // 
============================================================================ // RunExecution with ResumeContext tests // ============================================================================ func TestRunExecutionResumeContext(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("resumes from task index with previous results", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write 'hello'"}, }, Order: 0, Status: types.TaskCompleted, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write 'world'"}, }, Order: 1, Status: types.TaskPending, }, { ID: "task-003", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write '!'"}, }, Order: 2, Status: types.TaskPending, }, } // Simulate resume: task-001 already completed, resume from task-002 previousResult := types.TaskResult{ TaskID: "task-001", Success: true, Output: "hello", Duration: 100, } exec.ResumeContext = &types.ResumeContext{ TaskIndex: 1, PreviousResults: []types.TaskResult{previousResult}, } e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err) // Should have 3 results: 1 from previous + 2 new require.Len(t, exec.Results, 3) assert.Equal(t, "task-001", exec.Results[0].TaskID) assert.True(t, exec.Results[0].Success) assert.Equal(t, "hello", exec.Results[0].Output) assert.Equal(t, "task-002", exec.Results[1].TaskID) assert.True(t, exec.Results[1].Success) assert.Equal(t, "task-003", exec.Results[2].TaskID) 
assert.True(t, exec.Results[2].Success) // ResumeContext should be cleared after completion assert.Nil(t, exec.ResumeContext) // Only task-002 and task-003 should have been executed (check status) assert.Equal(t, types.TaskCompleted, exec.Tasks[1].Status) assert.Equal(t, types.TaskCompleted, exec.Tasks[2].Status) }) t.Run("resumes from last task", func(t *testing.T) { robot := createRunTestRobot(t) exec := createRunTestExecution(robot) exec.Tasks = []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write 'hello'"}, }, Order: 0, Status: types.TaskCompleted, }, { ID: "task-002", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write 'world'"}, }, Order: 1, Status: types.TaskWaitingInput, }, } // Resume from the last task exec.ResumeContext = &types.ResumeContext{ TaskIndex: 1, PreviousResults: []types.TaskResult{ {TaskID: "task-001", Success: true, Output: "hello", Duration: 100}, }, } e := standard.New() err := e.RunExecution(ctx, exec, nil) require.NoError(t, err) require.Len(t, exec.Results, 2) assert.True(t, exec.Results[1].Success) assert.Equal(t, types.TaskCompleted, exec.Tasks[1].Status) }) } // ============================================================================ // Suspend method tests (using Executor directly) // ============================================================================ func TestSuspendExecution(t *testing.T) { t.Run("suspend sets waiting fields and returns ErrExecutionSuspended", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-suspend", TeamID: "test-team-1", DisplayName: "Suspend Test Robot", } exec := &types.Execution{ ID: "exec-suspend-001", MemberID: robot.MemberID, TeamID: robot.TeamID, Status: types.ExecRunning, Phase: types.PhaseRun, Tasks: []types.Task{ {ID: "task-001", Status: 
types.TaskRunning}, {ID: "task-002", Status: types.TaskPending}, }, Results: []types.TaskResult{}, } exec.SetRobot(robot) e := standard.NewWithConfig(executortypes.Config{SkipPersistence: true}) err := e.Suspend( types.NewContext(context.Background(), nil), exec, 0, "What time range?", ) assert.ErrorIs(t, err, types.ErrExecutionSuspended) assert.Equal(t, types.ExecWaiting, exec.Status) assert.Equal(t, "task-001", exec.WaitingTaskID) assert.Equal(t, "What time range?", exec.WaitingQuestion) assert.NotNil(t, exec.WaitingSince) assert.NotNil(t, exec.ResumeContext) assert.Equal(t, 0, exec.ResumeContext.TaskIndex) assert.Equal(t, types.TaskWaitingInput, exec.Tasks[0].Status) }) t.Run("suspend with out of range taskIndex is safe", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-suspend-2", TeamID: "test-team-1", } exec := &types.Execution{ ID: "exec-suspend-002", MemberID: robot.MemberID, TeamID: robot.TeamID, Status: types.ExecRunning, Tasks: []types.Task{}, Results: []types.TaskResult{}, } exec.SetRobot(robot) e := standard.NewWithConfig(executortypes.Config{SkipPersistence: true}) err := e.Suspend( types.NewContext(context.Background(), nil), exec, 5, "some question", ) assert.ErrorIs(t, err, types.ErrExecutionSuspended) assert.Equal(t, types.ExecWaiting, exec.Status) assert.Empty(t, exec.WaitingTaskID) }) } // ============================================================================ // ExecuteWithControl handles ErrExecutionSuspended // ============================================================================ func TestExecuteWithControlSuspend(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("returns ErrExecutionSuspended without marking as failed", func(t *testing.T) { // This test requires a robot-need-input assistant that returns need_input. 
// Since we don't have one yet (Stage 6), we test the suspend path indirectly // by verifying that when RunExecution returns ErrExecutionSuspended, // ExecuteWithControl propagates it correctly. // // Full E2E test with real assistant will be in Stage 6. robot := &types.Robot{ MemberID: "test-robot-suspend-exec", TeamID: "test-team-1", DisplayName: "Suspend Exec Test", Config: &types.Config{ Identity: &types.Identity{ Role: "Test", }, Resources: &types.Resources{ Phases: map[types.Phase]string{}, Agents: []string{"experts.text-writer"}, }, Quota: &types.Quota{Max: 5}, }, } ctx := types.NewContext(context.Background(), testAuth()) e := standard.New() // Execute normally (no need_input expected from text-writer) exec, err := e.Execute(ctx, robot, types.TriggerHuman, "Write a greeting") if err == types.ErrExecutionSuspended { // If somehow suspended, verify state assert.Equal(t, types.ExecWaiting, exec.Status) assert.NotEmpty(t, exec.WaitingQuestion) } else { // Normal completion assert.NoError(t, err) assert.NotNil(t, exec) } }) } // ============================================================================ // ResumeContext data structure tests // ============================================================================ func TestResumeContext(t *testing.T) { t.Run("stores task index and previous results", func(t *testing.T) { rc := &types.ResumeContext{ TaskIndex: 2, PreviousResults: []types.TaskResult{ {TaskID: "t1", Success: true, Output: "out1"}, {TaskID: "t2", Success: false, Error: "some error"}, }, } assert.Equal(t, 2, rc.TaskIndex) assert.Len(t, rc.PreviousResults, 2) assert.True(t, rc.PreviousResults[0].Success) assert.False(t, rc.PreviousResults[1].Success) }) } // ============================================================================ // NeedInput in TaskResult // ============================================================================ func TestTaskResultNeedInput(t *testing.T) { t.Run("NeedInput fields are populated correctly", func(t 
*testing.T) { result := types.TaskResult{ TaskID: "task-001", Success: true, Output: "some output", NeedInput: true, InputQuestion: "What time range?", } assert.True(t, result.NeedInput) assert.Equal(t, "What time range?", result.InputQuestion) }) } // ============================================================================ // Execution status transitions for suspend/resume // ============================================================================ func TestExecutionStatusTransitions(t *testing.T) { t.Run("ExecWaiting is a valid status", func(t *testing.T) { exec := &types.Execution{ ID: "exec-001", Status: types.ExecWaiting, } assert.Equal(t, types.ExecStatus("waiting"), exec.Status) }) t.Run("TaskWaitingInput is a valid task status", func(t *testing.T) { task := types.Task{ ID: "task-001", Status: types.TaskWaitingInput, } assert.Equal(t, types.TaskStatus("waiting_input"), task.Status) }) t.Run("Execution V2 fields are accessible", func(t *testing.T) { now := time.Now() exec := &types.Execution{ ID: "exec-v2-001", ChatID: "robot_member1_exec001", WaitingTaskID: "task-002", WaitingQuestion: "What period?", WaitingSince: &now, ResumeContext: &types.ResumeContext{ TaskIndex: 1, PreviousResults: []types.TaskResult{ {TaskID: "task-001", Success: true}, }, }, } assert.Equal(t, "robot_member1_exec001", exec.ChatID) assert.Equal(t, "task-002", exec.WaitingTaskID) assert.Equal(t, "What period?", exec.WaitingQuestion) assert.NotNil(t, exec.WaitingSince) assert.NotNil(t, exec.ResumeContext) assert.Equal(t, 1, exec.ResumeContext.TaskIndex) }) } ================================================ FILE: agent/robot/executor/standard/suspend_test.go ================================================ package standard import ( "testing" "github.com/stretchr/testify/assert" ) // ============================================================================ // detectNeedMoreInfo unit tests (internal — tests unexported function) // 
============================================================================ func TestDetectNeedMoreInfo(t *testing.T) { t.Run("returns false for nil result", func(t *testing.T) { needInput, question := detectNeedMoreInfo(nil) assert.False(t, needInput) assert.Empty(t, question) }) t.Run("returns false for nil Next", func(t *testing.T) { result := &CallResult{Content: "some text"} needInput, question := detectNeedMoreInfo(result) assert.False(t, needInput) assert.Empty(t, question) }) t.Run("returns false for non-map Next", func(t *testing.T) { result := &CallResult{Next: "just a string"} needInput, question := detectNeedMoreInfo(result) assert.False(t, needInput) assert.Empty(t, question) }) t.Run("returns false when status is not need_input", func(t *testing.T) { result := &CallResult{ Next: map[string]interface{}{ "status": "ok", "content": "everything is fine", }, } needInput, question := detectNeedMoreInfo(result) assert.False(t, needInput) assert.Empty(t, question) }) t.Run("returns true with question from Next", func(t *testing.T) { result := &CallResult{ Next: map[string]interface{}{ "status": "need_input", "question": "What time range should I use?", }, } needInput, question := detectNeedMoreInfo(result) assert.True(t, needInput) assert.Equal(t, "What time range should I use?", question) }) t.Run("falls back to GetText when question is empty", func(t *testing.T) { result := &CallResult{ Content: "I need more information about the time range.", Next: map[string]interface{}{ "status": "need_input", }, } needInput, question := detectNeedMoreInfo(result) assert.True(t, needInput) assert.Equal(t, "I need more information about the time range.", question) }) t.Run("returns true with empty question when both are empty", func(t *testing.T) { result := &CallResult{ Next: map[string]interface{}{ "status": "need_input", }, } needInput, question := detectNeedMoreInfo(result) assert.True(t, needInput) assert.Empty(t, question) }) t.Run("unwraps data envelope from Next 
hook", func(t *testing.T) { result := &CallResult{ Next: map[string]interface{}{ "data": map[string]interface{}{ "status": "need_input", "question": "Which database should I query?", }, }, } needInput, question := detectNeedMoreInfo(result) assert.True(t, needInput) assert.Equal(t, "Which database should I query?", question) }) } ================================================ FILE: agent/robot/executor/standard/tasks.go ================================================ package standard import ( "fmt" agentcontext "github.com/yaoapp/yao/agent/context" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // RunTasks executes P2: Tasks phase // Calls the Tasks Agent to break down goals into executable tasks // // Input: // - Goals (from P1) with markdown content // - Available resources (Agents, MCP tools, KB, DB) // // Output: // - List of Task objects with executor assignments, expected outputs, and validation rules func (e *Executor) RunTasks(ctx *robottypes.Context, exec *robottypes.Execution, _ interface{}) error { // §18.2: confirming phase may have already populated Tasks — skip regeneration if len(exec.Tasks) > 0 { return nil } // Get robot for resources robot := exec.GetRobot() if robot == nil { return fmt.Errorf("robot not found in execution") } // Update UI field with i18n locale := getEffectiveLocale(robot, exec.Input) e.updateUIFields(ctx, exec, "", getLocalizedMessage(locale, "breaking_down_tasks")) // Validate: Goals must exist (from P1) if exec.Goals == nil || exec.Goals.Content == "" { return fmt.Errorf("goals not available for task planning") } // Get agent ID for tasks phase agentID := "__yao.tasks" // default if robot.Config != nil && robot.Config.Resources != nil { agentID = robot.Config.Resources.GetPhaseAgent(robottypes.PhaseTasks) } // Build prompt with goals and available resources formatter := NewInputFormatter() userContent := formatter.FormatGoals(exec.Goals, robot) if userContent == "" { return fmt.Errorf("tasks agent (%s) received 
empty input for task planning", agentID) } // Call agent caller := NewAgentCaller() caller.log = newExecLogger(robot, exec.ID) caller.Connector = robot.LanguageModel result, err := caller.CallWithMessages(ctx, agentID, userContent) if err != nil { return fmt.Errorf("tasks agent (%s) call failed: %w", agentID, err) } // Parse response as JSON // Tasks Agent returns: { "tasks": [...] } data, err := result.GetJSON() if err != nil { return fmt.Errorf("tasks agent (%s) returned invalid JSON: %w", agentID, err) } // Extract tasks array tasksData, ok := data["tasks"].([]interface{}) if !ok || len(tasksData) == 0 { return fmt.Errorf("tasks agent (%s) returned no tasks", agentID) } // Parse tasks tasks, err := ParseTasks(tasksData) if err != nil { return fmt.Errorf("tasks agent (%s) returned invalid task structure: %w", agentID, err) } // Validate tasks if err := ValidateTasks(tasks); err != nil { return fmt.Errorf("tasks validation failed: %w", err) } exec.Tasks = tasks // Log task overview for developer observability el := newExecLogger(robot, exec.ID) el.logTaskOverview(tasks) return nil } // ParseTasks converts raw JSON array to []Task // Tasks are sorted by Order field after parsing func ParseTasks(data []interface{}) ([]robottypes.Task, error) { tasks := make([]robottypes.Task, 0, len(data)) for i, item := range data { taskMap, ok := item.(map[string]interface{}) if !ok { return nil, fmt.Errorf("task %d is not a valid object", i) } task, err := ParseTask(taskMap, i) if err != nil { return nil, fmt.Errorf("task %d: %w", i, err) } tasks = append(tasks, *task) } // Sort tasks by Order field to ensure correct execution sequence SortTasksByOrder(tasks) return tasks, nil } // ParseTask converts a map to Task struct func ParseTask(data map[string]interface{}, index int) (*robottypes.Task, error) { task := &robottypes.Task{ Status: robottypes.TaskPending, Order: index, } // Required: id if id, ok := data["id"].(string); ok && id != "" { task.ID = id } else { task.ID = 
fmt.Sprintf("task-%03d", index+1) } // Required: executor_type if execType, ok := data["executor_type"].(string); ok { task.ExecutorType = ParseExecutorType(execType) } else { return nil, fmt.Errorf("missing executor_type") } // Required: executor_id if execID, ok := data["executor_id"].(string); ok && execID != "" { task.ExecutorID = execID } else { return nil, fmt.Errorf("missing executor_id") } // Optional: goal_ref if goalRef, ok := data["goal_ref"].(string); ok { task.GoalRef = goalRef } // Optional: source (default to auto) if source, ok := data["source"].(string); ok { task.Source = robottypes.TaskSource(source) } else { task.Source = robottypes.TaskSourceAuto } // Optional: order (override default) if order, ok := data["order"].(float64); ok { task.Order = int(order) } // Optional: messages (task instructions) if messages, ok := data["messages"].([]interface{}); ok { task.Messages = ParseMessages(messages) } // Optional: description - save to Description field and convert to message if no messages if desc, ok := data["description"].(string); ok && desc != "" { task.Description = desc // Also convert to Messages for execution if no explicit messages provided if len(task.Messages) == 0 { task.Messages = []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: desc}, } } } // Optional: args if args, ok := data["args"].([]interface{}); ok { task.Args = make([]any, len(args)) copy(task.Args, args) } // MCP-specific fields (required when executor_type is "mcp") if mcpServer, ok := data["mcp_server"].(string); ok { task.MCPServer = mcpServer } if mcpTool, ok := data["mcp_tool"].(string); ok { task.MCPTool = mcpTool } // Optional: expected_output (for P3 validation) if expectedOutput, ok := data["expected_output"].(string); ok { task.ExpectedOutput = expectedOutput } // Optional: validation_rules (for P3 validation) if rules, ok := data["validation_rules"].([]interface{}); ok { task.ValidationRules = make([]string, 0, len(rules)) for _, r := range rules { if 
s, ok := r.(string); ok { task.ValidationRules = append(task.ValidationRules, s) } } } return task, nil } // ParseMessages converts raw message array to []Message func ParseMessages(data []interface{}) []agentcontext.Message { messages := make([]agentcontext.Message, 0, len(data)) for _, item := range data { msgMap, ok := item.(map[string]interface{}) if !ok { continue } msg := agentcontext.Message{} // Role if role, ok := msgMap["role"].(string); ok { msg.Role = agentcontext.MessageRole(role) } else { msg.Role = agentcontext.RoleUser } // Content if content, ok := msgMap["content"].(string); ok { msg.Content = content } else if content, ok := msgMap["content"]; ok { // Handle non-string content (multimodal) msg.Content = content } if msg.Content != nil { messages = append(messages, msg) } } return messages } // ParseExecutorType converts string to ExecutorType func ParseExecutorType(s string) robottypes.ExecutorType { switch s { case "agent", "assistant": return robottypes.ExecutorAssistant case "mcp": return robottypes.ExecutorMCP case "process": return robottypes.ExecutorProcess default: return robottypes.ExecutorAssistant // default to assistant } } // ValidateTasks validates the task list func ValidateTasks(tasks []robottypes.Task) error { if len(tasks) == 0 { return fmt.Errorf("no tasks generated") } seenIDs := make(map[string]bool) for i, task := range tasks { // Check unique ID if seenIDs[task.ID] { return fmt.Errorf("task %d: duplicate task ID '%s'", i, task.ID) } seenIDs[task.ID] = true // Check executor if task.ExecutorID == "" { return fmt.Errorf("task %d (%s): missing executor_id", i, task.ID) } // Check messages or description if len(task.Messages) == 0 { return fmt.Errorf("task %d (%s): missing messages or description", i, task.ID) } // Note: Executor existence is NOT validated here // - ValidateExecutorExists() can be called separately if needed // - Unknown executors will fail at P3 runtime with clear error message // - This allows flexibility for 
dynamically registered executors // Note: Validation rules are optional // - P3 can still do basic validation without explicit rules } return nil } // ValidateTasksWithResources validates tasks and checks executor existence // Returns a list of warnings for unknown executors (does not fail) func ValidateTasksWithResources(tasks []robottypes.Task, robot *robottypes.Robot) (warnings []string, err error) { // First do basic validation if err := ValidateTasks(tasks); err != nil { return nil, err } // Then check executor existence (warnings only) for _, task := range tasks { if !ValidateExecutorExists(task.ExecutorID, task.ExecutorType, robot) { warnings = append(warnings, fmt.Sprintf( "task %s: executor '%s' (%s) not found in available resources", task.ID, task.ExecutorID, task.ExecutorType, )) } } return warnings, nil } // IsValidExecutorType checks if the executor type is valid func IsValidExecutorType(t robottypes.ExecutorType) bool { switch t { case robottypes.ExecutorAssistant, robottypes.ExecutorMCP, robottypes.ExecutorProcess: return true default: return false } } // SortTasksByOrder sorts tasks by their Order field (ascending) // This ensures tasks are executed in the correct sequence regardless of // the order they appear in the LLM response func SortTasksByOrder(tasks []robottypes.Task) { for i := 0; i < len(tasks)-1; i++ { for j := i + 1; j < len(tasks); j++ { if tasks[j].Order < tasks[i].Order { tasks[i], tasks[j] = tasks[j], tasks[i] } } } } // ValidateExecutorExists checks if the executor ID exists in available resources // This is an optional validation - tasks with unknown executors will still be created // but may fail during P3 execution // For MCP tasks, pass mcpServer as the second parameter (executorID is ignored for MCP) func ValidateExecutorExists(executorID string, executorType robottypes.ExecutorType, robot *robottypes.Robot) bool { if robot == nil || robot.Config == nil || robot.Config.Resources == nil { return true // Skip validation if no 
resources configured } switch executorType { case robottypes.ExecutorAssistant: for _, agent := range robot.Config.Resources.Agents { if agent == executorID { return true } } return false case robottypes.ExecutorMCP: // For MCP, executorID can be either: // 1. The mcp_server value (new format) // 2. The combined mcp_server.mcp_tool format (for display) // We validate against mcp_server (the MCP server/client ID) for _, mcp := range robot.Config.Resources.MCP { if mcp.ID == executorID { return true } } return false case robottypes.ExecutorProcess: // Process executors are not validated against resources // They are validated at runtime by the Yao process system return true } return false } // ValidateMCPTask validates MCP task fields // Returns an error if mcp_server or mcp_tool is missing for MCP tasks func ValidateMCPTask(task *robottypes.Task) error { if task.ExecutorType != robottypes.ExecutorMCP { return nil } if task.MCPServer == "" { return fmt.Errorf("MCP task %s: mcp_server field is required", task.ID) } if task.MCPTool == "" { return fmt.Errorf("MCP task %s: mcp_tool field is required", task.ID) } return nil } ================================================ FILE: agent/robot/executor/standard/tasks_test.go ================================================ package standard_test import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ============================================================================ // P2 Tasks Phase Tests // ============================================================================ func TestRunTasksBasic(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) 
// TestRunTasksBasic is an integration test (skipped with -short) that runs
// the P2 tasks phase through RunTasks for clock-triggered executions and
// inspects the generated tasks.
//
// NOTE(review): whitespace inside the raw-string goal fixtures is reproduced
// as it appears in this view of the file — confirm against the real source.
func TestRunTasksBasic(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	// One context is shared by all subtests.
	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("generates tasks from goals (clock trigger)", func(t *testing.T) {
		// Create robot with tasks agent configured
		robot := createTasksTestRobot(t, "robot.tasks")

		// Create execution with goals (from P1)
		exec := createTasksTestExecution(robot, types.TriggerClock)
		exec.Goals = &types.Goals{
			Content: `## Goals 1. [High] Analyze Q4 sales data and identify top performing products - Reason: Need to prepare quarterly report 2. [Normal] Generate a summary report for management - Reason: Weekly review meeting tomorrow`,
		}

		// Run tasks phase
		e := standard.New()
		err := e.RunTasks(ctx, exec, nil)

		require.NoError(t, err)
		require.NotNil(t, exec.Tasks)
		assert.NotEmpty(t, exec.Tasks)

		// Verify task structure: every task needs an ID, an executor and
		// at least one message.
		for i, task := range exec.Tasks {
			t.Logf("Task %d: ID=%s, ExecutorType=%s, ExecutorID=%s", i, task.ID, task.ExecutorType, task.ExecutorID)
			assert.NotEmpty(t, task.ID, "task should have ID")
			assert.NotEmpty(t, task.ExecutorID, "task should have executor ID")
			assert.NotEmpty(t, task.Messages, "task should have messages")
		}
	})

	t.Run("includes expected output and validation rules", func(t *testing.T) {
		robot := createTasksTestRobot(t, "robot.tasks")
		exec := createTasksTestExecution(robot, types.TriggerClock)
		exec.Goals = &types.Goals{
			Content: `## Goals 1. [High] Fetch latest news about AI developments - Reason: Stay updated on industry trends 2. 
[Normal] Summarize the key findings - Reason: Share with team`,
		}

		e := standard.New()
		err := e.RunTasks(ctx, exec, nil)

		require.NoError(t, err)
		require.NotEmpty(t, exec.Tasks)

		// Check that at least one task has validation info
		hasValidationInfo := false
		for _, task := range exec.Tasks {
			if task.ExpectedOutput != "" || len(task.ValidationRules) > 0 {
				hasValidationInfo = true
				t.Logf("Task %s has validation: expected_output=%q, rules=%v", task.ID, task.ExpectedOutput, task.ValidationRules)
			}
		}
		// Note: LLM might not always include validation rules, so we just log
		// (the flag is reported, not asserted).
		t.Logf("Tasks have validation info: %v", hasValidationInfo)
	})
}
// TestRunTasksHumanTrigger is an integration test (skipped with -short) that
// runs RunTasks for a human-triggered execution and requires a non-empty
// task list.
//
// NOTE(review): whitespace inside the raw-string goal fixture is reproduced
// as it appears in this view of the file — confirm against the real source.
func TestRunTasksHumanTrigger(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("generates tasks from human-triggered goals", func(t *testing.T) {
		robot := createTasksTestRobot(t, "robot.tasks")
		exec := createTasksTestExecution(robot, types.TriggerHuman)

		// Goals from human request (P1 output)
		exec.Goals = &types.Goals{
			Content: `## Goals 1. [High] Research competitor pricing strategies - Reason: User requested competitive analysis 2. 
[Normal] Create comparison report - Reason: User needs data for presentation`,
		}

		e := standard.New()
		err := e.RunTasks(ctx, exec, nil)

		require.NoError(t, err)
		require.NotEmpty(t, exec.Tasks)

		// Tasks should relate to the goals; this is only logged for manual
		// inspection, not asserted.
		for _, task := range exec.Tasks {
			t.Logf("Task: %s -> %s", task.ID, task.ExecutorID)
		}
	})
}
// TestRunTasksWithExpertAgents is an integration test (skipped with -short)
// that runs RunTasks on goals calling for different capabilities and logs
// which executors the tasks were assigned to.
//
// NOTE(review): whitespace inside the raw-string goal fixture is reproduced
// as it appears in this view of the file — confirm against the real source.
func TestRunTasksWithExpertAgents(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	ctx := types.NewContext(context.Background(), testAuth())

	t.Run("assigns appropriate expert agents to tasks", func(t *testing.T) {
		robot := createTasksTestRobot(t, "robot.tasks")
		exec := createTasksTestExecution(robot, types.TriggerClock)

		// Goals that require different expert agents
		exec.Goals = &types.Goals{
			Content: `## Goals 1. [High] Analyze sales data from database - Reason: Quarterly review needed - Requires: Data analysis capabilities 2. [Normal] Write executive summary report - Reason: Management presentation - Requires: Text generation capabilities 3. 
[Low] Summarize key findings - Reason: Quick reference for team - Requires: Summarization capabilities`,
		}

		e := standard.New()
		err := e.RunTasks(ctx, exec, nil)

		require.NoError(t, err)
		require.NotEmpty(t, exec.Tasks)

		// Log assigned executors
		executorCounts := make(map[string]int)
		for _, task := range exec.Tasks {
			executorCounts[task.ExecutorID]++
			t.Logf("Task %s assigned to: %s (%s)", task.ID, task.ExecutorID, task.ExecutorType)
		}

		// Verify different executors were assigned (not all to same agent).
		// The distribution is only logged; nothing is asserted about it.
		t.Logf("Executor distribution: %v", executorCounts)
	})
}
[Low] Summarize key findings - Reason: Quick reference for team - Requires: Summarization capabilities`, } e := standard.New() err := e.RunTasks(ctx, exec, nil) require.NoError(t, err) require.NotEmpty(t, exec.Tasks) // Log assigned executors executorCounts := make(map[string]int) for _, task := range exec.Tasks { executorCounts[task.ExecutorID]++ t.Logf("Task %s assigned to: %s (%s)", task.ID, task.ExecutorID, task.ExecutorType) } // Verify different executors were assigned (not all to same agent) t.Logf("Executor distribution: %v", executorCounts) }) } func TestRunTasksErrorHandling(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("returns error when robot is nil", func(t *testing.T) { exec := &types.Execution{ ID: "test-exec-1", TriggerType: types.TriggerClock, } // Don't set robot e := standard.New() err := e.RunTasks(ctx, exec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "robot not found") }) t.Run("returns error when goals not available", func(t *testing.T) { robot := createTasksTestRobot(t, "robot.tasks") exec := createTasksTestExecution(robot, types.TriggerClock) exec.Goals = nil // No goals e := standard.New() err := e.RunTasks(ctx, exec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "goals not available") }) t.Run("returns error when goals content is empty", func(t *testing.T) { robot := createTasksTestRobot(t, "robot.tasks") exec := createTasksTestExecution(robot, types.TriggerClock) exec.Goals = &types.Goals{Content: ""} // Empty content e := standard.New() err := e.RunTasks(ctx, exec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "goals not available") }) t.Run("returns error when agent not found", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-robot-1", TeamID: "test-team-1", Config: &types.Config{ Identity: &types.Identity{Role: "Test"}, Resources: &types.Resources{ 
Phases: map[types.Phase]string{ types.PhaseTasks: "non.existent.agent", }, }, }, } exec := createTasksTestExecution(robot, types.TriggerClock) exec.Goals = &types.Goals{Content: "Test goals"} e := standard.New() err := e.RunTasks(ctx, exec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "call failed") }) } // ============================================================================ // ParseTasks Unit Tests // ============================================================================ func TestParseTasks(t *testing.T) { t.Run("parses valid tasks array", func(t *testing.T) { data := []interface{}{ map[string]interface{}{ "id": "task-001", "goal_ref": "Goal 1", "executor_type": "agent", "executor_id": "experts.data-analyst", "messages": []interface{}{ map[string]interface{}{ "role": "user", "content": "Analyze sales data", }, }, "expected_output": "JSON with sales metrics", "validation_rules": []interface{}{ // Natural language rules (matched by validator) "output must be valid JSON", "must contain 'total_sales'", // Structured rule: check field type `{"type": "type", "path": "product_rankings", "value": "array"}`, }, "order": float64(0), }, map[string]interface{}{ "id": "task-002", "goal_ref": "Goal 1", "executor_type": "agent", "executor_id": "experts.text-writer", "description": "Generate report from analysis", "order": float64(1), }, } tasks, err := standard.ParseTasks(data) require.NoError(t, err) require.Len(t, tasks, 2) // First task assert.Equal(t, "task-001", tasks[0].ID) assert.Equal(t, "Goal 1", tasks[0].GoalRef) assert.Equal(t, types.ExecutorAssistant, tasks[0].ExecutorType) assert.Equal(t, "experts.data-analyst", tasks[0].ExecutorID) assert.Len(t, tasks[0].Messages, 1) assert.Equal(t, "JSON with sales metrics", tasks[0].ExpectedOutput) assert.Len(t, tasks[0].ValidationRules, 3) assert.Equal(t, 0, tasks[0].Order) // Second task assert.Equal(t, "task-002", tasks[1].ID) assert.Equal(t, "experts.text-writer", tasks[1].ExecutorID) 
assert.Equal(t, "Generate report from analysis", tasks[1].Description) // description saved to field assert.Len(t, tasks[1].Messages, 1) // description also converted to message assert.Equal(t, 1, tasks[1].Order) }) t.Run("generates ID if missing", func(t *testing.T) { data := []interface{}{ map[string]interface{}{ "executor_type": "agent", "executor_id": "experts.summarizer", "description": "Summarize content", }, } tasks, err := standard.ParseTasks(data) require.NoError(t, err) require.Len(t, tasks, 1) assert.Equal(t, "task-001", tasks[0].ID) }) t.Run("saves description to field and preserves explicit messages", func(t *testing.T) { data := []interface{}{ map[string]interface{}{ "id": "task-001", "executor_type": "agent", "executor_id": "experts.summarizer", "description": "Task description for UI", "messages": []interface{}{ map[string]interface{}{ "role": "user", "content": "Explicit message content", }, }, }, } tasks, err := standard.ParseTasks(data) require.NoError(t, err) require.Len(t, tasks, 1) // Description should be saved to field assert.Equal(t, "Task description for UI", tasks[0].Description) // Explicit messages should be preserved (not overwritten by description) assert.Len(t, tasks[0].Messages, 1) content, ok := tasks[0].Messages[0].GetContentAsString() assert.True(t, ok) assert.Equal(t, "Explicit message content", content) }) t.Run("converts description to message when no messages provided", func(t *testing.T) { data := []interface{}{ map[string]interface{}{ "id": "task-001", "executor_type": "agent", "executor_id": "experts.summarizer", "description": "Only description, no messages", }, } tasks, err := standard.ParseTasks(data) require.NoError(t, err) require.Len(t, tasks, 1) // Description should be saved to field assert.Equal(t, "Only description, no messages", tasks[0].Description) // Description should also be converted to message for execution assert.Len(t, tasks[0].Messages, 1) content, ok := tasks[0].Messages[0].GetContentAsString() 
assert.True(t, ok) assert.Equal(t, "Only description, no messages", content) }) t.Run("returns error for missing executor_type", func(t *testing.T) { data := []interface{}{ map[string]interface{}{ "id": "task-001", "executor_id": "experts.summarizer", "description": "Summarize content", }, } _, err := standard.ParseTasks(data) assert.Error(t, err) assert.Contains(t, err.Error(), "missing executor_type") }) t.Run("returns error for missing executor_id", func(t *testing.T) { data := []interface{}{ map[string]interface{}{ "id": "task-001", "executor_type": "agent", "description": "Summarize content", }, } _, err := standard.ParseTasks(data) assert.Error(t, err) assert.Contains(t, err.Error(), "missing executor_id") }) t.Run("handles different executor types", func(t *testing.T) { data := []interface{}{ map[string]interface{}{ "executor_type": "agent", "executor_id": "test-agent", "description": "Agent task", }, map[string]interface{}{ "executor_type": "assistant", "executor_id": "test-assistant", "description": "Assistant task", }, map[string]interface{}{ "executor_type": "mcp", "executor_id": "test-mcp", "description": "MCP task", }, map[string]interface{}{ "executor_type": "process", "executor_id": "test-process", "description": "Process task", }, } tasks, err := standard.ParseTasks(data) require.NoError(t, err) require.Len(t, tasks, 4) assert.Equal(t, types.ExecutorAssistant, tasks[0].ExecutorType) assert.Equal(t, types.ExecutorAssistant, tasks[1].ExecutorType) // assistant -> ExecutorAssistant assert.Equal(t, types.ExecutorMCP, tasks[2].ExecutorType) assert.Equal(t, types.ExecutorProcess, tasks[3].ExecutorType) }) } func TestValidateTasks(t *testing.T) { t.Run("validates valid tasks", func(t *testing.T) { tasks := []types.Task{ { ID: "task-001", ExecutorType: types.ExecutorAssistant, ExecutorID: "experts.data-analyst", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Analyze data"}, }, }, { ID: "task-002", ExecutorType: 
types.ExecutorAssistant, ExecutorID: "experts.text-writer", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write report"}, }, }, } err := standard.ValidateTasks(tasks) assert.NoError(t, err) }) t.Run("returns error for empty tasks", func(t *testing.T) { tasks := []types.Task{} err := standard.ValidateTasks(tasks) assert.Error(t, err) assert.Contains(t, err.Error(), "no tasks generated") }) t.Run("returns error for duplicate IDs", func(t *testing.T) { tasks := []types.Task{ { ID: "task-001", ExecutorID: "agent-1", Messages: []agentcontext.Message{{Content: "test"}}, }, { ID: "task-001", // duplicate ExecutorID: "agent-2", Messages: []agentcontext.Message{{Content: "test"}}, }, } err := standard.ValidateTasks(tasks) assert.Error(t, err) assert.Contains(t, err.Error(), "duplicate task ID") }) t.Run("returns error for missing executor_id", func(t *testing.T) { tasks := []types.Task{ { ID: "task-001", ExecutorID: "", // missing Messages: []agentcontext.Message{{Content: "test"}}, }, } err := standard.ValidateTasks(tasks) assert.Error(t, err) assert.Contains(t, err.Error(), "missing executor_id") }) t.Run("returns error for missing messages", func(t *testing.T) { tasks := []types.Task{ { ID: "task-001", ExecutorID: "agent-1", Messages: []agentcontext.Message{}, // empty }, } err := standard.ValidateTasks(tasks) assert.Error(t, err) assert.Contains(t, err.Error(), "missing messages") }) } func TestParseExecutorType(t *testing.T) { t.Run("parses agent", func(t *testing.T) { assert.Equal(t, types.ExecutorAssistant, standard.ParseExecutorType("agent")) }) t.Run("parses assistant", func(t *testing.T) { assert.Equal(t, types.ExecutorAssistant, standard.ParseExecutorType("assistant")) }) t.Run("parses mcp", func(t *testing.T) { assert.Equal(t, types.ExecutorMCP, standard.ParseExecutorType("mcp")) }) t.Run("parses process", func(t *testing.T) { assert.Equal(t, types.ExecutorProcess, standard.ParseExecutorType("process")) }) t.Run("defaults to assistant 
for unknown", func(t *testing.T) { assert.Equal(t, types.ExecutorAssistant, standard.ParseExecutorType("unknown")) assert.Equal(t, types.ExecutorAssistant, standard.ParseExecutorType("")) }) } func TestIsValidExecutorType(t *testing.T) { t.Run("valid executor types", func(t *testing.T) { assert.True(t, standard.IsValidExecutorType(types.ExecutorAssistant)) assert.True(t, standard.IsValidExecutorType(types.ExecutorMCP)) assert.True(t, standard.IsValidExecutorType(types.ExecutorProcess)) }) t.Run("invalid executor types", func(t *testing.T) { assert.False(t, standard.IsValidExecutorType(types.ExecutorType("invalid"))) assert.False(t, standard.IsValidExecutorType(types.ExecutorType(""))) }) } func TestSortTasksByOrder(t *testing.T) { t.Run("sorts tasks by order", func(t *testing.T) { tasks := []types.Task{ {ID: "task-c", Order: 2}, {ID: "task-a", Order: 0}, {ID: "task-b", Order: 1}, } standard.SortTasksByOrder(tasks) assert.Equal(t, "task-a", tasks[0].ID) assert.Equal(t, "task-b", tasks[1].ID) assert.Equal(t, "task-c", tasks[2].ID) }) t.Run("handles already sorted tasks", func(t *testing.T) { tasks := []types.Task{ {ID: "task-a", Order: 0}, {ID: "task-b", Order: 1}, {ID: "task-c", Order: 2}, } standard.SortTasksByOrder(tasks) assert.Equal(t, "task-a", tasks[0].ID) assert.Equal(t, "task-b", tasks[1].ID) assert.Equal(t, "task-c", tasks[2].ID) }) t.Run("handles single task", func(t *testing.T) { tasks := []types.Task{ {ID: "task-a", Order: 0}, } standard.SortTasksByOrder(tasks) assert.Len(t, tasks, 1) assert.Equal(t, "task-a", tasks[0].ID) }) t.Run("handles empty tasks", func(t *testing.T) { tasks := []types.Task{} standard.SortTasksByOrder(tasks) assert.Empty(t, tasks) }) } func TestValidateExecutorExists(t *testing.T) { t.Run("returns true for existing agent", func(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Resources: &types.Resources{ Agents: []string{"experts.data-analyst", "experts.text-writer"}, }, }, } assert.True(t, 
standard.ValidateExecutorExists("experts.data-analyst", types.ExecutorAssistant, robot))
		assert.True(t, standard.ValidateExecutorExists("experts.text-writer", types.ExecutorAssistant, robot))
	})

	// An agent ID that is not listed in Resources.Agents is reported as missing.
	t.Run("returns false for non-existing agent", func(t *testing.T) {
		robot := &types.Robot{
			Config: &types.Config{
				Resources: &types.Resources{
					Agents: []string{"experts.data-analyst"},
				},
			},
		}

		assert.False(t, standard.ValidateExecutorExists("experts.unknown", types.ExecutorAssistant, robot))
	})

	// MCP executors are matched against the IDs in Resources.MCP.
	t.Run("returns true for existing MCP", func(t *testing.T) {
		robot := &types.Robot{
			Config: &types.Config{
				Resources: &types.Resources{
					MCP: []types.MCPConfig{
						{ID: "database"},
						{ID: "email"},
					},
				},
			},
		}

		assert.True(t, standard.ValidateExecutorExists("database", types.ExecutorMCP, robot))
		assert.True(t, standard.ValidateExecutorExists("email", types.ExecutorMCP, robot))
	})

	t.Run("returns false for non-existing MCP", func(t *testing.T) {
		robot := &types.Robot{
			Config: &types.Config{
				Resources: &types.Resources{
					MCP: []types.MCPConfig{
						{ID: "database"},
					},
				},
			},
		}

		assert.False(t, standard.ValidateExecutorExists("unknown", types.ExecutorMCP, robot))
	})

	// Process executors are accepted even though Resources lists nothing.
	t.Run("returns true for process (not validated)", func(t *testing.T) {
		robot := &types.Robot{
			Config: &types.Config{
				Resources: &types.Resources{},
			},
		}

		assert.True(t, standard.ValidateExecutorExists("models.user.Find", types.ExecutorProcess, robot))
	})

	// With no robot there is nothing to validate against, so the check passes.
	t.Run("returns true when robot is nil", func(t *testing.T) {
		assert.True(t, standard.ValidateExecutorExists("any", types.ExecutorAssistant, nil))
	})

	// Likewise when the robot has a Config but no Resources.
	t.Run("returns true when resources is nil", func(t *testing.T) {
		robot := &types.Robot{
			Config: &types.Config{},
		}

		assert.True(t, standard.ValidateExecutorExists("any", types.ExecutorAssistant, robot))
	})
}

// TestValidateTasksWithResources checks that structural problems are returned
// as an error while unknown executors are only reported as warnings.
func TestValidateTasksWithResources(t *testing.T) {
	t.Run("returns no warnings for valid tasks", func(t *testing.T) {
		robot := &types.Robot{
			Config: &types.Config{
				Resources: &types.Resources{
					Agents: []string{"experts.data-analyst",
"experts.text-writer"},
				},
			},
		}

		tasks := []types.Task{
			{
				ID:           "task-001",
				ExecutorType: types.ExecutorAssistant,
				ExecutorID:   "experts.data-analyst",
				Messages:     []agentcontext.Message{{Content: "test"}},
			},
		}

		warnings, err := standard.ValidateTasksWithResources(tasks, robot)
		assert.NoError(t, err)
		assert.Empty(t, warnings)
	})

	// An executor that is not in the robot's resources yields a warning, not an error.
	t.Run("returns warnings for unknown executor", func(t *testing.T) {
		robot := &types.Robot{
			Config: &types.Config{
				Resources: &types.Resources{
					Agents: []string{"experts.data-analyst"},
				},
			},
		}

		tasks := []types.Task{
			{
				ID:           "task-001",
				ExecutorType: types.ExecutorAssistant,
				ExecutorID:   "experts.unknown",
				Messages:     []agentcontext.Message{{Content: "test"}},
			},
		}

		warnings, err := standard.ValidateTasksWithResources(tasks, robot)
		assert.NoError(t, err)
		assert.Len(t, warnings, 1)
		assert.Contains(t, warnings[0], "experts.unknown")
		assert.Contains(t, warnings[0], "not found")
	})

	// Structural validation failures are still surfaced as an error.
	t.Run("returns error for invalid tasks", func(t *testing.T) {
		robot := &types.Robot{}
		tasks := []types.Task{} // empty

		_, err := standard.ValidateTasksWithResources(tasks, robot)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "no tasks generated")
	})
}

// ============================================================================
// InputFormatter Tests for P2
// ============================================================================

// TestInputFormatterFormatGoalsForTasks checks the text produced by
// FormatGoals: the goals content, the optional "Available Resources" section
// and the optional "Delivery Target" section.
func TestInputFormatterFormatGoalsForTasks(t *testing.T) {
	formatter := standard.NewInputFormatter()

	t.Run("formats goals with resources", func(t *testing.T) {
		goals := &types.Goals{
			Content: "## Goals\n\n1. [High] Analyze data\n2. [Normal] Write report",
		}
		robot := &types.Robot{
			MemberID: "test-robot",
			Config: &types.Config{
				Resources: &types.Resources{
					Agents: []string{"experts.data-analyst", "experts.text-writer"},
				},
			},
		}

		content := formatter.FormatGoals(goals, robot)

		assert.Contains(t, content, "## Goals")
		assert.Contains(t, content, "[High] Analyze data")
		assert.Contains(t, content, "## Available Resources")
		assert.Contains(t, content, "experts.data-analyst")
		assert.Contains(t, content, "experts.text-writer")
	})

	// Without a robot no resources section is emitted.
	t.Run("formats goals without robot", func(t *testing.T) {
		goals := &types.Goals{
			Content: "## Goals\n\n1. Test goal",
		}

		content := formatter.FormatGoals(goals, nil)

		assert.Contains(t, content, "## Goals")
		assert.Contains(t, content, "Test goal")
		assert.NotContains(t, content, "## Available Resources")
	})

	// A delivery target adds a section listing type, recipients, format and template.
	t.Run("formats goals with delivery target", func(t *testing.T) {
		goals := &types.Goals{
			Content: "## Goals\n\n1. Generate weekly report",
			Delivery: &types.DeliveryTarget{
				Type:       types.DeliveryEmail,
				Recipients: []string{"team@example.com", "manager@example.com"},
				Format:     "markdown",
				Template:   "weekly-report",
			},
		}
		robot := &types.Robot{
			MemberID: "test-robot",
			Config: &types.Config{
				Resources: &types.Resources{
					Agents: []string{"experts.text-writer"},
				},
			},
		}

		content := formatter.FormatGoals(goals, robot)

		assert.Contains(t, content, "## Goals")
		assert.Contains(t, content, "## Delivery Target")
		assert.Contains(t, content, "email")
		assert.Contains(t, content, "team@example.com")
		assert.Contains(t, content, "manager@example.com")
		assert.Contains(t, content, "markdown")
		assert.Contains(t, content, "weekly-report")
		assert.Contains(t, content, "Design tasks to produce output suitable")
	})

	t.Run("formats goals without delivery target", func(t *testing.T) {
		goals := &types.Goals{
			Content:  "## Goals\n\n1. Test goal",
			Delivery: nil,
		}

		content := formatter.FormatGoals(goals, nil)

		assert.Contains(t, content, "## Goals")
		assert.NotContains(t, content, "## Delivery Target")
	})
}

// ============================================================================
// Helper Functions
// ============================================================================

// createTasksTestRobot creates a test robot with specified tasks agent
// Includes available expert agents for task assignment
//
// agentID is registered as the agent for types.PhaseTasks.
func createTasksTestRobot(t *testing.T, agentID string) *types.Robot {
	t.Helper()
	return &types.Robot{
		MemberID:    "test-robot-1",
		TeamID:      "test-team-1",
		DisplayName: "Test Robot",
		Config: &types.Config{
			Identity: &types.Identity{
				Role:   "Test Assistant",
				Duties: []string{"Testing", "Data Analysis", "Report Generation"},
			},
			Resources: &types.Resources{
				Phases: map[types.Phase]string{
					types.PhaseTasks: agentID,
				},
				// Available expert agents that can be assigned to tasks
				Agents: []string{
					"experts.data-analyst",
					"experts.summarizer",
					"experts.text-writer",
					"experts.web-reader",
				},
			},
		},
	}
}

// createTasksTestExecution creates a test execution for tasks phase
//
// The execution copies MemberID and TeamID from robot, starts in the running
// state at types.PhaseTasks, and has the robot attached via SetRobot.
func createTasksTestExecution(robot *types.Robot, trigger types.TriggerType) *types.Execution {
	exec := &types.Execution{
		ID:          "test-exec-tasks-1",
		MemberID:    robot.MemberID,
		TeamID:      robot.TeamID,
		TriggerType: trigger,
		StartTime:   time.Now(),
		Status:      types.ExecRunning,
		Phase:       types.PhaseTasks,
	}
	exec.SetRobot(robot)
	return exec
}

// Note: testAuth is defined in goals_test.go in the same package

================================================
FILE: agent/robot/executor/standard/ui_fields_test.go
================================================
package standard

import (
	"testing"

	"github.com/stretchr/testify/assert"
	agentcontext "github.com/yaoapp/yao/agent/context"
	robottypes "github.com/yaoapp/yao/agent/robot/types"
)

// ============================================================================
// getEffectiveLocale Tests
// ============================================================================

// TestGetEffectiveLocale checks the locale precedence: trigger input locale,
// then the robot's DefaultLocale, then "en".
func TestGetEffectiveLocale(t *testing.T) {
	t.Run("returns_input_locale_when_provided", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{
				DefaultLocale: "en",
			},
		}
		input := &robottypes.TriggerInput{
			Locale: "zh",
		}

		locale := getEffectiveLocale(robot, input)
		assert.Equal(t, "zh", locale)
	})

	t.Run("returns_robot_default_locale_when_input_locale_empty", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{
				DefaultLocale: "zh",
			},
		}
		input := &robottypes.TriggerInput{
			Locale: "",
		}

		locale := getEffectiveLocale(robot, input)
		assert.Equal(t, "zh", locale)
	})

	t.Run("returns_system_default_when_no_locale_configured", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{},
		}
		input := &robottypes.TriggerInput{}

		locale := getEffectiveLocale(robot, input)
		assert.Equal(t, "en", locale)
	})

	// The remaining cases verify that nil config, robot or input do not panic.
	t.Run("returns_system_default_when_robot_config_nil", func(t *testing.T) {
		robot := &robottypes.Robot{}
		input := &robottypes.TriggerInput{}

		locale := getEffectiveLocale(robot, input)
		assert.Equal(t, "en", locale)
	})

	t.Run("returns_system_default_when_robot_nil", func(t *testing.T) {
		input := &robottypes.TriggerInput{}

		locale := getEffectiveLocale(nil, input)
		assert.Equal(t, "en", locale)
	})

	t.Run("returns_system_default_when_input_nil", func(t *testing.T) {
		robot := &robottypes.Robot{}

		locale := getEffectiveLocale(robot, nil)
		assert.Equal(t, "en", locale)
	})
}

// ============================================================================
// getLocalizedMessage Tests
// ============================================================================

// TestGetLocalizedMessage checks message lookup per locale, the English
// fallback for unknown locales, and the key echo for unknown keys.
func TestGetLocalizedMessage(t *testing.T) {
	t.Run("returns_english_message_for_en_locale", func(t *testing.T) {
		msg := getLocalizedMessage("en", "preparing")
		assert.Equal(t, "Preparing...", msg)
	})

	t.Run("returns_chinese_message_for_zh_locale", func(t *testing.T) {
		msg := getLocalizedMessage("zh",
"preparing")
		assert.Equal(t, "准备中...", msg)
	})

	// Locales without a translation table fall back to English.
	t.Run("returns_english_fallback_for_unknown_locale", func(t *testing.T) {
		msg := getLocalizedMessage("fr", "preparing")
		assert.Equal(t, "Preparing...", msg)
	})

	// Unknown keys are returned unchanged.
	t.Run("returns_key_for_unknown_message", func(t *testing.T) {
		msg := getLocalizedMessage("en", "unknown_key")
		assert.Equal(t, "unknown_key", msg)
	})

	// A missing translation would be echoed back as the key itself, so
	// NotEqual(key, msg) proves every key has an English entry.
	t.Run("all_english_messages_exist", func(t *testing.T) {
		keys := []string{
			"preparing",
			"starting",
			"scheduled_execution",
			"event_prefix",
			"event_triggered",
			"analyzing_context",
			"planning_goals",
			"breaking_down_tasks",
			"generating_delivery",
			"sending_delivery",
			"learning_from_exec",
			"completed",
			"failed_prefix",
			"task_prefix",
			// Phase names for failure messages
			"phase_inspiration",
			"phase_goals",
			"phase_tasks",
			"phase_run",
			"phase_delivery",
			"phase_learning",
		}

		for _, key := range keys {
			msg := getLocalizedMessage("en", key)
			assert.NotEqual(t, key, msg, "English message should exist for key: %s", key)
		}
	})

	// Same completeness check for the Chinese table.
	t.Run("all_chinese_messages_exist", func(t *testing.T) {
		keys := []string{
			"preparing",
			"starting",
			"scheduled_execution",
			"event_prefix",
			"event_triggered",
			"analyzing_context",
			"planning_goals",
			"breaking_down_tasks",
			"generating_delivery",
			"sending_delivery",
			"learning_from_exec",
			"completed",
			"failed_prefix",
			"task_prefix",
			// Phase names for failure messages
			"phase_inspiration",
			"phase_goals",
			"phase_tasks",
			"phase_run",
			"phase_delivery",
			"phase_learning",
		}

		for _, key := range keys {
			msg := getLocalizedMessage("zh", key)
			assert.NotEqual(t, key, msg, "Chinese message should exist for key: %s", key)
		}
	})

	t.Run("failure_message_is_concise", func(t *testing.T) {
		// Test that failure messages use phase names, not full error text
		enFailure := getLocalizedMessage("en", "failed_prefix") + getLocalizedMessage("en", "phase_inspiration")
		assert.Equal(t, "Failed at inspiration", enFailure)

		zhFailure := getLocalizedMessage("zh", "failed_prefix") + getLocalizedMessage("zh", "phase_inspiration")
		assert.Equal(t,
"失败于灵感阶段", zhFailure)
	})
}

// ============================================================================
// initUIFields Tests
// ============================================================================

// TestInitUIFields checks the initial execution name and current-task label
// produced for each trigger type and locale.
func TestInitUIFields(t *testing.T) {
	executor := New()

	// Human triggers take the name from the first user message.
	t.Run("human_trigger_extracts_name_from_message", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "en"},
		}
		input := &robottypes.TriggerInput{
			Messages: []agentcontext.Message{
				{Role: agentcontext.RoleUser, Content: "Please analyze the sales data"},
			},
		}

		name, currentTaskName := executor.initUIFields(robottypes.TriggerHuman, input, robot)
		assert.Equal(t, "Please analyze the sales data", name)
		assert.Equal(t, "Starting...", currentTaskName)
	})

	t.Run("human_trigger_truncates_long_message", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "en"},
		}
		longMessage := "This is a very long message that exceeds one hundred characters and should be truncated with an ellipsis at the end"
		input := &robottypes.TriggerInput{
			Messages: []agentcontext.Message{
				{Role: agentcontext.RoleUser, Content: longMessage},
			},
		}

		name, _ := executor.initUIFields(robottypes.TriggerHuman, input, robot)
		assert.LessOrEqual(t, len(name), 103) // 100 chars + "..."
		// Either the name kept more than 100 bytes (ellipsis included) or it
		// is exactly the first 100 bytes of the message plus "...".
		assert.True(t, len(name) > 100 || name == longMessage[:100]+"...")
	})

	// Clock triggers use a fixed, localized name.
	t.Run("clock_trigger_uses_scheduled_execution", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "en"},
		}
		input := &robottypes.TriggerInput{}

		name, currentTaskName := executor.initUIFields(robottypes.TriggerClock, input, robot)
		assert.Equal(t, "Scheduled execution", name)
		assert.Equal(t, "Starting...", currentTaskName)
	})

	t.Run("clock_trigger_chinese_locale", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "zh"},
		}
		input := &robottypes.TriggerInput{}

		name, currentTaskName := executor.initUIFields(robottypes.TriggerClock, input, robot)
		assert.Equal(t, "定时执行", name)
		assert.Equal(t, "启动中...", currentTaskName)
	})

	// Event triggers prefix the event type when one is provided.
	t.Run("event_trigger_with_event_type", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "en"},
		}
		input := &robottypes.TriggerInput{
			EventType: "lead.created",
		}

		name, currentTaskName := executor.initUIFields(robottypes.TriggerEvent, input, robot)
		assert.Equal(t, "Event: lead.created", name)
		assert.Equal(t, "Starting...", currentTaskName)
	})

	t.Run("event_trigger_without_event_type", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "en"},
		}
		input := &robottypes.TriggerInput{}

		name, _ := executor.initUIFields(robottypes.TriggerEvent, input, robot)
		assert.Equal(t, "Event triggered", name)
	})

	t.Run("event_trigger_chinese_locale", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "zh"},
		}
		input := &robottypes.TriggerInput{
			EventType: "order.placed",
		}

		name, _ := executor.initUIFields(robottypes.TriggerEvent, input, robot)
		assert.Equal(t, "事件: order.placed", name)
	})

	// The locale on the trigger input wins over the robot's default locale.
	t.Run("input_locale_overrides_robot_default", func(t *testing.T) {
		robot := &robottypes.Robot{
			Config: &robottypes.Config{DefaultLocale: "en"},
		}
		input := &robottypes.TriggerInput{
			Locale: "zh",
		}

		name, currentTaskName :=
executor.initUIFields(robottypes.TriggerClock, input, robot)
		assert.Equal(t, "定时执行", name)
		assert.Equal(t, "启动中...", currentTaskName)
	})
}

// ============================================================================
// extractGoalName Tests
// ============================================================================

// TestExtractGoalName checks how a short display name is derived from the
// goals content.
func TestExtractGoalName(t *testing.T) {
	t.Run("extracts_first_line_from_content", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "Generate monthly sales report\nAnalyze trends\nSend to stakeholders",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Generate monthly sales report", name)
	})

	// Nil goals and empty content both yield an empty name.
	t.Run("returns_empty_for_nil_goals", func(t *testing.T) {
		name := extractGoalName(nil)
		assert.Equal(t, "", name)
	})

	t.Run("returns_empty_for_empty_content", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "", name)
	})

	t.Run("truncates_long_first_line", func(t *testing.T) {
		longLine := "This is an extremely long goal description that exceeds one hundred and fifty characters and should be truncated with an ellipsis at the end to keep the display manageable"
		goals := &robottypes.Goals{
			Content: longLine,
		}

		name := extractGoalName(goals)
		assert.LessOrEqual(t, len(name), 153) // 150 chars + "..."
})

	t.Run("handles_single_line_content", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "Single line goal",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Single line goal", name)
	})

	// Windows line endings must not leave a trailing "\r" in the name.
	t.Run("handles_carriage_return", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "First goal\r\nSecond goal",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "First goal", name)
	})

	// Markdown headers are skipped in favour of the first content line.
	t.Run("skips_markdown_h1_header", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "# Goals\nSystem optimization and monitoring",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "System optimization and monitoring", name)
	})

	t.Run("skips_markdown_h2_header", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "## Goals\n\nPerform system maintenance tasks",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Perform system maintenance tasks", name)
	})

	t.Run("skips_multiple_markdown_headers", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "## Goals\n### 1. [High] First Goal\nActual description here",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Actual description here", name)
	})

	// Inline markdown markers are removed from the extracted name.
	t.Run("strips_bold_formatting", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "**Important** task to complete",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Important task to complete", name)
	})

	t.Run("strips_italic_formatting", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "*Urgent* system update needed",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Urgent system update needed", name)
	})

	t.Run("strips_inline_code", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "Run `npm install` command",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Run npm install command", name)
	})

	t.Run("skips_empty_lines", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "\n\n\nFirst real content\nSecond line",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "First real content", name)
	})

	// When the content consists only of headers, the first header text is used.
	t.Run("fallback_to_header_content_if_only_headers", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "## Goals\n### Tasks",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Goals", name)
	})

	t.Run("skips_horizontal_rules", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "---\nActual content here",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Actual content here", name)
	})

	// Realistic goals output: headers are skipped and bold markers stripped.
	t.Run("handles_complex_markdown_content", func(t *testing.T) {
		goals := &robottypes.Goals{
			Content: "## Goals\n\n### 1. [High] System Maintenance\n**Description**: Perform system optimization based on diagnostic results\n**Reason**: Time-sensitive maintenance",
		}

		name := extractGoalName(goals)
		assert.Equal(t, "Description: Perform system optimization based on diagnostic results", name)
	})
}

// ============================================================================
// stripMarkdownFormatting Tests
// ============================================================================

// TestStripMarkdownFormatting checks removal of bold, italic, inline-code and
// link syntax while leaving plain text untouched.
func TestStripMarkdownFormatting(t *testing.T) {
	t.Run("strips_bold", func(t *testing.T) {
		result := stripMarkdownFormatting("**bold text**")
		assert.Equal(t, "bold text", result)
	})

	t.Run("strips_italic_asterisk", func(t *testing.T) {
		result := stripMarkdownFormatting("*italic text*")
		assert.Equal(t, "italic text", result)
	})

	t.Run("strips_italic_underscore", func(t *testing.T) {
		result := stripMarkdownFormatting("_italic text_")
		assert.Equal(t, "italic text", result)
	})

	t.Run("strips_inline_code", func(t *testing.T) {
		result := stripMarkdownFormatting("`code`")
		assert.Equal(t, "code", result)
	})

	// Links keep their text and drop the URL.
	t.Run("strips_link_syntax", func(t *testing.T) {
		result := stripMarkdownFormatting("[link text](https://example.com)")
		assert.Equal(t, "link text", result)
	})

	t.Run("preserves_plain_text", func(t *testing.T) {
		result := stripMarkdownFormatting("plain text without formatting")
		assert.Equal(t, "plain text without formatting", result)
	})

	t.Run("handles_mixed_formatting", func(t *testing.T) {
		result := stripMarkdownFormatting("**bold** and *italic* and `code`")
assert.Equal(t, "bold and italic and code", result)
	})
}

// ============================================================================
// formatTaskProgressName Tests
// ============================================================================

// TestFormatTaskProgressName checks the "<prefix> i/n: <label>" progress
// string. The index passed in is zero-based and displayed one-based.
func TestFormatTaskProgressName(t *testing.T) {
	t.Run("prioritizes_description_field_over_messages", func(t *testing.T) {
		task := &robottypes.Task{
			ID:           "task-001",
			Description:  "High-level task description for UI",
			ExecutorType: robottypes.ExecutorAssistant,
			ExecutorID:   "analyst",
			Messages: []agentcontext.Message{
				{Role: agentcontext.RoleUser, Content: "Detailed message content for execution"},
			},
		}

		name := formatTaskProgressName(task, 0, 3, "en")
		// Should use Description field, NOT the message content
		assert.Equal(t, "Task 1/3: High-level task description for UI", name)
	})

	t.Run("falls_back_to_message_when_no_description", func(t *testing.T) {
		task := &robottypes.Task{
			ID:           "task-001",
			Description:  "", // Empty description
			ExecutorType: robottypes.ExecutorAssistant,
			ExecutorID:   "analyst",
			Messages: []agentcontext.Message{
				{Role: agentcontext.RoleUser, Content: "Analyze sales data"},
			},
		}

		name := formatTaskProgressName(task, 0, 3, "en")
		assert.Equal(t, "Task 1/3: Analyze sales data", name)
	})

	// The prefix is localized; the message content is passed through as-is.
	t.Run("formats_with_chinese_locale", func(t *testing.T) {
		task := &robottypes.Task{
			ID:           "task-001",
			ExecutorType: robottypes.ExecutorAssistant,
			ExecutorID:   "analyst",
			Messages: []agentcontext.Message{
				{Role: agentcontext.RoleUser, Content: "分析销售数据"},
			},
		}

		name := formatTaskProgressName(task, 1, 5, "zh")
		assert.Equal(t, "任务 2/5: 分析销售数据", name)
	})

	t.Run("truncates_long_description_field", func(t *testing.T) {
		longDesc := "This is a very long task description that should be truncated because it exceeds 80 characters which is the maximum length allowed"
		task := &robottypes.Task{
			ID:           "task-001",
			Description:  longDesc,
			ExecutorType: robottypes.ExecutorAssistant,
			ExecutorID:   "analyst",
			Messages:     []agentcontext.Message{},
		}

		name :=
formatTaskProgressName(task, 0, 1, "en")
		// Should be "Task 1/1: " (10 chars) + truncated content (83 chars max with "...")
		assert.Contains(t, name, "...")
		assert.LessOrEqual(t, len(name), 100)
	})

	t.Run("truncates_long_message_content", func(t *testing.T) {
		longContent := "This is a very long message content that should be truncated because it exceeds 80 characters which is the maximum length allowed"
		task := &robottypes.Task{
			ID:           "task-001",
			Description:  "", // No description, will use message
			ExecutorType: robottypes.ExecutorAssistant,
			ExecutorID:   "analyst",
			Messages: []agentcontext.Message{
				{Role: agentcontext.RoleUser, Content: longContent},
			},
		}

		name := formatTaskProgressName(task, 0, 1, "en")
		// Should be "Task 1/1: " (10 chars) + truncated content (83 chars max with "...")
		assert.Contains(t, name, "...")
		assert.LessOrEqual(t, len(name), 100)
	})

	// With neither description nor messages the label is "<executor type>:<executor id>".
	t.Run("fallback_to_executor_info_when_no_messages", func(t *testing.T) {
		task := &robottypes.Task{
			ID:           "task-001",
			ExecutorType: robottypes.ExecutorMCP,
			ExecutorID:   "calculator",
			Messages:     []agentcontext.Message{},
		}

		name := formatTaskProgressName(task, 2, 4, "en")
		assert.Equal(t, "Task 3/4: mcp:calculator", name)
	})
}

================================================
FILE: agent/robot/executor/standard/validator.go
================================================
package standard

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/yaoapp/gou/process"
	robottypes "github.com/yaoapp/yao/agent/robot/types"
	"github.com/yaoapp/yao/assert"
)

// ValidatorConfig configures validation behavior (decoupled from RunConfig)
type ValidatorConfig struct {
	// ValidationThreshold is the minimum score to pass validation (default: 0.6)
	ValidationThreshold float64
}

// DefaultValidatorConfig returns the default validator configuration
// (ValidationThreshold of 0.6).
func DefaultValidatorConfig() *ValidatorConfig {
	return &ValidatorConfig{
		ValidationThreshold: 0.6,
	}
}

// Validator handles task result validation using a two-layer approach:
// 1. Rule-based validation: Uses yao/assert for deterministic rules (type, contains, regex, json_path)
// 2. Semantic validation: Calls Validation Agent for semantic understanding (ExpectedOutput)
type Validator struct {
	ctx      *robottypes.Context // context passed to agent calls and the script runner
	robot    *robottypes.Robot   // robot whose config selects the validation agent
	config   *ValidatorConfig    // thresholds; never nil after NewValidator
	asserter *assert.Asserter    // rule engine wired with robot-specific agent/script hooks
}

// NewValidator creates a new task validator
//
// A nil config is replaced by DefaultValidatorConfig(). The asserter is wired
// with a robotAgentValidator and a robotScriptRunner.
func NewValidator(ctx *robottypes.Context, robot *robottypes.Robot, config *ValidatorConfig) *Validator {
	if config == nil {
		config = DefaultValidatorConfig()
	}

	v := &Validator{
		ctx:      ctx,
		robot:    robot,
		config:   config,
		asserter: assert.New(),
	}

	// Configure asserter with robot-specific implementations
	v.asserter.WithAgentValidator(&robotAgentValidator{v: v})
	v.asserter.WithScriptRunner(&robotScriptRunner{ctx: ctx})

	return v
}

// Validate validates task output using two-layer validation (without multi-turn context)
// Equivalent to ValidateWithContext(task, output, nil)
// Use ValidateWithContext when you have a CallResult for better multi-turn support
func (v *Validator) Validate(task *robottypes.Task, output interface{}) *robottypes.ValidationResult {
	return v.ValidateWithContext(task, output, nil)
}

// ValidateWithContext validates task output and determines execution state for multi-turn conversation.
// It extends basic validation with:
//   - Complete: whether expected result is obtained
//   - NeedReply: whether to continue conversation
//   - ReplyContent: content for next turn
//
// Parameters:
//   - task: the task being executed
//   - output: the output from assistant/mcp/process
//   - callResult: the full call result (for detecting assistant's need for more info)
//
// Returns a ValidationResult. When rule-based validation fails, the rule
// result is returned immediately and the semantic layer is not run.
func (v *Validator) ValidateWithContext(task *robottypes.Task, output interface{}, callResult *CallResult) *robottypes.ValidationResult {
	// If no validation rules and no expected output, return passed and complete
	// (Complete still requires a non-empty output.)
	if task.ExpectedOutput == "" && len(task.ValidationRules) == 0 {
		return &robottypes.ValidationResult{
			Passed:   true,
			Score:    1.0,
			Complete: v.hasValidOutput(output),
		}
	}

	result := &robottypes.ValidationResult{
		Passed: true,
		Score:  1.0,
	}

	// Layer 1: Rule-based validation (using yao/assert)
	if len(task.ValidationRules) > 0 {
		ruleResult := v.validateRules(task.ValidationRules, output)
		if !ruleResult.Passed {
			// Rule validation failed - check if we should retry with feedback
			ruleResult.Complete = false
			ruleResult.NeedReply, ruleResult.ReplyContent = v.checkNeedReplyOnFailure(task, ruleResult)
			return ruleResult
		}
		// Merge rule validation results
		result.Issues = append(result.Issues, ruleResult.Issues...)
		result.Suggestions = append(result.Suggestions, ruleResult.Suggestions...)
}

	// Layer 2: Semantic validation (using Validation Agent)
	// Only run if ExpectedOutput is set or there are agent-type rules
	if task.ExpectedOutput != "" || v.hasAgentRules(task.ValidationRules) {
		semanticResult := v.validateSemantic(task, output)
		result = v.mergeResults(result, semanticResult)
	}

	// Determine execution state
	result.Complete = v.isComplete(task, output, result)
	result.NeedReply, result.ReplyContent = v.checkNeedReply(task, output, callResult, result)

	return result
}

// hasValidOutput checks if output is non-empty and valid
//
// nil, whitespace-only strings, empty []interface{} and empty
// map[string]interface{} count as "no output"; any other value counts as valid.
func (v *Validator) hasValidOutput(output interface{}) bool {
	if output == nil {
		return false
	}

	switch o := output.(type) {
	case string:
		return strings.TrimSpace(o) != ""
	case []interface{}:
		return len(o) > 0
	case map[string]interface{}:
		return len(o) > 0
	default:
		return true
	}
}

// isComplete determines if the expected result has been obtained
//
// The result is complete only when validation passed, the output is non-empty
// and the score reaches the configured ValidationThreshold. The task argument
// is currently not consulted.
func (v *Validator) isComplete(task *robottypes.Task, output interface{}, result *robottypes.ValidationResult) bool {
	// If validation failed, not complete
	if !result.Passed {
		return false
	}

	// Must have valid output
	if !v.hasValidOutput(output) {
		return false
	}

	// If score is below threshold, consider incomplete
	if result.Score < v.config.ValidationThreshold {
		return false
	}

	return true
}

// checkNeedReply determines if conversation should continue and generates reply content
//
// It returns (true, reply) when another turn is worthwhile and (false, "")
// otherwise. result.Complete must already be set by the caller. callResult may
// be nil, in which case the "assistant asks for more info" scenario is skipped.
func (v *Validator) checkNeedReply(task *robottypes.Task, output interface{}, callResult *CallResult, result *robottypes.ValidationResult) (bool, string) {
	// If already complete, no need to reply
	if result.Complete {
		return false, ""
	}

	// Scenario 1: Assistant explicitly asks for more information
	if callResult != nil {
		text := callResult.GetText()
		if v.detectNeedMoreInfo(text) {
			return true, v.generateClarificationReply(task, text)
		}
	}

	// Scenario 2: Validation passed but output is incomplete/empty
	if result.Passed && !v.hasValidOutput(output) {
		return true, "Please continue and provide the complete result as specified in the task."
	}

	// Scenario 3: Validation failed with suggestions - can retry with feedback
	if !result.Passed && len(result.Suggestions) > 0 {
		return true, v.generateFeedbackReply(result)
	}

	// Scenario 4: Low confidence score - ask for improvement
	if result.Passed && result.Score < v.config.ValidationThreshold {
		// Semantic validation can run for agent-type rules alone, so
		// ExpectedOutput may be empty here; avoid a dangling
		// "expected output: " at the end of the message in that case.
		if task.ExpectedOutput == "" {
			return true, fmt.Sprintf("The result is partially correct (score: %.2f), but needs improvement. Please refine your response.", result.Score)
		}
		return true, fmt.Sprintf("The result is partially correct (score: %.2f), but needs improvement. Please refine your response to better match the expected output: %s", result.Score, task.ExpectedOutput)
	}

	// No need to continue
	return false, ""
}

// checkNeedReplyOnFailure handles the case when rule validation fails
//
// Suggestions take precedence and produce the structured feedback reply.
// Otherwise the issues are listed, followed by the expected output when the
// task defines one. With neither suggestions nor issues it returns (false, "").
func (v *Validator) checkNeedReplyOnFailure(task *robottypes.Task, result *robottypes.ValidationResult) (bool, string) {
	// If there are suggestions, we can try to fix
	if len(result.Suggestions) > 0 {
		return true, v.generateFeedbackReply(result)
	}

	// If there are issues, provide feedback
	if len(result.Issues) > 0 {
		var sb strings.Builder
		sb.WriteString("Your response did not pass validation. Please fix the following issues:\n\n")
		for _, issue := range result.Issues {
			fmt.Fprintf(&sb, "- %s\n", issue)
		}
		// Rule-only tasks have no ExpectedOutput; do not append an empty
		// "Expected output: " trailer for them.
		if task.ExpectedOutput != "" {
			fmt.Fprintf(&sb, "\nExpected output: %s", task.ExpectedOutput)
		}
		return true, sb.String()
	}

	return false, ""
}

// detectNeedMoreInfo checks if assistant's response indicates need for more information
//
// It matches a fixed list of English phrases case-insensitively, and also
// treats text that ends with "?" and contains at least two "?" as a request
// for clarification.
func (v *Validator) detectNeedMoreInfo(text string) bool {
	if text == "" {
		return false
	}

	textLower := strings.ToLower(text)
	keywords := []string{
		"need more information",
		"please clarify",
		"could you provide",
		"can you specify",
		"what is the",
		"which one",
		"please provide",
		"i need to know",
		"could you tell me",
		"what do you mean",
	}

	for _, kw := range keywords {
		if strings.Contains(textLower, kw) {
			return true
		}
	}

	// Check for question marks at the end (likely asking for clarification)
	// Note: We require 2+ question marks to avoid false positives from rhetorical questions
	// or questions that are part of the output (e.g., "How can I help you?")
	// Single questions are often just conversational and don't need clarification
	trimmed := strings.TrimSpace(text)
	return strings.HasSuffix(trimmed, "?") && strings.Count(text, "?") >= 2
}

// generateClarificationReply generates a reply when assistant asks for clarification
//
// The reply tells the assistant to proceed with the available information and
// restates the expected output when the task defines one. assistantText is
// accepted for signature compatibility but is not used in the reply.
func (v *Validator) generateClarificationReply(task *robottypes.Task, assistantText string) string {
	var sb strings.Builder
	sb.WriteString("Please proceed with the task based on the available information.\n\n")

	if task.ExpectedOutput != "" {
		fmt.Fprintf(&sb, "**Expected Output**: %s\n\n", task.ExpectedOutput)
	}

	sb.WriteString("If you need to make assumptions, please state them clearly and proceed with the most reasonable interpretation.")
	return sb.String()
}

// generateFeedbackReply generates a reply with validation feedback
func (v *Validator) generateFeedbackReply(result *robottypes.ValidationResult) string {
	var sb strings.Builder
	sb.WriteString("## Validation Feedback\n\n")
sb.WriteString("Your previous response needs improvement. Please address the following:\n\n") if len(result.Issues) > 0 { sb.WriteString("### Issues\n") for _, issue := range result.Issues { sb.WriteString(fmt.Sprintf("- %s\n", issue)) } sb.WriteString("\n") } if len(result.Suggestions) > 0 { sb.WriteString("### Suggestions\n") for _, suggestion := range result.Suggestions { sb.WriteString(fmt.Sprintf("- %s\n", suggestion)) } sb.WriteString("\n") } sb.WriteString("Please provide an improved response that addresses these points.") return sb.String() } // validateRules validates output against rule-based assertions func (v *Validator) validateRules(rules []string, output interface{}) *robottypes.ValidationResult { result := &robottypes.ValidationResult{ Passed: true, Score: 1.0, } // Parse rules into assertions assertions := v.parseRules(rules) if len(assertions) == 0 { return result } // Run assertions passed, message := v.asserter.Validate(assertions, output) if !passed { result.Passed = false result.Score = 0 result.Issues = append(result.Issues, message) } return result } // parseRules converts validation rules (strings or JSON) to assertions // Supports: // - Simple string rules: "output must be valid JSON" (converted to type check) // - JSON assertion objects: {"type": "contains", "value": "success"} func (v *Validator) parseRules(rules []string) []*assert.Assertion { var assertions []*assert.Assertion for _, rule := range rules { // Try to parse as JSON assertion if strings.HasPrefix(rule, "{") { var assertionMap map[string]interface{} if err := json.Unmarshal([]byte(rule), &assertionMap); err == nil { parsed := assert.ParseAssertions(assertionMap) assertions = append(assertions, parsed...) 
continue } } // Convert common string rules to assertions assertion := v.convertStringRule(rule) if assertion != nil { assertions = append(assertions, assertion) } } return assertions } // convertStringRule converts a human-readable rule string to an assertion // Examples: // - "output must be valid JSON" -> {"type": "type", "value": "object"} // - "must contain 'success'" -> {"type": "contains", "value": "success"} // - "count > 0" -> (passed to semantic validation) func (v *Validator) convertStringRule(rule string) *assert.Assertion { ruleLower := strings.ToLower(rule) // JSON type check if strings.Contains(ruleLower, "valid json") || strings.Contains(ruleLower, "json object") { return &assert.Assertion{ Type: "type", Value: "object", Message: rule, } } // Array type check if strings.Contains(ruleLower, "json array") || strings.Contains(ruleLower, "must be array") { return &assert.Assertion{ Type: "type", Value: "array", Message: rule, } } // Contains check if strings.Contains(ruleLower, "contain") { // Extract the value in quotes if start := strings.Index(rule, "'"); start != -1 { if end := strings.Index(rule[start+1:], "'"); end != -1 { value := rule[start+1 : start+1+end] return &assert.Assertion{ Type: "contains", Value: value, Message: rule, } } } if start := strings.Index(rule, "\""); start != -1 { if end := strings.Index(rule[start+1:], "\""); end != -1 { value := rule[start+1 : start+1+end] return &assert.Assertion{ Type: "contains", Value: value, Message: rule, } } } } // Not empty check - use regex to match at least one character if strings.Contains(ruleLower, "not empty") || strings.Contains(ruleLower, "non-empty") { return &assert.Assertion{ Type: "regex", Value: ".+", Message: rule, } } // For other rules, return nil (will be handled by semantic validation) return nil } // hasAgentRules checks if any rule requires agent-based validation func (v *Validator) hasAgentRules(rules []string) bool { for _, rule := range rules { if strings.HasPrefix(rule, 
"{") { var assertionMap map[string]interface{} if err := json.Unmarshal([]byte(rule), &assertionMap); err == nil { if assertionMap["type"] == "agent" { return true } } } } return false } // validateSemantic performs semantic validation using the Validation Agent func (v *Validator) validateSemantic(task *robottypes.Task, output interface{}) *robottypes.ValidationResult { // Get validation agent ID validationAgentID := "__yao.validation" // default if v.robot.Config != nil && v.robot.Config.Resources != nil { if customID, ok := v.robot.Config.Resources.Phases["validation"]; ok && customID != "" { validationAgentID = customID } } // Build validation prompt validationPrompt := v.BuildSemanticPrompt(task, output) // Call validation agent caller := NewAgentCaller() caller.Connector = v.robot.LanguageModel result, err := caller.CallWithMessages(v.ctx, validationAgentID, validationPrompt) if err != nil { return &robottypes.ValidationResult{ Passed: false, Score: 0, Issues: []string{fmt.Sprintf("Validation agent error: %s", err.Error())}, } } return v.ParseAgentResult(result) } // BuildSemanticPrompt builds the prompt for semantic validation // Format matches the Validation Agent's expected input structure: // 1. Task: task definition with expected_output and validation_rules // 2. Result: actual output from task execution // 3. 
Success Criteria: overall criteria (optional) func (v *Validator) BuildSemanticPrompt(task *robottypes.Task, output interface{}) string { var sb strings.Builder // Section 1: Task (matches Agent's expected "Task" input) sb.WriteString("## Task\n\n") sb.WriteString(fmt.Sprintf("**Task ID**: %s\n", task.ID)) sb.WriteString(fmt.Sprintf("**Executor**: %s (%s)\n\n", task.ExecutorID, task.ExecutorType)) // Task description (instructions) if len(task.Messages) > 0 { sb.WriteString("**Instructions**:\n") for _, msg := range task.Messages { if content, ok := msg.Content.(string); ok { sb.WriteString(content + "\n") } } sb.WriteString("\n") } // Expected output (primary criterion for semantic validation) if task.ExpectedOutput != "" { sb.WriteString(fmt.Sprintf("**expected_output**: %s\n\n", task.ExpectedOutput)) } // Validation rules semanticRules := v.getSemanticRules(task.ValidationRules) if len(semanticRules) > 0 { sb.WriteString("**validation_rules**:\n") for _, rule := range semanticRules { sb.WriteString(fmt.Sprintf("- %s\n", rule)) } sb.WriteString("\n") } // Section 2: Result (matches Agent's expected "Result" input) sb.WriteString("## Result\n\n") if output != nil { outputJSON, err := json.MarshalIndent(output, "", " ") if err == nil { sb.WriteString(fmt.Sprintf("```json\n%s\n```\n", string(outputJSON))) } else { sb.WriteString(fmt.Sprintf("%v\n", output)) } } else { sb.WriteString("(no output)\n") } // Section 3: Success Criteria (optional, from goals if available) // Note: This could be extended to include criteria from exec.Goals if needed sb.WriteString("\n## Success Criteria\n\n") if task.ExpectedOutput != "" { sb.WriteString(fmt.Sprintf("The task should produce: %s\n", task.ExpectedOutput)) } else { sb.WriteString("Complete the task successfully with valid output.\n") } return sb.String() } // getSemanticRules returns rules that need semantic validation (not convertible to assertions) func (v *Validator) getSemanticRules(rules []string) []string { var 
semanticRules []string for _, rule := range rules { // Skip JSON assertions (already handled) if strings.HasPrefix(rule, "{") { continue } // Skip rules that were converted to assertions if v.convertStringRule(rule) == nil { semanticRules = append(semanticRules, rule) } } return semanticRules } // ParseAgentResult parses the validation agent's response func (v *Validator) ParseAgentResult(result *CallResult) *robottypes.ValidationResult { validation := &robottypes.ValidationResult{ Passed: false, Score: 0, } // Try to parse as JSON data, err := result.GetJSON() if err != nil { // If not JSON, try to interpret the text response text := result.GetText() if text != "" { validation.Details = text // Simple heuristic: check for positive keywords textLower := strings.ToLower(text) positiveKeywords := []string{"passed", "valid", "correct", "success"} for _, keyword := range positiveKeywords { if strings.Contains(textLower, keyword) { validation.Passed = true validation.Score = 0.8 break } } } return validation } // Parse JSON fields if passed, ok := data["passed"].(bool); ok { validation.Passed = passed } if score, ok := data["score"].(float64); ok { validation.Score = score } if issues, ok := data["issues"].([]interface{}); ok { for _, issue := range issues { if s, ok := issue.(string); ok { validation.Issues = append(validation.Issues, s) } } } if suggestions, ok := data["suggestions"].([]interface{}); ok { for _, suggestion := range suggestions { if s, ok := suggestion.(string); ok { validation.Suggestions = append(validation.Suggestions, s) } } } if details, ok := data["details"].(string); ok { validation.Details = details } return validation } // mergeResults merges rule-based and semantic validation results func (v *Validator) mergeResults(ruleResult, semanticResult *robottypes.ValidationResult) *robottypes.ValidationResult { // If either failed, the overall result is failed if !ruleResult.Passed || !semanticResult.Passed { return &robottypes.ValidationResult{ 
// robotAgentValidator implements assert.AgentValidator for robot package.
// It wraps a Validator so that agent-type assertions can be evaluated with
// the same execution context and language model as the validator itself.
type robotAgentValidator struct {
	v *Validator // owning validator; its ctx and robot.LanguageModel drive the agent call
}
data["reason"].(string); ok { result.Message = reason } result.Expected = data return result } // robotScriptRunner implements assert.ScriptRunner for robot package type robotScriptRunner struct { ctx *robottypes.Context } // Run runs an assertion script using Yao process func (r *robotScriptRunner) Run(scriptName string, output, input, expected interface{}) (bool, string, error) { // Build script arguments args := []interface{}{output, input, expected} // Create and run the process proc, err := process.Of(scriptName, args...) if err != nil { return false, "", fmt.Errorf("failed to create process: %w", err) } // Set context for timeout and cancellation support if r.ctx != nil { proc.Context = r.ctx.Context } if err := proc.Execute(); err != nil { return false, "", fmt.Errorf("script execution failed: %w", err) } defer proc.Release() // Parse result - expected format: bool or { "pass": bool, "message": string } res := proc.Value() switch v := res.(type) { case bool: if v { return true, "script assertion passed", nil } return false, "script assertion failed", nil case map[string]interface{}: passed := false message := "" if pass, ok := v["pass"].(bool); ok { passed = pass } if msg, ok := v["message"].(string); ok { message = msg } return passed, message, nil default: return false, fmt.Sprintf("script returned unexpected type: %T", res), nil } } ================================================ FILE: agent/robot/executor/standard/validator_test.go ================================================ package standard_test import ( "context" "testing" "github.com/stretchr/testify/assert" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ============================================================================ // Validator Tests - Two-Layer Validation System // 
============================================================================ func TestValidatorValidateWithContext(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("validates with no rules - passes with valid output", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "", ValidationRules: []string{}, } result := validator.ValidateWithContext(task, "Some output", nil) assert.True(t, result.Passed) assert.True(t, result.Complete) assert.False(t, result.NeedReply) assert.Equal(t, 1.0, result.Score) }) t.Run("validates with no rules - incomplete with empty output", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "", ValidationRules: []string{}, } result := validator.ValidateWithContext(task, "", nil) assert.True(t, result.Passed) assert.False(t, result.Complete) // Empty output = not complete }) t.Run("validates with rule-based validation - passes", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ `{"type": "contains", "value": "hello"}`, }, } result := validator.ValidateWithContext(task, "hello world", nil) assert.True(t, result.Passed) assert.True(t, result.Complete) assert.False(t, result.NeedReply) }) t.Run("validates with rule-based validation - fails", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", 
ValidationRules: []string{ `{"type": "contains", "value": "expected_string"}`, }, } result := validator.ValidateWithContext(task, "actual output without expected", nil) assert.False(t, result.Passed) assert.False(t, result.Complete) assert.True(t, result.NeedReply) // Should suggest retry assert.NotEmpty(t, result.ReplyContent) assert.NotEmpty(t, result.Issues) }) t.Run("validates with semantic validation", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "A professional greeting message", } result := validator.ValidateWithContext(task, "Dear Sir/Madam, I hope this message finds you well.", nil) // Semantic validation should pass for this appropriate output t.Logf("Validation result: passed=%v, complete=%v, score=%.2f", result.Passed, result.Complete, result.Score) t.Logf("Issues: %v", result.Issues) t.Logf("Suggestions: %v", result.Suggestions) // The semantic validator should recognize this as appropriate assert.NotNil(t, result) }) } func TestValidatorIsComplete(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("complete when passed with valid output", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "", ValidationRules: []string{}, } result := validator.ValidateWithContext(task, "Valid output", nil) assert.True(t, result.Complete) }) t.Run("not complete when passed but empty output", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "", ValidationRules: 
[]string{}, } result := validator.ValidateWithContext(task, "", nil) assert.False(t, result.Complete) }) t.Run("not complete when validation failed", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ `{"type": "contains", "value": "MUST_CONTAIN_THIS"}`, }, } result := validator.ValidateWithContext(task, "output without required string", nil) assert.False(t, result.Passed) assert.False(t, result.Complete) }) t.Run("not complete when score below threshold", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() config.ValidationThreshold = 0.9 // High threshold validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "A very specific output format that's hard to match exactly", } // This output might get a lower score due to semantic mismatch result := validator.ValidateWithContext(task, "Some generic output", nil) // If score is below threshold, should not be complete if result.Passed && result.Score < config.ValidationThreshold { assert.False(t, result.Complete) } }) } func TestValidatorCheckNeedReply(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("no reply needed when complete", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "", ValidationRules: []string{}, } result := validator.ValidateWithContext(task, "Complete output", nil) assert.True(t, result.Complete) assert.False(t, result.NeedReply) assert.Empty(t, result.ReplyContent) }) t.Run("reply needed when validation failed with suggestions", func(t 
*testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ `{"type": "type", "value": "object"}`, }, } // String output when object expected result := validator.ValidateWithContext(task, "not an object", nil) assert.False(t, result.Passed) assert.True(t, result.NeedReply) assert.NotEmpty(t, result.ReplyContent) // The reply should contain validation feedback about the issue assert.Contains(t, result.ReplyContent, "did not pass validation") }) t.Run("reply needed when output is empty but passed", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "", ValidationRules: []string{}, } result := validator.ValidateWithContext(task, " ", nil) // Whitespace only // Passed (no rules) but not complete (empty output) assert.True(t, result.Passed) assert.False(t, result.Complete) // When passed but not complete (empty output), checkNeedReply may or may not // set NeedReply depending on the implementation details // Just verify the result is consistent t.Logf("NeedReply: %v, ReplyContent: %s", result.NeedReply, result.ReplyContent) }) } func TestValidatorConvertStringRule(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("converts 'valid JSON' rule", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ "output must be valid JSON", }, } // Valid JSON object result := validator.ValidateWithContext(task, map[string]interface{}{"key": "value"}, nil) assert.True(t, 
result.Passed) // Invalid (string is not an object) result2 := validator.ValidateWithContext(task, "not json", nil) assert.False(t, result2.Passed) }) t.Run("converts 'must contain' rule", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ "must contain 'success'", }, } result := validator.ValidateWithContext(task, "Operation was a success!", nil) assert.True(t, result.Passed) result2 := validator.ValidateWithContext(task, "Operation failed", nil) assert.False(t, result2.Passed) }) t.Run("converts 'not empty' rule", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ "output must not be empty", }, } result := validator.ValidateWithContext(task, "Some content", nil) assert.True(t, result.Passed) // Note: The "not empty" rule may be converted to semantic validation // rather than a rule-based assertion, so empty string might still pass // if semantic validation is lenient result2 := validator.ValidateWithContext(task, "", nil) t.Logf("Empty string validation: passed=%v, issues=%v", result2.Passed, result2.Issues) }) t.Run("converts 'json array' rule", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ "must be json array", }, } result := validator.ValidateWithContext(task, []interface{}{"a", "b", "c"}, nil) assert.True(t, result.Passed) result2 := validator.ValidateWithContext(task, map[string]interface{}{"key": "value"}, nil) assert.False(t, result2.Passed) }) } func TestValidatorParseRules(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } 
testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("parses JSON assertion rules", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ `{"type": "equals", "value": "expected"}`, }, } result := validator.ValidateWithContext(task, "expected", nil) assert.True(t, result.Passed) result2 := validator.ValidateWithContext(task, "different", nil) assert.False(t, result2.Passed) }) t.Run("parses regex rules", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ `{"type": "regex", "value": "^[A-Z][a-z]+$"}`, }, } result := validator.ValidateWithContext(task, "Hello", nil) assert.True(t, result.Passed) result2 := validator.ValidateWithContext(task, "hello", nil) assert.False(t, result2.Passed) }) t.Run("parses json_path rules", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ `{"type": "json_path", "path": "data.count", "value": 42}`, }, } result := validator.ValidateWithContext(task, map[string]interface{}{ "data": map[string]interface{}{ "count": 42, }, }, nil) assert.True(t, result.Passed) result2 := validator.ValidateWithContext(task, map[string]interface{}{ "data": map[string]interface{}{ "count": 10, }, }, nil) assert.False(t, result2.Passed) }) t.Run("parses type rules with path", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ValidationRules: []string{ `{"type": 
"type", "path": "items", "value": "array"}`, }, } result := validator.ValidateWithContext(task, map[string]interface{}{ "items": []interface{}{"a", "b"}, }, nil) assert.True(t, result.Passed) result2 := validator.ValidateWithContext(task, map[string]interface{}{ "items": "not an array", }, nil) assert.False(t, result2.Passed) }) } func TestValidatorSemanticValidation(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("semantic validation with expected output", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "A JSON object containing user information with name and email fields", } output := map[string]interface{}{ "name": "John Doe", "email": "john@example.com", } result := validator.ValidateWithContext(task, output, nil) t.Logf("Semantic validation: passed=%v, score=%.2f, complete=%v", result.Passed, result.Score, result.Complete) t.Logf("Details: %s", result.Details) // Should pass semantic validation assert.NotNil(t, result) }) t.Run("semantic validation with complex criteria", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "A professional email with greeting, body, and signature", Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Write a professional email"}, }, } output := `Dear Mr. Smith, I hope this email finds you well. I am writing to follow up on our previous conversation regarding the project timeline. Please let me know if you have any questions. 
Best regards, John Doe` result := validator.ValidateWithContext(task, output, nil) t.Logf("Email validation: passed=%v, score=%.2f", result.Passed, result.Score) // Should recognize this as a valid professional email assert.NotNil(t, result) }) } func TestValidatorMergeResults(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), testAuth()) t.Run("both rule and semantic validation pass", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "A greeting message", ValidationRules: []string{ `{"type": "contains", "value": "Hello"}`, }, } result := validator.ValidateWithContext(task, "Hello, how are you today?", nil) assert.True(t, result.Passed) assert.True(t, result.Complete) }) t.Run("rule passes but semantic fails", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "A formal business letter with proper formatting", ValidationRules: []string{ `{"type": "contains", "value": "Hello"}`, // This will pass }, } // Contains "Hello" but not a formal business letter result := validator.ValidateWithContext(task, "Hello there buddy!", nil) // Rule passes, but semantic might not t.Logf("Merged result: passed=%v, score=%.2f", result.Passed, result.Score) }) t.Run("rule fails - semantic not run", func(t *testing.T) { robot := createValidatorTestRobot(t) config := standard.DefaultValidatorConfig() validator := standard.NewValidator(ctx, robot, config) task := &types.Task{ ID: "task-001", ExpectedOutput: "Some expected output", ValidationRules: []string{ `{"type": "contains", "value": "REQUIRED_STRING"}`, }, } result := validator.ValidateWithContext(task, 
"Output without required string", nil) // Should fail at rule level, semantic not needed assert.False(t, result.Passed) assert.False(t, result.Complete) }) } // ============================================================================ // Helper Functions // ============================================================================ // createValidatorTestRobot creates a test robot for validator tests func createValidatorTestRobot(t *testing.T) *types.Robot { t.Helper() return &types.Robot{ MemberID: "test-robot-validator", TeamID: "test-team-1", DisplayName: "Test Robot for Validator", SystemPrompt: "You are a helpful assistant.", Config: &types.Config{ Identity: &types.Identity{ Role: "Test Assistant", Duties: []string{"Validate outputs"}, }, Resources: &types.Resources{ Phases: map[types.Phase]string{ types.PhaseRun: "robot.validation", "validation": "robot.validation", // For semantic validation agent }, Agents: []string{ "experts.data-analyst", "experts.text-writer", }, }, }, } } ================================================ FILE: agent/robot/executor/types/helpers.go ================================================ package types import ( "time" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // BuildTriggerInput builds TriggerInput from trigger data // Shared helper used by all executor implementations func BuildTriggerInput(trigger robottypes.TriggerType, data interface{}) *robottypes.TriggerInput { input := &robottypes.TriggerInput{} switch trigger { case robottypes.TriggerClock: input.Clock = robottypes.NewClockContext(time.Now(), "") case robottypes.TriggerHuman: if existing, ok := data.(*robottypes.TriggerInput); ok { return existing } if req, ok := data.(*robottypes.InterveneRequest); ok { input.Action = req.Action input.Messages = req.Messages } case robottypes.TriggerEvent: if req, ok := data.(*robottypes.EventRequest); ok { input.Source = robottypes.EventSource(req.Source) input.EventType = req.EventType input.Data = req.Data } } return 
input } ================================================ FILE: agent/robot/executor/types/types.go ================================================ package types import ( "time" robottypes "github.com/yaoapp/yao/agent/robot/types" ) // Executor defines the interface for robot phase execution // Different implementations provide different execution strategies: // - Standard: Real Agent calls with full phase execution // - DryRun: Plan-only mode, simulates execution without Agent calls // - Sandbox: Isolated execution with resource limits and safety controls type Executor interface { // ExecuteWithControl runs a robot through all applicable phases with execution control // ctx: Execution context with auth and logging // robot: Robot configuration and state // trigger: What triggered this execution (clock, human, event) // data: Trigger-specific data (human input, event payload, etc.) // execID: Pre-generated execution ID (empty string to auto-generate) // control: Optional execution control for pause/resume functionality // Returns: Execution record with all phase outputs ExecuteWithControl(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string, control robottypes.ExecutionControl) (*robottypes.Execution, error) // ExecuteWithID runs a robot through all applicable phases with a pre-generated execution ID // This is a convenience wrapper around ExecuteWithControl without control ExecuteWithID(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}, execID string) (*robottypes.Execution, error) // Execute runs a robot through all applicable phases (auto-generates execution ID) // This is a convenience wrapper around ExecuteWithControl Execute(ctx *robottypes.Context, robot *robottypes.Robot, trigger robottypes.TriggerType, data interface{}) (*robottypes.Execution, error) // Resume resumes a suspended execution with human-provided input. 
// PhaseExecutor defines the interface for individual phase execution.
// Used internally by Executor implementations.
//
// Every method has the same shape: it receives the execution context, the
// execution record, and the trigger-specific data, and reports failure
// through its error result.
// NOTE(review): whether implementations mutate exec to store phase output is
// not visible in this declaration; confirm against the implementations.
type PhaseExecutor interface {
	// RunInspiration executes P0: Inspiration phase
	RunInspiration(ctx *robottypes.Context, exec *robottypes.Execution, data interface{}) error

	// RunGoals executes P1: Goals phase
	RunGoals(ctx *robottypes.Context, exec *robottypes.Execution, data interface{}) error

	// RunTasks executes P2: Tasks phase
	RunTasks(ctx *robottypes.Context, exec *robottypes.Execution, data interface{}) error

	// RunExecution executes P3: Run phase (task execution)
	RunExecution(ctx *robottypes.Context, exec *robottypes.Execution, data interface{}) error

	// RunDelivery executes P4: Delivery phase
	RunDelivery(ctx *robottypes.Context, exec *robottypes.Execution, data interface{}) error

	// RunLearning executes P5: Learning phase
	RunLearning(ctx *robottypes.Context, exec *robottypes.Execution, data interface{}) error
}

// Config holds common executor configuration.
// It is embedded by the mode-specific configuration structs in this file.
type Config struct {
	// SkipPersistence skips execution record persistence (for testing)
	SkipPersistence bool

	// OnPhaseStart is a callback for when a phase starts; it receives the phase.
	// NOTE(review): presumably optional (nil allowed) - confirm in the executors.
	OnPhaseStart func(phase robottypes.Phase)

	// OnPhaseEnd is a callback for when a phase ends; it receives the phase.
	// NOTE(review): presumably optional (nil allowed) - confirm in the executors.
	OnPhaseEnd func(phase robottypes.Phase)
}

// DryRunConfig holds dry-run specific configuration.
// It embeds Config, so the common fields are available directly on it.
type DryRunConfig struct {
	Config

	// Delay simulates execution delay for each phase
	Delay time.Duration

	// OnStart callback on execution start
	OnStart func()

	// OnEnd callback on execution end
	OnEnd func()
}
// Mode identifies which executor implementation should run a robot.
type Mode string

// Supported executor modes.
const (
	// ModeStandard performs real Agent execution (production).
	ModeStandard Mode = "standard"
	// ModeDryRun performs simulated execution (testing/demo).
	ModeDryRun Mode = "dryrun"
	// ModeSandbox is container-isolated execution (NOT IMPLEMENTED).
	ModeSandbox Mode = "sandbox"
)

// Setting carries the executor settings read from configuration.
type Setting struct {
	// Mode selects the executor mode.
	Mode Mode `json:"mode,omitempty" yaml:"mode,omitempty"`
	// MaxDuration is the maximum execution time.
	MaxDuration time.Duration `json:"max_duration,omitempty" yaml:"max_duration,omitempty"`
	// MaxMemory is the maximum memory, in bytes.
	MaxMemory int64 `json:"max_memory,omitempty" yaml:"max_memory,omitempty"`
	// AllowedAgents lists the allowed agent IDs.
	AllowedAgents []string `json:"allowed_agents,omitempty" yaml:"allowed_agents,omitempty"`
	// NetworkAccess allows network access.
	NetworkAccess bool `json:"network_access,omitempty" yaml:"network_access,omitempty"`
	// FileAccess allows file system access.
	FileAccess bool `json:"file_access,omitempty" yaml:"file_access,omitempty"`
}

// DefaultSetting returns a freshly allocated Setting holding the defaults:
// standard mode, a 30 minute duration limit, a 512MB memory limit, network
// access enabled, file access disabled, and no agent allow-list.
func DefaultSetting() *Setting {
	const defaultMaxMemory int64 = 512 * 1024 * 1024 // 512MB

	defaults := new(Setting)
	defaults.Mode = ModeStandard
	defaults.MaxDuration = 30 * time.Minute
	defaults.MaxMemory = defaultMaxMemory
	defaults.NetworkAccess = true
	defaults.FileAccess = false
	return defaults
}
// Logger provides robot-level structured logging. Integration adapters,
// dispatchers, event handlers and similar components share this
// implementation.
//
// Dev mode  -> colored stdout + kun/log (unified).
// Prod mode -> kun/log at matching level.
type Logger struct {
	tag string // component name rendered inside the "[robot:...]" prefix
}

// New returns a Logger tagged with the given component name
// (e.g. "telegram", "dispatcher", "message", "delivery").
func New(tag string) *Logger {
	lg := new(Logger)
	lg.tag = tag
	return lg
}

// prefix renders the bracketed tag placed in front of every log message.
func (l *Logger) prefix() string {
	const layout = "[robot:%s]"
	return fmt.Sprintf(layout, l.tag)
}
if config.IsDevelopment() { fmt.Printf("%s ⚠ %s %s%s\n", yellow, l.prefix(), msg, reset) } kunlog.Warn("%s %s", l.prefix(), msg) } func (l *Logger) Error(format string, args ...interface{}) { msg := fmt.Sprintf(format, args...) if config.IsDevelopment() { fmt.Printf("%s ✗ %s %s%s\n", red, l.prefix(), msg, reset) } kunlog.Error("%s %s", l.prefix(), msg) } // IsDev returns true when running in development mode. func IsDev() bool { return config.IsDevelopment() } // Raw writes pre-formatted text directly to stdout in dev mode only. // Use for rich multi-line output (box-style logs, tables, etc.) // that should bypass the standard single-line prefix format. func Raw(s string) { if config.IsDevelopment() { fmt.Print(s) } } ================================================ FILE: agent/robot/manager/integration_clock_test.go ================================================ package manager_test // Integration tests for Clock trigger modes // Tests all three clock modes: times, interval, daemon // Includes timezone handling and day-of-week filtering import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/executor" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // createClockTestManager creates a manager with mock executor for clock tests func createClockTestManager(t *testing.T, tickInterval time.Duration, workerSize, queueSize int) (*manager.Manager, *executor.DryRunExecutor) { exec := executor.NewDryRunWithDelay(0) config := &manager.Config{ TickInterval: tickInterval, PoolConfig: &pool.Config{WorkerSize: workerSize, QueueSize: queueSize}, Executor: exec, } m := manager.NewWithConfig(config) return m, exec } // ==================== Times Mode Tests ==================== // 
TestIntegrationClockTimesMode tests the times mode clock trigger func TestIntegrationClockTimesMode(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("triggers at configured time", func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_times1", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00", "14:00", "17:00"}, "days": []string{"Mon", "Tue", "Wed", "Thu", "Fri"}, "tz": "Asia/Shanghai", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() // Verify robot is loaded into cache robot := m.Cache().Get("robot_integ_clock_times1") require.NotNil(t, robot, "Robot should be loaded into cache") exec.Reset() // Trigger at 09:00 on Wednesday loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) // Wednesday 09:00 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(300 * time.Millisecond) assert.GreaterOrEqual(t, exec.ExecCount(), 1, "Should trigger at 09:00") }) t.Run("does not trigger at non-configured time", func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_times2", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00", "14:00"}, "days": []string{"Mon", "Tue", "Wed", "Thu", "Fri"}, "tz": "Asia/Shanghai", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() // Trigger at 10:30 (not configured) loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 10, 30, 0, 0, loc) // Wednesday 10:30 ctx := 
types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) assert.Equal(t, 0, exec.ExecCount(), "Should not trigger at non-configured time") }) t.Run("does not trigger on non-configured day", func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_times3", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "days": []string{"Mon", "Tue", "Wed", "Thu", "Fri"}, // Weekdays only "tz": "Asia/Shanghai", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() // Trigger at 09:00 on Saturday (not configured) loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 18, 9, 0, 0, 0, loc) // Saturday 09:00 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) assert.Equal(t, 0, exec.ExecCount(), "Should not trigger on Saturday") }) t.Run("wildcard days matches all days", func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_times4", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "days": []string{"*"}, // All days "tz": "Asia/Shanghai", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() // Trigger at 09:00 on Saturday loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 18, 9, 0, 0, 0, loc) // Saturday 09:00 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(300 * time.Millisecond) assert.GreaterOrEqual(t, exec.ExecCount(), 1, "Should trigger on Saturday with wildcard days") }) t.Run("dedup prevents double trigger in same minute", 
func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_times5", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "days": []string{"*"}, "tz": "Asia/Shanghai", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() loc, _ := time.LoadLocation("Asia/Shanghai") ctx := types.NewContext(context.Background(), nil) // First tick at 09:00:00 now1 := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) err = m.Tick(ctx, now1) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) firstCount := exec.ExecCount() assert.GreaterOrEqual(t, firstCount, 1, "First tick should trigger") // Second tick at 09:00:30 (same minute) now2 := time.Date(2025, 1, 15, 9, 0, 30, 0, loc) err = m.Tick(ctx, now2) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) // Should not trigger again in same minute assert.Equal(t, firstCount, exec.ExecCount(), "Should not trigger twice in same minute") }) } // ==================== Interval Mode Tests ==================== // TestIntegrationClockIntervalMode tests the interval mode clock trigger func TestIntegrationClockIntervalMode(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("triggers on first run", func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_interval1", "team_integ_clock", map[string]interface{}{ "mode": "interval", "every": "30m", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) now := time.Now() err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(300 * time.Millisecond) 
assert.GreaterOrEqual(t, exec.ExecCount(), 1, "Should trigger on first run") }) t.Run("triggers after interval passed", func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_interval2", "team_integ_clock", map[string]interface{}{ "mode": "interval", "every": "100ms", // Short interval for testing }) m, exec := createClockTestManager(t, 50*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) // First tick now1 := time.Now() err = m.Tick(ctx, now1) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) firstCount := exec.ExecCount() assert.GreaterOrEqual(t, firstCount, 1, "First tick should trigger") // Wait for interval to pass time.Sleep(150 * time.Millisecond) // Second tick after interval now2 := time.Now() err = m.Tick(ctx, now2) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) // Should have triggered again assert.Greater(t, exec.ExecCount(), firstCount, "Should trigger again after interval") }) t.Run("does not trigger before interval passed", func(t *testing.T) { // Clean up before each subtest to ensure isolation cleanupIntegrationRobots(t) setupClockTestRobot(t, "robot_integ_clock_interval3", "team_integ_clock", map[string]interface{}{ "mode": "interval", "every": "1h", // Long interval }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) // First tick now1 := time.Now() err = m.Tick(ctx, now1) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) firstCount := exec.ExecCount() assert.GreaterOrEqual(t, firstCount, 1, "First tick should trigger") // Second tick immediately (interval not passed) now2 := now1.Add(1 * time.Minute) // Only 1 minute later err = m.Tick(ctx, now2) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) 
// Should not trigger again assert.Equal(t, firstCount, exec.ExecCount(), "Should not trigger before interval") }) } // ==================== Daemon Mode Tests ==================== // TestIntegrationClockDaemonMode tests the daemon mode clock trigger func TestIntegrationClockDaemonMode(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("triggers when robot can run", func(t *testing.T) { setupClockTestRobot(t, "robot_integ_clock_daemon1", "team_integ_clock", map[string]interface{}{ "mode": "daemon", "timeout": "5m", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, time.Now()) assert.NoError(t, err) time.Sleep(300 * time.Millisecond) assert.GreaterOrEqual(t, exec.ExecCount(), 1, "Daemon should trigger when idle") }) t.Run("respects quota limit", func(t *testing.T) { // Create daemon robot with Max=1 setupClockTestRobotWithQuota(t, "robot_integ_clock_daemon2", "team_integ_clock", map[string]interface{}{ "mode": "daemon", "timeout": "5m", }, 1, 5, 5) // Max=1, Queue=5 m, exec := createClockTestManager(t, 50*time.Millisecond, 5, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) // Trigger multiple times rapidly for i := 0; i < 5; i++ { err = m.Tick(ctx, time.Now()) assert.NoError(t, err) time.Sleep(60 * time.Millisecond) } // Robot should respect quota (Max=1) robot := m.Cache().Get("robot_integ_clock_daemon2") assert.NotNil(t, robot) // Running count should be at most Max assert.LessOrEqual(t, robot.RunningCount(), 1, "Should respect quota limit") }) } // ==================== Timezone Tests ==================== // TestIntegrationClockTimezone tests timezone handling func TestIntegrationClockTimezone(t 
*testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("respects robot timezone", func(t *testing.T) { // Robot configured for Asia/Shanghai (UTC+8) setupClockTestRobot(t, "robot_integ_clock_tz1", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "days": []string{"*"}, "tz": "Asia/Shanghai", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) // 09:00 in Shanghai = 01:00 UTC shanghai, _ := time.LoadLocation("Asia/Shanghai") shanghaiTime := time.Date(2025, 1, 15, 9, 0, 0, 0, shanghai) err = m.Tick(ctx, shanghaiTime) assert.NoError(t, err) time.Sleep(300 * time.Millisecond) assert.GreaterOrEqual(t, exec.ExecCount(), 1, "Should trigger at 09:00 Shanghai time") }) t.Run("different timezone same UTC time", func(t *testing.T) { // Robot 1: Asia/Shanghai at 09:00 (UTC+8) = 01:00 UTC setupClockTestRobot(t, "robot_integ_clock_tz2", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "days": []string{"*"}, "tz": "Asia/Shanghai", }) // Robot 2: America/New_York at 09:00 (UTC-5) = 14:00 UTC setupClockTestRobot(t, "robot_integ_clock_tz3", "team_integ_clock", map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "days": []string{"*"}, "tz": "America/New_York", }) m, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err := m.Start() require.NoError(t, err) defer m.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) // Test at 01:00 UTC (09:00 Shanghai) utcTime := time.Date(2025, 1, 15, 1, 0, 0, 0, time.UTC) err = m.Tick(ctx, utcTime) assert.NoError(t, err) time.Sleep(300 * time.Millisecond) // Only Shanghai robot should trigger execCount := exec.ExecCount() assert.GreaterOrEqual(t, execCount, 1, 
"Shanghai robot should trigger") // New York robot should not trigger (it's 20:00 in NY) }) } // ==================== Edge Cases ==================== // TestIntegrationClockEdgeCases tests edge cases in clock triggering func TestIntegrationClockEdgeCases(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("robot with clock disabled is skipped", func(t *testing.T) { // Create robot with clock trigger disabled m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{"role": "Clock Disabled Robot"}, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "tz": "Asia/Shanghai", }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_integ_clock_disabled", "team_id": "team_integ_clock", "member_type": "robot", "display_name": "Clock Disabled Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) require.NoError(t, err) mgr, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err = mgr.Start() require.NoError(t, err) defer mgr.Stop() exec.Reset() // Trigger at matching time loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) ctx := types.NewContext(context.Background(), nil) err = mgr.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) assert.Equal(t, 0, exec.ExecCount(), "Clock disabled robot should not trigger") }) t.Run("paused robot is skipped", func(t *testing.T) { // Create paused robot m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig 
:= map[string]interface{}{ "identity": map[string]interface{}{"role": "Paused Robot"}, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "tz": "Asia/Shanghai", }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_integ_clock_paused", "team_id": "team_integ_clock", "member_type": "robot", "display_name": "Paused Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "paused", // Paused status "robot_config": string(configJSON), }, }) require.NoError(t, err) mgr, exec := createClockTestManager(t, 100*time.Millisecond, 3, 20) err = mgr.Start() require.NoError(t, err) defer mgr.Stop() exec.Reset() // Trigger at matching time loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) ctx := types.NewContext(context.Background(), nil) err = mgr.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) assert.Equal(t, 0, exec.ExecCount(), "Paused robot should not trigger") }) t.Run("robot without clock config is skipped", func(t *testing.T) { // Create robot without clock config m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{"role": "No Clock Robot"}, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, // No clock config } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_integ_clock_noconfig", "team_id": "team_integ_clock", "member_type": "robot", "display_name": "No Clock Config Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) require.NoError(t, err) mgr, exec := 
createClockTestManager(t, 100*time.Millisecond, 3, 20) err = mgr.Start() require.NoError(t, err) defer mgr.Stop() exec.Reset() ctx := types.NewContext(context.Background(), nil) err = mgr.Tick(ctx, time.Now()) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) assert.Equal(t, 0, exec.ExecCount(), "Robot without clock config should not trigger") }) } // ==================== Test Data Setup Helpers ==================== // setupClockTestRobot creates a robot with specified clock config func setupClockTestRobot(t *testing.T, memberID, teamID string, clockConfig map[string]interface{}) { setupClockTestRobotWithQuota(t, memberID, teamID, clockConfig, 3, 20, 5) } // setupClockTestRobotWithQuota creates a robot with specified clock config and quota func setupClockTestRobotWithQuota(t *testing.T, memberID, teamID string, clockConfig map[string]interface{}, max, queue, priority int) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Clock Test Robot " + memberID, }, "quota": map[string]interface{}{ "max": max, "queue": queue, "priority": priority, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": clockConfig, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Clock Test Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } ================================================ FILE: agent/robot/manager/integration_concurrent_test.go ================================================ package manager_test // Integration tests for concurrent execution and quota enforcement // Tests the two-level 
concurrency model: // 1. Global pool limit (worker count) // 2. Per-robot quota limit (Quota.Max, Quota.Queue) import ( "context" "encoding/json" "fmt" "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/executor" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ==================== Concurrent Execution Tests ==================== // TestIntegrationConcurrentExecution tests concurrent execution of multiple robots func TestIntegrationConcurrentExecution(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("multiple robots execute concurrently", func(t *testing.T) { // Create 5 robots for i := 0; i < 5; i++ { memberID := "robot_integ_conc_multi_" + string(rune('A'+i)) setupConcurrentTestRobot(t, memberID, "team_integ_conc", 3, 20) } // Track concurrent execution count var maxConcurrent int32 var currentConcurrent int32 exec := executor.NewDryRunWithCallbacks(100*time.Millisecond, func() { curr := atomic.AddInt32(¤tConcurrent, 1) for { old := atomic.LoadInt32(&maxConcurrent) if curr <= old || atomic.CompareAndSwapInt32(&maxConcurrent, old, curr) { break } } }, func() { atomic.AddInt32(¤tConcurrent, -1) }, ) config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 5, QueueSize: 50}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() // Verify robots are loaded into cache for i := 0; i < 5; i++ { memberID := "robot_integ_conc_multi_" + string(rune('A'+i)) robot := m.Cache().Get(memberID) require.NotNil(t, robot, "Robot %s should be loaded into 
cache", memberID) } ctx := types.NewContext(context.Background(), nil) // Trigger all robots simultaneously var wg sync.WaitGroup for i := 0; i < 5; i++ { wg.Add(1) memberID := "robot_integ_conc_multi_" + string(rune('A'+i)) go func(id string) { defer wg.Done() m.TriggerManual(ctx, id, types.TriggerClock, nil) }(memberID) } wg.Wait() // Wait for all executions time.Sleep(500 * time.Millisecond) // Should have achieved concurrent execution assert.GreaterOrEqual(t, int(maxConcurrent), 2, "Should achieve concurrent execution") assert.GreaterOrEqual(t, exec.ExecCount(), 5, "All robots should execute") }) t.Run("same robot multiple triggers", func(t *testing.T) { setupConcurrentTestRobot(t, "robot_integ_conc_same", "team_integ_conc", 3, 20) exec := executor.NewDryRunWithDelay(50 * time.Millisecond) config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 5, QueueSize: 50}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger same robot multiple times for i := 0; i < 5; i++ { _, err := m.TriggerManual(ctx, "robot_integ_conc_same", types.TriggerClock, nil) assert.NoError(t, err) } // Wait for all executions time.Sleep(800 * time.Millisecond) // All 5 should eventually execute assert.GreaterOrEqual(t, exec.ExecCount(), 5, "All triggers should execute") }) } // ==================== Quota Enforcement Tests ==================== // TestIntegrationQuotaEnforcement tests per-robot quota limits func TestIntegrationQuotaEnforcement(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("respects Quota.Max limit", func(t *testing.T) { // Create robot with Max=2 setupConcurrentTestRobot(t, "robot_integ_quota_max", "team_integ_quota", 2, 20) // Track max concurrent for this robot 
var maxConcurrent int32 var currentConcurrent int32 exec := executor.NewDryRunWithCallbacks(200*time.Millisecond, func() { curr := atomic.AddInt32(¤tConcurrent, 1) for { old := atomic.LoadInt32(&maxConcurrent) if curr <= old || atomic.CompareAndSwapInt32(&maxConcurrent, old, curr) { break } } }, func() { atomic.AddInt32(¤tConcurrent, -1) }, ) config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 10, QueueSize: 50}, // Many workers } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Submit 10 jobs for the same robot for i := 0; i < 10; i++ { m.TriggerManual(ctx, "robot_integ_quota_max", types.TriggerClock, nil) } // Wait a bit for concurrent execution time.Sleep(300 * time.Millisecond) // Max concurrent should not exceed Quota.Max (2) assert.LessOrEqual(t, int(maxConcurrent), 2, "Should not exceed Quota.Max") // Wait for all to complete time.Sleep(1500 * time.Millisecond) // All should eventually execute assert.GreaterOrEqual(t, exec.ExecCount(), 10, "All jobs should eventually execute") }) t.Run("respects Quota.Queue limit", func(t *testing.T) { // Create robot with Max=1, Queue=3 setupConcurrentTestRobot(t, "robot_integ_quota_queue", "team_integ_quota", 1, 3) exec := executor.NewDryRunWithDelay(300 * time.Millisecond) // Slow execution config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 10, QueueSize: 100}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Submit many jobs - some should be rejected due to queue limit successCount := 0 for i := 0; i < 20; i++ { _, err := m.TriggerManual(ctx, "robot_integ_quota_queue", types.TriggerClock, nil) if err == nil { successCount++ } } // Should accept at most Max + Queue = 1 + 3 = 4 
jobs assert.LessOrEqual(t, successCount, 4, "Should respect queue limit") assert.GreaterOrEqual(t, successCount, 1, "Should accept at least 1 job") }) t.Run("different robots have independent quotas", func(t *testing.T) { // Robot A: Max=1 setupConcurrentTestRobot(t, "robot_integ_quota_A", "team_integ_quota", 1, 10) // Robot B: Max=3 setupConcurrentTestRobot(t, "robot_integ_quota_B", "team_integ_quota", 3, 10) var concurrentA int32 var concurrentB int32 var maxA int32 var maxB int32 // Custom executor that tracks per-robot concurrency exec := &trackingExecutor{ delay: 150 * time.Millisecond, onStart: func(robot *types.Robot) { if robot.MemberID == "robot_integ_quota_A" { curr := atomic.AddInt32(&concurrentA, 1) for { old := atomic.LoadInt32(&maxA) if curr <= old || atomic.CompareAndSwapInt32(&maxA, old, curr) { break } } } else { curr := atomic.AddInt32(&concurrentB, 1) for { old := atomic.LoadInt32(&maxB) if curr <= old || atomic.CompareAndSwapInt32(&maxB, old, curr) { break } } } }, onEnd: func(robot *types.Robot) { if robot.MemberID == "robot_integ_quota_A" { atomic.AddInt32(&concurrentA, -1) } else { atomic.AddInt32(&concurrentB, -1) } }, } config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 10, QueueSize: 50}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Submit 5 jobs for each robot for i := 0; i < 5; i++ { m.TriggerManual(ctx, "robot_integ_quota_A", types.TriggerClock, nil) m.TriggerManual(ctx, "robot_integ_quota_B", types.TriggerClock, nil) } // Wait a bit time.Sleep(300 * time.Millisecond) // Robot A should have max 1 concurrent assert.LessOrEqual(t, int(maxA), 1, "Robot A should respect its quota") // Robot B should have max 3 concurrent assert.LessOrEqual(t, int(maxB), 3, "Robot B should respect its quota") // Wait for completion time.Sleep(1 * time.Second) }) } // 
==================== Global Pool Limit Tests ==================== // TestIntegrationGlobalPoolLimit tests global worker pool limits func TestIntegrationGlobalPoolLimit(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("respects global worker limit", func(t *testing.T) { // Create 10 robots with high quotas for i := 0; i < 10; i++ { memberID := "robot_integ_pool_limit_" + string(rune('A'+i)) setupConcurrentTestRobot(t, memberID, "team_integ_pool", 5, 20) } var maxConcurrent int32 var currentConcurrent int32 exec := executor.NewDryRunWithCallbacks(200*time.Millisecond, func() { curr := atomic.AddInt32(¤tConcurrent, 1) for { old := atomic.LoadInt32(&maxConcurrent) if curr <= old || atomic.CompareAndSwapInt32(&maxConcurrent, old, curr) { break } } }, func() { atomic.AddInt32(¤tConcurrent, -1) }, ) config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 100}, // Only 3 workers } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger all 10 robots for i := 0; i < 10; i++ { memberID := "robot_integ_pool_limit_" + string(rune('A'+i)) m.TriggerManual(ctx, memberID, types.TriggerClock, nil) } // Wait a bit time.Sleep(300 * time.Millisecond) // Max concurrent should not exceed worker limit (3) assert.LessOrEqual(t, int(maxConcurrent), 3, "Should not exceed worker limit") // Wait for all to complete time.Sleep(1 * time.Second) // All 10 should execute assert.GreaterOrEqual(t, exec.ExecCount(), 10, "All robots should execute") }) t.Run("respects global queue limit", func(t *testing.T) { // Create robots for i := 0; i < 20; i++ { memberID := "robot_integ_pool_queue_" + string(rune('A'+i%26)) setupConcurrentTestRobot(t, memberID, "team_integ_pool", 5, 20) } exec 
:= executor.NewDryRunWithDelay(500 * time.Millisecond) // Slow execution config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 1, QueueSize: 5}, // Small queue } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Try to submit many jobs successCount := 0 for i := 0; i < 20; i++ { memberID := "robot_integ_pool_queue_" + string(rune('A'+i%26)) _, err := m.TriggerManual(ctx, memberID, types.TriggerClock, nil) if err == nil { successCount++ } } // Should respect global queue limit // Max = WorkerSize + QueueSize = 1 + 5 = 6 assert.LessOrEqual(t, successCount, 6, "Should respect global queue limit") }) } // ==================== Priority Tests ==================== // TestIntegrationPriorityExecution tests priority-based execution order func TestIntegrationPriorityExecution(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("higher priority executes first", func(t *testing.T) { // Create robots with different priorities setupConcurrentTestRobotWithPriority(t, "robot_integ_prio_low", "team_integ_prio", 2, 10, 1) setupConcurrentTestRobotWithPriority(t, "robot_integ_prio_med", "team_integ_prio", 2, 10, 5) setupConcurrentTestRobotWithPriority(t, "robot_integ_prio_high", "team_integ_prio", 2, 10, 10) executionOrder := make([]string, 0) var mu sync.Mutex exec := &trackingExecutor{ delay: 50 * time.Millisecond, onStart: func(robot *types.Robot) { mu.Lock() executionOrder = append(executionOrder, robot.MemberID) mu.Unlock() }, onEnd: func(robot *types.Robot) {}, } config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 1, QueueSize: 50}, // Single worker for ordering } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err 
:= m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Submit in low-to-high priority order _, err = m.TriggerManual(ctx, "robot_integ_prio_low", types.TriggerClock, nil) assert.NoError(t, err) _, err = m.TriggerManual(ctx, "robot_integ_prio_med", types.TriggerClock, nil) assert.NoError(t, err) _, err = m.TriggerManual(ctx, "robot_integ_prio_high", types.TriggerClock, nil) assert.NoError(t, err) // Wait for all to complete time.Sleep(500 * time.Millisecond) // Verify execution order (high priority should be first or early) mu.Lock() order := executionOrder mu.Unlock() assert.Len(t, order, 3, "All 3 robots should execute") // Note: First job may already be picked up before others are queued // So we just verify all executed }) t.Run("human trigger has higher priority than clock", func(t *testing.T) { setupConcurrentTestRobotAllTriggers(t, "robot_integ_prio_trigger", "team_integ_prio", 2, 10, 5) executionOrder := make([]types.TriggerType, 0) var mu sync.Mutex exec := &triggerTrackingExecutor{ delay: 50 * time.Millisecond, onStart: func(trigger types.TriggerType) { mu.Lock() executionOrder = append(executionOrder, trigger) mu.Unlock() }, } config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 1, QueueSize: 50}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Submit clock first, then human _, err = m.TriggerManual(ctx, "robot_integ_prio_trigger", types.TriggerClock, nil) assert.NoError(t, err) _, err = m.TriggerManual(ctx, "robot_integ_prio_trigger", types.TriggerHuman, nil) assert.NoError(t, err) // Wait for execution time.Sleep(300 * time.Millisecond) mu.Lock() order := executionOrder mu.Unlock() assert.Len(t, order, 2, "Both triggers should execute") }) } // ==================== Helper Types ==================== // trackingExecutor tracks 
execution per robot type trackingExecutor struct { delay time.Duration onStart func(robot *types.Robot) onEnd func(robot *types.Robot) count int32 } func (e *trackingExecutor) Execute(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}) (*types.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, "", nil) } func (e *trackingExecutor) ExecuteWithID(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string) (*types.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, execID, nil) } func (e *trackingExecutor) ExecuteWithControl(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string, control types.ExecutionControl) (*types.Execution, error) { if robot == nil { return nil, types.ErrRobotNotFound } // Use provided execID or generate unique ID for each execution to properly track quota if execID == "" { execID = fmt.Sprintf("exec_%d", time.Now().UnixNano()) } exec := &types.Execution{ ID: execID, MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: trigger, StartTime: time.Now(), Status: types.ExecPending, } if !robot.TryAcquireSlot(exec) { return nil, types.ErrQuotaExceeded } defer robot.RemoveExecution(exec.ID) if e.onStart != nil { e.onStart(robot) } exec.Status = types.ExecRunning time.Sleep(e.delay) if e.onEnd != nil { e.onEnd(robot) } exec.Status = types.ExecCompleted now := time.Now() exec.EndTime = &now atomic.AddInt32(&e.count, 1) return exec, nil } func (e *trackingExecutor) ExecCount() int { return int(atomic.LoadInt32(&e.count)) } func (e *trackingExecutor) CurrentCount() int { return 0 } func (e *trackingExecutor) Resume(ctx *types.Context, execID string, reply string) error { return fmt.Errorf("resume not supported in tracking executor") } func (e *trackingExecutor) Reset() { atomic.StoreInt32(&e.count, 0) } // triggerTrackingExecutor tracks execution by trigger type type 
triggerTrackingExecutor struct { delay time.Duration onStart func(trigger types.TriggerType) count int32 } func (e *triggerTrackingExecutor) Execute(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}) (*types.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, "", nil) } func (e *triggerTrackingExecutor) ExecuteWithID(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string) (*types.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, execID, nil) } func (e *triggerTrackingExecutor) ExecuteWithControl(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string, control types.ExecutionControl) (*types.Execution, error) { if robot == nil { return nil, types.ErrRobotNotFound } // Use provided execID or generate unique ID for each execution to properly track quota if execID == "" { execID = fmt.Sprintf("exec_trigger_%s_%d", string(trigger), time.Now().UnixNano()) } exec := &types.Execution{ ID: execID, MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: trigger, StartTime: time.Now(), Status: types.ExecPending, } if !robot.TryAcquireSlot(exec) { return nil, types.ErrQuotaExceeded } defer robot.RemoveExecution(exec.ID) if e.onStart != nil { e.onStart(trigger) } exec.Status = types.ExecRunning time.Sleep(e.delay) exec.Status = types.ExecCompleted now := time.Now() exec.EndTime = &now atomic.AddInt32(&e.count, 1) return exec, nil } func (e *triggerTrackingExecutor) ExecCount() int { return int(atomic.LoadInt32(&e.count)) } func (e *triggerTrackingExecutor) CurrentCount() int { return 0 } func (e *triggerTrackingExecutor) Resume(ctx *types.Context, execID string, reply string) error { return fmt.Errorf("resume not supported in trigger tracking executor") } func (e *triggerTrackingExecutor) Reset() { atomic.StoreInt32(&e.count, 0) } // ==================== Test Data Setup Helpers ==================== // 
setupConcurrentTestRobot creates a robot for concurrency testing func setupConcurrentTestRobot(t *testing.T, memberID, teamID string, max, queue int) { setupConcurrentTestRobotWithPriority(t, memberID, teamID, max, queue, 5) } // setupConcurrentTestRobotWithPriority creates a robot with specified priority func setupConcurrentTestRobotWithPriority(t *testing.T, memberID, teamID string, max, queue, priority int) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Concurrent Test Robot " + memberID, }, "quota": map[string]interface{}{ "max": max, "queue": queue, "priority": priority, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "tz": "Asia/Shanghai", }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Concurrent Test Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupConcurrentTestRobotAllTriggers creates a robot with all triggers enabled func setupConcurrentTestRobotAllTriggers(t *testing.T, memberID, teamID string, max, queue, priority int) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "All Triggers Test Robot", }, "quota": map[string]interface{}{ "max": max, "queue": queue, "priority": priority, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": 
true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "All Triggers Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } ================================================ FILE: agent/robot/manager/integration_control_test.go ================================================ package manager_test // Integration tests for execution control (Pause/Resume/Stop) // Tests Manager's execution control methods and ExecutionController import ( "context" "encoding/json" "fmt" "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ==================== Pause/Resume Tests ==================== // TestIntegrationExecutionPauseResume tests pausing and resuming executions func TestIntegrationExecutionPauseResume(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("pause and resume execution", func(t *testing.T) { setupControlTestRobot(t, "robot_integ_ctrl_pause", "team_integ_ctrl") // Use slow executor to have time to pause exec := &slowExecutor{delay: 500 * time.Millisecond} config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 20}, } m := 
manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() // Verify robot is loaded into cache robot := m.Cache().Get("robot_integ_ctrl_pause") require.NotNil(t, robot, "Robot should be loaded into cache") ctx := types.NewContext(context.Background(), nil) // Trigger execution req := &types.InterveneRequest{ MemberID: "robot_integ_ctrl_pause", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) require.NoError(t, err) execID := result.ExecutionID // Wait for execution to be tracked time.Sleep(100 * time.Millisecond) // Pause execution err = m.PauseExecution(ctx, execID) assert.NoError(t, err) // Verify paused status, err := m.GetExecutionStatus(execID) assert.NoError(t, err) assert.True(t, status.IsPaused(), "Execution should be paused") // Resume execution err = m.ResumeExecution(ctx, execID) assert.NoError(t, err) // Verify resumed status, err = m.GetExecutionStatus(execID) assert.NoError(t, err) assert.False(t, status.IsPaused(), "Execution should be resumed") }) t.Run("pause non-existent execution", func(t *testing.T) { m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) err = m.PauseExecution(ctx, "nonexistent_exec") assert.Error(t, err) assert.Contains(t, err.Error(), "not found") }) t.Run("resume non-paused execution", func(t *testing.T) { setupControlTestRobot(t, "robot_integ_ctrl_resume", "team_integ_ctrl") exec := &slowExecutor{delay: 500 * time.Millisecond} config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 20}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger execution req := &types.InterveneRequest{ MemberID: 
"robot_integ_ctrl_resume", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) require.NoError(t, err) execID := result.ExecutionID // Wait for execution to be tracked time.Sleep(100 * time.Millisecond) // Resume without pausing first - should be safe err = m.ResumeExecution(ctx, execID) // May or may not error depending on implementation // The important thing is it doesn't panic }) } // ==================== Stop Tests ==================== // TestIntegrationExecutionStop tests stopping executions func TestIntegrationExecutionStop(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("stop execution", func(t *testing.T) { setupControlTestRobot(t, "robot_integ_ctrl_stop", "team_integ_ctrl") exec := &slowExecutor{delay: 1 * time.Second} config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 20}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger execution req := &types.InterveneRequest{ MemberID: "robot_integ_ctrl_stop", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) require.NoError(t, err) execID := result.ExecutionID // Wait for execution to be tracked time.Sleep(100 * time.Millisecond) // Stop execution err = m.StopExecution(ctx, execID) assert.NoError(t, err) // Execution should be removed from tracking _, err = m.GetExecutionStatus(execID) assert.Error(t, err) assert.Contains(t, err.Error(), "not found") }) t.Run("stop non-existent execution", func(t *testing.T) { m := manager.New() err := m.Start() require.NoError(t, err) defer 
m.Stop() ctx := types.NewContext(context.Background(), nil) err = m.StopExecution(ctx, "nonexistent_exec") assert.Error(t, err) assert.Contains(t, err.Error(), "not found") }) } // ==================== List Executions Tests ==================== // TestIntegrationListExecutions tests listing executions func TestIntegrationListExecutions(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("list all executions", func(t *testing.T) { setupControlTestRobot(t, "robot_integ_ctrl_list1", "team_integ_ctrl") setupControlTestRobot(t, "robot_integ_ctrl_list2", "team_integ_ctrl") exec := &slowExecutor{delay: 500 * time.Millisecond} config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 5, QueueSize: 20}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger multiple executions execIDs := make([]string, 0) for _, memberID := range []string{"robot_integ_ctrl_list1", "robot_integ_ctrl_list2"} { req := &types.InterveneRequest{ MemberID: memberID, Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) require.NoError(t, err) execIDs = append(execIDs, result.ExecutionID) } // Wait for executions to be tracked time.Sleep(100 * time.Millisecond) // List all executions execs := m.ListExecutions() assert.GreaterOrEqual(t, len(execs), 2, "Should have at least 2 executions") // Verify our executions are in the list foundCount := 0 for _, e := range execs { for _, id := range execIDs { if e.ID == id { foundCount++ } } } assert.Equal(t, 2, foundCount, "Both executions should be in list") }) t.Run("list executions by member", func(t *testing.T) { setupControlTestRobot(t, 
"robot_integ_ctrl_member1", "team_integ_ctrl") setupControlTestRobot(t, "robot_integ_ctrl_member2", "team_integ_ctrl") exec := &slowExecutor{delay: 500 * time.Millisecond} config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 5, QueueSize: 20}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger 3 executions for robot 1 for i := 0; i < 3; i++ { req := &types.InterveneRequest{ MemberID: "robot_integ_ctrl_member1", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } _, err := m.Intervene(ctx, req) require.NoError(t, err) } // Trigger 2 executions for robot 2 for i := 0; i < 2; i++ { req := &types.InterveneRequest{ MemberID: "robot_integ_ctrl_member2", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } _, err := m.Intervene(ctx, req) require.NoError(t, err) } // Wait for executions to be tracked time.Sleep(100 * time.Millisecond) // List executions for robot 1 execs1 := m.ListExecutionsByMember("robot_integ_ctrl_member1") assert.GreaterOrEqual(t, len(execs1), 1, "Robot 1 should have executions") // List executions for robot 2 execs2 := m.ListExecutionsByMember("robot_integ_ctrl_member2") assert.GreaterOrEqual(t, len(execs2), 1, "Robot 2 should have executions") // Verify member IDs for _, e := range execs1 { assert.Equal(t, "robot_integ_ctrl_member1", e.MemberID) } for _, e := range execs2 { assert.Equal(t, "robot_integ_ctrl_member2", e.MemberID) } }) } // ==================== Multiple Control Operations Tests ==================== // TestIntegrationMultipleControlOperations tests sequences of control operations func TestIntegrationMultipleControlOperations(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer 
testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("pause-resume-pause-stop sequence", func(t *testing.T) { setupControlTestRobot(t, "robot_integ_ctrl_seq", "team_integ_ctrl") exec := &slowExecutor{delay: 2 * time.Second} config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 20}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger execution req := &types.InterveneRequest{ MemberID: "robot_integ_ctrl_seq", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) require.NoError(t, err) execID := result.ExecutionID // Wait for tracking time.Sleep(100 * time.Millisecond) // Pause err = m.PauseExecution(ctx, execID) assert.NoError(t, err) status, _ := m.GetExecutionStatus(execID) assert.True(t, status.IsPaused()) // Resume err = m.ResumeExecution(ctx, execID) assert.NoError(t, err) status, _ = m.GetExecutionStatus(execID) assert.False(t, status.IsPaused()) // Pause again err = m.PauseExecution(ctx, execID) assert.NoError(t, err) status, _ = m.GetExecutionStatus(execID) assert.True(t, status.IsPaused()) // Stop err = m.StopExecution(ctx, execID) assert.NoError(t, err) _, err = m.GetExecutionStatus(execID) assert.Error(t, err) // Should be removed }) t.Run("concurrent control operations", func(t *testing.T) { setupControlTestRobot(t, "robot_integ_ctrl_conc", "team_integ_ctrl") exec := &slowExecutor{delay: 1 * time.Second} config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 20}, } m := manager.NewWithConfig(config) m.Pool().SetExecutor(exec) err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger execution req := 
&types.InterveneRequest{ MemberID: "robot_integ_ctrl_conc", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) require.NoError(t, err) execID := result.ExecutionID // Wait for tracking time.Sleep(100 * time.Millisecond) // Concurrent pause/resume operations should not panic var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(2) go func() { defer wg.Done() m.PauseExecution(ctx, execID) }() go func() { defer wg.Done() m.ResumeExecution(ctx, execID) }() } // Wait with timeout done := make(chan struct{}) go func() { wg.Wait() close(done) }() select { case <-done: // Success - no deadlock case <-time.After(5 * time.Second): t.Fatal("Concurrent control operations caused deadlock") } }) } // ==================== Helper Types ==================== // slowExecutor is an executor with configurable delay type slowExecutor struct { delay time.Duration count int32 current int32 } func (e *slowExecutor) Execute(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}) (*types.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, "", nil) } func (e *slowExecutor) ExecuteWithID(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string) (*types.Execution, error) { return e.ExecuteWithControl(ctx, robot, trigger, data, execID, nil) } func (e *slowExecutor) ExecuteWithControl(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string, control types.ExecutionControl) (*types.Execution, error) { if robot == nil { return nil, types.ErrRobotNotFound } // Use provided execID or generate one if execID == "" { execID = "exec_slow_" + robot.MemberID } exec := &types.Execution{ ID: execID, MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: trigger, StartTime: time.Now(), Status: types.ExecPending, } if !robot.TryAcquireSlot(exec) { return nil, 
types.ErrQuotaExceeded } defer robot.RemoveExecution(exec.ID) atomic.AddInt32(&e.current, 1) defer atomic.AddInt32(&e.current, -1) exec.Status = types.ExecRunning time.Sleep(e.delay) exec.Status = types.ExecCompleted now := time.Now() exec.EndTime = &now atomic.AddInt32(&e.count, 1) return exec, nil } func (e *slowExecutor) ExecCount() int { return int(atomic.LoadInt32(&e.count)) } func (e *slowExecutor) CurrentCount() int { return int(atomic.LoadInt32(&e.current)) } func (e *slowExecutor) Resume(ctx *types.Context, execID string, reply string) error { return fmt.Errorf("resume not supported in slow executor") } func (e *slowExecutor) Reset() { atomic.StoreInt32(&e.count, 0) atomic.StoreInt32(&e.current, 0) } // ==================== Test Data Setup Helpers ==================== // setupControlTestRobot creates a robot for control testing func setupControlTestRobot(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Control Test Robot", "duties": []string{"Test execution control"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": true}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Control Test Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } ================================================ FILE: agent/robot/manager/integration_event_test.go 
================================================ package manager_test // Integration tests for Event triggers // Tests Manager.HandleEvent() with various event types and scenarios import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ==================== Event Trigger Tests ==================== // TestIntegrationEventTrigger tests event trigger flow func TestIntegrationEventTrigger(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("webhook event success", func(t *testing.T) { setupEventTestRobot(t, "robot_integ_event_webhook", "team_integ_event") config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 20}, } m := manager.NewWithConfig(config) err := m.Start() require.NoError(t, err) defer m.Stop() // Verify robot is loaded into cache robot := m.Cache().Get("robot_integ_event_webhook") require.NotNil(t, robot, "Robot should be loaded into cache") ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_webhook", Source: "webhook", EventType: "lead.created", Data: map[string]interface{}{ "name": "John Doe", "email": "john@example.com", "company": "Acme Corp", }, } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) assert.Equal(t, types.ExecPending, result.Status) assert.Contains(t, result.Message, "webhook") assert.Contains(t, result.Message, "lead.created") // Wait for execution time.Sleep(500 * time.Millisecond) // Verify execution completed 
assert.GreaterOrEqual(t, m.Executor().ExecCount(), 1) }) t.Run("database event success", func(t *testing.T) { setupEventTestRobot(t, "robot_integ_event_db", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_db", Source: "database", EventType: "order.paid", Data: map[string]interface{}{ "order_id": "ORD-12345", "amount": 1500.00, "customer": "customer_001", }, } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) t.Run("event with complex data", func(t *testing.T) { setupEventTestRobot(t, "robot_integ_event_complex", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_complex", Source: "webhook", EventType: "crm.contact.updated", Data: map[string]interface{}{ "contact": map[string]interface{}{ "id": "contact_001", "name": "Jane Smith", "email": "jane@example.com", "tags": []string{"vip", "enterprise"}, }, "changes": map[string]interface{}{ "old_status": "active", "new_status": "premium", }, "timestamp": time.Now().Unix(), }, } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) } // TestIntegrationEventTriggerErrors tests error cases for event triggers func TestIntegrationEventTriggerErrors(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("robot not found", func(t *testing.T) { m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_nonexistent", Source: "webhook", 
EventType: "test.event", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrRobotNotFound, err) }) t.Run("robot paused", func(t *testing.T) { setupEventTestRobotPaused(t, "robot_integ_event_paused", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_paused", Source: "webhook", EventType: "test.event", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrRobotPaused, err) }) t.Run("event trigger disabled", func(t *testing.T) { setupEventTestRobotDisabled(t, "robot_integ_event_disabled", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_disabled", Source: "webhook", EventType: "test.event", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrTriggerDisabled, err) }) t.Run("invalid request - empty member_id", func(t *testing.T) { m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "", // Empty Source: "webhook", EventType: "test.event", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id") }) t.Run("invalid request - empty source", func(t *testing.T) { setupEventTestRobot(t, "robot_integ_event_nosource", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_nosource", Source: "", // Empty EventType: "test.event", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "source") }) t.Run("invalid request - empty event_type", func(t *testing.T) { 
setupEventTestRobot(t, "robot_integ_event_notype", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_notype", Source: "webhook", EventType: "", // Empty } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "event_type") }) t.Run("manager not started", func(t *testing.T) { m := manager.New() // Don't start ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_test", Source: "webhook", EventType: "test.event", } _, err := m.HandleEvent(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) } // TestIntegrationEventTriggerTypes tests various event types func TestIntegrationEventTriggerTypes(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) // Common event types to test eventTypes := []struct { name string eventType string data map[string]interface{} }{ { name: "lead.created", eventType: "lead.created", data: map[string]interface{}{"name": "John", "email": "john@example.com"}, }, { name: "order.paid", eventType: "order.paid", data: map[string]interface{}{"order_id": "ORD-001", "amount": 100.0}, }, { name: "customer.signup", eventType: "customer.signup", data: map[string]interface{}{"customer_id": "cust_001", "plan": "premium"}, }, { name: "ticket.created", eventType: "ticket.created", data: map[string]interface{}{"ticket_id": "TKT-001", "priority": "high"}, }, { name: "inventory.low", eventType: "inventory.low", data: map[string]interface{}{"product_id": "PRD-001", "quantity": 5}, }, } for _, tc := range eventTypes { t.Run(tc.name, func(t *testing.T) { memberID := "robot_integ_event_type_" + tc.name setupEventTestRobot(t, memberID, "team_integ_event") m := manager.New() err := m.Start() 
require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: memberID, Source: "webhook", EventType: tc.eventType, Data: tc.data, } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err, "Event type %s should succeed", tc.eventType) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) assert.Equal(t, types.ExecPending, result.Status) }) } } // TestIntegrationEventTriggerSources tests different event sources func TestIntegrationEventTriggerSources(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) sources := []string{"webhook", "database", "api", "scheduler", "internal"} for _, source := range sources { t.Run("source_"+source, func(t *testing.T) { memberID := "robot_integ_event_src_" + source setupEventTestRobot(t, memberID, "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: memberID, Source: source, EventType: "test.event", Data: map[string]interface{}{"source": source}, } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err, "Source %s should succeed", source) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) } } // TestIntegrationEventTriggerWithEmptyData tests event with empty or nil data func TestIntegrationEventTriggerWithEmptyData(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("nil data", func(t *testing.T) { setupEventTestRobot(t, "robot_integ_event_nildata", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: 
"robot_integ_event_nildata", Source: "webhook", EventType: "ping", Data: nil, // Nil data } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) t.Run("empty data map", func(t *testing.T) { setupEventTestRobot(t, "robot_integ_event_emptydata", "team_integ_event") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_integ_event_emptydata", Source: "webhook", EventType: "heartbeat", Data: map[string]interface{}{}, // Empty map } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) } // ==================== Test Data Setup Helpers ==================== // setupEventTestRobot creates a robot with event trigger enabled func setupEventTestRobot(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Event Test Robot", "duties": []string{"Handle event triggers"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": false}, "event": map[string]interface{}{"enabled": true}, }, "events": []map[string]interface{}{ { "type": "webhook", "source": "/webhook/events", }, { "type": "database", "source": "orders", }, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Event Test Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to 
insert %s: %v", memberID, err) } } // setupEventTestRobotPaused creates a paused robot func setupEventTestRobotPaused(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{"role": "Paused Event Robot"}, "triggers": map[string]interface{}{ "event": map[string]interface{}{"enabled": true}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Paused Event Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "paused", // Paused "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupEventTestRobotDisabled creates a robot with event trigger disabled func setupEventTestRobotDisabled(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{"role": "Event Disabled Robot"}, "triggers": map[string]interface{}{ "event": map[string]interface{}{"enabled": false}, // Disabled }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Event Disabled Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } ================================================ FILE: agent/robot/manager/integration_human_test.go ================================================ package manager_test // Integration tests for Human intervention triggers // Tests 
Manager.Intervene() with various actions and scenarios import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ==================== Human Intervention Tests ==================== // TestIntegrationHumanIntervention tests human intervention trigger flow func TestIntegrationHumanIntervention(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("task.add action success", func(t *testing.T) { setupInterveneTestRobot(t, "robot_integ_human_add", "team_integ_human") config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 3, QueueSize: 20}, } m := manager.NewWithConfig(config) err := m.Start() require.NoError(t, err) defer m.Stop() // Verify robot is loaded into cache robot := m.Cache().Get("robot_integ_human_add") require.NotNil(t, robot, "Robot should be loaded into cache") ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ TeamID: "team_integ_human", MemberID: "robot_integ_human_add", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Add a new task: analyze sales data"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) assert.Equal(t, types.ExecPending, result.Status) assert.Contains(t, result.Message, "task.add") // Wait for execution time.Sleep(500 * time.Millisecond) // Verify execution completed assert.GreaterOrEqual(t, m.Executor().ExecCount(), 1) }) t.Run("goal.adjust action 
success", func(t *testing.T) { setupInterveneTestRobot(t, "robot_integ_human_goal", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ TeamID: "team_integ_human", MemberID: "robot_integ_human_goal", Action: types.ActionGoalAdjust, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Focus on high-priority customers only"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) t.Run("instruct action success", func(t *testing.T) { setupInterveneTestRobot(t, "robot_integ_human_instruct", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ TeamID: "team_integ_human", MemberID: "robot_integ_human_instruct", Action: types.ActionInstruct, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Generate a weekly report"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) } // TestIntegrationHumanInterventionErrors tests error cases for human intervention func TestIntegrationHumanInterventionErrors(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("robot not found", func(t *testing.T) { m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_nonexistent", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test"}, }, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrRobotNotFound, err) }) 
t.Run("robot paused", func(t *testing.T) { setupInterveneTestRobotPaused(t, "robot_integ_human_paused", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_integ_human_paused", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test"}, }, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrRobotPaused, err) }) t.Run("intervene trigger disabled", func(t *testing.T) { setupInterveneTestRobotDisabled(t, "robot_integ_human_disabled", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_integ_human_disabled", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test"}, }, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrTriggerDisabled, err) }) t.Run("invalid request - empty member_id", func(t *testing.T) { m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "", // Empty Action: types.ActionTaskAdd, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id") }) t.Run("invalid request - empty action", func(t *testing.T) { setupInterveneTestRobot(t, "robot_integ_human_noaction", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_integ_human_noaction", Action: "", // Empty action } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "action") }) t.Run("manager not started", func(t *testing.T) { m := manager.New() 
// Don't start ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_test", Action: types.ActionTaskAdd, } _, err := m.Intervene(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) } // TestIntegrationHumanInterventionMultimodal tests multimodal input support func TestIntegrationHumanInterventionMultimodal(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("text message", func(t *testing.T) { setupInterveneTestRobot(t, "robot_integ_human_text", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_integ_human_text", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: "Analyze the quarterly sales report", }, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) t.Run("message with image reference", func(t *testing.T) { setupInterveneTestRobot(t, "robot_integ_human_image", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_integ_human_image", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ { Role: agentcontext.RoleUser, Content: []interface{}{ map[string]interface{}{ "type": "text", "text": "Analyze this chart", }, map[string]interface{}{ "type": "image_url", "image_url": map[string]interface{}{ "url": "https://example.com/chart.png", }, }, }, }, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) t.Run("multiple messages", func(t *testing.T) { 
setupInterveneTestRobot(t, "robot_integ_human_multi", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_integ_human_multi", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "First, check the sales data"}, {Role: agentcontext.RoleUser, Content: "Then, prepare a summary report"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) }) } // TestIntegrationHumanInterventionAllActions tests all intervention actions func TestIntegrationHumanInterventionAllActions(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) // Test all defined actions actions := []types.InterventionAction{ types.ActionTaskAdd, types.ActionTaskCancel, types.ActionTaskUpdate, types.ActionGoalAdjust, types.ActionGoalAdd, types.ActionGoalComplete, types.ActionGoalCancel, types.ActionInstruct, // Note: plan.add, plan.remove, plan.update are handled differently } for _, action := range actions { t.Run(string(action), func(t *testing.T) { memberID := "robot_integ_action_" + string(action) setupInterveneTestRobot(t, memberID, "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: memberID, Action: action, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test action: " + string(action)}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err, "Action %s should succeed", action) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) assert.Equal(t, types.ExecPending, result.Status) }) } } // TestIntegrationHumanInterventionPlanAdd tests 
plan.add action (deferred execution) func TestIntegrationHumanInterventionPlanAdd(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("plan.add with future time", func(t *testing.T) { setupInterveneTestRobot(t, "robot_integ_human_plan", "team_integ_human") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) planTime := time.Now().Add(1 * time.Hour) req := &types.InterveneRequest{ MemberID: "robot_integ_human_plan", Action: types.ActionPlanAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Send weekly report"}, }, PlanTime: &planTime, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.ExecPending, result.Status) assert.Contains(t, result.Message, "Planned") // Note: Plan queue not implemented yet, so execution is deferred }) } // ==================== Test Data Setup Helpers ==================== // setupInterveneTestRobot creates a robot with intervene trigger enabled func setupInterveneTestRobot(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Intervene Test Robot", "duties": []string{"Handle human interventions"}, }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, "event": map[string]interface{}{"enabled": false}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Intervene Test Robot " + memberID, 
"status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupInterveneTestRobotPaused creates a paused robot func setupInterveneTestRobotPaused(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{"role": "Paused Robot"}, "triggers": map[string]interface{}{ "intervene": map[string]interface{}{"enabled": true}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Paused Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "paused", // Paused "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupInterveneTestRobotDisabled creates a robot with intervene trigger disabled func setupInterveneTestRobotDisabled(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{"role": "Intervene Disabled Robot"}, "triggers": map[string]interface{}{ "intervene": map[string]interface{}{"enabled": false}, // Disabled }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Intervene Disabled Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } ================================================ 
FILE: agent/robot/manager/integration_test.go ================================================ package manager_test // Integration tests for the Robot Agent scheduling system // These tests verify the complete end-to-end flow: // Trigger → Manager → Cache → Pool → Worker → Executor → Job // // Test Structure: // - integration_test.go: Core scheduling flow tests // - integration_clock_test.go: Clock trigger mode tests (times/interval/daemon) // - integration_human_test.go: Human intervention trigger tests // - integration_event_test.go: Event trigger tests // - integration_concurrent_test.go: Concurrent execution & quota tests // - integration_control_test.go: Pause/Resume/Stop tests // // Test Data: // All tests use real database records in __yao.member table // Test robot IDs are prefixed with "robot_integ_" for easy cleanup import ( "context" "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent/robot/executor" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // ==================== Core Scheduling Flow Tests ==================== // TestIntegrationSchedulingFlow tests the complete scheduling flow: // Create robot → Start manager → Trigger → Verify execution func TestIntegrationSchedulingFlow(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupIntegrationRobots(t) defer cleanupIntegrationRobots(t) t.Run("complete clock trigger flow", func(t *testing.T) { // Setup: Create a robot with times mode clock config setupIntegrationRobotTimes(t, "robot_integ_flow_clock", "team_integ_flow") // Create manager with fast tick interval for testing config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: 
&pool.Config{WorkerSize: 5, QueueSize: 50}, } m := manager.NewWithConfig(config) // Start manager err := m.Start() require.NoError(t, err) defer m.Stop() // Verify robot is loaded into cache robot := m.Cache().Get("robot_integ_flow_clock") require.NotNil(t, robot, "Robot should be loaded into cache") assert.Equal(t, "robot_integ_flow_clock", robot.MemberID) assert.Equal(t, types.RobotIdle, robot.Status) // Simulate clock trigger at matching time (03:33 on Wednesday) loc, _ := time.LoadLocation("Asia/Shanghai") triggerTime := time.Date(2025, 1, 15, 3, 33, 0, 0, loc) // Wednesday 03:33 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, triggerTime) assert.NoError(t, err) // Wait for execution to complete time.Sleep(500 * time.Millisecond) // Verify execution happened execCount := m.Executor().ExecCount() assert.GreaterOrEqual(t, execCount, 1, "Should have at least 1 execution") }) t.Run("robot loaded from database", func(t *testing.T) { // Setup: Create multiple robots setupIntegrationRobotTimes(t, "robot_integ_flow_db1", "team_integ_flow") setupIntegrationRobotInterval(t, "robot_integ_flow_db2", "team_integ_flow") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() // Verify both robots are in cache robot1 := m.Cache().Get("robot_integ_flow_db1") robot2 := m.Cache().Get("robot_integ_flow_db2") assert.NotNil(t, robot1, "Robot 1 should be loaded") assert.NotNil(t, robot2, "Robot 2 should be loaded") // Verify config is parsed correctly assert.NotNil(t, robot1.Config) assert.NotNil(t, robot1.Config.Clock) assert.Equal(t, types.ClockTimes, robot1.Config.Clock.Mode) assert.NotNil(t, robot2.Config) assert.NotNil(t, robot2.Config.Clock) assert.Equal(t, types.ClockInterval, robot2.Config.Clock.Mode) }) t.Run("inactive robot not loaded", func(t *testing.T) { // Setup: Create an inactive robot setupIntegrationRobotInactive(t, "robot_integ_flow_inactive", "team_integ_flow") m := manager.New() err := m.Start() require.NoError(t, err) 
defer m.Stop() // Inactive robot should not be in cache robot := m.Cache().Get("robot_integ_flow_inactive") assert.Nil(t, robot, "Inactive robot should not be loaded") }) t.Run("robot with autonomous_mode=false not loaded", func(t *testing.T) { // Setup: Create a robot with autonomous_mode=false setupIntegrationRobotNonAutonomous(t, "robot_integ_flow_nonauto", "team_integ_flow") m := manager.New() err := m.Start() require.NoError(t, err) defer m.Stop() // Non-autonomous robot should not be in cache robot := m.Cache().Get("robot_integ_flow_nonauto") assert.Nil(t, robot, "Non-autonomous robot should not be loaded") }) }

// TestIntegrationJobSubmission tests job submission to pool and execution.
// It is skipped in -short mode.
func TestIntegrationJobSubmission(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupIntegrationRobots(t)
	defer cleanupIntegrationRobots(t)

	t.Run("job submitted to pool and executed", func(t *testing.T) {
		setupIntegrationRobotTimes(t, "robot_integ_submit", "team_integ_submit")

		config := &manager.Config{
			TickInterval: 100 * time.Millisecond,
			PoolConfig:   &pool.Config{WorkerSize: 3, QueueSize: 20},
		}
		m := manager.NewWithConfig(config)
		err := m.Start()
		require.NoError(t, err)
		defer m.Stop()

		// Manually trigger execution
		ctx := types.NewContext(context.Background(), nil)
		execID, err := m.TriggerManual(ctx, "robot_integ_submit", types.TriggerClock, nil)
		assert.NoError(t, err)
		assert.NotEmpty(t, execID, "Should return execution ID")

		// Wait for execution
		time.Sleep(500 * time.Millisecond)

		// Verify execution completed
		assert.GreaterOrEqual(t, m.Executor().ExecCount(), 1)
	})

	t.Run("multiple jobs queued and executed in order", func(t *testing.T) {
		setupIntegrationRobotHighQuota(t, "robot_integ_queue", "team_integ_submit")

		config := &manager.Config{
			TickInterval: 100 * time.Millisecond,
			PoolConfig:   &pool.Config{WorkerSize: 2, QueueSize: 50},
		}
		m := manager.NewWithConfig(config)
		err := m.Start()
		require.NoError(t, err)
		defer m.Stop()

		ctx := types.NewContext(context.Background(), nil)

		// Submit multiple jobs
		execIDs := make([]string, 5)
		for i := 0; i < 5; i++ {
			execID, err := m.TriggerManual(ctx, "robot_integ_queue", types.TriggerClock, nil)
			assert.NoError(t, err)
			execIDs[i] = execID
		}

		// All should have valid IDs
		for i, id := range execIDs {
			assert.NotEmpty(t, id, "Execution %d should have valid ID", i)
		}

		// Wait for all to complete (longer wait for slow execution)
		time.Sleep(2 * time.Second)

		// All jobs should have executed
		execCount := m.Executor().ExecCount()
		assert.GreaterOrEqual(t, execCount, 5, "Expected at least 5 executions, got %d", execCount)
	})
}

// TestIntegrationPhaseProgression tests that execution progresses through all phases.
// Each subtest records the phases reported through the executor's OnPhaseStart
// callback and checks their count and the first/last phase.
func TestIntegrationPhaseProgression(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupIntegrationRobots(t)
	defer cleanupIntegrationRobots(t)

	t.Run("clock trigger executes all phases P0-P5", func(t *testing.T) {
		cleanupIntegrationRobots(t)
		setupIntegrationRobotTimes(t, "robot_integ_phases_clock", "team_integ_phases")

		// Track phases executed
		phasesExecuted := make([]types.Phase, 0)
		exec := executor.NewDryRunWithConfig(executor.DryRunConfig{
			Config: executor.Config{
				OnPhaseStart: func(phase types.Phase) {
					phasesExecuted = append(phasesExecuted, phase)
				},
			},
		})

		config := &manager.Config{
			TickInterval: 100 * time.Millisecond,
			PoolConfig:   &pool.Config{WorkerSize: 2, QueueSize: 20},
			Executor:     exec,
		}
		m := manager.NewWithConfig(config)
		err := m.Start()
		require.NoError(t, err)

		// Trigger execution
		ctx := types.NewContext(context.Background(), nil)
		_, err = m.TriggerManual(ctx, "robot_integ_phases_clock", types.TriggerClock, nil)
		assert.NoError(t, err)

		// Wait for execution
		time.Sleep(500 * time.Millisecond)

		// Stop manager before asserting to prevent ticker from triggering extra executions
		m.Stop()

		// Verify all 6 phases executed (P0-P5).
		// require (not assert) stops the subtest on a wrong count, so the
		// index expressions below cannot panic on a short slice.
		require.Len(t, phasesExecuted, 6, "Should execute all 6 phases for clock trigger")
		assert.Equal(t, types.PhaseInspiration, phasesExecuted[0], "Should start with P0")
		assert.Equal(t, types.PhaseLearning, phasesExecuted[5], "Should end with P5")
	})

	t.Run("human trigger skips P0 and executes P1-P5", func(t *testing.T) {
		cleanupIntegrationRobots(t)
		setupIntegrationRobotIntervene(t, "robot_integ_phases_human", "team_integ_phases")

		// Track phases executed
		phasesExecuted := make([]types.Phase, 0)
		exec := executor.NewDryRunWithConfig(executor.DryRunConfig{
			Config: executor.Config{
				OnPhaseStart: func(phase types.Phase) {
					phasesExecuted = append(phasesExecuted, phase)
				},
			},
		})

		config := &manager.Config{
			TickInterval: 100 * time.Millisecond,
			PoolConfig:   &pool.Config{WorkerSize: 2, QueueSize: 20},
			Executor:     exec,
		}
		m := manager.NewWithConfig(config)
		err := m.Start()
		require.NoError(t, err)

		// Trigger execution via human trigger
		ctx := types.NewContext(context.Background(), nil)
		_, err = m.TriggerManual(ctx, "robot_integ_phases_human", types.TriggerHuman, nil)
		assert.NoError(t, err)

		// Wait for execution
		time.Sleep(500 * time.Millisecond)

		// Stop manager before asserting to prevent ticker from triggering extra executions
		m.Stop()

		// Verify 5 phases executed (P1-P5, skipping P0).
		// require (not assert) stops the subtest on a wrong count, so the
		// index expressions below cannot panic on a short slice.
		require.Len(t, phasesExecuted, 5, "Should execute 5 phases for human trigger")
		assert.Equal(t, types.PhaseGoals, phasesExecuted[0], "Should start with P1 (Goals)")
		assert.Equal(t, types.PhaseLearning, phasesExecuted[4], "Should end with P5")
	})
}

// TestIntegrationCacheRefresh tests that cache refresh works correctly
func TestIntegrationCacheRefresh(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupIntegrationRobots(t)
	defer cleanupIntegrationRobots(t)

	t.Run("cache refresh loads new robots", func(t *testing.T) {
		// Start with one robot
		setupIntegrationRobotTimes(t, "robot_integ_refresh1", "team_integ_refresh")

		m := manager.New()
		err := m.Start()
		require.NoError(t, err)
		defer m.Stop()

		// Verify first robot is loaded
		robot1 := m.Cache().Get("robot_integ_refresh1")
		assert.NotNil(t, robot1)

		// Add another robot to database
		setupIntegrationRobotTimes(t, "robot_integ_refresh2", "team_integ_refresh")

		// Manually refresh cache
		ctx := types.NewContext(context.Background(), nil)
		err = m.Cache().Load(ctx)
		assert.NoError(t, err)

		// Verify new robot is now in cache
		robot2 := m.Cache().Get("robot_integ_refresh2")
		assert.NotNil(t, robot2, "New robot should be loaded after refresh")
	})
}

// ==================== Test Data Setup Helpers ====================

// setupIntegrationRobotTimes creates a robot with times mode clock config.
// It inserts one active, autonomous robot row into the table backing the
// "__yao.member" model and fails the test if the config cannot be encoded or
// the insert returns an error.
func setupIntegrationRobotTimes(t *testing.T, memberID, teamID string) {
	t.Helper() // report failures at the caller's line

	m := model.Select("__yao.member")
	tableName := m.MetaData.Table.Name
	qb := capsule.Query()

	robotConfig := map[string]interface{}{
		"identity": map[string]interface{}{
			"role":   "Integration Test Robot (Times)",
			"duties": []string{"Test scheduling"},
		},
		"quota": map[string]interface{}{
			"max":      3,
			"queue":    20,
			"priority": 5,
		},
		"triggers": map[string]interface{}{
			"clock":     map[string]interface{}{"enabled": true},
			"intervene": map[string]interface{}{"enabled": true},
			"event":     map[string]interface{}{"enabled": true},
		},
		"clock": map[string]interface{}{
			"mode":    "times",
			"times":   []string{"03:33"},
			"days":    []string{"Mon", "Tue", "Wed", "Thu", "Fri"},
			"tz":      "Asia/Shanghai",
			"timeout": "30m",
		},
	}
	configJSON, err := json.Marshal(robotConfig)
	if err != nil {
		t.Fatalf("Failed to marshal robot_config for %s: %v", memberID, err)
	}

	err = qb.Table(tableName).Insert([]map[string]interface{}{
		{
			"member_id":       memberID,
			"team_id":         teamID,
			"member_type":     "robot",
			"display_name":    "Test Robot " + memberID,
			"system_prompt":   "You are an integration test robot.",
			"status":          "active",
			"role_id":         "member",
			"autonomous_mode": true,
			"robot_status":    "idle",
			"robot_config":    string(configJSON),
		},
	})
	if err != nil {
		t.Fatalf("Failed to insert %s: %v", memberID, err)
	}
}

// setupIntegrationRobotInterval creates a robot with interval
mode clock config func setupIntegrationRobotInterval(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Integration Test Robot (Interval)", }, "quota": map[string]interface{}{ "max": 2, "queue": 10, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "interval", "every": "30m", "timeout": "10m", }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Test Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupIntegrationRobotHighQuota creates a robot with high quota for queue tests func setupIntegrationRobotHighQuota(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Integration Test Robot (High Quota)", }, "quota": map[string]interface{}{ "max": 10, "queue": 50, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"03:33"}, "tz": "Asia/Shanghai", }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Test Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { 
t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupIntegrationRobotIntervene creates a robot with intervene trigger enabled func setupIntegrationRobotIntervene(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Integration Test Robot (Intervene)", }, "quota": map[string]interface{}{ "max": 5, "queue": 20, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Test Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupIntegrationRobotInactive creates an inactive robot (should not be loaded) func setupIntegrationRobotInactive(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Inactive Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Inactive Robot " + memberID, "status": "inactive", // Inactive status "role_id": "member", "autonomous_mode": true, "robot_status": "paused", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // setupIntegrationRobotNonAutonomous creates a 
robot with autonomous_mode=false func setupIntegrationRobotNonAutonomous(t *testing.T, memberID, teamID string) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() robotConfig := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Non-Autonomous Robot", }, } configJSON, _ := json.Marshal(robotConfig) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": memberID, "team_id": teamID, "member_type": "robot", "display_name": "Non-Autonomous Robot " + memberID, "status": "active", "role_id": "member", "autonomous_mode": false, // Not autonomous "robot_status": "idle", "robot_config": string(configJSON), }, }) if err != nil { t.Fatalf("Failed to insert %s: %v", memberID, err) } } // cleanupIntegrationRobots removes all integration test robots func cleanupIntegrationRobots(t *testing.T) { qb := capsule.Query() m := model.Select("__yao.member") tableName := m.MetaData.Table.Name // Delete all robots with member_id starting with "robot_integ_" // Using LIKE pattern for cleanup _, err := qb.Table(tableName).Where("member_id", "like", "robot_integ_%").Delete() if err != nil { // Log but don't fail - cleanup errors are not critical t.Logf("Warning: cleanup error: %v", err) } } ================================================ FILE: agent/robot/manager/interact.go ================================================ package manager import ( "encoding/json" "fmt" "strings" "time" "github.com/yaoapp/kun/log" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" robotevents "github.com/yaoapp/yao/agent/robot/events" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/robot/utils" "github.com/yaoapp/yao/event" ) // executeResume resumes a suspended execution using the Manager's shared executor. 
// This avoids creating orphan Executor instances with independent counters. func (m *Manager) executeResume(ctx *types.Context, execID, reply string) error { return m.executor.Resume(types.NewContext(ctx.Context, ctx.Auth), execID, reply) } // InteractRequest represents a unified interaction with a robot (Manager layer). type InteractRequest struct { ExecutionID string `json:"execution_id,omitempty"` TaskID string `json:"task_id,omitempty"` Source types.InteractSource `json:"source,omitempty"` Message string `json:"message"` Action string `json:"action,omitempty"` } // InteractResponse is the result of an interaction. type InteractResponse struct { ExecutionID string `json:"execution_id,omitempty"` Status string `json:"status"` Message string `json:"message,omitempty"` ChatID string `json:"chat_id,omitempty"` Reply string `json:"reply,omitempty"` WaitForMore bool `json:"wait_for_more,omitempty"` } // CancelExecution cancels a waiting/confirming execution. func (m *Manager) CancelExecution(ctx *types.Context, execID string) error { m.mu.RLock() if !m.started { m.mu.RUnlock() return fmt.Errorf("manager not started") } m.mu.RUnlock() execStore := store.NewExecutionStore() record, err := execStore.Get(ctx.Context, execID) if err != nil { return fmt.Errorf("execution not found: %s", execID) } if record == nil { return fmt.Errorf("execution not found: %s", execID) } if record.Status != types.ExecWaiting && record.Status != types.ExecConfirming { return fmt.Errorf("execution %s is in status %s, only waiting/confirming can be cancelled", execID, record.Status) } if err := execStore.UpdateStatus(ctx.Context, execID, types.ExecCancelled, "cancelled by user"); err != nil { return fmt.Errorf("failed to cancel execution: %w", err) } m.execController.Untrack(execID) if robot := m.cache.Get(record.MemberID); robot != nil { robot.RemoveExecution(execID) } event.Push(ctx.Context, robotevents.ExecCancelled, robotevents.ExecPayload{ ExecutionID: execID, MemberID: record.MemberID, 
TeamID: record.TeamID, Status: string(types.ExecCancelled), ChatID: record.ChatID, }) return nil } // HandleInteract processes all human-robot interactions through a unified entry point. // // Routing logic (§16.37): // - No execution_id: new interaction → createConfirmingExecution → Host Agent (assign) // - execution_id with status=confirming: Host Agent (assign) → processHostAction // - execution_id with status=waiting: Host Agent (clarify) → processHostAction // - execution_id with status=running: Host Agent (guide) → processHostAction func (m *Manager) HandleInteract(ctx *types.Context, memberID string, req *InteractRequest) (*InteractResponse, error) { m.mu.RLock() if !m.started { m.mu.RUnlock() return nil, fmt.Errorf("manager not started") } m.mu.RUnlock() if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil || req.Message == "" { return nil, fmt.Errorf("message is required") } robot, _, err := m.getOrLoadRobot(ctx, memberID) if err != nil { return nil, fmt.Errorf("robot not found: %w", err) } execStore := store.NewExecutionStore() // No execution_id → create a new confirming execution if req.ExecutionID == "" { return m.handleNewInteraction(ctx, robot, req, execStore) } // Existing execution_id → load and route by status record, err := execStore.Get(ctx.Context, req.ExecutionID) if err != nil { return nil, fmt.Errorf("execution not found: %s", req.ExecutionID) } switch record.Status { case types.ExecConfirming: return m.handleConfirmingInteraction(ctx, robot, record, req, execStore) case types.ExecWaiting: return m.handleWaitingInteraction(ctx, robot, record, req, execStore) case types.ExecRunning: if record.WaitingTaskID == "" { return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "rejected", Message: "Execution is running and not waiting for input", }, nil } return m.handleRunningInteraction(ctx, robot, record, req, execStore) default: return nil, fmt.Errorf("execution %s is in status %s, cannot interact", 
req.ExecutionID, record.Status) } } // handleNewInteraction creates a confirming execution and calls Host Agent with "assign" scenario. func (m *Manager) handleNewInteraction(ctx *types.Context, robot *types.Robot, req *InteractRequest, execStore *store.ExecutionStore) (*InteractResponse, error) { exec, chatID, err := m.createConfirmingExecution(ctx, robot, req, execStore) if err != nil { return nil, fmt.Errorf("failed to create confirming execution: %w", err) } hostOutput, err := m.callHostAgentForScenario(ctx, robot, "assign", req.Message, nil, chatID) if err != nil { log.Warn("Host Agent call failed, using direct assign: %v", err) return m.directAssign(ctx, robot, exec, req, execStore) } resp, err := m.processHostAction(ctx, robot, exec, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = exec.ExecutionID resp.ChatID = chatID return resp, nil } // handleConfirmingInteraction continues a confirming flow with Host Agent. func (m *Manager) handleConfirmingInteraction(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore) (*InteractResponse, error) { hostCtx := m.buildHostContext(robot, record, nil) hostOutput, err := m.callHostAgentForScenario(ctx, robot, "assign", req.Message, hostCtx, record.ChatID) if err != nil { log.Warn("Host Agent call failed during confirming: %v", err) return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "error", Message: fmt.Sprintf("Host Agent failed: %v", err), }, nil } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } // handleWaitingInteraction processes input for a waiting (suspended) execution. 
func (m *Manager) handleWaitingInteraction(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore) (*InteractResponse, error) { waitingTask := m.findWaitingTask(record) hostCtx := m.buildHostContext(robot, record, waitingTask) hostOutput, err := m.callHostAgentForScenario(ctx, robot, "clarify", req.Message, hostCtx, record.ChatID) if err != nil { log.Warn("Host Agent call failed during clarify, falling back to direct resume: %v", err) return m.directResume(ctx, record, req) } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } // handleRunningInteraction allows guidance for a running execution. func (m *Manager) handleRunningInteraction(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore) (*InteractResponse, error) { hostCtx := m.buildHostContext(robot, record, nil) hostOutput, err := m.callHostAgentForScenario(ctx, robot, "guide", req.Message, hostCtx, record.ChatID) if err != nil { return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "acknowledged", Message: "Guidance noted (Host Agent unavailable)", }, nil } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } // ==================== Helper Methods ==================== // createConfirmingExecution creates a new execution in "confirming" status. 
func (m *Manager) createConfirmingExecution(ctx *types.Context, robot *types.Robot, req *InteractRequest, execStore *store.ExecutionStore) (*store.ExecutionRecord, string, error) { execID := pool.GenerateExecID() chatID := fmt.Sprintf("robot_%s_%s", robot.MemberID, execID) now := time.Now() record := &store.ExecutionRecord{ ExecutionID: execID, MemberID: robot.MemberID, TeamID: robot.TeamID, TriggerType: types.TriggerHuman, Status: types.ExecConfirming, Phase: types.PhaseGoals, ChatID: chatID, Input: &types.TriggerInput{ Action: types.ActionTaskAdd, Messages: []agentcontext.Message{{Role: "user", Content: req.Message}}, UserID: ctx.UserID(), }, StartTime: &now, } if err := execStore.Save(ctx.Context, record); err != nil { return nil, "", fmt.Errorf("failed to save confirming execution: %w", err) } return record, chatID, nil } // buildHostContext builds the HostContext for Host Agent calls. func (m *Manager) buildHostContext(robot *types.Robot, record *store.ExecutionRecord, waitingTask *types.Task) *types.HostContext { hostCtx := &types.HostContext{ RobotStatus: m.buildRobotStatusSnapshot(robot), } if record.Goals != nil { hostCtx.Goals = record.Goals } if len(record.Tasks) > 0 { hostCtx.Tasks = record.Tasks } if waitingTask != nil { hostCtx.CurrentTask = waitingTask } if record.WaitingQuestion != "" { hostCtx.AgentReply = record.WaitingQuestion } return hostCtx } // buildRobotStatusSnapshot builds a status snapshot for the Host Agent. func (m *Manager) buildRobotStatusSnapshot(robot *types.Robot) *types.RobotStatusSnapshot { if robot == nil { return nil } snapshot := &types.RobotStatusSnapshot{ MemberID: robot.MemberID, Status: robot.Status, ActiveCount: robot.ActiveCount(), WaitingCount: robot.WaitingCount(), MaxQuota: robot.MaxQuota(), ActiveExecs: robot.ListExecutionBriefs(), } if m.pool != nil { snapshot.QueuedCount = m.pool.QueueSize() } return snapshot } // findWaitingTask finds the task that is currently waiting for input. 
func (m *Manager) findWaitingTask(record *store.ExecutionRecord) *types.Task { if record.WaitingTaskID == "" { return nil } for i := range record.Tasks { if record.Tasks[i].ID == record.WaitingTaskID { return &record.Tasks[i] } } return nil } // callHostAgentForScenario calls the Host Agent with a given scenario. func (m *Manager) callHostAgentForScenario(ctx *types.Context, robot *types.Robot, scenario string, message string, hostCtx *types.HostContext, chatID string) (*types.HostOutput, error) { agentID := "" if robot.Config != nil && robot.Config.Resources != nil { agentID = robot.Config.Resources.GetPhaseAgent(types.PhaseHost) } if agentID == "" { return nil, fmt.Errorf("no Host Agent configured for robot %s", robot.MemberID) } return m.callHostAgent(ctx, agentID, &types.HostInput{ Scenario: scenario, Messages: []agentcontext.Message{{Role: "user", Content: message}}, Context: hostCtx, }, chatID) } // callHostAgent calls the Host Agent assistant and parses output. func (m *Manager) callHostAgent(ctx *types.Context, agentID string, input *types.HostInput, chatID string) (*types.HostOutput, error) { inputJSON, err := json.Marshal(input) if err != nil { return nil, fmt.Errorf("failed to marshal host input: %w", err) } caller := standard.NewConversationCaller(chatID) result, err := caller.CallWithMessages(ctx, agentID, string(inputJSON)) if err != nil { return nil, fmt.Errorf("host agent (%s) call failed: %w", agentID, err) } return m.parseHostAgentResult(result) } // parseHostAgentResult inspects the agent result to determine if it is an action // decision (JSON with "action" field) or a conversational reply (natural language). 
func (m *Manager) parseHostAgentResult(result *standard.CallResult) (*types.HostOutput, error) { data, err := result.GetJSON() if err == nil { output := &types.HostOutput{} raw, _ := json.Marshal(data) if err := json.Unmarshal(raw, output); err == nil && output.Action != "" { return output, nil } } return &types.HostOutput{ Reply: result.GetText(), WaitForMore: true, }, nil } // processHostAction processes the output from Host Agent and takes the appropriate action. func (m *Manager) processHostAction(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, output *types.HostOutput, execStore *store.ExecutionStore) (*InteractResponse, error) { resp := &InteractResponse{ Reply: output.Reply, WaitForMore: output.WaitForMore, } if output.WaitForMore { resp.Status = "waiting_for_more" resp.Message = output.Reply return resp, nil } switch output.Action { case types.HostActionConfirm: if err := m.advanceExecution(ctx, robot, record, execStore); err != nil { return nil, fmt.Errorf("failed to advance execution: %w", err) } resp.Status = "confirmed" resp.Message = "Execution confirmed and started" case types.HostActionAdjust: if err := m.adjustExecution(ctx, record, output.ActionData, execStore); err != nil { return nil, fmt.Errorf("failed to adjust execution: %w", err) } resp.Status = "adjusted" resp.Message = "Execution plan adjusted" case types.HostActionAddTask: if err := m.injectTask(ctx, record, output.ActionData, execStore); err != nil { return nil, fmt.Errorf("failed to inject task: %w", err) } resp.Status = "task_added" resp.Message = "New task injected" case types.HostActionSkip: if err := m.skipWaitingTask(ctx, record, execStore); err != nil { return nil, fmt.Errorf("failed to skip task: %w", err) } resp.Status = "task_skipped" resp.Message = "Waiting task skipped" case types.HostActionInjectCtx: if err := m.resumeWithContext(ctx, record, output.ActionData, execStore); err != nil { if err == types.ErrExecutionSuspended { resp.Status = "waiting" 
resp.Message = "Execution suspended again" return resp, nil } return nil, fmt.Errorf("failed to resume with context: %w", err) } resp.Status = "resumed" resp.Message = "Execution resumed with additional context" case types.HostActionCancel: if err := m.CancelExecution(ctx, record.ExecutionID); err != nil { return nil, fmt.Errorf("failed to cancel execution: %w", err) } resp.Status = "cancelled" resp.Message = "Execution cancelled" default: resp.Status = "acknowledged" resp.Message = output.Reply } return resp, nil } // advanceExecution moves a confirming execution to running. func (m *Manager) advanceExecution(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, execStore *store.ExecutionStore) error { if err := execStore.UpdateStatus(ctx.Context, record.ExecutionID, types.ExecRunning, ""); err != nil { return err } ctrlExec := m.execController.Track(record.ExecutionID, record.MemberID, record.TeamID) execCtx := types.NewContext(ctrlExec.Context(), ctx.Auth) triggerInput := record.Input _, err := m.pool.SubmitWithID(execCtx, robot, types.TriggerHuman, triggerInput, record.ExecutionID, ctrlExec) if err != nil { m.execController.Untrack(record.ExecutionID) return fmt.Errorf("failed to submit execution to pool: %w", err) } return nil } // adjustExecution adjusts goals/tasks based on Host Agent output. 
func (m *Manager) adjustExecution(ctx *types.Context, record *store.ExecutionRecord, actionData interface{}, execStore *store.ExecutionStore) error { if actionData == nil { return nil } data, ok := actionData.(map[string]interface{}) if !ok { raw, err := json.Marshal(actionData) if err != nil { return nil } json.Unmarshal(raw, &data) } if goalsContent, ok := data["goals"].(string); ok && goalsContent != "" { record.Goals = &types.Goals{Content: goalsContent} } if tasksRaw, ok := data["tasks"]; ok { raw, _ := json.Marshal(tasksRaw) var tasks []types.Task if err := json.Unmarshal(raw, &tasks); err == nil { record.Tasks = tasks } } return execStore.Save(ctx.Context, record) } // injectTask adds a new task to the execution's task list. func (m *Manager) injectTask(ctx *types.Context, record *store.ExecutionRecord, actionData interface{}, execStore *store.ExecutionStore) error { if actionData == nil { return fmt.Errorf("task data is required") } raw, err := json.Marshal(actionData) if err != nil { return fmt.Errorf("invalid task data: %w", err) } var newTask types.Task if err := json.Unmarshal(raw, &newTask); err != nil { return fmt.Errorf("failed to parse task: %w", err) } if newTask.ID == "" { newTask.ID = fmt.Sprintf("injected-%s", utils.NewID()[:8]) } newTask.Status = types.TaskPending record.Tasks = append(record.Tasks, newTask) return execStore.Save(ctx.Context, record) } // skipWaitingTask skips the currently waiting task and resumes execution. 
func (m *Manager) skipWaitingTask(ctx *types.Context, record *store.ExecutionRecord, execStore *store.ExecutionStore) error { if record.WaitingTaskID == "" { return fmt.Errorf("no task is waiting") } for i := range record.Tasks { if record.Tasks[i].ID == record.WaitingTaskID { record.Tasks[i].Status = types.TaskSkipped break } } err := m.executeResume(ctx, record.ExecutionID, "__skip__") if err != nil && err != types.ErrExecutionSuspended { return fmt.Errorf("failed to resume after skip: %w", err) } return nil } // resumeWithContext injects context and resumes the waiting execution. func (m *Manager) resumeWithContext(ctx *types.Context, record *store.ExecutionRecord, actionData interface{}, execStore *store.ExecutionStore) error { reply := "" if actionData != nil { if s, ok := actionData.(string); ok { reply = s } else if data, ok := actionData.(map[string]interface{}); ok { if r, ok := data["reply"].(string); ok { reply = r } else { raw, _ := json.Marshal(data) reply = string(raw) } } } return m.executeResume(ctx, record.ExecutionID, reply) } // directAssign is the fallback when Host Agent is unavailable: directly start execution. func (m *Manager) directAssign(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore) (*InteractResponse, error) { if err := m.advanceExecution(ctx, robot, record, execStore); err != nil { return nil, fmt.Errorf("direct assign failed: %w", err) } return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "confirmed", Message: "Execution started (direct assign)", ChatID: record.ChatID, }, nil } // directResume is the fallback when Host Agent is unavailable: directly resume. 
func (m *Manager) directResume(ctx *types.Context, record *store.ExecutionRecord, req *InteractRequest) (*InteractResponse, error) { err := m.executeResume(ctx, record.ExecutionID, req.Message) if err != nil { if err == types.ErrExecutionSuspended { return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "waiting", Message: "Execution suspended again: needs more input", ChatID: record.ChatID, }, nil } return nil, fmt.Errorf("failed to resume execution: %w", err) } return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "resumed", Message: "Execution resumed and completed successfully", ChatID: record.ChatID, }, nil } // ==================== Streaming Interact ==================== // HandleInteractStream is the streaming version of HandleInteract. // It streams Host Agent text tokens via streamFn while still returning the final InteractResponse. func (m *Manager) HandleInteractStream(ctx *types.Context, memberID string, req *InteractRequest, streamFn standard.StreamCallback) (*InteractResponse, error) { m.mu.RLock() if !m.started { m.mu.RUnlock() return nil, fmt.Errorf("manager not started") } m.mu.RUnlock() if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil || req.Message == "" { return nil, fmt.Errorf("message is required") } robot, _, err := m.getOrLoadRobot(ctx, memberID) if err != nil { return nil, fmt.Errorf("robot not found: %w", err) } execStore := store.NewExecutionStore() if req.ExecutionID == "" { return m.handleNewInteractionStream(ctx, robot, req, execStore, streamFn) } record, err := execStore.Get(ctx.Context, req.ExecutionID) if err != nil { return nil, fmt.Errorf("execution not found: %s", req.ExecutionID) } switch record.Status { case types.ExecConfirming: return m.handleConfirmingInteractionStream(ctx, robot, record, req, execStore, streamFn) case types.ExecWaiting: return m.handleWaitingInteractionStream(ctx, robot, record, req, execStore, streamFn) case types.ExecRunning: if 
record.WaitingTaskID == "" { return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "rejected", Message: "Execution is running and not waiting for input", }, nil } return m.handleRunningInteractionStream(ctx, robot, record, req, execStore, streamFn) default: return nil, fmt.Errorf("execution %s is in status %s, cannot interact", req.ExecutionID, record.Status) } } func (m *Manager) handleNewInteractionStream(ctx *types.Context, robot *types.Robot, req *InteractRequest, execStore *store.ExecutionStore, streamFn standard.StreamCallback) (*InteractResponse, error) { exec, chatID, err := m.createConfirmingExecution(ctx, robot, req, execStore) if err != nil { return nil, fmt.Errorf("failed to create confirming execution: %w", err) } hostOutput, err := m.callHostAgentForScenarioStream(ctx, robot, "assign", req.Message, nil, chatID, streamFn) if err != nil { log.Warn("Host Agent call failed, using direct assign: %v", err) return m.directAssign(ctx, robot, exec, req, execStore) } resp, err := m.processHostAction(ctx, robot, exec, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = exec.ExecutionID resp.ChatID = chatID return resp, nil } func (m *Manager) handleConfirmingInteractionStream(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore, streamFn standard.StreamCallback) (*InteractResponse, error) { hostCtx := m.buildHostContext(robot, record, nil) hostOutput, err := m.callHostAgentForScenarioStream(ctx, robot, "assign", req.Message, hostCtx, record.ChatID, streamFn) if err != nil { log.Warn("Host Agent call failed during confirming: %v", err) return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "error", Message: fmt.Sprintf("Host Agent failed: %v", err), }, nil } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, 
nil } func (m *Manager) handleWaitingInteractionStream(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore, streamFn standard.StreamCallback) (*InteractResponse, error) { waitingTask := m.findWaitingTask(record) hostCtx := m.buildHostContext(robot, record, waitingTask) hostOutput, err := m.callHostAgentForScenarioStream(ctx, robot, "clarify", req.Message, hostCtx, record.ChatID, streamFn) if err != nil { log.Warn("Host Agent call failed during clarify, falling back to direct resume: %v", err) return m.directResume(ctx, record, req) } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } func (m *Manager) handleRunningInteractionStream(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore, streamFn standard.StreamCallback) (*InteractResponse, error) { hostCtx := m.buildHostContext(robot, record, nil) hostOutput, err := m.callHostAgentForScenarioStream(ctx, robot, "guide", req.Message, hostCtx, record.ChatID, streamFn) if err != nil { return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "acknowledged", Message: "Guidance noted (Host Agent unavailable)", }, nil } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } func (m *Manager) callHostAgentForScenarioStream(ctx *types.Context, robot *types.Robot, scenario string, msg string, hostCtx *types.HostContext, chatID string, streamFn standard.StreamCallback) (*types.HostOutput, error) { agentID := "" if robot.Config != nil && robot.Config.Resources != nil { agentID = robot.Config.Resources.GetPhaseAgent(types.PhaseHost) } if agentID == "" { return nil, fmt.Errorf("no Host Agent 
configured for robot %s", robot.MemberID) } return m.callHostAgentStream(ctx, agentID, &types.HostInput{ Scenario: scenario, Messages: []agentcontext.Message{{Role: "user", Content: msg}}, Context: hostCtx, }, chatID, streamFn) } func (m *Manager) callHostAgentStream(ctx *types.Context, agentID string, input *types.HostInput, chatID string, streamFn standard.StreamCallback) (*types.HostOutput, error) { inputJSON, err := json.Marshal(input) if err != nil { return nil, fmt.Errorf("failed to marshal host input: %w", err) } caller := standard.NewConversationCaller(chatID) result, err := caller.CallWithMessagesStream(ctx, agentID, string(inputJSON), streamFn) if err != nil { return nil, fmt.Errorf("host agent (%s) call failed: %w", agentID, err) } return m.parseHostAgentResult(result) } // ==================== Raw Message Streaming (CUI Protocol) ==================== // HandleInteractStreamRaw is the CUI-protocol-aligned streaming version of HandleInteract. // It passes raw message.Message objects directly to the onMessage callback, preserving all // CUI protocol fields for direct SSE passthrough to the frontend. 
func (m *Manager) HandleInteractStreamRaw(ctx *types.Context, memberID string, req *InteractRequest, onMessage agentcontext.OnMessageFunc) (*InteractResponse, error) { m.mu.RLock() if !m.started { m.mu.RUnlock() return nil, fmt.Errorf("manager not started") } m.mu.RUnlock() if memberID == "" { return nil, fmt.Errorf("member_id is required") } if req == nil || req.Message == "" { return nil, fmt.Errorf("message is required") } robot, _, err := m.getOrLoadRobot(ctx, memberID) if err != nil { return nil, fmt.Errorf("robot not found: %w", err) } execStore := store.NewExecutionStore() if req.ExecutionID == "" { return m.handleNewInteractionStreamRaw(ctx, robot, req, execStore, onMessage) } record, err := execStore.Get(ctx.Context, req.ExecutionID) if err != nil { return nil, fmt.Errorf("execution not found: %s", req.ExecutionID) } switch record.Status { case types.ExecConfirming: return m.handleConfirmingInteractionStreamRaw(ctx, robot, record, req, execStore, onMessage) case types.ExecWaiting: return m.handleWaitingInteractionStreamRaw(ctx, robot, record, req, execStore, onMessage) case types.ExecRunning: if record.WaitingTaskID == "" { return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "rejected", Message: "Execution is running and not waiting for input", }, nil } return m.handleRunningInteractionStreamRaw(ctx, robot, record, req, execStore, onMessage) default: return nil, fmt.Errorf("execution %s is in status %s, cannot interact", req.ExecutionID, record.Status) } } func (m *Manager) handleNewInteractionStreamRaw(ctx *types.Context, robot *types.Robot, req *InteractRequest, execStore *store.ExecutionStore, onMessage agentcontext.OnMessageFunc) (*InteractResponse, error) { exec, chatID, err := m.createConfirmingExecution(ctx, robot, req, execStore) if err != nil { return nil, fmt.Errorf("failed to create confirming execution: %w", err) } hostOutput, err := m.callHostAgentForScenarioStreamRaw(ctx, robot, "assign", req.Message, nil, chatID, onMessage) if 
err != nil { log.Warn("Host Agent call failed, using direct assign: %v", err) return m.directAssign(ctx, robot, exec, req, execStore) } resp, err := m.processHostAction(ctx, robot, exec, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = exec.ExecutionID resp.ChatID = chatID return resp, nil } func (m *Manager) handleConfirmingInteractionStreamRaw(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore, onMessage agentcontext.OnMessageFunc) (*InteractResponse, error) { hostCtx := m.buildHostContext(robot, record, nil) hostOutput, err := m.callHostAgentForScenarioStreamRaw(ctx, robot, "assign", req.Message, hostCtx, record.ChatID, onMessage) if err != nil { log.Warn("Host Agent call failed during confirming: %v", err) return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "error", Message: fmt.Sprintf("Host Agent failed: %v", err), }, nil } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } func (m *Manager) handleWaitingInteractionStreamRaw(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore, onMessage agentcontext.OnMessageFunc) (*InteractResponse, error) { waitingTask := m.findWaitingTask(record) hostCtx := m.buildHostContext(robot, record, waitingTask) hostOutput, err := m.callHostAgentForScenarioStreamRaw(ctx, robot, "clarify", req.Message, hostCtx, record.ChatID, onMessage) if err != nil { log.Warn("Host Agent call failed during clarify, falling back to direct resume: %v", err) return m.directResume(ctx, record, req) } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } func (m *Manager) 
handleRunningInteractionStreamRaw(ctx *types.Context, robot *types.Robot, record *store.ExecutionRecord, req *InteractRequest, execStore *store.ExecutionStore, onMessage agentcontext.OnMessageFunc) (*InteractResponse, error) { hostCtx := m.buildHostContext(robot, record, nil) hostOutput, err := m.callHostAgentForScenarioStreamRaw(ctx, robot, "guide", req.Message, hostCtx, record.ChatID, onMessage) if err != nil { return &InteractResponse{ ExecutionID: record.ExecutionID, Status: "acknowledged", Message: "Guidance noted (Host Agent unavailable)", }, nil } resp, err := m.processHostAction(ctx, robot, record, hostOutput, execStore) if err != nil { return nil, err } resp.ExecutionID = record.ExecutionID resp.ChatID = record.ChatID return resp, nil } func (m *Manager) callHostAgentForScenarioStreamRaw(ctx *types.Context, robot *types.Robot, scenario string, msg string, hostCtx *types.HostContext, chatID string, onMessage agentcontext.OnMessageFunc) (*types.HostOutput, error) { agentID := "" if robot.Config != nil && robot.Config.Resources != nil { agentID = robot.Config.Resources.GetPhaseAgent(types.PhaseHost) } if agentID == "" { return nil, fmt.Errorf("no Host Agent configured for robot %s", robot.MemberID) } return m.callHostAgentStreamRaw(ctx, agentID, &types.HostInput{ Scenario: scenario, Messages: []agentcontext.Message{{Role: "user", Content: msg}}, Context: hostCtx, }, chatID, onMessage) } // callHostAgentStreamRaw calls the Host Agent with CUI raw message streaming. // It buffers text chunks that look like JSON output (starting with "{" or "```json") // so the frontend never sees raw decision JSON. If the final result is a decision, // the buffered chunks are discarded and a clean reply is sent instead. If the // result is a normal conversation turn, buffered chunks are flushed through. 
func (m *Manager) callHostAgentStreamRaw(ctx *types.Context, agentID string, input *types.HostInput, chatID string, onMessage agentcontext.OnMessageFunc) (*types.HostOutput, error) { inputJSON, err := json.Marshal(input) if err != nil { return nil, fmt.Errorf("failed to marshal host input: %w", err) } var ( bufferedChunks []*message.Message buffering bool accumulatedText string lastTextMsgID string ) wrappedOnMessage := func(msg *message.Message) int { if msg == nil { return onMessage(msg) } // Only intercept text type messages with delta content if msg.Type != message.TypeText || !msg.Delta { return onMessage(msg) } if msg.MessageID != "" { lastTextMsgID = msg.MessageID } // Extract the text content from this chunk chunkText := "" if msg.Props != nil { if c, ok := msg.Props["content"].(string); ok { chunkText = c } } accumulatedText += chunkText // Decide whether to buffer: check accumulated text so far trimmed := strings.TrimSpace(accumulatedText) if !buffering && len(trimmed) > 0 { if trimmed[0] == '{' || strings.HasPrefix(trimmed, "```") { buffering = true } } if buffering { bufferedChunks = append(bufferedChunks, msg) return 0 } return onMessage(msg) } caller := standard.NewConversationCaller(chatID) result, err := caller.CallWithMessagesStreamRaw(ctx, agentID, string(inputJSON), wrappedOnMessage) if err != nil { return nil, fmt.Errorf("host agent (%s) call failed: %w", agentID, err) } output, err := m.parseHostAgentResult(result) if err != nil { return nil, err } if output.Action != "" && lastTextMsgID != "" { // Decision detected — discard buffered JSON chunks, send reply text onMessage(&message.Message{ Type: message.TypeText, MessageID: lastTextMsgID, Props: map[string]interface{}{"content": output.Reply}, Delta: false, }) } else if len(bufferedChunks) > 0 { // Not a decision — flush all buffered chunks to the frontend for _, chunk := range bufferedChunks { if onMessage(chunk) != 0 { break } } } return output, nil } 
================================================ FILE: agent/robot/manager/interact_helpers_test.go ================================================ package manager import ( "context" "encoding/json" "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/cache" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // mockExecutor is a minimal Executor for unit testing type mockExecutor struct { resumeErr error } func (m *mockExecutor) ExecuteWithControl(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string, control types.ExecutionControl) (*types.Execution, error) { return nil, fmt.Errorf("not implemented") } func (m *mockExecutor) ExecuteWithID(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string) (*types.Execution, error) { return nil, fmt.Errorf("not implemented") } func (m *mockExecutor) Execute(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}) (*types.Execution, error) { return nil, fmt.Errorf("not implemented") } func (m *mockExecutor) Resume(ctx *types.Context, execID string, reply string) error { return m.resumeErr } func (m *mockExecutor) ExecCount() int { return 0 } func (m *mockExecutor) CurrentCount() int { return 0 } func (m *mockExecutor) Reset() {} // HL1: createConfirmingExecution func TestCreateConfirmingExecution(t *testing.T) { m := &Manager{} t.Run("creates record with correct fields", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-hl1", TeamID: "team-hl1"} req := &InteractRequest{Message: "do something"} execStore := store.NewExecutionStore() record, chatID, err := m.createConfirmingExecution(ctx, robot, req, execStore) 
require.NoError(t, err) assert.NotEmpty(t, record.ExecutionID) assert.Equal(t, "member-hl1", record.MemberID) assert.Equal(t, "team-hl1", record.TeamID) assert.Equal(t, types.ExecConfirming, record.Status) assert.Equal(t, types.TriggerHuman, record.TriggerType) assert.Equal(t, types.PhaseGoals, record.Phase) assert.Contains(t, chatID, "robot_member-hl1_") assert.Equal(t, chatID, record.ChatID) assert.NotNil(t, record.Input) assert.Equal(t, types.ActionTaskAdd, record.Input.Action) assert.Len(t, record.Input.Messages, 1) assert.Equal(t, "do something", record.Input.Messages[0].Content) assert.NotNil(t, record.StartTime) }) t.Run("UserID empty when auth is nil", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-hl1b", TeamID: "team-hl1b"} req := &InteractRequest{Message: "test"} execStore := store.NewExecutionStore() record, _, err := m.createConfirmingExecution(ctx, robot, req, execStore) require.NoError(t, err) assert.Empty(t, record.Input.UserID) }) } // HL2-HL4: adjustExecution func TestAdjustExecution(t *testing.T) { m := &Manager{} t.Run("adjusts goals from string", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-hl2", MemberID: "member-hl2", TriggerType: types.TriggerHuman, Status: types.ExecPending, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) actionData := map[string]interface{}{"goals": "updated goals content"} err := m.adjustExecution(ctx, record, actionData, execStore) require.NoError(t, err) require.NotNil(t, record.Goals) assert.Equal(t, "updated goals content", record.Goals.Content) }) t.Run("adjusts tasks from array", func(t *testing.T) { if 
testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-hl3", MemberID: "member-hl3", TriggerType: types.TriggerHuman, Status: types.ExecPending, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) tasks := []map[string]interface{}{ {"id": "t1", "name": "Task 1"}, {"id": "t2", "name": "Task 2"}, } actionData := map[string]interface{}{"tasks": tasks} err := m.adjustExecution(ctx, record, actionData, execStore) require.NoError(t, err) assert.Len(t, record.Tasks, 2) }) t.Run("nil action data is noop", func(t *testing.T) { ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{} execStore := store.NewExecutionStore() err := m.adjustExecution(ctx, record, nil, execStore) require.NoError(t, err) assert.Nil(t, record.Goals) }) t.Run("non-map action data handled gracefully", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-hl4", MemberID: "member-hl4", TriggerType: types.TriggerHuman, Status: types.ExecPending, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) err := m.adjustExecution(ctx, record, "not a map", execStore) require.NoError(t, err) }) } // HL5-HL6: injectTask func TestInjectTask(t *testing.T) { m := &Manager{} t.Run("appends new task with auto-generated ID", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-hl5", MemberID: "member-hl5", TriggerType: types.TriggerHuman, Status: types.ExecPending, Phase: 
types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) taskData := map[string]interface{}{"name": "New Task"} err := m.injectTask(ctx, record, taskData, execStore) require.NoError(t, err) require.Len(t, record.Tasks, 1) assert.Contains(t, record.Tasks[0].ID, "injected-") assert.Equal(t, types.TaskPending, record.Tasks[0].Status) }) t.Run("preserves existing tasks", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-hl6", MemberID: "member-hl6", TriggerType: types.TriggerHuman, Status: types.ExecPending, Phase: types.PhaseInspiration, Tasks: []types.Task{ {ID: "existing-1", Description: "Existing"}, }, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) taskData := map[string]interface{}{"name": "Added Task"} err := m.injectTask(ctx, record, taskData, execStore) require.NoError(t, err) assert.Len(t, record.Tasks, 2) assert.Equal(t, "existing-1", record.Tasks[0].ID) }) t.Run("nil action data returns error", func(t *testing.T) { ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{} execStore := store.NewExecutionStore() err := m.injectTask(ctx, record, nil, execStore) assert.Error(t, err) assert.Contains(t, err.Error(), "task data is required") }) t.Run("respects provided task ID", func(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-hl6b", MemberID: "member-hl6b", TriggerType: types.TriggerHuman, Status: types.ExecPending, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) taskData := map[string]interface{}{"id": 
"custom-id", "name": "Custom"} err := m.injectTask(ctx, record, taskData, execStore) require.NoError(t, err) assert.Equal(t, "custom-id", record.Tasks[0].ID) }) } // HL7: callHostAgentForScenario func TestCallHostAgentForScenario(t *testing.T) { m := &Manager{} t.Run("no host agent returns error", func(t *testing.T) { ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-hl7"} _, err := m.callHostAgentForScenario(ctx, robot, "assign", "test", nil, "chat-1") assert.Error(t, err) assert.Contains(t, err.Error(), "no Host Agent configured") }) t.Run("robot with nil config returns error", func(t *testing.T) { ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-hl7b", Config: nil} _, err := m.callHostAgentForScenario(ctx, robot, "assign", "test", nil, "chat-1") assert.Error(t, err) assert.Contains(t, err.Error(), "no Host Agent configured") }) } // HL8: directAssign (needs pool — tested in processHostAction) // HL9-HL10: directResume (needs executor — tested in processHostAction) // Updated buildRobotStatusSnapshot tests func TestBuildRobotStatusSnapshotV2(t *testing.T) { m := &Manager{} t.Run("nil robot returns nil", func(t *testing.T) { snap := m.buildRobotStatusSnapshot(nil) assert.Nil(t, snap) }) t.Run("populates MemberID and Status", func(t *testing.T) { robot := &types.Robot{ MemberID: "member-snap", Status: types.RobotWorking, } snap := m.buildRobotStatusSnapshot(robot) require.NotNil(t, snap) assert.Equal(t, "member-snap", snap.MemberID) assert.Equal(t, types.RobotWorking, snap.Status) }) t.Run("uses ActiveCount and WaitingCount", func(t *testing.T) { robot := &types.Robot{MemberID: "member-snap2"} exec1 := &types.Execution{ID: "e1", Status: types.ExecRunning} exec2 := &types.Execution{ID: "e2", Status: types.ExecWaiting} robot.AddExecution(exec1) robot.AddExecution(exec2) snap := m.buildRobotStatusSnapshot(robot) require.NotNil(t, snap) assert.Equal(t, 1, snap.ActiveCount) 
assert.Equal(t, 1, snap.WaitingCount) }) t.Run("populates ActiveExecs briefs", func(t *testing.T) { robot := &types.Robot{MemberID: "member-snap3"} exec := &types.Execution{ID: "e-brief", Status: types.ExecRunning, Name: "Test Exec"} robot.AddExecution(exec) snap := m.buildRobotStatusSnapshot(robot) require.NotNil(t, snap) require.Len(t, snap.ActiveExecs, 1) assert.Equal(t, "e-brief", snap.ActiveExecs[0].ID) }) t.Run("uses robot MaxQuota", func(t *testing.T) { robot := &types.Robot{ MemberID: "member-snap4", Config: &types.Config{Quota: &types.Quota{Max: 7}}, } snap := m.buildRobotStatusSnapshot(robot) require.NotNil(t, snap) assert.Equal(t, 7, snap.MaxQuota) }) } // Test processHostAction — adjust branch func TestProcessHostActionAdjust(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) m := &Manager{} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa-adj"} t.Run("adjust with goals", func(t *testing.T) { record := &store.ExecutionRecord{ ExecutionID: "exec-pa2", MemberID: "member-pa-adj", TriggerType: types.TriggerHuman, Status: types.ExecConfirming, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) output := &types.HostOutput{ Reply: "Plan adjusted", Action: types.HostActionAdjust, ActionData: map[string]interface{}{"goals": "new goals"}, } resp, err := m.processHostAction(ctx, robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "adjusted", resp.Status) require.NotNil(t, record.Goals) assert.Equal(t, "new goals", record.Goals.Content) }) t.Run("adjust with tasks", func(t *testing.T) { record := &store.ExecutionRecord{ ExecutionID: "exec-pa3", MemberID: "member-pa-adj", TriggerType: types.TriggerHuman, Status: types.ExecConfirming, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) 
tasksJSON := []map[string]interface{}{{"id": "t1", "name": "Adjusted Task"}} output := &types.HostOutput{ Reply: "Tasks updated", Action: types.HostActionAdjust, ActionData: map[string]interface{}{"tasks": tasksJSON}, } resp, err := m.processHostAction(ctx, robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "adjusted", resp.Status) assert.Len(t, record.Tasks, 1) }) t.Run("adjust with nil data is noop", func(t *testing.T) { record := &store.ExecutionRecord{ ExecutionID: "exec-pa4", MemberID: "member-pa-adj", TriggerType: types.TriggerHuman, Status: types.ExecConfirming, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) output := &types.HostOutput{ Reply: "No changes", Action: types.HostActionAdjust, } resp, err := m.processHostAction(ctx, robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "adjusted", resp.Status) }) } // Test processHostAction — add_task branch func TestProcessHostActionAddTask(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) m := &Manager{} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa-at"} t.Run("add task success", func(t *testing.T) { record := &store.ExecutionRecord{ ExecutionID: "exec-pa5", MemberID: "member-pa-at", TriggerType: types.TriggerHuman, Status: types.ExecConfirming, Phase: types.PhaseInspiration, } execStore := store.NewExecutionStore() require.NoError(t, execStore.Save(ctx.Context, record)) output := &types.HostOutput{ Reply: "Task added", Action: types.HostActionAddTask, ActionData: map[string]interface{}{"name": "New task"}, } resp, err := m.processHostAction(ctx, robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "task_added", resp.Status) assert.Len(t, record.Tasks, 1) }) t.Run("add task nil data returns error", func(t *testing.T) { record := &store.ExecutionRecord{ 
ExecutionID: "exec-pa6", MemberID: "member-pa-at", } execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "Add task", Action: types.HostActionAddTask, } _, err := m.processHostAction(ctx, robot, record, output, execStore) assert.Error(t, err) assert.Contains(t, err.Error(), "task data is required") }) } // Test processHostAction — skip branch func TestProcessHostActionSkip(t *testing.T) { m := &Manager{} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa-skip"} t.Run("skip without waiting task returns error", func(t *testing.T) { record := &store.ExecutionRecord{ ExecutionID: "exec-pa8", MemberID: "member-pa-skip", } execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "Skip it", Action: types.HostActionSkip, } _, err := m.processHostAction(ctx, robot, record, output, execStore) assert.Error(t, err) assert.Contains(t, err.Error(), "no task is waiting") }) } // Test processHostAction — wait_for_more and default func TestProcessHostActionWaitForMoreAndDefault(t *testing.T) { m := &Manager{} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa-wfm"} t.Run("wait_for_more", func(t *testing.T) { record := &store.ExecutionRecord{} execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "More details please", WaitForMore: true, } resp, err := m.processHostAction(ctx, robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "waiting_for_more", resp.Status) assert.Equal(t, "More details please", resp.Reply) assert.True(t, resp.WaitForMore) }) t.Run("unknown action returns acknowledged", func(t *testing.T) { record := &store.ExecutionRecord{} execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "OK", Action: "unknown_action", } resp, err := m.processHostAction(ctx, robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "acknowledged", resp.Status) assert.Equal(t, "OK", 
resp.Message) }) } // Test processHostAction — cancel branch func TestProcessHostActionCancel(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("cancel waiting execution", func(t *testing.T) { // Cannot fully test without a started manager; verify the error path m := &Manager{started: false} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa-cancel"} record := &store.ExecutionRecord{ ExecutionID: "exec-pa11", MemberID: "member-pa-cancel", } execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "Cancel it", Action: types.HostActionCancel, } _, err := m.processHostAction(ctx, robot, record, output, execStore) assert.Error(t, err) assert.Contains(t, err.Error(), "manager not started") }) } // Test HandleInteract validation func TestHandleInteractValidationExtended(t *testing.T) { t.Run("manager not started returns error", func(t *testing.T) { m := &Manager{started: false} _, err := m.HandleInteract(types.NewContext(context.Background(), nil), "member-1", &InteractRequest{Message: "test"}) assert.Error(t, err) assert.Contains(t, err.Error(), "manager not started") }) t.Run("empty member_id returns error", func(t *testing.T) { m := &Manager{started: true} _, err := m.HandleInteract(types.NewContext(context.Background(), nil), "", &InteractRequest{Message: "test"}) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("nil request returns error", func(t *testing.T) { m := &Manager{started: true} _, err := m.HandleInteract(types.NewContext(context.Background(), nil), "member-1", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "message is required") }) t.Run("empty message returns error", func(t *testing.T) { m := &Manager{started: true} _, err := m.HandleInteract(types.NewContext(context.Background(), nil), "member-1", &InteractRequest{}) assert.Error(t, err) assert.Contains(t, err.Error(), "message 
is required") }) t.Run("non-interactable status returns error", func(t *testing.T) { if testing.Short() { t.Skip("Requires database and cache") } testutils.Prepare(t) defer testutils.Clean(t) // Would require a full Manager with cache — tested via E2E }) } // Test CancelExecution validation func TestCancelExecutionValidationExtended(t *testing.T) { t.Run("manager not started", func(t *testing.T) { m := &Manager{started: false} err := m.CancelExecution(types.NewContext(context.Background(), nil), "exec-1") assert.Error(t, err) assert.Contains(t, err.Error(), "manager not started") }) } // Test buildHostContext JSON output func TestBuildHostContextJSON(t *testing.T) { m := &Manager{} robot := &types.Robot{MemberID: "member-ctx"} record := &store.ExecutionRecord{ Goals: &types.Goals{Content: "test goals"}, Tasks: []types.Task{{ID: "t1"}}, WaitingQuestion: "What time?", } waitingTask := &types.Task{ID: "t1", Status: types.TaskWaitingInput} hostCtx := m.buildHostContext(robot, record, waitingTask) require.NotNil(t, hostCtx) data, err := json.Marshal(hostCtx) require.NoError(t, err) var parsed map[string]interface{} err = json.Unmarshal(data, &parsed) require.NoError(t, err) // Goals is a struct, not a plain string goalsRaw, ok := parsed["goals"] require.True(t, ok) goalsMap, ok := goalsRaw.(map[string]interface{}) require.True(t, ok, "Goals should be a JSON object, not a string") assert.Equal(t, "test goals", goalsMap["content"]) assert.Equal(t, "What time?", parsed["agent_reply"]) } // ==================== processHostAction -- confirm branch (PA1) ==================== func TestProcessHostActionConfirmRequiresPool(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) m := &Manager{started: false} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa1"} record := &store.ExecutionRecord{ ExecutionID: "exec-pa1", MemberID: "member-pa1", Status: types.ExecConfirming, } 
execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "Confirmed", Action: types.HostActionConfirm, } assert.Panics(t, func() { m.processHostAction(ctx, robot, record, output, execStore) }, "should panic because pool/executor are nil") } // ==================== processHostAction -- inject_ctx branch (PA9-PA10) ==================== func TestProcessHostActionInjectCtx(t *testing.T) { t.Run("nil executor panics", func(t *testing.T) { m := &Manager{} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa9"} record := &store.ExecutionRecord{ ExecutionID: "exec-pa9", MemberID: "member-pa9", Status: types.ExecWaiting, } execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "Here's context", Action: types.HostActionInjectCtx, ActionData: "additional context data", } assert.Panics(t, func() { m.processHostAction(ctx, robot, record, output, execStore) }) }) t.Run("with mock executor delegates resume", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: fmt.Errorf("mock error")} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa10"} record := &store.ExecutionRecord{ ExecutionID: "exec-pa10", MemberID: "member-pa10", } execStore := store.NewExecutionStore() output := &types.HostOutput{ Reply: "Resume with data", Action: types.HostActionInjectCtx, ActionData: map[string]interface{}{"reply": "detailed info"}, } _, err := m.processHostAction(ctx, robot, record, output, execStore) assert.Error(t, err) assert.Contains(t, err.Error(), "mock error") }) t.Run("ErrExecutionSuspended returns waiting status", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: types.ErrExecutionSuspended} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) robot := &types.Robot{MemberID: "member-pa10b"} record := &store.ExecutionRecord{ ExecutionID: "exec-pa10b", MemberID: "member-pa10b", } execStore := 
store.NewExecutionStore() output := &types.HostOutput{ Reply: "Resume", Action: types.HostActionInjectCtx, ActionData: "context", } resp, err := m.processHostAction(ctx, robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "waiting", resp.Status) }) } // ==================== HandleInteract routing (HI5-HI8) ==================== func TestHandleInteractRouting(t *testing.T) { t.Run("HI5: non-existent execution_id returns error", func(t *testing.T) { if testing.Short() { t.Skip("Requires database and cache") } testutils.Prepare(t) defer testutils.Clean(t) m := &Manager{started: true, cache: cache.New()} ctx := types.NewContext(context.Background(), nil) _, err := m.HandleInteract(ctx, "member-hi5", &InteractRequest{ ExecutionID: "nonexistent-exec", Message: "test", }) assert.Error(t, err) }) t.Run("HI6: non-existent robot returns error", func(t *testing.T) { if testing.Short() { t.Skip("Requires database and cache") } testutils.Prepare(t) defer testutils.Clean(t) m := &Manager{started: true, cache: cache.New()} ctx := types.NewContext(context.Background(), nil) _, err := m.HandleInteract(ctx, "nonexistent-robot", &InteractRequest{ Message: "test", }) assert.Error(t, err) assert.Contains(t, err.Error(), "robot not found") }) } // ==================== CancelExecution validation (CE2-CE5) ==================== func TestCancelExecutionStatusValidation(t *testing.T) { if testing.Short() { t.Skip("Requires database") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("CE2: non-existent execution returns error", func(t *testing.T) { m := &Manager{started: true} ctx := types.NewContext(context.Background(), nil) err := m.CancelExecution(ctx, "nonexistent-exec") assert.Error(t, err) assert.Contains(t, err.Error(), "execution not found") }) t.Run("CE3: running execution cannot be cancelled", func(t *testing.T) { m := &Manager{started: true} ctx := types.NewContext(context.Background(), nil) execStore := store.NewExecutionStore() record := 
&store.ExecutionRecord{ ExecutionID: "exec-ce3", MemberID: "member-ce3", Status: types.ExecRunning, TriggerType: types.TriggerHuman, Phase: types.PhaseInspiration, } require.NoError(t, execStore.Save(ctx.Context, record)) err := m.CancelExecution(ctx, "exec-ce3") assert.Error(t, err) assert.Contains(t, err.Error(), "only waiting/confirming can be cancelled") }) t.Run("CE4: completed execution cannot be cancelled", func(t *testing.T) { m := &Manager{started: true} ctx := types.NewContext(context.Background(), nil) execStore := store.NewExecutionStore() record := &store.ExecutionRecord{ ExecutionID: "exec-ce4", MemberID: "member-ce4", Status: types.ExecCompleted, TriggerType: types.TriggerHuman, Phase: types.PhaseInspiration, } require.NoError(t, execStore.Save(ctx.Context, record)) err := m.CancelExecution(ctx, "exec-ce4") assert.Error(t, err) assert.Contains(t, err.Error(), "only waiting/confirming can be cancelled") }) } // ==================== InteractRequest/InteractResponse struct validation ==================== func TestInteractRequestStructFields(t *testing.T) { req := &InteractRequest{ ExecutionID: "exec-1", TaskID: "task-1", Source: types.InteractSourceUI, Message: "do something", Action: "confirm", } assert.Equal(t, "exec-1", req.ExecutionID) assert.Equal(t, "task-1", req.TaskID) assert.Equal(t, types.InteractSourceUI, req.Source) assert.Equal(t, "do something", req.Message) assert.Equal(t, "confirm", req.Action) } func TestInteractResponseStructFields(t *testing.T) { resp := &InteractResponse{ ExecutionID: "exec-1", Status: "confirmed", Message: "Done", ChatID: "chat-1", Reply: "I'll do it", WaitForMore: true, } assert.Equal(t, "exec-1", resp.ExecutionID) assert.Equal(t, "confirmed", resp.Status) assert.Equal(t, "Done", resp.Message) assert.Equal(t, "chat-1", resp.ChatID) assert.Equal(t, "I'll do it", resp.Reply) assert.True(t, resp.WaitForMore) } // ==================== executeResume helper ==================== func TestExecuteResumeNilExecutor(t 
*testing.T) { m := &Manager{} ctx := types.NewContext(context.Background(), nil) assert.Panics(t, func() { _ = m.executeResume(ctx, "exec-test", "reply") }) } func TestExecuteResumeWithMock(t *testing.T) { t.Run("delegates to executor Resume", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: nil} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) err := m.executeResume(ctx, "exec-test", "reply") assert.NoError(t, err) }) t.Run("propagates error", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: fmt.Errorf("resume failed")} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) err := m.executeResume(ctx, "exec-test", "reply") assert.Error(t, err) assert.Contains(t, err.Error(), "resume failed") }) t.Run("propagates ErrExecutionSuspended", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: types.ErrExecutionSuspended} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) err := m.executeResume(ctx, "exec-test", "reply") assert.Equal(t, types.ErrExecutionSuspended, err) }) } // ==================== skipWaitingTask and directResume with mock ==================== func TestSkipWaitingTaskWithMock(t *testing.T) { t.Run("no waiting task returns error", func(t *testing.T) { mockExec := &mockExecutor{} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-skip", } execStore := store.NewExecutionStore() err := m.skipWaitingTask(ctx, record, execStore) assert.Error(t, err) assert.Contains(t, err.Error(), "no task is waiting") }) t.Run("marks waiting task as skipped and resumes", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: nil} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-skip2", WaitingTaskID: "task-w", Tasks: []types.Task{ {ID: "task-w", Status: types.TaskWaitingInput}, 
}, } execStore := store.NewExecutionStore() err := m.skipWaitingTask(ctx, record, execStore) assert.NoError(t, err) assert.Equal(t, types.TaskSkipped, record.Tasks[0].Status) }) } func TestDirectResumeWithMock(t *testing.T) { t.Run("successful resume", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: nil} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-dr", ChatID: "chat-dr", } req := &InteractRequest{Message: "continue"} resp, err := m.directResume(ctx, record, req) require.NoError(t, err) assert.Equal(t, "resumed", resp.Status) assert.Equal(t, "exec-dr", resp.ExecutionID) assert.Equal(t, "chat-dr", resp.ChatID) }) t.Run("suspended again", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: types.ErrExecutionSuspended} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-dr2", ChatID: "chat-dr2", } req := &InteractRequest{Message: "continue"} resp, err := m.directResume(ctx, record, req) require.NoError(t, err) assert.Equal(t, "waiting", resp.Status) }) t.Run("error propagated", func(t *testing.T) { mockExec := &mockExecutor{resumeErr: fmt.Errorf("resume failed")} m := &Manager{executor: mockExec} ctx := types.NewContext(context.Background(), nil) record := &store.ExecutionRecord{ ExecutionID: "exec-dr3", } req := &InteractRequest{Message: "continue"} _, err := m.directResume(ctx, record, req) assert.Error(t, err) assert.Contains(t, err.Error(), "resume failed") }) } ================================================ FILE: agent/robot/manager/interact_test.go ================================================ package manager import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/robot/executor/standard" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" ) func 
TestBuildRobotStatusSnapshot(t *testing.T) { m := &Manager{} t.Run("nil robot returns nil", func(t *testing.T) { snap := m.buildRobotStatusSnapshot(nil) assert.Nil(t, snap) }) t.Run("robot with quota", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-member", Config: &types.Config{ Quota: &types.Quota{Max: 5}, }, } snap := m.buildRobotStatusSnapshot(robot) require.NotNil(t, snap) assert.Equal(t, 5, snap.MaxQuota) }) t.Run("robot without quota uses default", func(t *testing.T) { robot := &types.Robot{ MemberID: "test-member", } snap := m.buildRobotStatusSnapshot(robot) require.NotNil(t, snap) assert.Equal(t, 2, snap.MaxQuota) // robot.MaxQuota() returns 2 for nil config }) } func TestFindWaitingTask(t *testing.T) { m := &Manager{} t.Run("returns nil when no waiting task id", func(t *testing.T) { record := &store.ExecutionRecord{ Tasks: []types.Task{ {ID: "task-1"}, }, } task := m.findWaitingTask(record) assert.Nil(t, task) }) t.Run("finds matching task", func(t *testing.T) { record := &store.ExecutionRecord{ WaitingTaskID: "task-2", Tasks: []types.Task{ {ID: "task-1"}, {ID: "task-2", Status: types.TaskWaitingInput}, {ID: "task-3"}, }, } task := m.findWaitingTask(record) require.NotNil(t, task) assert.Equal(t, "task-2", task.ID) }) t.Run("returns nil when task not found", func(t *testing.T) { record := &store.ExecutionRecord{ WaitingTaskID: "nonexistent", Tasks: []types.Task{ {ID: "task-1"}, }, } task := m.findWaitingTask(record) assert.Nil(t, task) }) } func TestBuildHostContext(t *testing.T) { m := &Manager{} t.Run("builds context with goals and tasks", func(t *testing.T) { robot := &types.Robot{MemberID: "test"} record := &store.ExecutionRecord{ Goals: &types.Goals{Content: "test goals"}, Tasks: []types.Task{ {ID: "task-1"}, }, WaitingQuestion: "What is the answer?", } waitingTask := &types.Task{ID: "task-1", Status: types.TaskWaitingInput} hostCtx := m.buildHostContext(robot, record, waitingTask) require.NotNil(t, hostCtx) assert.NotNil(t, 
hostCtx.Goals) assert.Equal(t, "test goals", hostCtx.Goals.Content) assert.Len(t, hostCtx.Tasks, 1) assert.NotNil(t, hostCtx.CurrentTask) assert.Equal(t, "What is the answer?", hostCtx.AgentReply) }) t.Run("builds context without optional fields", func(t *testing.T) { robot := &types.Robot{MemberID: "test"} record := &store.ExecutionRecord{} hostCtx := m.buildHostContext(robot, record, nil) require.NotNil(t, hostCtx) assert.Nil(t, hostCtx.Goals) assert.Nil(t, hostCtx.Tasks) assert.Nil(t, hostCtx.CurrentTask) assert.Empty(t, hostCtx.AgentReply) }) } func TestProcessHostAction(t *testing.T) { m := &Manager{} t.Run("wait_for_more returns waiting status", func(t *testing.T) { output := &types.HostOutput{ Reply: "Please provide more details", WaitForMore: true, } record := &store.ExecutionRecord{} robot := &types.Robot{} execStore := store.NewExecutionStore() resp, err := m.processHostAction(types.NewContext(nil, nil), robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "waiting_for_more", resp.Status) assert.Equal(t, "Please provide more details", resp.Reply) assert.True(t, resp.WaitForMore) }) t.Run("unknown action returns acknowledged", func(t *testing.T) { output := &types.HostOutput{ Reply: "Got it", Action: "unknown_action", } record := &store.ExecutionRecord{} robot := &types.Robot{} execStore := store.NewExecutionStore() resp, err := m.processHostAction(types.NewContext(nil, nil), robot, record, output, execStore) require.NoError(t, err) assert.Equal(t, "acknowledged", resp.Status) }) } func TestHandleInteractValidation(t *testing.T) { m := &Manager{started: true} t.Run("empty member_id returns error", func(t *testing.T) { _, err := m.HandleInteract(types.NewContext(nil, nil), "", &InteractRequest{Message: "test"}) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("nil request returns error", func(t *testing.T) { _, err := m.HandleInteract(types.NewContext(nil, nil), "member-1", nil) assert.Error(t, err) 
assert.Contains(t, err.Error(), "message is required") }) t.Run("empty message returns error", func(t *testing.T) { _, err := m.HandleInteract(types.NewContext(nil, nil), "member-1", &InteractRequest{}) assert.Error(t, err) assert.Contains(t, err.Error(), "message is required") }) t.Run("manager not started returns error", func(t *testing.T) { m2 := &Manager{started: false} _, err := m2.HandleInteract(types.NewContext(nil, nil), "member-1", &InteractRequest{Message: "test"}) assert.Error(t, err) assert.Contains(t, err.Error(), "manager not started") }) } func TestCancelExecutionValidation(t *testing.T) { m := &Manager{started: false} t.Run("manager not started returns error", func(t *testing.T) { err := m.CancelExecution(types.NewContext(nil, nil), "exec-1") assert.Error(t, err) assert.Contains(t, err.Error(), "manager not started") }) } func TestParseHostAgentResult(t *testing.T) { m := &Manager{} t.Run("plain text returns WaitForMore", func(t *testing.T) { result := &standard.CallResult{Content: "I understand your request. Shall I proceed?"} output, err := m.parseHostAgentResult(result) require.NoError(t, err) assert.True(t, output.WaitForMore, "plain text should set WaitForMore=true") assert.Equal(t, "I understand your request. 
Shall I proceed?", output.Reply) assert.Empty(t, string(output.Action), "plain text should have no action") }) t.Run("JSON with action returns action", func(t *testing.T) { result := &standard.CallResult{ Content: `{"reply":"Task confirmed","action":"confirm","wait_for_more":false}`, } output, err := m.parseHostAgentResult(result) require.NoError(t, err) assert.False(t, output.WaitForMore) assert.Equal(t, types.HostActionConfirm, output.Action) assert.Equal(t, "Task confirmed", output.Reply) }) t.Run("JSON without action returns WaitForMore", func(t *testing.T) { result := &standard.CallResult{ Content: `{"reply":"Let me think about this","some_field":"value"}`, } output, err := m.parseHostAgentResult(result) require.NoError(t, err) assert.True(t, output.WaitForMore, "JSON without action should set WaitForMore=true") assert.NotEmpty(t, output.Reply) }) t.Run("JSON with adjust action and action_data", func(t *testing.T) { result := &standard.CallResult{ Content: `{"reply":"Plan adjusted","action":"adjust","action_data":{"goals":"new goals"}}`, } output, err := m.parseHostAgentResult(result) require.NoError(t, err) assert.False(t, output.WaitForMore) assert.Equal(t, types.HostActionAdjust, output.Action) assert.NotNil(t, output.ActionData) }) t.Run("malformed JSON returns WaitForMore", func(t *testing.T) { result := &standard.CallResult{Content: `{invalid json`} output, err := m.parseHostAgentResult(result) require.NoError(t, err) assert.True(t, output.WaitForMore) assert.Equal(t, `{invalid json`, output.Reply) }) t.Run("empty content returns WaitForMore", func(t *testing.T) { result := &standard.CallResult{Content: ""} output, err := m.parseHostAgentResult(result) require.NoError(t, err) assert.True(t, output.WaitForMore) }) } ================================================ FILE: agent/robot/manager/manager.go ================================================ package manager import ( "context" "fmt" "sync" "time" "github.com/yaoapp/yao/agent/robot/cache" 
"github.com/yaoapp/yao/agent/robot/executor" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/trigger" "github.com/yaoapp/yao/agent/robot/types" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) // Default configuration values const ( DefaultTickInterval = time.Minute // default tick interval for clock checking ) // Config holds manager configuration type Config struct { TickInterval time.Duration // how often to check clock triggers (default: 1 minute) PoolConfig *pool.Config // worker pool configuration Executor types.Executor // optional: custom executor (default: real executor) } // DefaultConfig returns default manager configuration func DefaultConfig() *Config { return &Config{ TickInterval: DefaultTickInterval, PoolConfig: pool.DefaultConfig(), } } // Manager implements types.Manager interface // Orchestrates the robot scheduling system: Cache -> Dedup -> Pool -> Executor type Manager struct { config *Config cache *cache.Cache pool *pool.Pool executor types.Executor // Execution control for pause/resume/stop execController *trigger.ExecutionController // Ticker for clock trigger checking ticker *time.Ticker tickerDone chan struct{} // State started bool mu sync.RWMutex // Context for background operations ctx context.Context cancel context.CancelFunc } // New creates a new manager instance with default configuration func New() *Manager { return NewWithConfig(nil) } // NewWithConfig creates a new manager instance with custom configuration func NewWithConfig(config *Config) *Manager { if config == nil { config = DefaultConfig() } // Apply defaults for zero values if config.TickInterval <= 0 { config.TickInterval = DefaultTickInterval } // Create components c := cache.New() p := pool.NewWithConfig(config.PoolConfig) ec := trigger.NewExecutionController() // Use custom executor if provided, otherwise create default var e types.Executor if config.Executor != nil { e = config.Executor } else { e = executor.New() } // Wire up pool 
with executor p.SetExecutor(e) // Create shared executor instances for each mode // These are reused across all executions to maintain accurate counters dryRunExecutor := executor.NewDryRun() // Set executor factory for mode-specific executors p.SetExecutorFactory(func(mode types.ExecutorMode) types.Executor { switch mode { case types.ExecutorDryRun: return dryRunExecutor case types.ExecutorSandbox: // Sandbox not implemented, fall back to DryRun return dryRunExecutor default: // Standard mode or empty - use the configured executor return e } }) return &Manager{ config: config, cache: c, pool: p, executor: e, execController: ec, } } // Start starts the manager // 1. Load robots into cache // 2. Start worker pool // 3. Start clock ticker goroutine func (m *Manager) Start() error { m.mu.Lock() defer m.mu.Unlock() if m.started { return fmt.Errorf("manager already started") } // Create background context m.ctx, m.cancel = context.WithCancel(context.Background()) // Load robots into cache ctx := types.NewContext(m.ctx, nil) if err := m.cache.Load(ctx); err != nil { return fmt.Errorf("failed to load robots: %w", err) } // Set completion callback to clean up ExecutionController when execution finishes m.pool.SetOnComplete(func(execID, memberID string, status types.ExecStatus) { // Remove from ExecutionController (cleans up in-memory tracking) m.execController.Untrack(execID) // Remove from robot's in-memory execution list if robot := m.cache.Get(memberID); robot != nil { robot.RemoveExecution(execID) } }) // Start worker pool if err := m.pool.Start(); err != nil { return fmt.Errorf("failed to start pool: %w", err) } // Start clock ticker m.ticker = time.NewTicker(m.config.TickInterval) m.tickerDone = make(chan struct{}) go m.tickerLoop() // Start cache auto-refresh (every hour) m.cache.StartAutoRefresh(ctx, nil) m.started = true return nil } // Stop stops the manager gracefully // 1. Stop clock ticker // 2. Stop cache auto-refresh // 3. 
Stop worker pool (waits for running jobs) func (m *Manager) Stop() error { m.mu.Lock() if !m.started { m.mu.Unlock() return nil } m.started = false m.mu.Unlock() // Stop ticker if m.tickerDone != nil { close(m.tickerDone) } // Stop cache auto-refresh m.cache.StopAutoRefresh() // Stop pool (waits for running jobs) if err := m.pool.Stop(); err != nil { return fmt.Errorf("failed to stop pool: %w", err) } // Cancel background context if m.cancel != nil { m.cancel() } return nil } // tickerLoop is the main ticker goroutine func (m *Manager) tickerLoop() { for { select { case <-m.tickerDone: m.ticker.Stop() return case now := <-m.ticker.C: // Perform tick - context is created per-robot in Tick() _ = m.Tick(m.ctx, now) } } } // Tick processes a clock tick // 1. Get all cached robots // 2. For each robot with clock trigger enabled // 3. Check if should execute based on clock config // 4. Submit to pool with robot's own identity func (m *Manager) Tick(parentCtx context.Context, now time.Time) error { m.mu.RLock() if !m.started { m.mu.RUnlock() return nil } m.mu.RUnlock() // Get all cached robots robots := m.cache.ListAll() for _, robot := range robots { // Skip if robot is not active if robot.Status == types.RobotPaused || robot.Status == types.RobotError || robot.Status == types.RobotMaintenance { continue } // Skip if clock trigger is disabled if robot.Config == nil || robot.Config.Triggers == nil { continue } if !robot.Config.Triggers.IsEnabled(types.TriggerClock) { continue } // Skip if no clock config if robot.Config.Clock == nil { continue } // Check if should trigger based on clock config if !m.shouldTrigger(robot, now) { continue } // TODO: dedup check (Phase 11.1) // result, err := m.dedup.Check(ctx, robot.MemberID, types.TriggerClock) // if err != nil || result == types.DedupSkip { // continue // } // Pre-generate execution ID and track for pause/resume/stop // We need to track BEFORE submit so we can pass the cancellable context to the executor execID := 
pool.GenerateExecID() ctrlExec := m.execController.Track(execID, robot.MemberID, robot.TeamID) // Create context with robot's own identity and cancellable context // Clock-triggered executions run as the robot itself robotAuth := m.buildRobotAuth(robot) execCtx := types.NewContext(ctrlExec.Context(), robotAuth) // Create clock context for P0 inspiration clockCtx := types.NewClockContext(now, robot.Config.Clock.TZ) // Submit to pool with the cancellable context and execution control _, err := m.pool.SubmitWithID(execCtx, robot, types.TriggerClock, clockCtx, execID, ctrlExec) if err != nil { // If submission failed, untrack the execution m.execController.Untrack(execID) // Log error but continue with other robots // In production, this would be logged properly continue } // Update robot's last run time robot.LastRun = now } return nil } // buildRobotAuth creates AuthorizedInfo for a robot's own identity // Used when robot executes autonomously (clock trigger) func (m *Manager) buildRobotAuth(robot *types.Robot) *oauthtypes.AuthorizedInfo { return &oauthtypes.AuthorizedInfo{ UserID: robot.MemberID, TeamID: robot.TeamID, // ClientID could be set to a special "robot-agent" identifier if needed ClientID: "robot-agent", } } // shouldTrigger checks if a robot should be triggered based on its clock config func (m *Manager) shouldTrigger(robot *types.Robot, now time.Time) bool { clock := robot.Config.Clock if clock == nil { return false } // Get time in robot's timezone loc := clock.GetLocation() localNow := now.In(loc) switch clock.Mode { case types.ClockTimes: return m.shouldTriggerTimes(robot, clock, localNow) case types.ClockInterval: return m.shouldTriggerInterval(robot, clock, localNow) case types.ClockDaemon: return m.shouldTriggerDaemon(robot, clock, localNow) default: return false } } // shouldTriggerTimes checks if current time matches any configured times // times mode: run at specific times (e.g., ["09:00", "14:00", "17:00"]) func (m *Manager) 
shouldTriggerTimes(robot *types.Robot, clock *types.Clock, now time.Time) bool { // Check day of week first if !m.matchesDay(clock, now) { return false } // Check if current time matches any configured time currentTime := now.Format("15:04") for _, t := range clock.Times { if t == currentTime { // Check if already triggered in this minute if !robot.LastRun.IsZero() { lastRunInLoc := robot.LastRun.In(now.Location()) if lastRunInLoc.Format("15:04") == currentTime && lastRunInLoc.Day() == now.Day() { return false // Already triggered this minute today } } return true } } return false } // shouldTriggerInterval checks if enough time has passed since last run // interval mode: run every X duration (e.g., "30m", "2h") func (m *Manager) shouldTriggerInterval(robot *types.Robot, clock *types.Clock, now time.Time) bool { interval, err := time.ParseDuration(clock.Every) if err != nil { return false } // First run if never executed if robot.LastRun.IsZero() { return true } // Check if interval has passed return now.Sub(robot.LastRun) >= interval } // shouldTriggerDaemon checks if robot can restart immediately after last run // daemon mode: restart immediately after each run completes func (m *Manager) shouldTriggerDaemon(robot *types.Robot, clock *types.Clock, now time.Time) bool { // Daemon mode: trigger if not currently running // CanRun() checks if robot has available execution slots return robot.CanRun() } // matchesDay checks if current day matches the configured days func (m *Manager) matchesDay(clock *types.Clock, now time.Time) bool { // Empty days or ["*"] means all days if len(clock.Days) == 0 { return true } for _, day := range clock.Days { if day == "*" { return true } // Match day name (Mon, Tue, Wed, Thu, Fri, Sat, Sun) // or full name (Monday, Tuesday, etc.) weekday := now.Weekday().String() shortDay := weekday[:3] // Mon, Tue, etc. 
if day == weekday || day == shortDay { return true } } return false } // TriggerManual manually triggers a robot execution (for testing or API calls) // This bypasses clock checking and directly submits to pool // For non-autonomous robots: lazy-loads from DB, executes, then unloads func (m *Manager) TriggerManual(ctx *types.Context, memberID string, trigger types.TriggerType, data interface{}) (string, error) { m.mu.RLock() if !m.started { m.mu.RUnlock() return "", fmt.Errorf("manager not started") } m.mu.RUnlock() // Get robot from cache, or lazy-load if not found robot, lazyLoaded, err := m.getOrLoadRobot(ctx, memberID) if err != nil { return "", err } // Check robot status if robot.Status == types.RobotPaused { return "", types.ErrRobotPaused } // Check if trigger type is enabled if robot.Config != nil && robot.Config.Triggers != nil { if !robot.Config.Triggers.IsEnabled(trigger) { return "", types.ErrTriggerDisabled } } // Pre-generate execution ID and track for pause/resume/stop // We need to track BEFORE submit so we can pass the cancellable context to the executor execID := pool.GenerateExecID() ctrlExec := m.execController.Track(execID, memberID, robot.TeamID) // Create a new context with the cancellable context from ExecutionController // This allows Stop() to propagate cancellation to the executor execCtx := types.NewContext(ctrlExec.Context(), ctx.Auth) // Submit to pool with the cancellable context and execution control // The control interface allows executor to check pause state and wait if paused _, err = m.pool.SubmitWithID(execCtx, robot, trigger, data, execID, ctrlExec) if err != nil { // If submission failed, untrack the execution m.execController.Untrack(execID) // If lazy-loaded and submission failed, remove from cache if lazyLoaded { m.cache.Remove(memberID) } return "", err } // For lazy-loaded robots, schedule cleanup after execution completes if lazyLoaded { m.scheduleCleanup(robot) } return execID, nil } // ==================== Human 
Intervention & Event Triggers ==================== // Intervene processes a human intervention request // Human intervention skips P0 (inspiration) and goes directly to P1 (goals) // For non-autonomous robots: lazy-loads from DB, executes, then unloads func (m *Manager) Intervene(ctx *types.Context, req *types.InterveneRequest) (*types.ExecutionResult, error) { m.mu.RLock() if !m.started { m.mu.RUnlock() return nil, fmt.Errorf("manager not started") } m.mu.RUnlock() // Validate request if err := trigger.ValidateIntervention(req); err != nil { return nil, err } // Get robot from cache, or lazy-load if not found robot, lazyLoaded, err := m.getOrLoadRobot(ctx, req.MemberID) if err != nil { return nil, err } // Check robot status if robot.Status == types.RobotPaused { return nil, types.ErrRobotPaused } // Check if human trigger is enabled if robot.Config != nil && robot.Config.Triggers != nil { if !robot.Config.Triggers.IsEnabled(types.TriggerHuman) { return nil, types.ErrTriggerDisabled } } // Build trigger input triggerInput := &types.TriggerInput{ Action: req.Action, Messages: req.Messages, UserID: ctx.UserID(), } // Handle plan.add action - schedule for later if req.Action == types.ActionPlanAdd && req.PlanTime != nil { // If lazy-loaded but not executing, remove immediately if lazyLoaded { m.cache.Remove(req.MemberID) } // TODO: Add to plan queue (Phase 11.3) return &types.ExecutionResult{ Status: types.ExecPending, Message: fmt.Sprintf("Planned for %s (plan queue not implemented yet)", req.PlanTime.Format(time.RFC3339)), }, nil } // Determine executor mode: request > robot config > default executorMode := m.resolveExecutorMode(req.ExecutorMode, robot) // Submit to pool with executor mode execID, err := m.pool.SubmitWithMode(ctx, robot, types.TriggerHuman, triggerInput, executorMode) if err != nil { // If lazy-loaded and submission failed, remove from cache if lazyLoaded { m.cache.Remove(req.MemberID) } return nil, err } // Track execution for pause/resume/stop 
m.execController.Track(execID, req.MemberID, req.TeamID) // For lazy-loaded robots, schedule cleanup after execution completes if lazyLoaded { m.scheduleCleanup(robot) } return &types.ExecutionResult{ ExecutionID: execID, Status: types.ExecPending, Message: fmt.Sprintf("Human intervention (%s) submitted", req.Action), }, nil } // HandleEvent processes an event trigger request // Event trigger skips P0 (inspiration) and goes directly to P1 (goals) // For non-autonomous robots: lazy-loads from DB, executes, then unloads func (m *Manager) HandleEvent(ctx *types.Context, req *types.EventRequest) (*types.ExecutionResult, error) { m.mu.RLock() if !m.started { m.mu.RUnlock() return nil, fmt.Errorf("manager not started") } m.mu.RUnlock() // Validate request if err := trigger.ValidateEvent(req); err != nil { return nil, err } // Get robot from cache, or lazy-load if not found robot, lazyLoaded, err := m.getOrLoadRobot(ctx, req.MemberID) if err != nil { return nil, err } // Check robot status if robot.Status == types.RobotPaused { return nil, types.ErrRobotPaused } // Check if event trigger is enabled if robot.Config != nil && robot.Config.Triggers != nil { if !robot.Config.Triggers.IsEnabled(types.TriggerEvent) { return nil, types.ErrTriggerDisabled } } // Build trigger input triggerInput := trigger.BuildEventInput(req) // Determine executor mode: request > robot config > default executorMode := m.resolveExecutorMode(req.ExecutorMode, robot) // Submit to pool with executor mode execID, err := m.pool.SubmitWithMode(ctx, robot, types.TriggerEvent, triggerInput, executorMode) if err != nil { // If lazy-loaded and submission failed, remove from cache if lazyLoaded { m.cache.Remove(req.MemberID) } return nil, err } // Track execution for pause/resume/stop m.execController.Track(execID, req.MemberID, "") // For lazy-loaded robots, schedule cleanup after execution completes if lazyLoaded { m.scheduleCleanup(robot) } return &types.ExecutionResult{ ExecutionID: execID, Status: 
types.ExecPending, Message: fmt.Sprintf("Event trigger (%s: %s) submitted", req.Source, req.EventType), }, nil } // ==================== Execution Control ==================== // PauseExecution pauses a running execution func (m *Manager) PauseExecution(ctx *types.Context, execID string) error { // Get execution info before pausing exec := m.execController.Get(execID) if exec == nil { return fmt.Errorf("execution not found: %s", execID) } // Pause the execution if err := m.execController.Pause(execID); err != nil { return err } // Remove from robot's in-memory execution list (paused doesn't count as running) if robot := m.cache.Get(exec.MemberID); robot != nil { robot.RemoveExecution(execID) } return nil } // ResumeExecution resumes a paused execution func (m *Manager) ResumeExecution(ctx *types.Context, execID string) error { // Get execution info before resuming exec := m.execController.Get(execID) if exec == nil { return fmt.Errorf("execution not found: %s", execID) } // Resume the execution if err := m.execController.Resume(execID); err != nil { return err } // Add back to robot's in-memory execution list if robot := m.cache.Get(exec.MemberID); robot != nil { robot.AddExecution(&types.Execution{ ID: execID, MemberID: exec.MemberID, TeamID: exec.TeamID, Status: types.ExecRunning, }) } return nil } // StopExecution stops a running execution func (m *Manager) StopExecution(ctx *types.Context, execID string) error { // Get execution info before stopping exec := m.execController.Get(execID) if exec == nil { return fmt.Errorf("execution not found: %s", execID) } // Stop the execution if err := m.execController.Stop(execID); err != nil { return err } // Remove from robot's in-memory execution list if robot := m.cache.Get(exec.MemberID); robot != nil { robot.RemoveExecution(execID) } return nil } // GetExecutionStatus returns the status of an execution func (m *Manager) GetExecutionStatus(execID string) (*trigger.ControlledExecution, error) { exec := 
m.execController.Get(execID) if exec == nil { return nil, fmt.Errorf("execution not found: %s", execID) } return exec, nil } // ListExecutions returns all tracked executions func (m *Manager) ListExecutions() []*trigger.ControlledExecution { return m.execController.List() } // ListExecutionsByMember returns all executions for a specific robot func (m *Manager) ListExecutionsByMember(memberID string) []*trigger.ControlledExecution { return m.execController.ListByMember(memberID) } // ==================== Helper Methods ==================== // getOrLoadRobot gets a robot from cache, or lazy-loads from DB if not found // Returns: robot, wasLazyLoaded, error func (m *Manager) getOrLoadRobot(ctx *types.Context, memberID string) (*types.Robot, bool, error) { // Try cache first robot := m.cache.Get(memberID) if robot != nil { return robot, false, nil } // Not in cache - lazy load from database robot, err := m.cache.LoadByID(ctx, memberID) if err != nil { return nil, false, err } // Add to cache temporarily for execution tracking m.cache.Add(robot) // Return with lazyLoaded=true to indicate cleanup needed after execution return robot, true, nil } // scheduleCleanup schedules removal of a lazy-loaded robot after all executions complete // This runs in a goroutine that monitors the robot's execution count func (m *Manager) scheduleCleanup(robot *types.Robot) { go func() { memberID := robot.MemberID // Poll every 5 seconds to check if all executions are done ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() // Timeout after 24 hours to prevent memory leaks timeout := time.After(24 * time.Hour) for { select { case <-timeout: // Timeout - force cleanup m.cache.Remove(memberID) return case <-ticker.C: // Check if robot still exists in cache r := m.cache.Get(memberID) if r == nil { // Already removed return } // Check if all executions are done if r.RunningCount() == 0 { // Only remove if still non-autonomous // (user might have changed it during execution) if 
!r.AutonomousMode { m.cache.Remove(memberID) } return } } } }() } // resolveExecutorMode determines the executor mode to use // Priority: request > robot config > default (standard) func (m *Manager) resolveExecutorMode(requestMode types.ExecutorMode, robot *types.Robot) types.ExecutorMode { // Request mode takes precedence if requestMode != "" && requestMode.IsValid() { return requestMode } // Robot config mode if robot != nil && robot.Config != nil && robot.Config.Executor != nil { return robot.Config.Executor.GetMode() } // Default: standard return types.ExecutorStandard } // ==================== Getters for internal components ==================== // These are exposed for testing and advanced use cases // Cache returns the internal cache func (m *Manager) Cache() *cache.Cache { return m.cache } // Pool returns the internal pool func (m *Manager) Pool() *pool.Pool { return m.pool } // Executor returns the internal executor func (m *Manager) Executor() types.Executor { return m.executor } // IsStarted returns true if manager is started func (m *Manager) IsStarted() bool { m.mu.RLock() defer m.mu.RUnlock() return m.started } // Running returns number of currently running jobs func (m *Manager) Running() int { return m.pool.Running() } // Queued returns number of queued jobs func (m *Manager) Queued() int { return m.pool.Queued() } // CachedRobots returns number of cached robots func (m *Manager) CachedRobots() int { return m.cache.Count() } ================================================ FILE: agent/robot/manager/manager_test.go ================================================ package manager_test import ( "context" "encoding/json" "runtime" "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" 
"github.com/yaoapp/yao/agent/testutils" ) // TestManagerStartStop tests manager lifecycle func TestManagerStartStop(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) t.Run("start and stop manager", func(t *testing.T) { m := manager.New() // Should not be started assert.False(t, m.IsStarted()) // Start manager err := m.Start() assert.NoError(t, err) assert.True(t, m.IsStarted()) // Robots should be loaded assert.GreaterOrEqual(t, m.CachedRobots(), 2, "Should load at least 2 robots") // Stop manager err = m.Stop() assert.NoError(t, err) assert.False(t, m.IsStarted()) }) t.Run("double start should fail", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) // Second start should fail err = m.Start() assert.Error(t, err) assert.Contains(t, err.Error(), "already started") // Cleanup m.Stop() }) t.Run("stop without start should not panic", func(t *testing.T) { m := manager.New() assert.NotPanics(t, func() { err := m.Stop() assert.NoError(t, err) }) }) } // TestManagerTick tests the Tick function with different clock modes func TestManagerTick(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithClockConfig(t) defer cleanupTestRobots(t) t.Run("tick with times mode - matching time", func(t *testing.T) { // Create manager with short tick interval for testing config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 2, QueueSize: 10}, } m := manager.NewWithConfig(config) err := m.Start() assert.NoError(t, err) defer m.Stop() // Create a time that matches the configured time (09:00) loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) // Wednesday 09:00 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) 
assert.NoError(t, err) // Wait for job to be processed time.Sleep(200 * time.Millisecond) // Check that job was submitted (may be queued or running) // Note: The executor stub completes quickly, so we check execution count execCount := m.Executor().ExecCount() assert.GreaterOrEqual(t, execCount, 1, "Should have executed at least 1 job") }) t.Run("tick with times mode - non-matching time", func(t *testing.T) { config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 2, QueueSize: 10}, } m := manager.NewWithConfig(config) err := m.Start() assert.NoError(t, err) defer m.Stop() // Reset executor count m.Executor().Reset() // Create a time that does NOT match (10:30) loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 10, 30, 0, 0, loc) // Wednesday 10:30 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) assert.NoError(t, err) // Wait a bit time.Sleep(100 * time.Millisecond) // Should not have triggered (times mode robot only triggers at 09:00, 14:00) execCount := m.Executor().ExecCount() // Note: interval mode robot might trigger if enough time passed // We just verify the times mode robot didn't trigger assert.LessOrEqual(t, execCount, 1, "Times mode robot should not trigger at non-matching time") }) t.Run("tick with interval mode", func(t *testing.T) { config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 2, QueueSize: 10}, } m := manager.NewWithConfig(config) err := m.Start() assert.NoError(t, err) defer m.Stop() // Reset executor count m.Executor().Reset() // First tick - should trigger interval mode robot (first run) ctx := types.NewContext(context.Background(), nil) now := time.Now() err = m.Tick(ctx, now) assert.NoError(t, err) // Wait for execution time.Sleep(200 * time.Millisecond) // Should have at least 1 execution (interval robot first run) execCount := m.Executor().ExecCount() assert.GreaterOrEqual(t, execCount, 1, 
"Interval mode robot should trigger on first run") }) t.Run("tick skips paused robots", func(t *testing.T) { config := &manager.Config{ TickInterval: 100 * time.Millisecond, PoolConfig: &pool.Config{WorkerSize: 2, QueueSize: 10}, } m := manager.NewWithConfig(config) err := m.Start() assert.NoError(t, err) defer m.Stop() // Get the paused robot from cache pausedRobot := m.Cache().Get("robot_test_manager_paused") assert.NotNil(t, pausedRobot) assert.Equal(t, types.RobotPaused, pausedRobot.Status) // Reset executor count m.Executor().Reset() // Tick should skip paused robot ctx := types.NewContext(context.Background(), nil) loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) err = m.Tick(ctx, now) assert.NoError(t, err) // The paused robot should not have been triggered // (we can't directly verify this, but we verify the tick completed) }) } // TestManagerTriggerManual tests manual triggering of robots func TestManagerTriggerManual(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithClockConfig(t) defer cleanupTestRobots(t) t.Run("trigger manual - success", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Manually trigger a robot execID, err := m.TriggerManual(ctx, "robot_test_manager_times", types.TriggerHuman, nil) assert.NoError(t, err) assert.NotEmpty(t, execID) // Wait for execution time.Sleep(200 * time.Millisecond) // Should have executed assert.GreaterOrEqual(t, m.Executor().ExecCount(), 1) }) t.Run("trigger manual - robot not found", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Try to trigger non-existent robot _, err = m.TriggerManual(ctx, "robot_nonexistent", types.TriggerHuman, nil) assert.Error(t, err) 
assert.Equal(t, types.ErrRobotNotFound, err) }) t.Run("trigger manual - robot paused", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Try to trigger paused robot _, err = m.TriggerManual(ctx, "robot_test_manager_paused", types.TriggerHuman, nil) assert.Error(t, err) assert.Equal(t, types.ErrRobotPaused, err) }) t.Run("trigger manual - manager not started", func(t *testing.T) { m := manager.New() // Don't start manager ctx := types.NewContext(context.Background(), nil) _, err := m.TriggerManual(ctx, "robot_test_manager_times", types.TriggerHuman, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) } // TestManagerClockModes tests all three clock modes func TestManagerClockModes(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithClockConfig(t) defer cleanupTestRobots(t) t.Run("times mode - day matching", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() m.Executor().Reset() // Wednesday (configured day) loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) // Wednesday 09:00 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) assert.GreaterOrEqual(t, m.Executor().ExecCount(), 1, "Should trigger on matching day") }) t.Run("times mode - day not matching", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() m.Executor().Reset() // Saturday (not configured) loc, _ := time.LoadLocation("Asia/Shanghai") now := time.Date(2025, 1, 18, 9, 0, 0, 0, loc) // Saturday 09:00 ctx := types.NewContext(context.Background(), nil) err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(100 * time.Millisecond) // Times mode robot should not trigger on 
Saturday // Only interval/daemon robots might trigger }) t.Run("daemon mode - always triggers when idle", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() m.Executor().Reset() // Daemon robot should trigger whenever it can run ctx := types.NewContext(context.Background(), nil) now := time.Now() // First tick err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) // Should have triggered daemon robot assert.GreaterOrEqual(t, m.Executor().ExecCount(), 1, "Daemon mode should trigger") }) } // TestManagerTimezoneDedup tests that times mode dedup works correctly across timezones // This specifically tests the bug fix where LastRun.Day() must be converted to the same // timezone as 'now' before comparison func TestManagerTimezoneDedup(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithClockConfig(t) defer cleanupTestRobots(t) t.Run("times mode - same minute same day should not trigger twice", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() m.Executor().Reset() // Use Asia/Shanghai timezone (UTC+8) loc, _ := time.LoadLocation("Asia/Shanghai") // Wednesday 09:00:00 in Shanghai now := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) ctx := types.NewContext(context.Background(), nil) // First tick - should trigger err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) firstCount := m.Executor().ExecCount() assert.GreaterOrEqual(t, firstCount, 1, "First tick should trigger") // Second tick at 09:00:30 (same minute) - should NOT trigger again now2 := time.Date(2025, 1, 15, 9, 0, 30, 0, loc) err = m.Tick(ctx, now2) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) // Count should remain the same (times robot should not trigger twice) // Note: daemon/interval robots may still trigger, so we check the delta secondCount := 
m.Executor().ExecCount() t.Logf("First count: %d, Second count: %d", firstCount, secondCount) // The times robot should not have triggered again in the same minute // Delta should be <= 2 (daemon always triggers, interval might trigger) // If delta > 2, it means times robot triggered twice (bug!) delta := secondCount - firstCount assert.LessOrEqual(t, delta, 2, "Times robot should not trigger twice in same minute (delta: %d)", delta) }) t.Run("times mode - different day should trigger again", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() m.Executor().Reset() loc, _ := time.LoadLocation("Asia/Shanghai") ctx := types.NewContext(context.Background(), nil) // Wednesday 09:00 now1 := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) err = m.Tick(ctx, now1) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) firstCount := m.Executor().ExecCount() // Thursday 09:00 (next day, same time) now2 := time.Date(2025, 1, 16, 9, 0, 0, 0, loc) err = m.Tick(ctx, now2) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) // Should have triggered again on the new day secondCount := m.Executor().ExecCount() assert.Greater(t, secondCount, firstCount, "Should trigger on different day") }) t.Run("times mode - cross-timezone day boundary", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() m.Executor().Reset() // Robot is configured with Asia/Shanghai (UTC+8) // Test case: LastRun was set when it was Jan 15 in Shanghai // Now it's Jan 16 00:30 in Shanghai (still Jan 15 in UTC) // The comparison should use Shanghai timezone, not UTC loc, _ := time.LoadLocation("Asia/Shanghai") ctx := types.NewContext(context.Background(), nil) // First run: Jan 15, 09:00 Shanghai time now1 := time.Date(2025, 1, 15, 9, 0, 0, 0, loc) err = m.Tick(ctx, now1) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) // Second run: Jan 16, 09:00 Shanghai time // This is Jan 16 01:00 UTC, but should be treated as Jan 16 
in Shanghai now2 := time.Date(2025, 1, 16, 9, 0, 0, 0, loc) err = m.Tick(ctx, now2) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) // Should have triggered on both days assert.GreaterOrEqual(t, m.Executor().ExecCount(), 2, "Should trigger on both days") }) t.Run("times mode - UTC vs local timezone comparison", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() m.Executor().Reset() // Test with explicit UTC time converted to Shanghai // This tests that LastRun stored in one timezone is correctly // compared when 'now' is in a different timezone shanghai, _ := time.LoadLocation("Asia/Shanghai") ctx := types.NewContext(context.Background(), nil) // Create a time that's Jan 15 23:30 UTC = Jan 16 07:30 Shanghai utcTime := time.Date(2025, 1, 15, 23, 30, 0, 0, time.UTC) shanghaiTime := utcTime.In(shanghai) t.Logf("UTC: %v, Shanghai: %v", utcTime, shanghaiTime) // The robot is configured for 09:00 Shanghai time // So Jan 16 09:00 Shanghai should trigger now := time.Date(2025, 1, 16, 9, 0, 0, 0, shanghai) err = m.Tick(ctx, now) assert.NoError(t, err) time.Sleep(200 * time.Millisecond) assert.GreaterOrEqual(t, m.Executor().ExecCount(), 1, "Should trigger at 09:00 Shanghai") }) } // TestManagerGoroutineLeak tests that manager doesn't leak goroutines func TestManagerGoroutineLeak(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) t.Run("start stop cycle should not leak goroutines", func(t *testing.T) { // Record initial goroutine count runtime.GC() time.Sleep(100 * time.Millisecond) initialGoroutines := runtime.NumGoroutine() // Start and stop multiple times for i := 0; i < 5; i++ { m := manager.New() err := m.Start() assert.NoError(t, err) // Do some ticks ctx := types.NewContext(context.Background(), nil) m.Tick(ctx, time.Now()) time.Sleep(50 * time.Millisecond) err = m.Stop() 
assert.NoError(t, err) } // Wait for cleanup time.Sleep(200 * time.Millisecond) runtime.GC() time.Sleep(100 * time.Millisecond) // Check goroutine count finalGoroutines := runtime.NumGoroutine() assert.LessOrEqual(t, finalGoroutines, initialGoroutines+2, "Should not leak goroutines (initial: %d, final: %d)", initialGoroutines, finalGoroutines) }) } // TestManagerComponents tests access to internal components func TestManagerComponents(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobots(t) defer cleanupTestRobots(t) m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() t.Run("cache access", func(t *testing.T) { cache := m.Cache() assert.NotNil(t, cache) robot := cache.Get("robot_test_sales_001") assert.NotNil(t, robot) }) t.Run("pool access", func(t *testing.T) { pool := m.Pool() assert.NotNil(t, pool) assert.True(t, pool.IsStarted()) }) t.Run("executor access", func(t *testing.T) { executor := m.Executor() assert.NotNil(t, executor) }) t.Run("running and queued counts", func(t *testing.T) { running := m.Running() queued := m.Queued() cached := m.CachedRobots() assert.GreaterOrEqual(t, running, 0) assert.GreaterOrEqual(t, queued, 0) assert.GreaterOrEqual(t, cached, 2) }) } // ==================== Test Data Setup ==================== // setupTestRobots creates basic test robots (same as cache tests) func setupTestRobots(t *testing.T) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Robot 1: Sales Bot robotConfig1 := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Sales Manager", "duties": []string{"Manage leads", "Follow up customers"}, }, "quota": map[string]interface{}{ "max": 3, "queue": 15, "priority": 7, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00", 
"14:00"}, "days": []string{"Mon", "Tue", "Wed", "Thu", "Fri"}, "tz": "Asia/Shanghai", }, } config1JSON, _ := json.Marshal(robotConfig1) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_sales_001", "team_id": "team_test_cache_001", "member_type": "robot", "display_name": "Test Sales Bot", "system_prompt": "You are a professional sales manager assistant.", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(config1JSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_sales_001: %v", err) } // Robot 2: Support Bot robotConfig2 := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Customer Support", "duties": []string{"Answer questions", "Resolve issues"}, }, "quota": map[string]interface{}{ "max": 2, "queue": 10, "priority": 5, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "interval", "every": "1h", }, } config2JSON, _ := json.Marshal(robotConfig2) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_support_002", "team_id": "team_test_cache_001", "member_type": "robot", "display_name": "Test Support Bot", "system_prompt": "You are a helpful customer support assistant.", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(config2JSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_support_002: %v", err) } // Robot 3: Inactive robot err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_inactive_003", "team_id": "team_test_cache_001", "member_type": "robot", "display_name": "Test Inactive Bot", "status": "inactive", "role_id": "member", "autonomous_mode": true, "robot_status": "paused", }, }) if err != nil { t.Fatalf("Failed to insert robot_test_inactive_003: %v", err) } } // setupTestRobotsWithClockConfig creates 
robots with specific clock configurations func setupTestRobotsWithClockConfig(t *testing.T) { m := model.Select("__yao.member") tableName := m.MetaData.Table.Name qb := capsule.Query() // Robot 1: Times mode (09:00, 14:00 on weekdays) robotConfigTimes := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Times Mode Robot", }, "quota": map[string]interface{}{ "max": 2, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00", "14:00"}, "days": []string{"Mon", "Tue", "Wed", "Thu", "Fri"}, "tz": "Asia/Shanghai", }, } configTimesJSON, _ := json.Marshal(robotConfigTimes) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_times", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Times Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configTimesJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_times: %v", err) } // Robot 2: Interval mode (every 30 minutes) robotConfigInterval := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Interval Mode Robot", }, "quota": map[string]interface{}{ "max": 2, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "interval", "every": "30m", }, } configIntervalJSON, _ := json.Marshal(robotConfigInterval) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_interval", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Interval Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configIntervalJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_interval: %v", err) } // Robot 3: Daemon mode 
robotConfigDaemon := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Daemon Mode Robot", }, "quota": map[string]interface{}{ "max": 2, }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "daemon", "timeout": "5m", }, } configDaemonJSON, _ := json.Marshal(robotConfigDaemon) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_daemon", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Daemon Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configDaemonJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_daemon: %v", err) } // Robot 4: Paused robot (should be skipped) robotConfigPaused := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Paused Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": true}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, "tz": "Asia/Shanghai", }, } configPausedJSON, _ := json.Marshal(robotConfigPaused) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_paused", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Paused Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "paused", "robot_config": string(configPausedJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_paused: %v", err) } // Robot 5: Clock disabled robot robotConfigDisabled := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Clock Disabled Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, }, "clock": map[string]interface{}{ "mode": "times", "times": []string{"09:00"}, }, } configDisabledJSON, _ := json.Marshal(robotConfigDisabled) err 
= qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_disabled", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Clock Disabled Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configDisabledJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_disabled: %v", err) } } // ==================== Intervene Tests ==================== func TestManagerIntervene(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithInterveneConfig(t) defer cleanupTestRobots(t) t.Run("intervene success", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ TeamID: "team_test_manager", MemberID: "robot_test_manager_intervene", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Add a new task"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) assert.Equal(t, types.ExecPending, result.Status) }) t.Run("intervene - manager not started", func(t *testing.T) { m := manager.New() // Don't start ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_test_manager_intervene", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Add a new task"}, }, } _, err := m.Intervene(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) t.Run("intervene - robot not found", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "non_existent_robot", 
Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Add a new task"}, }, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrRobotNotFound, err) }) t.Run("intervene - robot paused", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_test_manager_paused", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Add a new task"}, }, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrRobotPaused, err) }) t.Run("intervene - invalid request", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "", // Invalid: empty member_id Action: types.ActionTaskAdd, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id") }) t.Run("intervene - trigger disabled", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_test_manager_intervene_disabled", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Add a new task"}, }, } _, err = m.Intervene(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrTriggerDisabled, err) }) } // ==================== HandleEvent Tests ==================== func TestManagerHandleEvent(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithEventConfig(t) defer cleanupTestRobots(t) t.Run("handle event success", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer 
m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_test_manager_event", Source: "webhook", EventType: "lead.created", Data: map[string]interface{}{"name": "John", "email": "john@example.com"}, } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err) assert.NotNil(t, result) assert.NotEmpty(t, result.ExecutionID) assert.Equal(t, types.ExecPending, result.Status) }) t.Run("handle event - manager not started", func(t *testing.T) { m := manager.New() // Don't start ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_test_manager_event", Source: "webhook", EventType: "lead.created", } _, err := m.HandleEvent(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "not started") }) t.Run("handle event - robot not found", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "non_existent_robot", Source: "webhook", EventType: "lead.created", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Equal(t, types.ErrRobotNotFound, err) }) t.Run("handle event - invalid request", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_test_manager_event", Source: "", // Invalid: empty source EventType: "lead.created", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Contains(t, err.Error(), "source") }) t.Run("handle event - trigger disabled", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) req := &types.EventRequest{ MemberID: "robot_test_manager_event_disabled", Source: "webhook", EventType: "lead.created", } _, err = m.HandleEvent(ctx, req) assert.Error(t, err) assert.Equal(t, 
types.ErrTriggerDisabled, err) }) } // ==================== Execution Control Tests ==================== func TestManagerExecutionControl(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithInterveneConfig(t) defer cleanupTestRobots(t) t.Run("pause and resume execution", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() // Trigger an execution ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_test_manager_intervene", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) execID := result.ExecutionID // Wait a bit for execution to be tracked time.Sleep(50 * time.Millisecond) // Pause err = m.PauseExecution(ctx, execID) assert.NoError(t, err) // Get status - should be paused status, err := m.GetExecutionStatus(execID) assert.NoError(t, err) assert.True(t, status.IsPaused()) // Resume err = m.ResumeExecution(ctx, execID) assert.NoError(t, err) // Get status - should not be paused status, err = m.GetExecutionStatus(execID) assert.NoError(t, err) assert.False(t, status.IsPaused()) }) t.Run("stop execution", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() // Trigger an execution ctx := types.NewContext(context.Background(), nil) req := &types.InterveneRequest{ MemberID: "robot_test_manager_intervene", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) execID := result.ExecutionID // Wait a bit for execution to be tracked time.Sleep(50 * time.Millisecond) // Stop err = m.StopExecution(ctx, execID) assert.NoError(t, err) // Get status - should not be found (removed after 
stop) _, err = m.GetExecutionStatus(execID) assert.Error(t, err) assert.Contains(t, err.Error(), "not found") }) t.Run("list executions", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Track execution IDs var execIDs []string // Trigger multiple executions for i := 0; i < 3; i++ { req := &types.InterveneRequest{ MemberID: "robot_test_manager_intervene", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test task"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) execIDs = append(execIDs, result.ExecutionID) } // Verify each execution was tracked (even if briefly) // Note: executions complete quickly with stub executor, so they may be removed // We just verify that we got valid execution IDs assert.Len(t, execIDs, 3) for _, id := range execIDs { assert.NotEmpty(t, id) } }) } // setupTestRobotsWithInterveneConfig creates test robots with intervene trigger enabled func setupTestRobotsWithInterveneConfig(t *testing.T) { // First setup the basic robots setupTestRobotsWithClockConfig(t) // Add robots for intervene tests qb := capsule.Query() m := model.Select("__yao.member") tableName := m.MetaData.Table.Name // Robot with intervene enabled robotConfigIntervene := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Intervene Test Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, }, "quota": map[string]interface{}{ "max": 5, "queue": 10, }, } configInterveneJSON, _ := json.Marshal(robotConfigIntervene) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_intervene", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Intervene Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": 
"idle", "robot_config": string(configInterveneJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_intervene: %v", err) } // Robot with intervene disabled robotConfigInterveneDisabled := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Intervene Disabled Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": false}, }, } configInterveneDisabledJSON, _ := json.Marshal(robotConfigInterveneDisabled) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_intervene_disabled", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Intervene Disabled Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configInterveneDisabledJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_intervene_disabled: %v", err) } } // setupTestRobotsWithEventConfig creates test robots with event trigger enabled func setupTestRobotsWithEventConfig(t *testing.T) { // First setup the basic robots setupTestRobotsWithClockConfig(t) // Add robots for event tests qb := capsule.Query() m := model.Select("__yao.member") tableName := m.MetaData.Table.Name // Robot with event enabled robotConfigEvent := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Event Test Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "event": map[string]interface{}{"enabled": true}, }, "quota": map[string]interface{}{ "max": 5, "queue": 10, }, } configEventJSON, _ := json.Marshal(robotConfigEvent) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_event", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Event Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", 
"robot_config": string(configEventJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_event: %v", err) } // Robot with event disabled robotConfigEventDisabled := map[string]interface{}{ "identity": map[string]interface{}{ "role": "Event Disabled Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "event": map[string]interface{}{"enabled": false}, }, } configEventDisabledJSON, _ := json.Marshal(robotConfigEventDisabled) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_event_disabled", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test Event Disabled Robot", "status": "active", "role_id": "member", "autonomous_mode": true, "robot_status": "idle", "robot_config": string(configEventDisabledJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_event_disabled: %v", err) } } // ==================== Lazy Load Tests for Non-Autonomous Robots ==================== // TestManagerLazyLoadNonAutonomous tests that non-autonomous robots are lazy-loaded on demand // and automatically cleaned up after execution completes func TestManagerLazyLoadNonAutonomous(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) setupTestRobotsWithNonAutonomous(t) defer cleanupTestRobots(t) t.Run("non-autonomous robot not in cache on startup", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() // Non-autonomous robot should NOT be in cache robot := m.Cache().Get("robot_test_manager_on_demand") assert.Nil(t, robot, "Non-autonomous robot should not be pre-loaded into cache") // Autonomous robot SHOULD be in cache autoRobot := m.Cache().Get("robot_test_manager_times") assert.NotNil(t, autoRobot, "Autonomous robot should be in cache") }) t.Run("TriggerManual lazy-loads non-autonomous robot", func(t *testing.T) { m 
:= manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Verify robot is NOT in cache before trigger assert.Nil(t, m.Cache().Get("robot_test_manager_on_demand")) // Trigger the non-autonomous robot manually execID, err := m.TriggerManual(ctx, "robot_test_manager_on_demand", types.TriggerHuman, nil) assert.NoError(t, err) assert.NotEmpty(t, execID) // Robot should now be in cache (lazy-loaded) robot := m.Cache().Get("robot_test_manager_on_demand") assert.NotNil(t, robot, "Robot should be lazy-loaded into cache") assert.Equal(t, "robot_test_manager_on_demand", robot.MemberID) assert.False(t, robot.AutonomousMode) }) t.Run("Intervene lazy-loads non-autonomous robot", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Verify robot is NOT in cache before trigger assert.Nil(t, m.Cache().Get("robot_test_manager_on_demand_intervene")) // Intervene on the non-autonomous robot req := &types.InterveneRequest{ TeamID: "team_test_manager", MemberID: "robot_test_manager_on_demand_intervene", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Test lazy load via intervene"}, }, } result, err := m.Intervene(ctx, req) assert.NoError(t, err) assert.NotEmpty(t, result.ExecutionID) // Robot should now be in cache (lazy-loaded) robot := m.Cache().Get("robot_test_manager_on_demand_intervene") assert.NotNil(t, robot, "Robot should be lazy-loaded into cache via Intervene") }) t.Run("HandleEvent lazy-loads non-autonomous robot", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Verify robot is NOT in cache before trigger assert.Nil(t, m.Cache().Get("robot_test_manager_on_demand_event")) // Send event to the non-autonomous robot req := &types.EventRequest{ MemberID: 
"robot_test_manager_on_demand_event", Source: "webhook", EventType: "data.updated", Data: map[string]interface{}{"test": true}, } result, err := m.HandleEvent(ctx, req) assert.NoError(t, err) assert.NotEmpty(t, result.ExecutionID) // Robot should now be in cache (lazy-loaded) robot := m.Cache().Get("robot_test_manager_on_demand_event") assert.NotNil(t, robot, "Robot should be lazy-loaded into cache via HandleEvent") }) t.Run("lazy-loaded robot is cleaned up after execution completes", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Trigger the non-autonomous robot _, err = m.TriggerManual(ctx, "robot_test_manager_on_demand", types.TriggerHuman, nil) assert.NoError(t, err) // Robot should be in cache immediately after trigger robot := m.Cache().Get("robot_test_manager_on_demand") assert.NotNil(t, robot, "Robot should be in cache after trigger") // Wait for execution to complete and cleanup to happen // The stub executor completes quickly, and cleanup runs every 5 seconds // We wait up to 10 seconds for the cleanup goroutine to remove the robot var removed bool for i := 0; i < 20; i++ { time.Sleep(500 * time.Millisecond) if m.Cache().Get("robot_test_manager_on_demand") == nil { removed = true break } } assert.True(t, removed, "Non-autonomous robot should be removed from cache after execution completes") }) t.Run("trigger non-existent robot returns error", func(t *testing.T) { m := manager.New() err := m.Start() assert.NoError(t, err) defer m.Stop() ctx := types.NewContext(context.Background(), nil) // Try to trigger a robot that doesn't exist in DB _, err = m.TriggerManual(ctx, "robot_nonexistent_xyz", types.TriggerHuman, nil) assert.Error(t, err) assert.Equal(t, types.ErrRobotNotFound, err) }) } // setupTestRobotsWithNonAutonomous creates test robots including non-autonomous ones func setupTestRobotsWithNonAutonomous(t *testing.T) { // First setup the autonomous 
robots setupTestRobotsWithClockConfig(t) // Add non-autonomous robots qb := capsule.Query() m := model.Select("__yao.member") tableName := m.MetaData.Table.Name // Non-autonomous robot 1: for TriggerManual test robotConfigOnDemand := map[string]interface{}{ "identity": map[string]interface{}{ "role": "On-Demand Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, }, "quota": map[string]interface{}{ "max": 2, "queue": 5, }, } configOnDemandJSON, _ := json.Marshal(robotConfigOnDemand) err := qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_on_demand", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test On-Demand Robot", "status": "active", "role_id": "member", "autonomous_mode": false, // Non-autonomous! "robot_status": "idle", "robot_config": string(configOnDemandJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_on_demand: %v", err) } // Non-autonomous robot 2: for Intervene test robotConfigOnDemandIntervene := map[string]interface{}{ "identity": map[string]interface{}{ "role": "On-Demand Intervene Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "intervene": map[string]interface{}{"enabled": true}, }, "quota": map[string]interface{}{ "max": 2, }, } configOnDemandInterveneJSON, _ := json.Marshal(robotConfigOnDemandIntervene) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_on_demand_intervene", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test On-Demand Intervene Robot", "status": "active", "role_id": "member", "autonomous_mode": false, // Non-autonomous! 
"robot_status": "idle", "robot_config": string(configOnDemandInterveneJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_on_demand_intervene: %v", err) } // Non-autonomous robot 3: for HandleEvent test robotConfigOnDemandEvent := map[string]interface{}{ "identity": map[string]interface{}{ "role": "On-Demand Event Robot", }, "triggers": map[string]interface{}{ "clock": map[string]interface{}{"enabled": false}, "event": map[string]interface{}{"enabled": true}, }, "quota": map[string]interface{}{ "max": 2, }, } configOnDemandEventJSON, _ := json.Marshal(robotConfigOnDemandEvent) err = qb.Table(tableName).Insert([]map[string]interface{}{ { "member_id": "robot_test_manager_on_demand_event", "team_id": "team_test_manager", "member_type": "robot", "display_name": "Test On-Demand Event Robot", "status": "active", "role_id": "member", "autonomous_mode": false, // Non-autonomous! "robot_status": "idle", "robot_config": string(configOnDemandEventJSON), }, }) if err != nil { t.Fatalf("Failed to insert robot_test_manager_on_demand_event: %v", err) } } // cleanupTestRobots removes all test robot records func cleanupTestRobots(t *testing.T) { qb := capsule.Query() m := model.Select("__yao.member") tableName := m.MetaData.Table.Name // List of test robot IDs to clean up testRobotIDs := []string{ "robot_test_sales_001", "robot_test_support_002", "robot_test_inactive_003", "robot_test_manager_times", "robot_test_manager_interval", "robot_test_manager_daemon", "robot_test_manager_paused", "robot_test_manager_disabled", "robot_test_manager_intervene", "robot_test_manager_intervene_disabled", "robot_test_manager_event", "robot_test_manager_event_disabled", // Non-autonomous robots "robot_test_manager_on_demand", "robot_test_manager_on_demand_intervene", "robot_test_manager_on_demand_event", } for _, id := range testRobotIDs { // Soft delete m.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: id}, }, }) // Hard delete 
// getGoroutineCount returns the current number of goroutines as reported by
// runtime.NumGoroutine.
func getGoroutineCount() int {
	return runtime.NumGoroutine()
}

// waitForGoroutineCount polls the goroutine count until it drops to target or
// below, or until timeout has elapsed, and returns the most recent sample.
//
// The count is always sampled at least once, so a zero or negative timeout
// yields a real measurement instead of 0, and the value returned on timeout is
// taken after the final sleep rather than before it.
func waitForGoroutineCount(target int, timeout time.Duration) int {
	deadline := time.Now().Add(timeout)
	for {
		count := getGoroutineCount()
		if count <= target || !time.Now().Before(deadline) {
			return count
		}
		// Yield and back off briefly so exiting goroutines get a chance to finish.
		runtime.Gosched()
		time.Sleep(10 * time.Millisecond)
	}
}
time.Millisecond) p.Stop() } // Wait for cleanup finalCount := waitForGoroutineCount(baseline+2, 500*time.Millisecond) assert.LessOrEqual(t, finalCount, baseline+2, "Goroutine count should return to near baseline after multiple cycles (baseline=%d, final=%d)", baseline, finalCount) } // TestPoolStopWithoutJobs tests no leak when stopping pool with no jobs submitted func TestPoolStopWithoutJobs(t *testing.T) { runtime.GC() time.Sleep(50 * time.Millisecond) baseline := getGoroutineCount() exec := executor.New() p := pool.NewWithConfig(&pool.Config{ WorkerSize: 10, QueueSize: 100, }) p.SetExecutor(exec) p.Start() // Immediately stop without submitting any jobs p.Stop() finalCount := waitForGoroutineCount(baseline+2, 500*time.Millisecond) assert.LessOrEqual(t, finalCount, baseline+2, "Goroutine count should return to near baseline (baseline=%d, final=%d)", baseline, finalCount) } // TestPoolStopWithPendingJobs tests no leak when stopping with jobs in queue func TestPoolStopWithPendingJobs(t *testing.T) { runtime.GC() time.Sleep(50 * time.Millisecond) baseline := getGoroutineCount() // Use slow executor so jobs stay in queue exec := executor.NewDryRunWithDelay(500 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, // only 1 worker QueueSize: 100, }) p.SetExecutor(exec) p.Start() // Submit many jobs (most will be queued) ctx := types.NewContext(context.Background(), nil) robot := createTestRobot("robot_1", "team_1", 5, 50, 5) for i := 0; i < 20; i++ { p.Submit(ctx, robot, types.TriggerClock, nil) } // Stop immediately (some jobs still in queue) time.Sleep(50 * time.Millisecond) p.Stop() finalCount := waitForGoroutineCount(baseline+2, 500*time.Millisecond) assert.LessOrEqual(t, finalCount, baseline+2, "Goroutine count should return to near baseline even with pending jobs (baseline=%d, final=%d)", baseline, finalCount) } // TestPoolConcurrentStartStop tests no leak with concurrent start/stop func TestPoolConcurrentStartStop(t *testing.T) { 
runtime.GC() time.Sleep(50 * time.Millisecond) baseline := getGoroutineCount() exec := executor.NewDryRunWithDelay(10 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 5, QueueSize: 100, }) p.SetExecutor(exec) // Start pool p.Start() // Concurrent operations done := make(chan bool, 3) // Goroutine 1: Submit jobs go func() { ctx := types.NewContext(context.Background(), nil) robot := createTestRobot("robot_1", "team_1", 5, 10, 5) for i := 0; i < 20; i++ { p.Submit(ctx, robot, types.TriggerClock, nil) time.Sleep(5 * time.Millisecond) } done <- true }() // Goroutine 2: Check status go func() { for i := 0; i < 20; i++ { _ = p.Running() _ = p.Queued() time.Sleep(5 * time.Millisecond) } done <- true }() // Wait for operations <-done <-done // Stop pool p.Stop() finalCount := waitForGoroutineCount(baseline+2, 500*time.Millisecond) assert.LessOrEqual(t, finalCount, baseline+2, "Goroutine count should return to near baseline after concurrent ops (baseline=%d, final=%d)", baseline, finalCount) } // TestWorkerGoroutinesCleanup tests that worker goroutines are properly cleaned up func TestWorkerGoroutinesCleanup(t *testing.T) { runtime.GC() time.Sleep(50 * time.Millisecond) baseline := getGoroutineCount() exec := executor.NewDryRunWithDelay(10 * time.Millisecond) // Create pool with many workers p := pool.NewWithConfig(&pool.Config{ WorkerSize: 20, QueueSize: 100, }) p.SetExecutor(exec) p.Start() // Should have baseline + 20 workers afterStart := getGoroutineCount() assert.GreaterOrEqual(t, afterStart, baseline+20, "Should have at least 20 worker goroutines") // Stop pool p.Stop() // All worker goroutines should be cleaned up finalCount := waitForGoroutineCount(baseline+2, 500*time.Millisecond) assert.LessOrEqual(t, finalCount, baseline+2, "All worker goroutines should be cleaned up (baseline=%d, final=%d)", baseline, finalCount) } // TestPoolLongRunningJobsNoLeak tests no leak with long-running jobs func TestPoolLongRunningJobsNoLeak(t *testing.T) { 
runtime.GC() time.Sleep(50 * time.Millisecond) baseline := getGoroutineCount() exec := executor.NewDryRunWithDelay(200 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 3, QueueSize: 100, }) p.SetExecutor(exec) p.Start() // Submit jobs ctx := types.NewContext(context.Background(), nil) robot := createTestRobot("robot_1", "team_1", 5, 10, 5) for i := 0; i < 5; i++ { p.Submit(ctx, robot, types.TriggerClock, nil) } // Wait for some jobs to complete time.Sleep(500 * time.Millisecond) // Stop pool p.Stop() finalCount := waitForGoroutineCount(baseline+2, 500*time.Millisecond) assert.LessOrEqual(t, finalCount, baseline+2, "No goroutine leak after long-running jobs (baseline=%d, final=%d)", baseline, finalCount) } // TestQueueNoGoroutineLeak tests that queue operations don't leak goroutines func TestQueueNoGoroutineLeak(t *testing.T) { runtime.GC() time.Sleep(50 * time.Millisecond) baseline := getGoroutineCount() // Create queue and perform many operations pq := pool.NewPriorityQueue(1000) // Enqueue many items for i := 0; i < 500; i++ { robot := createTestRobot("robot_"+string(rune('A'+i%26)), "team_1", 5, 100, 5) pq.Enqueue(&pool.QueueItem{ Robot: robot, Trigger: types.TriggerClock, }) } // Dequeue all items for pq.Size() > 0 { pq.Dequeue() } runtime.GC() time.Sleep(50 * time.Millisecond) finalCount := getGoroutineCount() assert.LessOrEqual(t, finalCount, baseline+2, "Queue operations should not leak goroutines (baseline=%d, final=%d)", baseline, finalCount) } ================================================ FILE: agent/robot/pool/pool.go ================================================ package pool import ( "fmt" "sync" "sync/atomic" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/robot/utils" ) // Default configuration values const ( DefaultWorkerSize = 10 // default number of workers DefaultQueueSize = 100 // default global queue size ) // Config holds pool configuration type Config struct { WorkerSize int // number of 
workers (default: 10) QueueSize int // global queue size (default: 100) } // DefaultConfig returns default pool configuration func DefaultConfig() *Config { return &Config{ WorkerSize: DefaultWorkerSize, QueueSize: DefaultQueueSize, } } // ExecutorFactory creates an executor based on the mode type ExecutorFactory func(mode types.ExecutorMode) types.Executor // OnCompleteCallback is called when an execution completes (success or failure) // Parameters: execID, memberID, status type OnCompleteCallback func(execID, memberID string, status types.ExecStatus) // Pool implements types.Pool interface // Manages a pool of workers that execute robot jobs from a priority queue type Pool struct { size int // number of workers queue *PriorityQueue // priority queue for pending jobs executor types.Executor // default executor for running jobs executorFactory ExecutorFactory // optional: factory for mode-specific executors onComplete OnCompleteCallback // optional: callback when execution completes workers []*Worker // worker goroutines running atomic.Int32 // number of currently running jobs wg sync.WaitGroup // wait group for graceful shutdown started bool // whether pool has been started mu sync.RWMutex // protects started flag } // New creates a new pool instance with default configuration func New() *Pool { return NewWithConfig(nil) } // NewWithConfig creates a new pool instance with custom configuration func NewWithConfig(config *Config) *Pool { if config == nil { config = DefaultConfig() } // Apply defaults for zero values workerSize := config.WorkerSize if workerSize <= 0 { workerSize = DefaultWorkerSize } queueSize := config.QueueSize if queueSize <= 0 { queueSize = DefaultQueueSize } return &Pool{ size: workerSize, queue: NewPriorityQueue(queueSize), } } // SetExecutor sets the default executor for the pool // Must be called before Start() func (p *Pool) SetExecutor(executor types.Executor) { p.executor = executor } // SetExecutorFactory sets the executor factory for 
mode-specific executors // If set, the factory is used to create executors based on ExecutorMode func (p *Pool) SetExecutorFactory(factory ExecutorFactory) { p.executorFactory = factory } // SetOnComplete sets the callback for execution completion // Called when an execution finishes (completed, failed, or cancelled) func (p *Pool) SetOnComplete(callback OnCompleteCallback) { p.onComplete = callback } // GetExecutor returns the appropriate executor for the given mode // If factory is set and mode is specified, uses factory; otherwise uses default func (p *Pool) GetExecutor(mode types.ExecutorMode) types.Executor { // If factory is set and mode is specified, use factory if p.executorFactory != nil && mode != "" { return p.executorFactory(mode) } // Otherwise use default executor return p.executor } // Start starts the worker pool func (p *Pool) Start() error { p.mu.Lock() defer p.mu.Unlock() if p.started { return fmt.Errorf("pool already started") } if p.executor == nil { return fmt.Errorf("executor not set, call SetExecutor() first") } // Create and start workers p.workers = make([]*Worker, p.size) for i := 0; i < p.size; i++ { worker := newWorker(i+1, p, &p.wg) p.workers[i] = worker worker.start() } p.started = true return nil } // Stop stops the worker pool gracefully // Waits for all running jobs to complete func (p *Pool) Stop() error { p.mu.Lock() if !p.started { p.mu.Unlock() return nil // already stopped or never started } p.started = false p.mu.Unlock() // Stop all workers for _, worker := range p.workers { worker.stop() } // Wait for all workers to finish p.wg.Wait() return nil } // Submit submits a robot execution to the pool // Returns execution ID if successfully queued, error otherwise func (p *Pool) Submit(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}) (string, error) { return p.SubmitWithMode(ctx, robot, trigger, data, "") } // GenerateExecID generates a new execution ID // Exported so Manager can pre-generate 
IDs for tracking func GenerateExecID() string { return utils.NewID() } // SubmitWithMode submits a robot execution with specified executor mode // executorMode: optional, overrides robot's config if provided // Returns execution ID if successfully queued, error otherwise // Note: This method does not support execution control (pause/resume) func (p *Pool) SubmitWithMode(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, executorMode types.ExecutorMode) (string, error) { execID := GenerateExecID() return p.submitWithIDAndMode(ctx, robot, trigger, data, execID, executorMode, nil) } // SubmitWithID submits a robot execution with a pre-generated execution ID // This is used when the caller needs to track the execution before submission func (p *Pool) SubmitWithID(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string, control types.ExecutionControl) (string, error) { return p.submitWithIDAndMode(ctx, robot, trigger, data, execID, "", control) } // submitWithIDAndMode is the internal implementation that handles both cases func (p *Pool) submitWithIDAndMode(ctx *types.Context, robot *types.Robot, trigger types.TriggerType, data interface{}, execID string, executorMode types.ExecutorMode, control types.ExecutionControl) (string, error) { p.mu.RLock() if !p.started { p.mu.RUnlock() return "", fmt.Errorf("pool not started") } p.mu.RUnlock() if robot == nil { return "", fmt.Errorf("robot cannot be nil") } // Create queue item with the provided ID and control item := &QueueItem{ Robot: robot, Ctx: ctx, Trigger: trigger, Data: data, ExecutorMode: executorMode, ExecID: execID, Control: control, } // Try to add to queue if !p.queue.Enqueue(item) { return "", fmt.Errorf("queue full (max %d items)", p.queue.maxSize) } return execID, nil } // Running returns number of currently running jobs func (p *Pool) Running() int { return int(p.running.Load()) } // Queued returns number of queued jobs func (p *Pool) 
Queued() int { return p.queue.Size() } // incrementRunning increments the running counter func (p *Pool) incrementRunning() { p.running.Add(1) } // decrementRunning decrements the running counter func (p *Pool) decrementRunning() { p.running.Add(-1) } // Size returns the configured pool size func (p *Pool) Size() int { return p.size } // QueueSize returns the configured queue size func (p *Pool) QueueSize() int { return p.queue.maxSize } // IsStarted returns true if the pool has been started func (p *Pool) IsStarted() bool { p.mu.RLock() defer p.mu.RUnlock() return p.started } ================================================ FILE: agent/robot/pool/pool_test.go ================================================ package pool_test import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/executor" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" ) // createTestRobot creates a robot for testing with specified quota func createTestRobot(memberID, teamID string, maxConcurrent, queueSize, priority int) *types.Robot { return &types.Robot{ MemberID: memberID, TeamID: teamID, DisplayName: "Test Robot " + memberID, Status: types.RobotIdle, AutonomousMode: true, Config: &types.Config{ Identity: &types.Identity{Role: "Test"}, Quota: &types.Quota{ Max: maxConcurrent, Queue: queueSize, Priority: priority, }, }, } } // createTestContext creates a context for testing func createTestContext() *types.Context { return types.NewContext(context.Background(), nil) } // TestPoolStartStop tests pool start and stop lifecycle func TestPoolStartStop(t *testing.T) { p := pool.New() exec := executor.New() p.SetExecutor(exec) t.Run("start pool", func(t *testing.T) { err := p.Start() assert.NoError(t, err) assert.True(t, p.IsStarted()) }) t.Run("start already started pool", func(t *testing.T) { err := p.Start() assert.Error(t, err) assert.Contains(t, err.Error(), "already started") }) t.Run("stop pool", func(t 
*testing.T) { err := p.Stop() assert.NoError(t, err) assert.False(t, p.IsStarted()) }) t.Run("stop already stopped pool", func(t *testing.T) { err := p.Stop() assert.NoError(t, err) // should not error }) } // TestPoolSubmitWithoutStart tests submitting to unstarted pool func TestPoolSubmitWithoutStart(t *testing.T) { p := pool.New() exec := executor.New() p.SetExecutor(exec) robot := createTestRobot("robot_1", "team_1", 2, 10, 5) ctx := createTestContext() _, err := p.Submit(ctx, robot, types.TriggerClock, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "not started") } // TestPoolSubmitNilRobot tests submitting nil robot func TestPoolSubmitNilRobot(t *testing.T) { p := pool.New() exec := executor.New() p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() _, err := p.Submit(ctx, nil, types.TriggerClock, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "cannot be nil") } // TestPoolBasicExecution tests basic job execution func TestPoolBasicExecution(t *testing.T) { exec := executor.NewDryRunWithDelay(50 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 5, QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() robot := createTestRobot("robot_1", "team_1", 2, 10, 5) ctx := createTestContext() // Submit a job execID, err := p.Submit(ctx, robot, types.TriggerClock, nil) assert.NoError(t, err) assert.NotEmpty(t, execID) // Wait for execution (worker polls every 100ms + 50ms exec + buffer) time.Sleep(300 * time.Millisecond) // Verify execution completed assert.Equal(t, 1, exec.ExecCount()) // Note: CurrentCount may briefly be non-zero during execution, use Eventually pattern assert.Eventually(t, func() bool { return exec.CurrentCount() == 0 }, 500*time.Millisecond, 50*time.Millisecond, "CurrentCount should be 0 after execution") } // TestPoolConcurrencyLimit tests global worker limit func TestPoolConcurrencyLimit(t *testing.T) { exec := executor.NewDryRunWithDelay(200 * time.Millisecond) // longer delay p 
:= pool.NewWithConfig(&pool.Config{ WorkerSize: 3, // only 3 workers QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Create robots with high quota (won't be the bottleneck) robots := make([]*types.Robot, 10) for i := 0; i < 10; i++ { robots[i] = createTestRobot( "robot_"+string(rune('A'+i)), "team_1", 5, // max concurrent per robot 20, // queue size per robot 5, // priority ) } // Submit 10 jobs for i := 0; i < 10; i++ { _, err := p.Submit(ctx, robots[i], types.TriggerClock, nil) assert.NoError(t, err) } // Wait for workers to pick up jobs (worker polls every 100ms) time.Sleep(200 * time.Millisecond) // Should have at most 3 running (worker limit) running := p.Running() assert.LessOrEqual(t, running, 3, "Should not exceed worker limit") // Wait for all to complete (10 jobs / 3 workers * 200ms each = ~700ms + buffer) // Use Eventually to handle CI timing variations assert.Eventually(t, func() bool { return exec.ExecCount() >= 10 }, 2*time.Second, 100*time.Millisecond, "All 10 jobs should complete") } // TestRobotConcurrencyLimit tests per-robot concurrent execution limit func TestRobotConcurrencyLimit(t *testing.T) { exec := executor.NewDryRunWithDelay(100 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 10, // plenty of workers QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Create robot with Max=2 (can only run 2 at a time) robot := createTestRobot("robot_limited", "team_1", 2, 20, 5) // Submit 5 jobs for the same robot for i := 0; i < 5; i++ { _, err := p.Submit(ctx, robot, types.TriggerClock, nil) assert.NoError(t, err) } // Wait a bit for execution to start time.Sleep(150 * time.Millisecond) // Robot should have at most 2 running (Quota.Max=2) runningCount := robot.RunningCount() assert.LessOrEqual(t, runningCount, 2, "Robot should not exceed Quota.Max") // Wait for all to complete (with re-enqueue, need more time) // 5 jobs with Max=2: ~3 batches * 
100ms exec + poll overhead time.Sleep(800 * time.Millisecond) // All 5 jobs should eventually execute assert.GreaterOrEqual(t, exec.ExecCount(), 5, "All jobs should eventually execute") } // TestRobotQueueLimit tests per-robot queue limit func TestRobotQueueLimit(t *testing.T) { exec := executor.NewDryRunWithDelay(200 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 2, QueueSize: 100, // global queue is large }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Create robot with small queue limit robot := createTestRobot("robot_small_queue", "team_1", 1, 3, 5) // Queue=3 // Submit jobs until queue limit is reached successCount := 0 for i := 0; i < 10; i++ { _, err := p.Submit(ctx, robot, types.TriggerClock, nil) if err == nil { successCount++ } } // Should only accept up to Queue limit (some may have started executing) // Max accepted = Queue(3) + Max(1) = 4 (1 running + 3 in queue) assert.LessOrEqual(t, successCount, 4, "Should respect robot queue limit") assert.GreaterOrEqual(t, successCount, 1, "Should accept at least 1 job") } // TestGlobalQueueLimit tests global queue limit func TestGlobalQueueLimit(t *testing.T) { exec := executor.NewDryRunWithDelay(500 * time.Millisecond) // slow execution p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, // only 1 worker QueueSize: 5, // small global queue }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Create multiple robots with large queue limits successCount := 0 for i := 0; i < 20; i++ { robot := createTestRobot( "robot_"+string(rune('A'+i%26)), "team_1", 5, // large max 20, // large per-robot queue 5, ) _, err := p.Submit(ctx, robot, types.TriggerClock, nil) if err == nil { successCount++ } } // Should only accept up to global queue limit + running // Max = QueueSize(5) + WorkerSize(1) = 6 assert.LessOrEqual(t, successCount, 6, "Should respect global queue limit") } // TestPriorityOrder tests that higher priority jobs execute first func 
TestPriorityOrder(t *testing.T) { exec := executor.NewDryRunWithDelay(50 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, // single worker to ensure order QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Create robots with different priorities robotLow := createTestRobot("robot_low", "team_1", 5, 20, 1) // priority 1 robotMed := createTestRobot("robot_med", "team_1", 5, 20, 5) // priority 5 robotHigh := createTestRobot("robot_high", "team_1", 5, 20, 10) // priority 10 // Submit in low-to-high order p.Submit(ctx, robotLow, types.TriggerClock, nil) p.Submit(ctx, robotMed, types.TriggerClock, nil) p.Submit(ctx, robotHigh, types.TriggerClock, nil) // Wait for all to complete (3 jobs * (100ms poll + 50ms exec) = ~450ms + buffer) // Use Eventually for CI timing variations assert.Eventually(t, func() bool { return exec.ExecCount() >= 3 }, 1*time.Second, 50*time.Millisecond, "All 3 jobs should complete") } // TestTriggerTypePriority tests that human triggers have higher priority than clock func TestTriggerTypePriority(t *testing.T) { exec := executor.NewDryRunWithDelay(50 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, // single worker QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Same robot, same priority, different trigger types robot := createTestRobot("robot_1", "team_1", 5, 20, 5) // Submit clock first, then human p.Submit(ctx, robot, types.TriggerClock, nil) p.Submit(ctx, robot, types.TriggerHuman, nil) // should execute first // Wait for all to complete (2 jobs * (100ms poll + 50ms exec) = ~300ms + buffer) assert.Eventually(t, func() bool { return exec.ExecCount() >= 2 }, 1*time.Second, 50*time.Millisecond, "Both jobs should complete") } // TestMultipleRobotsFairness tests that multiple robots get fair access func TestMultipleRobotsFairness(t *testing.T) { exec := executor.NewDryRunWithDelay(30 * time.Millisecond) p := 
pool.NewWithConfig(&pool.Config{ WorkerSize: 5, QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Create 3 robots with same priority robotA := createTestRobot("robot_A", "team_1", 2, 10, 5) robotB := createTestRobot("robot_B", "team_1", 2, 10, 5) robotC := createTestRobot("robot_C", "team_1", 2, 10, 5) // Submit jobs for each robot for i := 0; i < 6; i++ { p.Submit(ctx, robotA, types.TriggerClock, nil) p.Submit(ctx, robotB, types.TriggerClock, nil) p.Submit(ctx, robotC, types.TriggerClock, nil) } // Wait for all to complete // 18 jobs with Quota.Max=2 per robot, 5 workers, 30ms each // Jobs are batched by robot quota, use Eventually for CI timing assert.Eventually(t, func() bool { return exec.ExecCount() >= 18 }, 3*time.Second, 100*time.Millisecond, "All 18 jobs should complete") } // TestGracefulShutdown tests that pool waits for running jobs on shutdown func TestGracefulShutdown(t *testing.T) { exec := executor.NewDryRunWithDelay(200 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 2, QueueSize: 10, }) p.SetExecutor(exec) p.Start() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 20, 5) // Submit 2 jobs p.Submit(ctx, robot, types.TriggerClock, nil) p.Submit(ctx, robot, types.TriggerClock, nil) // Wait for workers to pick up jobs (poll every 100ms) time.Sleep(150 * time.Millisecond) // Verify jobs are running assert.GreaterOrEqual(t, p.Running(), 1, "Should have at least 1 running job") // Stop - workers will finish their current tick cycle p.Stop() // After stop, verify jobs completed assert.GreaterOrEqual(t, exec.ExecCount(), 1, "Should have executed at least 1 job") } // TestDefaultConfig tests default configuration values func TestDefaultConfig(t *testing.T) { config := pool.DefaultConfig() assert.Equal(t, pool.DefaultWorkerSize, config.WorkerSize) assert.Equal(t, pool.DefaultQueueSize, config.QueueSize) } // TestPoolWithNilConfig tests pool creation with nil config 
func TestPoolWithNilConfig(t *testing.T) { p := pool.NewWithConfig(nil) assert.Equal(t, pool.DefaultWorkerSize, p.Size()) assert.Equal(t, pool.DefaultQueueSize, p.QueueSize()) } // TestPoolWithZeroConfig tests pool creation with zero values func TestPoolWithZeroConfig(t *testing.T) { p := pool.NewWithConfig(&pool.Config{ WorkerSize: 0, QueueSize: 0, }) // Should use defaults for zero values assert.Equal(t, pool.DefaultWorkerSize, p.Size()) assert.Equal(t, pool.DefaultQueueSize, p.QueueSize()) } // TestPoolWithoutExecutor tests starting pool without executor func TestPoolWithoutExecutor(t *testing.T) { p := pool.New() // Don't set executor err := p.Start() assert.Error(t, err) assert.Contains(t, err.Error(), "executor not set") } ================================================ FILE: agent/robot/pool/queue.go ================================================ package pool import ( "container/heap" "sync" "time" "github.com/yaoapp/yao/agent/robot/types" ) // QueueItem represents a job waiting in the queue type QueueItem struct { Robot *types.Robot Ctx *types.Context Trigger types.TriggerType Data interface{} ExecutorMode types.ExecutorMode // optional: override robot's executor mode ExecID string // pre-generated execution ID for tracking Control types.ExecutionControl // execution control for pause/resume/stop EnqueueTime time.Time Priority int // calculated priority for sorting Index int // index in heap (managed by container/heap) } // PriorityQueue implements a priority queue for robot executions // Sorted by: robot priority > trigger type priority > wait time type PriorityQueue struct { items []*QueueItem mu sync.RWMutex maxSize int // global queue size limit robotCount map[string]int // per-robot queue count: memberID -> count } // NewPriorityQueue creates a new priority queue func NewPriorityQueue(maxSize int) *PriorityQueue { pq := &PriorityQueue{ items: make([]*QueueItem, 0), maxSize: maxSize, robotCount: make(map[string]int), } heap.Init(pq) return pq } // 
Enqueue adds an item to the queue // Returns false if: // - Global queue is full (maxSize) // - Robot's queue limit reached (Quota.Queue) func (pq *PriorityQueue) Enqueue(item *QueueItem) bool { pq.mu.Lock() defer pq.mu.Unlock() // Check 1: Global queue limit if pq.maxSize > 0 && len(pq.items) >= pq.maxSize { return false // global queue full } // Check 2: Per-robot queue limit (prevents single robot from hogging the queue) if item.Robot != nil { memberID := item.Robot.MemberID robotQueueLimit := 10 // default if item.Robot.Config != nil && item.Robot.Config.Quota != nil { robotQueueLimit = item.Robot.Config.Quota.GetQueue() } if pq.robotCount[memberID] >= robotQueueLimit { return false // robot's queue limit reached } // Increment robot's queue count pq.robotCount[memberID]++ } item.Priority = calculatePriority(item) item.EnqueueTime = time.Now() heap.Push(pq, item) return true } // Dequeue removes and returns the highest priority item // Returns nil if queue is empty func (pq *PriorityQueue) Dequeue() *QueueItem { pq.mu.Lock() defer pq.mu.Unlock() if len(pq.items) == 0 { return nil } item := heap.Pop(pq).(*QueueItem) // Decrement robot's queue count if item.Robot != nil { memberID := item.Robot.MemberID if pq.robotCount[memberID] > 0 { pq.robotCount[memberID]-- } // Clean up if count reaches zero if pq.robotCount[memberID] == 0 { delete(pq.robotCount, memberID) } } return item } // Size returns the number of items in the queue (thread-safe) func (pq *PriorityQueue) Size() int { pq.mu.RLock() defer pq.mu.RUnlock() return len(pq.items) } // IsFull returns true if queue has reached max capacity func (pq *PriorityQueue) IsFull() bool { pq.mu.RLock() defer pq.mu.RUnlock() return pq.maxSize > 0 && len(pq.items) >= pq.maxSize } // RobotQueuedCount returns the number of queued items for a specific robot func (pq *PriorityQueue) RobotQueuedCount(memberID string) int { pq.mu.RLock() defer pq.mu.RUnlock() return pq.robotCount[memberID] } // ==================== 
heap.Interface implementation ==================== // These methods are called internally by heap.Push/Pop with lock already held func (pq *PriorityQueue) Len() int { return len(pq.items) } func (pq *PriorityQueue) Less(i, j int) bool { // Higher priority value = higher priority (processed first) // If priority is equal, older items (earlier EnqueueTime) come first if pq.items[i].Priority == pq.items[j].Priority { return pq.items[i].EnqueueTime.Before(pq.items[j].EnqueueTime) } return pq.items[i].Priority > pq.items[j].Priority } func (pq *PriorityQueue) Swap(i, j int) { pq.items[i], pq.items[j] = pq.items[j], pq.items[i] pq.items[i].Index = i pq.items[j].Index = j } // Push is required by heap.Interface // Note: This is called by heap.Push(), not directly func (pq *PriorityQueue) Push(x interface{}) { item := x.(*QueueItem) item.Index = len(pq.items) pq.items = append(pq.items, item) } // Pop is required by heap.Interface // Note: This is called by heap.Pop(), not directly func (pq *PriorityQueue) Pop() interface{} { old := pq.items n := len(old) item := old[n-1] old[n-1] = nil // avoid memory leak item.Index = -1 // mark as removed pq.items = old[0 : n-1] return item } // ==================== Priority Calculation ==================== // calculatePriority calculates the priority score for a queue item // Priority = robot_priority * 1000 + trigger_priority * 100 // Higher score = higher priority func calculatePriority(item *QueueItem) int { priority := 0 // 1. Robot priority (from config, 1-10, default 5) if item.Robot != nil && item.Robot.Config != nil && item.Robot.Config.Quota != nil { robotPriority := item.Robot.Config.Quota.GetPriority() priority += robotPriority * 1000 } else { priority += 5000 // default robot priority } // 2. 
Trigger type priority // Human intervention > Event > Clock triggerPriority := getTriggerPriority(item.Trigger) priority += triggerPriority * 100 return priority } // getTriggerPriority returns priority value for trigger type func getTriggerPriority(trigger types.TriggerType) int { switch trigger { case types.TriggerHuman: return 10 // highest priority case types.TriggerEvent: return 5 // medium priority case types.TriggerClock: return 1 // lowest priority default: return 0 } } ================================================ FILE: agent/robot/pool/queue_test.go ================================================ package pool_test import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" ) // ==================== Priority Queue Basic Tests ==================== // TestQueueNewPriorityQueue tests queue creation func TestQueueNewPriorityQueue(t *testing.T) { t.Run("create with positive size", func(t *testing.T) { pq := pool.NewPriorityQueue(100) assert.NotNil(t, pq) assert.Equal(t, 0, pq.Size()) assert.False(t, pq.IsFull()) }) t.Run("create with zero size (unlimited)", func(t *testing.T) { pq := pool.NewPriorityQueue(0) assert.NotNil(t, pq) assert.False(t, pq.IsFull()) // never full when maxSize=0 }) } // TestQueueEnqueueDequeue tests basic enqueue and dequeue func TestQueueEnqueueDequeue(t *testing.T) { pq := pool.NewPriorityQueue(100) robot := createTestRobot("robot_1", "team_1", 5, 10, 5) ctx := createTestContext() t.Run("enqueue single item", func(t *testing.T) { item := &pool.QueueItem{ Robot: robot, Ctx: ctx, Trigger: types.TriggerClock, Data: "test_data", } ok := pq.Enqueue(item) assert.True(t, ok) assert.Equal(t, 1, pq.Size()) }) t.Run("dequeue single item", func(t *testing.T) { item := pq.Dequeue() assert.NotNil(t, item) assert.Equal(t, "robot_1", item.Robot.MemberID) assert.Equal(t, "test_data", item.Data) assert.Equal(t, 0, pq.Size()) }) t.Run("dequeue from empty queue", 
func(t *testing.T) { item := pq.Dequeue() assert.Nil(t, item) }) } // TestQueueSize tests Size method func TestQueueSize(t *testing.T) { pq := pool.NewPriorityQueue(100) robot := createTestRobot("robot_1", "team_1", 5, 10, 5) assert.Equal(t, 0, pq.Size()) // Add 5 items for i := 0; i < 5; i++ { pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) } assert.Equal(t, 5, pq.Size()) // Remove 2 items pq.Dequeue() pq.Dequeue() assert.Equal(t, 3, pq.Size()) } // ==================== Global Queue Limit Tests ==================== // TestQueueGlobalLimit tests global queue size limit func TestQueueGlobalLimit(t *testing.T) { pq := pool.NewPriorityQueue(5) // max 5 items // Create different robots to avoid per-robot limit for i := 0; i < 10; i++ { robot := createTestRobot("robot_"+string(rune('A'+i)), "team_1", 5, 10, 5) item := &pool.QueueItem{Robot: robot, Trigger: types.TriggerClock} ok := pq.Enqueue(item) if i < 5 { assert.True(t, ok, "Should accept item %d", i) } else { assert.False(t, ok, "Should reject item %d (queue full)", i) } } assert.Equal(t, 5, pq.Size()) assert.True(t, pq.IsFull()) } // TestQueueUnlimitedSize tests queue with no size limit (maxSize=0) func TestQueueUnlimitedSize(t *testing.T) { pq := pool.NewPriorityQueue(0) // unlimited // Add many items for i := 0; i < 100; i++ { robot := createTestRobot("robot_"+string(rune('A'+i%26)), "team_1", 5, 1000, 5) item := &pool.QueueItem{Robot: robot, Trigger: types.TriggerClock} ok := pq.Enqueue(item) assert.True(t, ok) } assert.Equal(t, 100, pq.Size()) assert.False(t, pq.IsFull()) // never full } // ==================== Per-Robot Queue Limit Tests ==================== // TestQueuePerRobotLimit tests per-robot queue limit (Quota.Queue) func TestQueuePerRobotLimit(t *testing.T) { pq := pool.NewPriorityQueue(100) // large global limit // Robot with Queue=3 robot := createTestRobot("robot_limited", "team_1", 5, 3, 5) // Try to add 10 items for same robot successCount := 0 for i := 0; i < 10; i++ { 
item := &pool.QueueItem{Robot: robot, Trigger: types.TriggerClock} if pq.Enqueue(item) { successCount++ } } // Should only accept Queue(3) items assert.Equal(t, 3, successCount) assert.Equal(t, 3, pq.Size()) assert.Equal(t, 3, pq.RobotQueuedCount("robot_limited")) } // TestQueueMultipleRobotsIndependentLimits tests that each robot has independent queue limit func TestQueueMultipleRobotsIndependentLimits(t *testing.T) { pq := pool.NewPriorityQueue(100) // Robot A: Queue=2 robotA := createTestRobot("robot_A", "team_1", 5, 2, 5) // Robot B: Queue=3 robotB := createTestRobot("robot_B", "team_1", 5, 3, 5) // Add items for Robot A for i := 0; i < 5; i++ { pq.Enqueue(&pool.QueueItem{Robot: robotA, Trigger: types.TriggerClock}) } assert.Equal(t, 2, pq.RobotQueuedCount("robot_A")) // Add items for Robot B for i := 0; i < 5; i++ { pq.Enqueue(&pool.QueueItem{Robot: robotB, Trigger: types.TriggerClock}) } assert.Equal(t, 3, pq.RobotQueuedCount("robot_B")) // Total in queue assert.Equal(t, 5, pq.Size()) } // TestQueueRobotCountAfterDequeue tests robot count decrements after dequeue func TestQueueRobotCountAfterDequeue(t *testing.T) { pq := pool.NewPriorityQueue(100) robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Add 3 items for i := 0; i < 3; i++ { pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) } assert.Equal(t, 3, pq.RobotQueuedCount("robot_1")) // Dequeue 2 pq.Dequeue() assert.Equal(t, 2, pq.RobotQueuedCount("robot_1")) pq.Dequeue() assert.Equal(t, 1, pq.RobotQueuedCount("robot_1")) // Dequeue last pq.Dequeue() assert.Equal(t, 0, pq.RobotQueuedCount("robot_1")) } // TestQueueNilRobot tests handling of nil robot func TestQueueNilRobot(t *testing.T) { pq := pool.NewPriorityQueue(100) // Item with nil robot should still be enqueued item := &pool.QueueItem{ Robot: nil, Trigger: types.TriggerClock, } ok := pq.Enqueue(item) assert.True(t, ok) assert.Equal(t, 1, pq.Size()) // Dequeue should work dequeued := pq.Dequeue() assert.NotNil(t, dequeued) 
assert.Nil(t, dequeued.Robot) } // TestQueueDefaultRobotQueueLimit tests default queue limit when Quota is nil func TestQueueDefaultRobotQueueLimit(t *testing.T) { pq := pool.NewPriorityQueue(100) // Robot without Config robot := &types.Robot{ MemberID: "robot_no_config", TeamID: "team_1", } // Should use default queue limit (10) successCount := 0 for i := 0; i < 15; i++ { item := &pool.QueueItem{Robot: robot, Trigger: types.TriggerClock} if pq.Enqueue(item) { successCount++ } } assert.Equal(t, 10, successCount) // default queue limit } // ==================== Priority Tests ==================== // TestQueuePriorityByRobotPriority tests sorting by robot priority func TestQueuePriorityByRobotPriority(t *testing.T) { pq := pool.NewPriorityQueue(100) // Add robots with different priorities (low to high) robotLow := createTestRobot("robot_low", "team_1", 5, 10, 1) robotMed := createTestRobot("robot_med", "team_1", 5, 10, 5) robotHigh := createTestRobot("robot_high", "team_1", 5, 10, 10) // Add in low-to-high order pq.Enqueue(&pool.QueueItem{Robot: robotLow, Trigger: types.TriggerClock}) pq.Enqueue(&pool.QueueItem{Robot: robotMed, Trigger: types.TriggerClock}) pq.Enqueue(&pool.QueueItem{Robot: robotHigh, Trigger: types.TriggerClock}) // Dequeue should return high priority first item1 := pq.Dequeue() assert.Equal(t, "robot_high", item1.Robot.MemberID) item2 := pq.Dequeue() assert.Equal(t, "robot_med", item2.Robot.MemberID) item3 := pq.Dequeue() assert.Equal(t, "robot_low", item3.Robot.MemberID) } // TestQueuePriorityByTriggerType tests sorting by trigger type func TestQueuePriorityByTriggerType(t *testing.T) { pq := pool.NewPriorityQueue(100) // Same robot, different trigger types robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Add in clock -> event -> human order pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerEvent}) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: 
types.TriggerHuman}) // Dequeue should return human first (highest trigger priority) item1 := pq.Dequeue() assert.Equal(t, types.TriggerHuman, item1.Trigger) item2 := pq.Dequeue() assert.Equal(t, types.TriggerEvent, item2.Trigger) item3 := pq.Dequeue() assert.Equal(t, types.TriggerClock, item3.Trigger) } // TestQueuePriorityRobotOverTrigger tests that robot priority > trigger priority func TestQueuePriorityRobotOverTrigger(t *testing.T) { pq := pool.NewPriorityQueue(100) // Low priority robot with human trigger robotLow := createTestRobot("robot_low", "team_1", 5, 10, 1) // High priority robot with clock trigger robotHigh := createTestRobot("robot_high", "team_1", 5, 10, 10) pq.Enqueue(&pool.QueueItem{Robot: robotLow, Trigger: types.TriggerHuman}) pq.Enqueue(&pool.QueueItem{Robot: robotHigh, Trigger: types.TriggerClock}) // Robot priority (10*1000=10000) > trigger priority (1*1000+10*100=2000) // So high priority robot should come first even with lower trigger type item1 := pq.Dequeue() assert.Equal(t, "robot_high", item1.Robot.MemberID) item2 := pq.Dequeue() assert.Equal(t, "robot_low", item2.Robot.MemberID) } // TestQueuePriorityByEnqueueTime tests FIFO for same priority func TestQueuePriorityByEnqueueTime(t *testing.T) { pq := pool.NewPriorityQueue(100) // Same robot, same trigger type (same priority) robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Add items with slight delay to ensure different EnqueueTime pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock, Data: "first"}) time.Sleep(1 * time.Millisecond) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock, Data: "second"}) time.Sleep(1 * time.Millisecond) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock, Data: "third"}) // Should dequeue in FIFO order (earlier EnqueueTime first) item1 := pq.Dequeue() assert.Equal(t, "first", item1.Data) item2 := pq.Dequeue() assert.Equal(t, "second", item2.Data) item3 := pq.Dequeue() assert.Equal(t, "third", 
item3.Data) } // ==================== Concurrency Tests ==================== // TestQueueConcurrentEnqueue tests concurrent enqueue operations func TestQueueConcurrentEnqueue(t *testing.T) { pq := pool.NewPriorityQueue(1000) // Concurrently add items from multiple goroutines done := make(chan bool) for i := 0; i < 10; i++ { go func(id int) { robot := createTestRobot("robot_"+string(rune('A'+id)), "team_1", 5, 100, 5) for j := 0; j < 50; j++ { pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) } done <- true }(i) } // Wait for all goroutines for i := 0; i < 10; i++ { <-done } // Should have 10 robots * 50 items = 500 items assert.Equal(t, 500, pq.Size()) } // TestQueueConcurrentDequeue tests concurrent dequeue operations func TestQueueConcurrentDequeue(t *testing.T) { pq := pool.NewPriorityQueue(1000) // Pre-fill queue for i := 0; i < 500; i++ { robot := createTestRobot("robot_"+string(rune('A'+i%10)), "team_1", 5, 100, 5) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) } // Concurrently dequeue from multiple goroutines dequeued := make(chan *pool.QueueItem, 500) done := make(chan bool) for i := 0; i < 10; i++ { go func() { for { item := pq.Dequeue() if item == nil { break } dequeued <- item } done <- true }() } // Wait for all goroutines for i := 0; i < 10; i++ { <-done } close(dequeued) // Count dequeued items count := 0 for range dequeued { count++ } assert.Equal(t, 500, count) assert.Equal(t, 0, pq.Size()) } // TestQueueConcurrentEnqueueDequeue tests concurrent enqueue and dequeue func TestQueueConcurrentEnqueueDequeue(t *testing.T) { pq := pool.NewPriorityQueue(100) // Run for a short time with concurrent operations done := make(chan bool) // Enqueue goroutine go func() { for i := 0; i < 200; i++ { robot := createTestRobot("robot_"+string(rune('A'+i%10)), "team_1", 5, 50, 5) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) time.Sleep(1 * time.Millisecond) } done <- true }() // Dequeue goroutine 
dequeueCount := 0 go func() { for i := 0; i < 200; i++ { if pq.Dequeue() != nil { dequeueCount++ } time.Sleep(1 * time.Millisecond) } done <- true }() // Wait for both <-done <-done // Should have processed some items (exact count depends on timing) assert.GreaterOrEqual(t, dequeueCount, 1) } // ==================== Edge Cases ==================== // TestQueueIsFull tests IsFull method func TestQueueIsFull(t *testing.T) { t.Run("not full initially", func(t *testing.T) { pq := pool.NewPriorityQueue(5) assert.False(t, pq.IsFull()) }) t.Run("full when at max", func(t *testing.T) { pq := pool.NewPriorityQueue(3) for i := 0; i < 3; i++ { robot := createTestRobot("robot_"+string(rune('A'+i)), "team_1", 5, 10, 5) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) } assert.True(t, pq.IsFull()) }) t.Run("not full after dequeue", func(t *testing.T) { pq := pool.NewPriorityQueue(3) for i := 0; i < 3; i++ { robot := createTestRobot("robot_"+string(rune('A'+i)), "team_1", 5, 10, 5) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) } pq.Dequeue() assert.False(t, pq.IsFull()) }) t.Run("never full when unlimited", func(t *testing.T) { pq := pool.NewPriorityQueue(0) for i := 0; i < 100; i++ { robot := createTestRobot("robot_"+string(rune('A'+i%26)), "team_1", 5, 1000, 5) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) } assert.False(t, pq.IsFull()) }) } // TestQueueRobotQueuedCount tests RobotQueuedCount method func TestQueueRobotQueuedCount(t *testing.T) { pq := pool.NewPriorityQueue(100) t.Run("zero for unknown robot", func(t *testing.T) { assert.Equal(t, 0, pq.RobotQueuedCount("unknown_robot")) }) t.Run("correct count for robot", func(t *testing.T) { robot := createTestRobot("robot_1", "team_1", 5, 10, 5) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) assert.Equal(t, 2, pq.RobotQueuedCount("robot_1")) }) t.Run("zero after 
all dequeued", func(t *testing.T) { robot := createTestRobot("robot_2", "team_1", 5, 10, 5) pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) pq.Dequeue() pq.Dequeue() // dequeue robot_1's items too pq.Dequeue() assert.Equal(t, 0, pq.RobotQueuedCount("robot_2")) }) } // TestQueueEnqueueSetsEnqueueTime tests that EnqueueTime is set on enqueue func TestQueueEnqueueSetsEnqueueTime(t *testing.T) { pq := pool.NewPriorityQueue(100) robot := createTestRobot("robot_1", "team_1", 5, 10, 5) before := time.Now() pq.Enqueue(&pool.QueueItem{Robot: robot, Trigger: types.TriggerClock}) after := time.Now() item := pq.Dequeue() assert.True(t, item.EnqueueTime.After(before) || item.EnqueueTime.Equal(before)) assert.True(t, item.EnqueueTime.Before(after) || item.EnqueueTime.Equal(after)) } ================================================ FILE: agent/robot/pool/worker.go ================================================ package pool import ( "sync" "time" "github.com/yaoapp/yao/agent/robot/logger" "github.com/yaoapp/yao/agent/robot/types" ) var log = logger.New("pool") // Worker represents a worker goroutine that processes jobs type Worker struct { id int pool *Pool stopChan chan struct{} wg *sync.WaitGroup } // newWorker creates a new worker func newWorker(id int, pool *Pool, wg *sync.WaitGroup) *Worker { return &Worker{ id: id, pool: pool, stopChan: make(chan struct{}), wg: wg, } } // start starts the worker goroutine func (w *Worker) start() { w.wg.Add(1) go w.run() } // stop signals the worker to stop func (w *Worker) stop() { close(w.stopChan) } // run is the main worker loop func (w *Worker) run() { defer w.wg.Done() ticker := time.NewTicker(100 * time.Millisecond) // poll queue every 100ms defer ticker.Stop() for { select { case <-w.stopChan: return case <-ticker.C: // Try to get a job from the queue item := w.pool.queue.Dequeue() if item == nil { continue // queue empty, wait for next tick } // Execute the job w.execute(item) } } } // execute processes a 
single queue item func (w *Worker) execute(item *QueueItem) { // Pre-check if robot can run (non-atomic, just for early rejection) // The actual atomic check happens inside Executor.Execute() via TryAcquireSlot() if !item.Robot.CanRun() { // Robot likely at quota, re-enqueue for later w.requeue(item, "quota pre-check failed") return } // Mark as running (only when actually executing) w.pool.incrementRunning() defer w.pool.decrementRunning() // Get executor based on mode (uses factory if available, otherwise default) exec := w.pool.GetExecutor(item.ExecutorMode) // Execute via Executor interface with pre-generated ID and control // Note: Executor.ExecuteWithControl() does atomic quota check via TryAcquireSlot() // The control parameter allows executor to check pause state during execution execution, err := exec.ExecuteWithControl(item.Ctx, item.Robot, item.Trigger, item.Data, item.ExecID, item.Control) if err != nil { // Check if it's a quota error (race condition - another worker got the slot) if err == types.ErrQuotaExceeded { w.requeue(item, "quota exceeded (race)") return } // Suspended execution: state is persisted, worker slot released gracefully. // Do NOT call onComplete — the execution stays in robot.Executions and execController // so that Resume can find it later (§16.1). 
if err == types.ErrExecutionSuspended { if execution != nil { log.Info("Worker %d: Execution %s suspended for robot %s (waiting for input)", w.id, execution.ID, item.Robot.MemberID) } return } log.Error("Worker %d: Execution failed for robot %s: %v", w.id, item.Robot.MemberID, err) // Notify completion callback with appropriate status if w.pool.onComplete != nil { status := types.ExecFailed if err == types.ErrExecutionCancelled { status = types.ExecCancelled } w.pool.onComplete(item.ExecID, item.Robot.MemberID, status) } return } if execution != nil { log.Info("Worker %d: Execution %s completed for robot %s (status: %s)", w.id, execution.ID, item.Robot.MemberID, execution.Status) // Notify completion callback if w.pool.onComplete != nil { w.pool.onComplete(execution.ID, item.Robot.MemberID, execution.Status) } } } // requeue attempts to put the item back in the queue func (w *Worker) requeue(item *QueueItem, reason string) { // Queue length is our system load threshold: // - If queue has space: task waits for robot quota // - If queue is full: system is overloaded, drop task if !w.pool.queue.Enqueue(item) { log.Warn("Worker %d: Task for robot %s dropped (queue full, %s)", w.id, item.Robot.MemberID, reason) } } ================================================ FILE: agent/robot/pool/worker_test.go ================================================ package pool_test import ( "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/executor" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/types" ) // ==================== Worker Basic Tests ==================== // TestWorkerExecutesJob tests that worker executes a job from queue func TestWorkerExecutesJob(t *testing.T) { exec := executor.NewDryRunWithDelay(10 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, QueueSize: 10, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() robot := 
createTestRobot("robot_1", "team_1", 5, 10, 5) // Submit job p.Submit(ctx, robot, types.TriggerClock, nil) // Wait for execution time.Sleep(200 * time.Millisecond) assert.Equal(t, 1, exec.ExecCount()) } // TestWorkerMultipleJobs tests worker processes multiple jobs sequentially func TestWorkerMultipleJobs(t *testing.T) { exec := executor.NewDryRunWithDelay(20 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, // single worker QueueSize: 10, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 10, 10, 5) // Submit 3 jobs for i := 0; i < 3; i++ { p.Submit(ctx, robot, types.TriggerClock, nil) } // Wait for all executions (worker polls every 100ms, each job takes 20ms) // Use Eventually for CI timing variations assert.Eventually(t, func() bool { return exec.ExecCount() >= 3 }, 1*time.Second, 50*time.Millisecond, "All 3 jobs should complete") } // ==================== Worker Quota Check Tests ==================== // TestWorkerRespectsRobotQuota tests worker re-enqueues when robot quota is full func TestWorkerRespectsRobotQuota(t *testing.T) { // This test verifies that all jobs eventually complete even when robot quota limits concurrency exec := executor.NewDryRunWithDelay(100 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 5, // multiple workers QueueSize: 20, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Robot can only run 2 at a time robot := createTestRobot("robot_limited", "team_1", 2, 10, 5) // Submit 5 jobs for same robot for i := 0; i < 5; i++ { p.Submit(ctx, robot, types.TriggerClock, nil) } // Wait for all to complete // With Quota.Max=2, jobs execute in batches: 2+2+1 = 3 batches // Each batch: 100ms exec + 100ms poll = ~200ms, total ~600ms, add buffer time.Sleep(1000 * time.Millisecond) // All should eventually execute assert.GreaterOrEqual(t, exec.ExecCount(), 5, "All jobs should eventually execute") } // 
TestWorkerReenqueueOnQuotaFull tests that jobs are re-enqueued when quota is full func TestWorkerReenqueueOnQuotaFull(t *testing.T) { exec := executor.NewDryRunWithDelay(100 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 3, QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Robot can only run 1 at a time, but large queue robot := createTestRobot("robot_1", "team_1", 1, 50, 5) // Submit 5 jobs for i := 0; i < 5; i++ { p.Submit(ctx, robot, types.TriggerClock, nil) } // Wait for all to complete // With Quota.Max=1, jobs execute sequentially: 5 * (100ms exec + 100ms poll) = ~1000ms // Use Eventually for CI timing variations assert.Eventually(t, func() bool { return exec.ExecCount() >= 5 }, 2*time.Second, 100*time.Millisecond, "All 5 jobs should complete") } // ==================== Worker Concurrency Tests ==================== // TestWorkersConcurrentExecution tests multiple workers execute concurrently func TestWorkersConcurrentExecution(t *testing.T) { // Track max concurrent executions var maxConcurrent int32 var currentConcurrent int32 exec := executor.NewDryRunWithCallbacks(100*time.Millisecond, func() { current := atomic.AddInt32(¤tConcurrent, 1) // Update max if current is higher for { max := atomic.LoadInt32(&maxConcurrent) if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) { break } } }, func() { atomic.AddInt32(¤tConcurrent, -1) }) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 5, // 5 workers QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Submit 10 jobs for different robots for i := 0; i < 10; i++ { robot := createTestRobot("robot_"+string(rune('A'+i)), "team_1", 5, 10, 5) p.Submit(ctx, robot, types.TriggerClock, nil) } // Wait for execution time.Sleep(400 * time.Millisecond) // Should have had concurrent execution (max > 1) assert.GreaterOrEqual(t, atomic.LoadInt32(&maxConcurrent), int32(2), "Should have concurrent 
execution") } // TestWorkersDoNotExceedPoolSize tests workers don't exceed pool size func TestWorkersDoNotExceedPoolSize(t *testing.T) { var maxConcurrent int32 var currentConcurrent int32 var mu sync.Mutex exec := executor.NewDryRunWithCallbacks(50*time.Millisecond, func() { mu.Lock() currentConcurrent++ if currentConcurrent > maxConcurrent { maxConcurrent = currentConcurrent } mu.Unlock() }, func() { mu.Lock() currentConcurrent-- mu.Unlock() }) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 3, // only 3 workers QueueSize: 100, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Submit 20 jobs for i := 0; i < 20; i++ { robot := createTestRobot("robot_"+string(rune('A'+i)), "team_1", 5, 10, 5) p.Submit(ctx, robot, types.TriggerClock, nil) } // Wait for all to complete time.Sleep(500 * time.Millisecond) // Max concurrent should not exceed worker size assert.LessOrEqual(t, maxConcurrent, int32(3), "Should not exceed worker size") } // ==================== Worker Stop Tests ==================== // TestWorkerStopsGracefully tests worker stops when signaled func TestWorkerStopsGracefully(t *testing.T) { exec := executor.NewDryRunWithDelay(50 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 2, QueueSize: 10, }) p.SetExecutor(exec) p.Start() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Submit jobs p.Submit(ctx, robot, types.TriggerClock, nil) p.Submit(ctx, robot, types.TriggerClock, nil) // Wait for jobs to start time.Sleep(150 * time.Millisecond) // Stop pool err := p.Stop() assert.NoError(t, err) // Pool should be stopped assert.False(t, p.IsStarted()) } // TestWorkerCompletesCurrentJobOnStop tests worker completes current job before stopping func TestWorkerCompletesCurrentJobOnStop(t *testing.T) { exec := executor.NewDryRunWithDelay(100 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, QueueSize: 10, }) p.SetExecutor(exec) p.Start() ctx := 
createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Submit job p.Submit(ctx, robot, types.TriggerClock, nil) // Wait for job to start time.Sleep(150 * time.Millisecond) // Stop pool - should wait for current job p.Stop() // Job should have completed assert.GreaterOrEqual(t, exec.ExecCount(), 1) } // ==================== Worker Error Handling Tests ==================== // TestWorkerHandlesExecutorError tests worker continues after executor error func TestWorkerHandlesExecutorError(t *testing.T) { exec := executor.NewDryRunWithDelay(10 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, QueueSize: 10, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Submit job that will fail (using special data) p.Submit(ctx, robot, types.TriggerClock, "simulate_failure") // Submit another job that should succeed p.Submit(ctx, robot, types.TriggerClock, nil) // Wait for execution time.Sleep(300 * time.Millisecond) // Both should have been attempted assert.GreaterOrEqual(t, exec.ExecCount(), 2) } // ==================== Worker Running Counter Tests ==================== // TestWorkerRunningCounterAccurate tests running counter is accurate func TestWorkerRunningCounterAccurate(t *testing.T) { exec := executor.NewDryRunWithDelay(100 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 3, QueueSize: 10, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() // Submit jobs for different robots for i := 0; i < 3; i++ { robot := createTestRobot("robot_"+string(rune('A'+i)), "team_1", 5, 10, 5) p.Submit(ctx, robot, types.TriggerClock, nil) } // Wait for jobs to start time.Sleep(200 * time.Millisecond) // Running should be > 0 while jobs are executing // Note: On fast CI, jobs may already be done, so we just verify it doesn't panic // Wait for completion and verify running counter returns to 0 assert.Eventually(t, func() bool 
{ return p.Running() == 0 }, 1*time.Second, 50*time.Millisecond, "Running should be 0 after all jobs complete") } // TestWorkerRunningCounterDecrementsOnError tests running counter decrements on error func TestWorkerRunningCounterDecrementsOnError(t *testing.T) { exec := executor.NewDryRunWithDelay(10 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, QueueSize: 10, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Submit failing job p.Submit(ctx, robot, types.TriggerClock, "simulate_failure") // Wait for execution time.Sleep(200 * time.Millisecond) // Running should be 0 (decremented even on error) assert.Equal(t, 0, p.Running()) } // ==================== Worker with Different Trigger Types ==================== // TestWorkerProcessesDifferentTriggers tests worker handles all trigger types func TestWorkerProcessesDifferentTriggers(t *testing.T) { exec := executor.NewDryRunWithDelay(10 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, QueueSize: 10, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Submit different trigger types p.Submit(ctx, robot, types.TriggerClock, nil) p.Submit(ctx, robot, types.TriggerHuman, nil) p.Submit(ctx, robot, types.TriggerEvent, nil) // Wait for execution (worker polls every 100ms, each job takes 10ms) // Use Eventually for CI timing variations assert.Eventually(t, func() bool { return exec.ExecCount() >= 3 }, 1*time.Second, 50*time.Millisecond, "All 3 trigger types should execute") } // ==================== Worker Polling Behavior Tests ==================== // TestWorkerPollsQueuePeriodically tests worker polls queue at regular intervals func TestWorkerPollsQueuePeriodically(t *testing.T) { exec := executor.NewDryRunWithDelay(10 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, QueueSize: 10, }) 
p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Submit job after pool started time.Sleep(50 * time.Millisecond) p.Submit(ctx, robot, types.TriggerClock, nil) // Worker should pick up job within poll interval (100ms) time.Sleep(200 * time.Millisecond) assert.Equal(t, 1, exec.ExecCount()) } // TestWorkerContinuesAfterEmptyQueue tests worker continues polling after empty queue func TestWorkerContinuesAfterEmptyQueue(t *testing.T) { exec := executor.NewDryRunWithDelay(10 * time.Millisecond) p := pool.NewWithConfig(&pool.Config{ WorkerSize: 1, QueueSize: 10, }) p.SetExecutor(exec) p.Start() defer p.Stop() ctx := createTestContext() robot := createTestRobot("robot_1", "team_1", 5, 10, 5) // Wait with empty queue time.Sleep(200 * time.Millisecond) // Submit job p.Submit(ctx, robot, types.TriggerClock, nil) // Worker should still be running and pick up job time.Sleep(200 * time.Millisecond) assert.Equal(t, 1, exec.ExecCount()) } ================================================ FILE: agent/robot/process.go ================================================ package robot import ( "context" "github.com/yaoapp/gou/process" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/robot/api" "github.com/yaoapp/yao/agent/robot/types" ) func init() { process.RegisterGroup("robot", map[string]process.Handler{ "get": processGet, "list": processList, "status": processStatus, "executions": processExecutions, "execution": processExecution, "updateChatTitle": processUpdateChatTitle, }) } // processGet handles robot.Get(memberID). 
// args[0]: memberID string func processGet(p *process.Process) interface{} { p.ValidateArgNums(1) memberID := p.ArgsString(0) ctx := types.NewContext(context.Background(), nil) result, err := api.GetRobotResponse(ctx, memberID) if err != nil { exception.New(err.Error(), 500).Throw() } return result } // processList handles robot.List(filter?). // args[0]: optional filter map with page, pagesize, status, search (keywords) func processList(p *process.Process) interface{} { p.ValidateArgNums(0) ctx := types.NewContext(context.Background(), nil) filter := &api.ListQuery{} if p.NumOfArgs() > 0 { raw := p.ArgsMap(0) if v, ok := raw["page"]; ok { filter.Page = toInt(v) } if v, ok := raw["pagesize"]; ok { filter.PageSize = toInt(v) } if v, ok := raw["status"]; ok { filter.Status = types.RobotStatus(toString(v)) } if v, ok := raw["search"]; ok { filter.Keywords = toString(v) } if v, ok := raw["team_id"]; ok { filter.TeamID = toString(v) } } result, err := api.ListRobots(ctx, filter) if err != nil { exception.New(err.Error(), 500).Throw() } return result } // processStatus handles robot.Status(memberID). // args[0]: memberID string func processStatus(p *process.Process) interface{} { p.ValidateArgNums(1) memberID := p.ArgsString(0) ctx := types.NewContext(context.Background(), nil) result, err := api.GetRobotStatus(ctx, memberID) if err != nil { exception.New(err.Error(), 500).Throw() } return result } // processExecutions handles robot.Executions(memberID, filter?). 
// args[0]: memberID string; args[1]: optional filter map func processExecutions(p *process.Process) interface{} { p.ValidateArgNums(1) memberID := p.ArgsString(0) ctx := types.NewContext(context.Background(), nil) filter := &api.ExecutionQuery{} if p.NumOfArgs() > 1 { raw := p.ArgsMap(1) if v, ok := raw["page"]; ok { filter.Page = toInt(v) } if v, ok := raw["pagesize"]; ok { filter.PageSize = toInt(v) } if v, ok := raw["status"]; ok { filter.Status = types.ExecStatus(toString(v)) } if v, ok := raw["trigger"]; ok { filter.Trigger = types.TriggerType(toString(v)) } } result, err := api.ListExecutions(ctx, memberID, filter) if err != nil { exception.New(err.Error(), 500).Throw() } return result } // processExecution handles robot.Execution(memberID, executionID). // args[0]: memberID string; args[1]: executionID string func processExecution(p *process.Process) interface{} { p.ValidateArgNums(2) memberID := p.ArgsString(0) executionID := p.ArgsString(1) ctx := types.NewContext(context.Background(), nil) result, err := api.GetExecutionStatus(ctx, executionID) if err != nil { exception.New(err.Error(), 500).Throw() } _ = memberID // reserved for future permission scoping return result } // processUpdateChatTitle handles robot.UpdateChatTitle(chatID, title). 
// toInt converts a loosely-typed numeric value (as found in a filter map)
// to int.
//
// All built-in integer and floating-point types are accepted; floats are
// truncated toward zero. Values exposing an Int64() (int64, error) method,
// such as encoding/json.Number, are converted through that method. Anything
// else, including nil, strings and values whose Int64 call fails, yields 0.
func toInt(v interface{}) int {
	switch n := v.(type) {
	case int:
		return n
	case int8:
		return int(n)
	case int16:
		return int(n)
	case int32:
		return int(n)
	case int64:
		return int(n)
	case uint:
		return int(n)
	case uint8:
		return int(n)
	case uint16:
		return int(n)
	case uint32:
		return int(n)
	case uint64:
		return int(n)
	case float32:
		return int(n)
	case float64:
		return int(n)
	case interface{ Int64() (int64, error) }:
		// Covers json.Number without importing encoding/json here.
		i, err := n.Int64()
		if err != nil {
			return 0
		}
		return int(i)
	default:
		return 0
	}
}

// toString returns v when it holds a string and "" for every other type
// (including nil); no formatting of non-string values is attempted.
func toString(v interface{}) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
chatStore.DeleteChat(chatID) title := "Create a mecha image, sci-fi style" p := process.New("robot.UpdateChatTitle", chatID, title) _, err = p.Exec() require.NoError(t, err) chat, err := chatStore.GetChat(chatID) require.NoError(t, err) assert.Equal(t, title, chat.Title, "Title should be updated to the confirmed goals") t.Logf("✓ robot.UpdateChatTitle: chat_id=%s, title=%q", chatID, chat.Title) }) t.Run("UpdatesLongGoalsTitle", func(t *testing.T) { chatID := fmt.Sprintf("robot_test_long_%s", uuid.New().String()[:8]) err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: "robot.host", Status: "active", Share: "private", CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer chatStore.DeleteChat(chatID) // Long goals string — title field should accommodate it title := "请帮我制作一张充满未来感的机甲图片,风格参考《攻壳机动队》,以赛博朋克城市为背景,色调偏冷,蓝紫配色" p := process.New("robot.UpdateChatTitle", chatID, title) _, err = p.Exec() require.NoError(t, err) chat, err := chatStore.GetChat(chatID) require.NoError(t, err) assert.Equal(t, title, chat.Title) t.Logf("✓ Long goals title persisted: %d chars", len(title)) }) t.Run("ErrorOnNonExistentChat", func(t *testing.T) { p := process.New("robot.UpdateChatTitle", "non_existent_chat_id", "some title") _, err := p.Exec() assert.Error(t, err, "Should error when chat does not exist") t.Logf("✓ Non-existent chat correctly returns error") }) } // ============================================================================ // robot.Get // ============================================================================ func TestProcessGet(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("ErrorOnNotFound", func(t *testing.T) { p := process.New("robot.Get", "non_existent_robot_member_id") _, err := p.Exec() assert.Error(t, err, "Should error for non-existent robot") t.Logf("✓ robot.Get returns error for non-existent robot") }) } // ============================================================================ // 
robot.List // ============================================================================ func TestProcessList(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("ReturnsListWithNoFilter", func(t *testing.T) { p := process.New("robot.List") result, err := p.Exec() require.NoError(t, err) // Result is a paginated list — just assert it's not nil assert.NotNil(t, result) t.Logf("✓ robot.List returned: %T", result) }) t.Run("ReturnsListWithPageFilter", func(t *testing.T) { p := process.New("robot.List", map[string]interface{}{ "page": 1, "pagesize": 5, }) result, err := p.Exec() require.NoError(t, err) assert.NotNil(t, result) t.Logf("✓ robot.List with page filter returned: %T", result) }) } // ============================================================================ // robot.Status // ============================================================================ func TestProcessStatus(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("ErrorOnNotFound", func(t *testing.T) { p := process.New("robot.Status", "non_existent_robot_member_id") _, err := p.Exec() assert.Error(t, err, "Should error for non-existent robot") t.Logf("✓ robot.Status returns error for non-existent robot") }) } // ============================================================================ // robot.Executions // ============================================================================ func TestProcessExecutions(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("ReturnsEmptyForUnknownRobot", func(t *testing.T) { memberID := fmt.Sprintf("proc_exec_test_%d", time.Now().UnixNano()) p := process.New("robot.Executions", memberID) result, err := p.Exec() // May error or return empty — both acceptable if err == nil { assert.NotNil(t, result) } t.Logf("✓ robot.Executions handled for unknown robot") }) t.Run("AcceptsFilterMap", func(t *testing.T) { memberID := fmt.Sprintf("proc_exec_filter_%d", time.Now().UnixNano()) p := 
process.New("robot.Executions", memberID, map[string]interface{}{ "page": 1, "pagesize": 10, "status": "completed", }) result, err := p.Exec() if err == nil { assert.NotNil(t, result) } t.Logf("✓ robot.Executions with filter handled") }) } // ============================================================================ // robot.Execution // ============================================================================ func TestProcessExecution(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("ErrorOnNonExistentExecution", func(t *testing.T) { p := process.New("robot.Execution", "some_member_id", "non_existent_exec_id") _, err := p.Exec() assert.Error(t, err, "Should error for non-existent execution") t.Logf("✓ robot.Execution returns error for non-existent execution") }) } // ============================================================================ // Argument Validation // ============================================================================ func TestProcessArgumentValidation(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) t.Run("UpdateChatTitle_RequiresTwoArgs", func(t *testing.T) { // Missing title argument p := process.New("robot.UpdateChatTitle", "some_chat_id") _, err := p.Exec() assert.Error(t, err, "Should require 2 arguments") }) t.Run("Get_RequiresOneArg", func(t *testing.T) { p := process.New("robot.Get") _, err := p.Exec() assert.Error(t, err, "Should require 1 argument") }) t.Run("Status_RequiresOneArg", func(t *testing.T) { p := process.New("robot.Status") _, err := p.Exec() assert.Error(t, err, "Should require 1 argument") }) t.Run("Execution_RequiresTwoArgs", func(t *testing.T) { p := process.New("robot.Execution", "only_one_arg") _, err := p.Exec() assert.Error(t, err, "Should require 2 arguments") }) } // ============================================================================ // Integration: UpdateChatTitle flow (simulate Host Agent Next Hook) // 
============================================================================ func TestProcessUpdateChatTitleIntegration(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) chatStore := assistant.GetChatStore() if chatStore == nil { t.Skip("Chat store not configured") } t.Run("SimulatesHostAgentNextHook", func(t *testing.T) { // Simulate the full flow: // 1. ChatDrawer creates a chat with robot_id in metadata // 2. Host Agent Next Hook calls robot.UpdateChatTitle with confirmed goals // 3. History dropdown shows the goals as the chat title memberID := "120004485525" chatID := fmt.Sprintf("robot_%s_%d", memberID, time.Now().UnixMilli()) confirmedGoals := "制作一张机甲图片,风格和设计由AI自主决定" // Step 1: Create chat (simulating AssignTaskDrawer) err := chatStore.CreateChat(&storetypes.Chat{ ChatID: chatID, AssistantID: "yao.robot-host", Status: "active", Share: "private", Metadata: map[string]interface{}{ "robot_id": memberID, }, CreatedAt: time.Now(), UpdatedAt: time.Now(), }) require.NoError(t, err) defer chatStore.DeleteChat(chatID) // Step 2: Next Hook calls robot.UpdateChatTitle p := process.New("robot.UpdateChatTitle", chatID, confirmedGoals) _, err = p.Exec() require.NoError(t, err) // Step 3: Verify title is set (history dropdown will display this) chat, err := chatStore.GetChat(chatID) require.NoError(t, err) assert.Equal(t, confirmedGoals, chat.Title, "History dropdown should show confirmed goals as title") require.NotNil(t, chat.Metadata) assert.Equal(t, memberID, chat.Metadata["robot_id"], "Metadata robot_id should be preserved after title update") t.Logf("✓ Full Host Agent flow: chat_id=%s, title=%q, robot_id=%v", chatID, chat.Title, chat.Metadata["robot_id"]) }) } // ensure context is used (avoid unused import) var _ = context.Background ================================================ FILE: agent/robot/robot.go ================================================ package robot import ( "context" 
"github.com/yaoapp/yao/agent/robot/cache" "github.com/yaoapp/yao/agent/robot/dedup" "github.com/yaoapp/yao/agent/robot/events/integrations" "github.com/yaoapp/yao/agent/robot/events/integrations/telegram" "github.com/yaoapp/yao/agent/robot/executor" "github.com/yaoapp/yao/agent/robot/logger" "github.com/yaoapp/yao/agent/robot/manager" "github.com/yaoapp/yao/agent/robot/plan" "github.com/yaoapp/yao/agent/robot/pool" "github.com/yaoapp/yao/agent/robot/store" robottypes "github.com/yaoapp/yao/agent/robot/types" ) var ( log = logger.New("robot") globalManager *manager.Manager globalCache *cache.Cache globalPool *pool.Pool globalDedup *dedup.Dedup globalStore *store.Store globalExecutor executor.Executor globalPlan *plan.Plan globalDispatcher *integrations.Dispatcher ) // Init initializes the robot agent system func Init() error { globalCache = cache.New() globalDedup = dedup.New() globalStore = store.New() globalPool = pool.New() globalExecutor = executor.New() globalManager = manager.New() globalPlan = plan.New() // Load robots into cache from database before starting dispatcher rCtx := robottypes.NewContext(context.Background(), nil) if err := globalCache.Load(rCtx); err != nil { log.Warn("robot.Init: cache load failed (will rely on config events): %v", err) } adapters := map[string]integrations.Adapter{ "telegram": telegram.NewAdapter(), } globalDispatcher = integrations.NewDispatcher(globalCache, adapters) if err := globalDispatcher.Start(context.Background()); err != nil { return err } return nil } // Shutdown gracefully shuts down the robot agent system func Shutdown() error { if globalDispatcher != nil { globalDispatcher.Stop() } return nil } // Manager returns the global manager instance func Manager() *manager.Manager { return globalManager } ================================================ FILE: agent/robot/store/execution.go ================================================ package store import ( "context" "encoding/json" "fmt" "time" 
"github.com/yaoapp/gou/model" "github.com/yaoapp/yao/agent/robot/types" ) // ExecutionRecord - persistent storage for robot execution history // Maps to __yao.agent_execution model type ExecutionRecord struct { ID int64 `json:"id,omitempty"` // Auto-increment primary key ExecutionID string `json:"execution_id"` // Unique execution identifier MemberID string `json:"member_id"` // Robot member ID (globally unique) TeamID string `json:"team_id"` // Team ID TriggerType types.TriggerType `json:"trigger_type"` // clock | human | event // Status tracking (synced with runtime Execution) Status types.ExecStatus `json:"status"` // pending | running | completed | failed | cancelled Phase types.Phase `json:"phase"` // Current phase Current *CurrentState `json:"current,omitempty"` Error string `json:"error,omitempty"` // UI display fields (updated by executor at each phase) Name string `json:"name,omitempty"` // Execution title CurrentTaskName string `json:"current_task_name,omitempty"` // Current task description // Trigger input Input *types.TriggerInput `json:"input,omitempty"` // Phase outputs (P0-P5) Inspiration *types.InspirationReport `json:"inspiration,omitempty"` Goals *types.Goals `json:"goals,omitempty"` Tasks []types.Task `json:"tasks,omitempty"` Results []types.TaskResult `json:"results,omitempty"` Delivery *types.DeliveryResult `json:"delivery,omitempty"` Learning []types.LearningEntry `json:"learning,omitempty"` // V2: Conversation and suspend-resume fields ChatID string `json:"chat_id,omitempty"` WaitingTaskID string `json:"waiting_task_id,omitempty"` WaitingQuestion string `json:"waiting_question,omitempty"` WaitingSince *time.Time `json:"waiting_since,omitempty"` ResumeContext *types.ResumeContext `json:"resume_context,omitempty"` // Timestamps StartTime *time.Time `json:"start_time,omitempty"` EndTime *time.Time `json:"end_time,omitempty"` CreatedAt *time.Time `json:"created_at,omitempty"` UpdatedAt *time.Time `json:"updated_at,omitempty"` } // CurrentState - 
current executing state (for JSON storage) type CurrentState struct { TaskIndex int `json:"task_index"` // index in Tasks slice Progress string `json:"progress,omitempty"` // human-readable progress (e.g., "2/5 tasks") } // ListOptions - options for listing execution records type ListOptions struct { MemberID string `json:"member_id,omitempty"` TeamID string `json:"team_id,omitempty"` Status types.ExecStatus `json:"status,omitempty"` ExcludeStatuses []types.ExecStatus `json:"exclude_statuses,omitempty"` TriggerType types.TriggerType `json:"trigger_type,omitempty"` Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` OrderBy string `json:"order_by,omitempty"` } // ListResult wraps paginated list results type ListResult struct { Data []*ExecutionRecord Total int Page int PageSize int } // ExecutionStore - persistent storage for robot execution records type ExecutionStore struct { modelID string } // NewExecutionStore creates a new execution store instance func NewExecutionStore() *ExecutionStore { return &ExecutionStore{ modelID: "__yao.agent.execution", } } // Save creates or updates an execution record func (s *ExecutionStore) Save(ctx context.Context, record *ExecutionRecord) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } data := s.recordToMap(record) // Check if record exists by execution_id existing, err := s.Get(ctx, record.ExecutionID) if err == nil && existing != nil { // Update existing record _, err = mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: record.ExecutionID}, }, }, data, ) if err != nil { return fmt.Errorf("failed to update execution record: %w", err) } return nil } // Create new record _, err = mod.Create(data) if err != nil { return fmt.Errorf("failed to create execution record: %w", err) } return nil } // Get retrieves an execution record by execution_id func (s *ExecutionStore) Get(ctx context.Context, executionID 
string) (*ExecutionRecord, error) { mod := model.Select(s.modelID) if mod == nil { return nil, fmt.Errorf("model %s not found", s.modelID) } rows, err := mod.Get(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, Limit: 1, }) if err != nil { return nil, fmt.Errorf("failed to get execution record: %w", err) } if len(rows) == 0 { return nil, nil } return s.mapToRecord(rows[0]) } // List retrieves execution records with pagination using mod.Paginate func (s *ExecutionStore) List(ctx context.Context, opts *ListOptions) (*ListResult, error) { mod := model.Select(s.modelID) if mod == nil { return nil, fmt.Errorf("model %s not found", s.modelID) } params := model.QueryParam{} var wheres []model.QueryWhere page := 1 pageSize := 20 if opts != nil { if opts.MemberID != "" { wheres = append(wheres, model.QueryWhere{Column: "member_id", Value: opts.MemberID}) } if opts.TeamID != "" { wheres = append(wheres, model.QueryWhere{Column: "team_id", Value: opts.TeamID}) } if opts.Status != "" { wheres = append(wheres, model.QueryWhere{Column: "status", Value: string(opts.Status)}) } for _, es := range opts.ExcludeStatuses { wheres = append(wheres, model.QueryWhere{Column: "status", Value: string(es), OP: "ne"}) } if opts.TriggerType != "" { wheres = append(wheres, model.QueryWhere{Column: "trigger_type", Value: string(opts.TriggerType)}) } if opts.Page > 0 { page = opts.Page } if opts.PageSize > 0 { pageSize = opts.PageSize if pageSize > 100 { pageSize = 100 } } if opts.OrderBy != "" { parts := splitOrderBy(opts.OrderBy) params.Orders = []model.QueryOrder{{Column: parts[0], Option: parts[1]}} } else { params.Orders = []model.QueryOrder{{Column: "start_time", Option: "desc"}} } } else { params.Orders = []model.QueryOrder{{Column: "start_time", Option: "desc"}} } params.Wheres = wheres res, err := mod.Paginate(params, page, pageSize) if err != nil { return nil, fmt.Errorf("failed to list execution records: %w", err) } total := 0 if v, ok := 
res["total"].(int64); ok { total = int(v) } else if v, ok := res["total"].(int); ok { total = v } records := make([]*ExecutionRecord, 0) for _, row := range toRows(res["data"]) { record, err := s.mapToRecord(row) if err != nil { continue } records = append(records, record) } return &ListResult{ Data: records, Total: total, Page: page, PageSize: pageSize, }, nil } // UpdatePhase updates the current phase and its data func (s *ExecutionStore) UpdatePhase(ctx context.Context, executionID string, phase types.Phase, data interface{}) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } updateData := map[string]interface{}{ "phase": string(phase), } // Set the appropriate phase output field switch phase { case types.PhaseInspiration: if data != nil { updateData["inspiration"] = data } case types.PhaseGoals: if data != nil { updateData["goals"] = data } case types.PhaseTasks: if data != nil { updateData["tasks"] = data } case types.PhaseRun: if data != nil { updateData["results"] = data } case types.PhaseDelivery: if data != nil { updateData["delivery"] = data } case types.PhaseLearning: if data != nil { updateData["learning"] = data } } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }, updateData, ) if err != nil { return fmt.Errorf("failed to update phase: %w", err) } return nil } // UpdateStatus updates the execution status func (s *ExecutionStore) UpdateStatus(ctx context.Context, executionID string, status types.ExecStatus, errorMsg string) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } updateData := map[string]interface{}{ "status": string(status), } if errorMsg != "" { updateData["error"] = errorMsg } // Set end_time for terminal states if status == types.ExecCompleted || status == types.ExecFailed || status == types.ExecCancelled { now := time.Now() updateData["end_time"] = 
now } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }, updateData, ) if err != nil { return fmt.Errorf("failed to update status: %w", err) } return nil } // UpdateCurrent updates the current executing state func (s *ExecutionStore) UpdateCurrent(ctx context.Context, executionID string, current *CurrentState) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } updateData := map[string]interface{}{ "current": current, } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }, updateData, ) if err != nil { return fmt.Errorf("failed to update current state: %w", err) } return nil } // UpdateTasks updates the tasks array with current status // This should be called after each task completes to persist status changes func (s *ExecutionStore) UpdateTasks(ctx context.Context, executionID string, tasks []types.Task, current *CurrentState) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } updateData := map[string]interface{}{ "tasks": tasks, "current": current, } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }, updateData, ) if err != nil { return fmt.Errorf("failed to update tasks: %w", err) } return nil } // UpdateUIFields updates the UI display fields (name and current_task_name) // These fields are updated by executor at each phase for frontend display func (s *ExecutionStore) UpdateUIFields(ctx context.Context, executionID string, name string, currentTaskName string) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } updateData := map[string]interface{}{} if name != "" { updateData["name"] = name } if currentTaskName != "" { updateData["current_task_name"] = currentTaskName } if 
len(updateData) == 0 { return nil // Nothing to update } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }, updateData, ) if err != nil { return fmt.Errorf("failed to update UI fields: %w", err) } return nil } // UpdateSuspendState atomically transitions an execution to waiting status // with all suspend-related fields in a single DB write. func (s *ExecutionStore) UpdateSuspendState(ctx context.Context, executionID string, waitingTaskID string, question string, resumeCtx *types.ResumeContext) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } now := time.Now() updateData := map[string]interface{}{ "status": string(types.ExecWaiting), "waiting_task_id": waitingTaskID, "waiting_question": question, "waiting_since": now, } if resumeCtx != nil { updateData["resume_context"] = resumeCtx } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }, updateData, ) if err != nil { return fmt.Errorf("failed to update suspend state: %w", err) } return nil } // UpdateResumeState clears waiting fields and transitions execution back to running. 
func (s *ExecutionStore) UpdateResumeState(ctx context.Context, executionID string) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } updateData := map[string]interface{}{ "status": string(types.ExecRunning), "waiting_task_id": "", "waiting_question": "", "waiting_since": nil, "resume_context": nil, } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }, updateData, ) if err != nil { return fmt.Errorf("failed to update resume state: %w", err) } return nil } // Delete removes an execution record by execution_id func (s *ExecutionStore) Delete(ctx context.Context, executionID string) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } _, err := mod.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", Value: executionID}, }, }) if err != nil { return fmt.Errorf("failed to delete execution record: %w", err) } return nil } // recordToMap converts ExecutionRecord to map for model operations func (s *ExecutionStore) recordToMap(record *ExecutionRecord) map[string]interface{} { data := map[string]interface{}{ "execution_id": record.ExecutionID, "member_id": record.MemberID, "team_id": record.TeamID, "trigger_type": string(record.TriggerType), "status": string(record.Status), "phase": string(record.Phase), } if record.Error != "" { data["error"] = record.Error } if record.Name != "" { data["name"] = record.Name } if record.CurrentTaskName != "" { data["current_task_name"] = record.CurrentTaskName } if record.Current != nil { data["current"] = record.Current } if record.Input != nil { data["input"] = record.Input } if record.Inspiration != nil { data["inspiration"] = record.Inspiration } if record.Goals != nil { data["goals"] = record.Goals } if record.Tasks != nil { data["tasks"] = record.Tasks } if record.Results != nil { data["results"] = record.Results } if 
record.Delivery != nil { data["delivery"] = record.Delivery } if record.Learning != nil { data["learning"] = record.Learning } // V2 fields if record.ChatID != "" { data["chat_id"] = record.ChatID } if record.WaitingTaskID != "" { data["waiting_task_id"] = record.WaitingTaskID } if record.WaitingQuestion != "" { data["waiting_question"] = record.WaitingQuestion } if record.WaitingSince != nil { data["waiting_since"] = *record.WaitingSince } if record.ResumeContext != nil { data["resume_context"] = record.ResumeContext } if record.StartTime != nil { data["start_time"] = *record.StartTime } if record.EndTime != nil { data["end_time"] = *record.EndTime } return data } // mapToRecord converts a model row to ExecutionRecord func (s *ExecutionStore) mapToRecord(row map[string]interface{}) (*ExecutionRecord, error) { record := &ExecutionRecord{} // Basic fields if v, ok := row["id"]; ok { switch id := v.(type) { case float64: record.ID = int64(id) case int64: record.ID = id case int: record.ID = int64(id) } } if v, ok := row["execution_id"].(string); ok { record.ExecutionID = v } if v, ok := row["member_id"].(string); ok { record.MemberID = v } if v, ok := row["team_id"].(string); ok { record.TeamID = v } if v, ok := row["trigger_type"].(string); ok { record.TriggerType = types.TriggerType(v) } if v, ok := row["status"].(string); ok { record.Status = types.ExecStatus(v) } if v, ok := row["phase"].(string); ok { record.Phase = types.Phase(v) } if v, ok := row["error"].(string); ok { record.Error = v } if v, ok := row["name"].(string); ok { record.Name = v } if v, ok := row["current_task_name"].(string); ok { record.CurrentTaskName = v } // JSON fields - need to unmarshal if v := row["current"]; v != nil { record.Current = s.parseCurrentState(v) } if v := row["input"]; v != nil { record.Input = s.parseTriggerInput(v) } if v := row["inspiration"]; v != nil { record.Inspiration = s.parseInspirationReport(v) } if v := row["goals"]; v != nil { record.Goals = s.parseGoals(v) } 
if v := row["tasks"]; v != nil { record.Tasks = s.parseTasks(v) } if v := row["results"]; v != nil { record.Results = s.parseResults(v) } if v := row["delivery"]; v != nil { record.Delivery = s.parseDeliveryResult(v) } if v := row["learning"]; v != nil { record.Learning = s.parseLearningEntries(v) } // V2 fields if v, ok := row["chat_id"].(string); ok { record.ChatID = v } if v, ok := row["waiting_task_id"].(string); ok { record.WaitingTaskID = v } if v, ok := row["waiting_question"].(string); ok { record.WaitingQuestion = v } if v := row["waiting_since"]; v != nil { record.WaitingSince = s.parseTime(v) } if v := row["resume_context"]; v != nil { record.ResumeContext = s.parseResumeContext(v) } // Timestamps if v := row["start_time"]; v != nil { record.StartTime = s.parseTime(v) } if v := row["end_time"]; v != nil { record.EndTime = s.parseTime(v) } if v := row["created_at"]; v != nil { record.CreatedAt = s.parseTime(v) } if v := row["updated_at"]; v != nil { record.UpdatedAt = s.parseTime(v) } return record, nil } // Helper functions for parsing JSON fields func (s *ExecutionStore) parseCurrentState(v interface{}) *CurrentState { data, err := s.toJSON(v) if err != nil { return nil } var state CurrentState if err := json.Unmarshal(data, &state); err != nil { return nil } return &state } func (s *ExecutionStore) parseTriggerInput(v interface{}) *types.TriggerInput { data, err := s.toJSON(v) if err != nil { return nil } var input types.TriggerInput if err := json.Unmarshal(data, &input); err != nil { return nil } return &input } func (s *ExecutionStore) parseInspirationReport(v interface{}) *types.InspirationReport { data, err := s.toJSON(v) if err != nil { return nil } var report types.InspirationReport if err := json.Unmarshal(data, &report); err != nil { return nil } return &report } func (s *ExecutionStore) parseGoals(v interface{}) *types.Goals { data, err := s.toJSON(v) if err != nil { return nil } var goals types.Goals if err := json.Unmarshal(data, &goals); 
err != nil { return nil } return &goals } func (s *ExecutionStore) parseTasks(v interface{}) []types.Task { data, err := s.toJSON(v) if err != nil { return nil } var tasks []types.Task if err := json.Unmarshal(data, &tasks); err != nil { return nil } return tasks } func (s *ExecutionStore) parseResults(v interface{}) []types.TaskResult { data, err := s.toJSON(v) if err != nil { return nil } var results []types.TaskResult if err := json.Unmarshal(data, &results); err != nil { return nil } return results } func (s *ExecutionStore) parseDeliveryResult(v interface{}) *types.DeliveryResult { data, err := s.toJSON(v) if err != nil { return nil } var result types.DeliveryResult if err := json.Unmarshal(data, &result); err != nil { return nil } return &result } func (s *ExecutionStore) parseLearningEntries(v interface{}) []types.LearningEntry { data, err := s.toJSON(v) if err != nil { return nil } var entries []types.LearningEntry if err := json.Unmarshal(data, &entries); err != nil { return nil } return entries } func (s *ExecutionStore) parseResumeContext(v interface{}) *types.ResumeContext { data, err := s.toJSON(v) if err != nil { return nil } var ctx types.ResumeContext if err := json.Unmarshal(data, &ctx); err != nil { return nil } return &ctx } func (s *ExecutionStore) toJSON(v interface{}) ([]byte, error) { switch data := v.(type) { case []byte: return data, nil case string: return []byte(data), nil case map[string]interface{}, []interface{}: return json.Marshal(data) default: return json.Marshal(v) } } // splitOrderBy parses "column desc" or "column asc" or just "column" // Returns [column, option] where option defaults to "desc" // toRows converts Paginate result data to []map[string]interface{} // handles type aliases like maps.MapStrAny via JSON round-trip func toRows(data interface{}) []map[string]interface{} { if data == nil { return nil } raw, err := json.Marshal(data) if err != nil { return nil } var rows []map[string]interface{} if err := 
json.Unmarshal(raw, &rows); err != nil { return nil } return rows } func splitOrderBy(orderBy string) [2]string { parts := [2]string{"", "desc"} if orderBy == "" { return parts } // Split by space for i, c := range orderBy { if c == ' ' { parts[0] = orderBy[:i] rest := orderBy[i+1:] if rest == "asc" || rest == "ASC" { parts[1] = "asc" } else if rest == "desc" || rest == "DESC" { parts[1] = "desc" } return parts } } // No space found, just column name parts[0] = orderBy return parts } func (s *ExecutionStore) parseTime(v interface{}) *time.Time { switch t := v.(type) { case time.Time: return &t case *time.Time: return t case string: // Try parsing common time formats formats := []string{ time.RFC3339, time.RFC3339Nano, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z", } for _, format := range formats { if parsed, err := time.Parse(format, t); err == nil { return &parsed } } } return nil } // ==================== Results & Activities ==================== // ResultListOptions - options for listing execution results (deliveries) type ResultListOptions struct { MemberID string `json:"member_id,omitempty"` TeamID string `json:"team_id,omitempty"` TriggerType types.TriggerType `json:"trigger_type,omitempty"` Keyword string `json:"keyword,omitempty"` Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` } // ResultListResponse - paginated result list response type ResultListResponse struct { Data []*ExecutionRecord `json:"data"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"pagesize"` } // ListResults retrieves completed executions with delivery content // Only returns executions where delivery.content is not null func (s *ExecutionStore) ListResults(ctx context.Context, opts *ResultListOptions) (*ResultListResponse, error) { mod := model.Select(s.modelID) if mod == nil { return nil, fmt.Errorf("model %s not found", s.modelID) } // Build where conditions var wheres []model.QueryWhere // Must have completed status and delivery content 
wheres = append(wheres, model.QueryWhere{Column: "status", Value: "completed"}) wheres = append(wheres, model.QueryWhere{Column: "delivery", OP: "notnull"}) if opts != nil { if opts.MemberID != "" { wheres = append(wheres, model.QueryWhere{Column: "member_id", Value: opts.MemberID}) } if opts.TeamID != "" { wheres = append(wheres, model.QueryWhere{Column: "team_id", Value: opts.TeamID}) } if opts.TriggerType != "" { wheres = append(wheres, model.QueryWhere{Column: "trigger_type", Value: string(opts.TriggerType)}) } // Keyword search in name field (delivery.content.summary is in JSON, harder to search) // For now search in the name field if opts.Keyword != "" { wheres = append(wheres, model.QueryWhere{Column: "name", OP: "like", Value: "%" + opts.Keyword + "%"}) } } page := 1 pageSize := 20 if opts != nil { if opts.Page > 0 { page = opts.Page } if opts.PageSize > 0 { pageSize = opts.PageSize if pageSize > 100 { pageSize = 100 } } } params := model.QueryParam{ Wheres: wheres, Orders: []model.QueryOrder{{Column: "end_time", Option: "desc"}}, } res, err := mod.Paginate(params, page, pageSize) if err != nil { return nil, fmt.Errorf("failed to list results: %w", err) } total := 0 if v, ok := res["total"].(int64); ok { total = int(v) } else if v, ok := res["total"].(int); ok { total = v } records := make([]*ExecutionRecord, 0) for _, row := range toRows(res["data"]) { record, err := s.mapToRecord(row) if err != nil { continue } if record.Delivery != nil && record.Delivery.Content != nil { records = append(records, record) } } return &ResultListResponse{ Data: records, Total: total, Page: page, PageSize: pageSize, }, nil } // CountResults counts total results matching criteria func (s *ExecutionStore) CountResults(ctx context.Context, opts *ResultListOptions) (int, error) { var wheres []model.QueryWhere // Must have completed status and delivery content wheres = append(wheres, model.QueryWhere{Column: "status", Value: "completed"}) wheres = append(wheres, 
model.QueryWhere{Column: "delivery", OP: "notnull"}) if opts != nil { if opts.MemberID != "" { wheres = append(wheres, model.QueryWhere{Column: "member_id", Value: opts.MemberID}) } if opts.TeamID != "" { wheres = append(wheres, model.QueryWhere{Column: "team_id", Value: opts.TeamID}) } if opts.TriggerType != "" { wheres = append(wheres, model.QueryWhere{Column: "trigger_type", Value: string(opts.TriggerType)}) } if opts.Keyword != "" { wheres = append(wheres, model.QueryWhere{Column: "name", OP: "like", Value: "%" + opts.Keyword + "%"}) } } return s.countWithWheres(wheres) } // countWithWheres counts records matching the given where conditions func (s *ExecutionStore) countWithWheres(wheres []model.QueryWhere) (int, error) { mod := model.Select(s.modelID) if mod == nil { return 0, fmt.Errorf("model %s not found", s.modelID) } // Use model Paginate to get total count params := model.QueryParam{ Wheres: wheres, Limit: 1, } result, err := mod.Paginate(params, 1, 1) if err != nil { return 0, fmt.Errorf("failed to count records: %w", err) } // Paginate returns map with total field if result == nil { return 0, nil } total := 0 if t, ok := result["total"]; ok { switch v := t.(type) { case float64: total = int(v) case int64: total = int(v) case int: total = v } } return total, nil } // ActivityType represents the type of activity type ActivityType string const ( ActivityExecutionStarted ActivityType = "execution.started" ActivityExecutionCompleted ActivityType = "execution.completed" ActivityExecutionFailed ActivityType = "execution.failed" ActivityExecutionCancelled ActivityType = "execution.cancelled" ) // Activity represents a robot activity entry type Activity struct { Type ActivityType `json:"type"` RobotID string `json:"robot_id"` RobotName string `json:"robot_name,omitempty"` // Will be populated by API layer ExecutionID string `json:"execution_id"` Message string `json:"message"` Timestamp time.Time `json:"timestamp"` } // ActivityListOptions - options for listing 
activities type ActivityListOptions struct { TeamID string `json:"team_id,omitempty"` // Filter by team ID Since *time.Time `json:"since,omitempty"` // Only activities after this time Limit int `json:"limit,omitempty"` Type ActivityType `json:"type,omitempty"` // Filter by activity type } // ListActivities derives activities from recent execution status changes func (s *ExecutionStore) ListActivities(ctx context.Context, opts *ActivityListOptions) ([]*Activity, error) { mod := model.Select(s.modelID) if mod == nil { return nil, fmt.Errorf("model %s not found", s.modelID) } // Build where conditions var wheres []model.QueryWhere // Filter by activity type if specified // Map activity types to execution statuses if opts != nil && opts.Type != "" { switch opts.Type { case ActivityExecutionStarted: wheres = append(wheres, model.QueryWhere{Column: "status", Value: "running"}) case ActivityExecutionCompleted: wheres = append(wheres, model.QueryWhere{Column: "status", Value: "completed"}) case ActivityExecutionFailed: wheres = append(wheres, model.QueryWhere{Column: "status", Value: "failed"}) case ActivityExecutionCancelled: wheres = append(wheres, model.QueryWhere{Column: "status", Value: "cancelled"}) default: // Unknown type, return empty return []*Activity{}, nil } } else { // Only completed, failed, or cancelled executions generate activities // For started activities, we'd need running status wheres = append(wheres, model.QueryWhere{ Column: "status", OP: "in", Value: []string{"completed", "failed", "cancelled", "running"}, }) } if opts != nil { if opts.TeamID != "" { wheres = append(wheres, model.QueryWhere{Column: "team_id", Value: opts.TeamID}) } if opts.Since != nil { // Get executions that ended or started after 'since' wheres = append(wheres, model.QueryWhere{Column: "updated_at", OP: ">=", Value: *opts.Since}) } } limit := 20 if opts != nil && opts.Limit > 0 { limit = opts.Limit if limit > 100 { limit = 100 } } params := model.QueryParam{ Wheres: wheres, 
Limit: limit, Orders: []model.QueryOrder{{Column: "updated_at", Option: "desc"}}, } rows, err := mod.Get(params) if err != nil { return nil, fmt.Errorf("failed to list activities: %w", err) } activities := make([]*Activity, 0, len(rows)) for _, row := range rows { record, err := s.mapToRecord(row) if err != nil { continue } activity := s.executionToActivity(record) if activity != nil { activities = append(activities, activity) } } return activities, nil } // executionToActivity converts an execution record to an activity func (s *ExecutionStore) executionToActivity(record *ExecutionRecord) *Activity { var actType ActivityType var message string var timestamp time.Time switch record.Status { case types.ExecRunning: actType = ActivityExecutionStarted message = "Started" if record.StartTime != nil { timestamp = *record.StartTime } else { timestamp = time.Now() } case types.ExecCompleted: actType = ActivityExecutionCompleted message = "Completed" if record.EndTime != nil { timestamp = *record.EndTime } else if record.UpdatedAt != nil { timestamp = *record.UpdatedAt } else { timestamp = time.Now() } case types.ExecFailed: actType = ActivityExecutionFailed message = "Failed" if record.Error != "" { message = "Failed: " + record.Error // Truncate long error messages if len(message) > 100 { message = message[:97] + "..." } } if record.EndTime != nil { timestamp = *record.EndTime } else if record.UpdatedAt != nil { timestamp = *record.UpdatedAt } else { timestamp = time.Now() } case types.ExecCancelled: actType = ActivityExecutionCancelled message = "Cancelled" if record.EndTime != nil { timestamp = *record.EndTime } else if record.UpdatedAt != nil { timestamp = *record.UpdatedAt } else { timestamp = time.Now() } default: return nil // Other statuses don't generate activities } // Add execution name to message if available if record.Name != "" { message = message + ": " + record.Name // Truncate long messages if len(message) > 150 { message = message[:147] + "..." 
} } return &Activity{ Type: actType, RobotID: record.MemberID, ExecutionID: record.ExecutionID, Message: message, Timestamp: timestamp, } } // FromExecution creates an ExecutionRecord from a runtime Execution func FromExecution(exec *types.Execution) *ExecutionRecord { record := &ExecutionRecord{ ExecutionID: exec.ID, MemberID: exec.MemberID, TeamID: exec.TeamID, TriggerType: exec.TriggerType, Status: exec.Status, Phase: exec.Phase, Error: exec.Error, Name: exec.Name, CurrentTaskName: exec.CurrentTaskName, Input: exec.Input, Inspiration: exec.Inspiration, Goals: exec.Goals, Tasks: exec.Tasks, Results: exec.Results, Delivery: exec.Delivery, Learning: exec.Learning, ChatID: exec.ChatID, WaitingTaskID: exec.WaitingTaskID, WaitingQuestion: exec.WaitingQuestion, WaitingSince: exec.WaitingSince, ResumeContext: exec.ResumeContext, } // Convert timestamps if !exec.StartTime.IsZero() { record.StartTime = &exec.StartTime } if exec.EndTime != nil { record.EndTime = exec.EndTime } // Convert CurrentState if exec.Current != nil { record.Current = &CurrentState{ TaskIndex: exec.Current.TaskIndex, Progress: exec.Current.Progress, } } return record } // ToExecution converts an ExecutionRecord to a runtime Execution func (r *ExecutionRecord) ToExecution() *types.Execution { exec := &types.Execution{ ID: r.ExecutionID, MemberID: r.MemberID, TeamID: r.TeamID, TriggerType: r.TriggerType, Status: r.Status, Phase: r.Phase, Error: r.Error, Name: r.Name, CurrentTaskName: r.CurrentTaskName, Input: r.Input, Inspiration: r.Inspiration, Goals: r.Goals, Tasks: r.Tasks, Results: r.Results, Delivery: r.Delivery, Learning: r.Learning, ChatID: r.ChatID, WaitingTaskID: r.WaitingTaskID, WaitingQuestion: r.WaitingQuestion, WaitingSince: r.WaitingSince, ResumeContext: r.ResumeContext, } // Convert timestamps if r.StartTime != nil { exec.StartTime = *r.StartTime } if r.EndTime != nil { exec.EndTime = r.EndTime } // Convert CurrentState if r.Current != nil { exec.Current = &types.CurrentState{ 
TaskIndex: r.Current.TaskIndex, Progress: r.Current.Progress, } } return exec } ================================================ FILE: agent/robot/store/execution_test.go ================================================ package store_test import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/gou/model" "github.com/yaoapp/yao/agent/robot/store" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/testutils" ) // TestExecutionStoreSave tests creating and updating execution records func TestExecutionStoreSave(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Clean up any existing test data cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() t.Run("creates_new_execution_record", func(t *testing.T) { startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_save_001", MemberID: "member_test_001", TeamID: "team_test_001", TriggerType: types.TriggerClock, Status: types.ExecPending, Phase: types.PhaseInspiration, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) // Verify it was created saved, err := s.Get(ctx, "exec_test_save_001") require.NoError(t, err) require.NotNil(t, saved) assert.Equal(t, "exec_test_save_001", saved.ExecutionID) assert.Equal(t, "member_test_001", saved.MemberID) assert.Equal(t, "team_test_001", saved.TeamID) assert.Equal(t, types.TriggerClock, saved.TriggerType) assert.Equal(t, types.ExecPending, saved.Status) assert.Equal(t, types.PhaseInspiration, saved.Phase) assert.NotNil(t, saved.StartTime) assert.NotNil(t, saved.CreatedAt) }) t.Run("updates_existing_execution_record", func(t *testing.T) { // First create a record startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_save_002", MemberID: "member_test_002", TeamID: "team_test_002", 
TriggerType: types.TriggerHuman, Status: types.ExecPending, Phase: types.PhaseInspiration, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) // Update the record record.Status = types.ExecRunning record.Phase = types.PhaseGoals record.Goals = &types.Goals{Content: "Test goals content"} err = s.Save(ctx, record) require.NoError(t, err) // Verify the update saved, err := s.Get(ctx, "exec_test_save_002") require.NoError(t, err) require.NotNil(t, saved) assert.Equal(t, types.ExecRunning, saved.Status) assert.Equal(t, types.PhaseGoals, saved.Phase) assert.NotNil(t, saved.Goals) assert.Equal(t, "Test goals content", saved.Goals.Content) }) } // TestExecutionStoreGet tests retrieving execution records func TestExecutionStoreGet(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Create a test record with all fields populated setupTestExecution(t, s, ctx) t.Run("returns_existing_record", func(t *testing.T) { record, err := s.Get(ctx, "exec_test_get_001") require.NoError(t, err) require.NotNil(t, record) assert.Equal(t, "exec_test_get_001", record.ExecutionID) assert.Equal(t, "member_test_get", record.MemberID) assert.Equal(t, "team_test_get", record.TeamID) assert.Equal(t, types.TriggerClock, record.TriggerType) assert.Equal(t, types.ExecCompleted, record.Status) assert.Equal(t, types.PhaseDelivery, record.Phase) // Verify phase outputs assert.NotNil(t, record.Inspiration) assert.Equal(t, "Test inspiration content", record.Inspiration.Content) assert.NotNil(t, record.Goals) assert.Equal(t, "Test goals content", record.Goals.Content) assert.Len(t, record.Tasks, 2) assert.Equal(t, "task_001", record.Tasks[0].ID) assert.Len(t, record.Results, 2) assert.True(t, record.Results[0].Success) }) t.Run("returns_nil_for_non_existent_record", func(t *testing.T) { record, err 
:= s.Get(ctx, "exec_non_existent") require.NoError(t, err) assert.Nil(t, record) }) } // TestExecutionStoreList tests listing execution records with filters func TestExecutionStoreList(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Create multiple test records setupTestExecutionsForList(t, s, ctx) t.Run("lists_all_records_without_filters", func(t *testing.T) { result, err := s.List(ctx, nil) require.NoError(t, err) assert.GreaterOrEqual(t, len(result.Data), 4) }) t.Run("filters_by_member_id", func(t *testing.T) { result, err := s.List(ctx, &store.ListOptions{ MemberID: "member_list_001", }) require.NoError(t, err) assert.Equal(t, 2, len(result.Data)) for _, r := range result.Data { assert.Equal(t, "member_list_001", r.MemberID) } }) t.Run("filters_by_team_id", func(t *testing.T) { result, err := s.List(ctx, &store.ListOptions{ TeamID: "team_list_001", }) require.NoError(t, err) assert.Equal(t, 3, len(result.Data)) for _, r := range result.Data { assert.Equal(t, "team_list_001", r.TeamID) } }) t.Run("filters_by_status", func(t *testing.T) { result, err := s.List(ctx, &store.ListOptions{ Status: types.ExecCompleted, }) require.NoError(t, err) assert.GreaterOrEqual(t, len(result.Data), 2) for _, r := range result.Data { assert.Equal(t, types.ExecCompleted, r.Status) } }) t.Run("filters_by_trigger_type", func(t *testing.T) { result, err := s.List(ctx, &store.ListOptions{ TriggerType: types.TriggerHuman, }) require.NoError(t, err) assert.GreaterOrEqual(t, len(result.Data), 1) for _, r := range result.Data { assert.Equal(t, types.TriggerHuman, r.TriggerType) } }) t.Run("respects_pagesize", func(t *testing.T) { result, err := s.List(ctx, &store.ListOptions{ PageSize: 2, }) require.NoError(t, err) assert.Equal(t, 2, len(result.Data)) }) t.Run("combines_multiple_filters", func(t 
*testing.T) { result, err := s.List(ctx, &store.ListOptions{ TeamID: "team_list_001", Status: types.ExecCompleted, }) require.NoError(t, err) assert.Equal(t, 2, len(result.Data)) for _, r := range result.Data { assert.Equal(t, "team_list_001", r.TeamID) assert.Equal(t, types.ExecCompleted, r.Status) } }) } // TestExecutionStoreUpdatePhase tests updating phase and phase data func TestExecutionStoreUpdatePhase(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Create a base record startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_phase_001", MemberID: "member_phase_001", TeamID: "team_phase_001", TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseInspiration, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) t.Run("updates_inspiration_phase", func(t *testing.T) { inspiration := &types.InspirationReport{ Content: "Updated inspiration content", } err := s.UpdatePhase(ctx, "exec_test_phase_001", types.PhaseInspiration, inspiration) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_phase_001") require.NoError(t, err) assert.Equal(t, types.PhaseInspiration, saved.Phase) assert.NotNil(t, saved.Inspiration) assert.Equal(t, "Updated inspiration content", saved.Inspiration.Content) }) t.Run("updates_goals_phase", func(t *testing.T) { goals := &types.Goals{ Content: "Updated goals content", } err := s.UpdatePhase(ctx, "exec_test_phase_001", types.PhaseGoals, goals) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_phase_001") require.NoError(t, err) assert.Equal(t, types.PhaseGoals, saved.Phase) assert.NotNil(t, saved.Goals) assert.Equal(t, "Updated goals content", saved.Goals.Content) }) t.Run("updates_tasks_phase", func(t *testing.T) { tasks := []types.Task{ {ID: "task_phase_001", 
ExecutorType: types.ExecutorAssistant}, {ID: "task_phase_002", ExecutorType: types.ExecutorProcess}, } err := s.UpdatePhase(ctx, "exec_test_phase_001", types.PhaseTasks, tasks) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_phase_001") require.NoError(t, err) assert.Equal(t, types.PhaseTasks, saved.Phase) assert.Len(t, saved.Tasks, 2) assert.Equal(t, "task_phase_001", saved.Tasks[0].ID) }) t.Run("updates_run_phase", func(t *testing.T) { results := []types.TaskResult{ {TaskID: "task_phase_001", Success: true, Output: "Result 1"}, {TaskID: "task_phase_002", Success: false, Error: "Failed"}, } err := s.UpdatePhase(ctx, "exec_test_phase_001", types.PhaseRun, results) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_phase_001") require.NoError(t, err) assert.Equal(t, types.PhaseRun, saved.Phase) assert.Len(t, saved.Results, 2) assert.True(t, saved.Results[0].Success) assert.False(t, saved.Results[1].Success) }) t.Run("updates_delivery_phase", func(t *testing.T) { delivery := &types.DeliveryResult{ Success: true, } err := s.UpdatePhase(ctx, "exec_test_phase_001", types.PhaseDelivery, delivery) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_phase_001") require.NoError(t, err) assert.Equal(t, types.PhaseDelivery, saved.Phase) assert.NotNil(t, saved.Delivery) assert.True(t, saved.Delivery.Success) }) t.Run("updates_learning_phase", func(t *testing.T) { learning := []types.LearningEntry{ {Type: types.LearnExecution, Content: "Learned something"}, } err := s.UpdatePhase(ctx, "exec_test_phase_001", types.PhaseLearning, learning) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_phase_001") require.NoError(t, err) assert.Equal(t, types.PhaseLearning, saved.Phase) assert.Len(t, saved.Learning, 1) assert.Equal(t, "Learned something", saved.Learning[0].Content) }) } // TestExecutionStoreUpdateStatus tests updating execution status func TestExecutionStoreUpdateStatus(t *testing.T) { if testing.Short() { t.Skip("Skipping integration 
test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() t.Run("updates_status_to_running", func(t *testing.T) { startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_status_001", MemberID: "member_status_001", TeamID: "team_status_001", TriggerType: types.TriggerClock, Status: types.ExecPending, Phase: types.PhaseInspiration, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) err = s.UpdateStatus(ctx, "exec_test_status_001", types.ExecRunning, "") require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_status_001") require.NoError(t, err) assert.Equal(t, types.ExecRunning, saved.Status) assert.Nil(t, saved.EndTime) // Should not set end_time for running }) t.Run("updates_status_to_completed_with_end_time", func(t *testing.T) { startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_status_002", MemberID: "member_status_002", TeamID: "team_status_002", TriggerType: types.TriggerHuman, Status: types.ExecRunning, Phase: types.PhaseDelivery, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) err = s.UpdateStatus(ctx, "exec_test_status_002", types.ExecCompleted, "") require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_status_002") require.NoError(t, err) assert.Equal(t, types.ExecCompleted, saved.Status) assert.NotNil(t, saved.EndTime) // Should set end_time for completed }) t.Run("updates_status_to_failed_with_error", func(t *testing.T) { startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_status_003", MemberID: "member_status_003", TeamID: "team_status_003", TriggerType: types.TriggerEvent, Status: types.ExecRunning, Phase: types.PhaseRun, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) err = s.UpdateStatus(ctx, "exec_test_status_003", types.ExecFailed, "Task execution failed: timeout") 
require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_status_003") require.NoError(t, err) assert.Equal(t, types.ExecFailed, saved.Status) assert.Equal(t, "Task execution failed: timeout", saved.Error) assert.NotNil(t, saved.EndTime) // Should set end_time for failed }) t.Run("updates_status_to_cancelled", func(t *testing.T) { startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_status_004", MemberID: "member_status_004", TeamID: "team_status_004", TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseTasks, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) err = s.UpdateStatus(ctx, "exec_test_status_004", types.ExecCancelled, "User cancelled") require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_status_004") require.NoError(t, err) assert.Equal(t, types.ExecCancelled, saved.Status) assert.Equal(t, "User cancelled", saved.Error) assert.NotNil(t, saved.EndTime) // Should set end_time for cancelled }) } // TestExecutionStoreUpdateCurrent tests updating current state func TestExecutionStoreUpdateCurrent(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Create a base record startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_current_001", MemberID: "member_current_001", TeamID: "team_current_001", TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseRun, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) t.Run("updates_current_state", func(t *testing.T) { current := &store.CurrentState{ TaskIndex: 2, Progress: "3/5 tasks completed", } err := s.UpdateCurrent(ctx, "exec_test_current_001", current) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_current_001") require.NoError(t, err) assert.NotNil(t, 
saved.Current) assert.Equal(t, 2, saved.Current.TaskIndex) assert.Equal(t, "3/5 tasks completed", saved.Current.Progress) }) } // TestExecutionStoreUpdateUIFields tests updating UI display fields func TestExecutionStoreUpdateUIFields(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Create a base record startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_uifields_001", MemberID: "member_uifields_001", TeamID: "team_uifields_001", TriggerType: types.TriggerHuman, Status: types.ExecRunning, Phase: types.PhaseInspiration, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) t.Run("updates_name_only", func(t *testing.T) { err := s.UpdateUIFields(ctx, "exec_test_uifields_001", "Analyze sales data", "") require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_uifields_001") require.NoError(t, err) assert.Equal(t, "Analyze sales data", saved.Name) assert.Equal(t, "", saved.CurrentTaskName) }) t.Run("updates_current_task_name_only", func(t *testing.T) { err := s.UpdateUIFields(ctx, "exec_test_uifields_001", "", "Analyzing context...") require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_uifields_001") require.NoError(t, err) assert.Equal(t, "Analyze sales data", saved.Name) // Previous value retained assert.Equal(t, "Analyzing context...", saved.CurrentTaskName) }) t.Run("updates_both_fields", func(t *testing.T) { err := s.UpdateUIFields(ctx, "exec_test_uifields_001", "Generate monthly report", "Task 1/3: Collect data") require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_uifields_001") require.NoError(t, err) assert.Equal(t, "Generate monthly report", saved.Name) assert.Equal(t, "Task 1/3: Collect data", saved.CurrentTaskName) }) t.Run("does_nothing_when_both_empty", func(t *testing.T) { // Get current values before, 
err := s.Get(ctx, "exec_test_uifields_001") require.NoError(t, err) // Update with empty strings err = s.UpdateUIFields(ctx, "exec_test_uifields_001", "", "") require.NoError(t, err) // Values should remain unchanged after, err := s.Get(ctx, "exec_test_uifields_001") require.NoError(t, err) assert.Equal(t, before.Name, after.Name) assert.Equal(t, before.CurrentTaskName, after.CurrentTaskName) }) t.Run("handles_chinese_content", func(t *testing.T) { err := s.UpdateUIFields(ctx, "exec_test_uifields_001", "生成月度报告", "任务 2/3: 分析数据") require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_uifields_001") require.NoError(t, err) assert.Equal(t, "生成月度报告", saved.Name) assert.Equal(t, "任务 2/3: 分析数据", saved.CurrentTaskName) }) t.Run("handles_long_content", func(t *testing.T) { longName := "This is a very long execution name that might come from a detailed user instruction about what they want the robot to accomplish in this particular run cycle" longTask := "Task 1/5: Processing a complex multi-step operation with various sub-tasks that need to be completed..." 
err := s.UpdateUIFields(ctx, "exec_test_uifields_001", longName, longTask) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_uifields_001") require.NoError(t, err) assert.Equal(t, longName, saved.Name) assert.Equal(t, longTask, saved.CurrentTaskName) }) } // TestExecutionStoreUpdateTasks tests updating tasks array with status func TestExecutionStoreUpdateTasks(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Create a base record with initial tasks startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_tasks_001", MemberID: "member_tasks_001", TeamID: "team_tasks_001", TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseRun, StartTime: &startTime, Tasks: []types.Task{ {ID: "task_001", ExecutorType: types.ExecutorAssistant, Status: types.TaskPending, Order: 0}, {ID: "task_002", ExecutorType: types.ExecutorProcess, Status: types.TaskPending, Order: 1}, {ID: "task_003", ExecutorType: types.ExecutorAssistant, Status: types.TaskPending, Order: 2}, }, } err := s.Save(ctx, record) require.NoError(t, err) t.Run("updates_task_status_to_running", func(t *testing.T) { // Update first task to running tasks := []types.Task{ {ID: "task_001", ExecutorType: types.ExecutorAssistant, Status: types.TaskRunning, Order: 0}, {ID: "task_002", ExecutorType: types.ExecutorProcess, Status: types.TaskPending, Order: 1}, {ID: "task_003", ExecutorType: types.ExecutorAssistant, Status: types.TaskPending, Order: 2}, } current := &store.CurrentState{TaskIndex: 0, Progress: "1/3 tasks"} err := s.UpdateTasks(ctx, "exec_test_tasks_001", tasks, current) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_tasks_001") require.NoError(t, err) require.Len(t, saved.Tasks, 3) assert.Equal(t, types.TaskRunning, saved.Tasks[0].Status) assert.Equal(t, 
types.TaskPending, saved.Tasks[1].Status) assert.Equal(t, types.TaskPending, saved.Tasks[2].Status) assert.NotNil(t, saved.Current) assert.Equal(t, 0, saved.Current.TaskIndex) }) t.Run("updates_task_status_to_completed", func(t *testing.T) { // First task completed, second running tasks := []types.Task{ {ID: "task_001", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted, Order: 0}, {ID: "task_002", ExecutorType: types.ExecutorProcess, Status: types.TaskRunning, Order: 1}, {ID: "task_003", ExecutorType: types.ExecutorAssistant, Status: types.TaskPending, Order: 2}, } current := &store.CurrentState{TaskIndex: 1, Progress: "2/3 tasks"} err := s.UpdateTasks(ctx, "exec_test_tasks_001", tasks, current) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_tasks_001") require.NoError(t, err) assert.Equal(t, types.TaskCompleted, saved.Tasks[0].Status) assert.Equal(t, types.TaskRunning, saved.Tasks[1].Status) assert.Equal(t, types.TaskPending, saved.Tasks[2].Status) assert.Equal(t, 1, saved.Current.TaskIndex) }) t.Run("updates_task_status_to_failed_with_skipped", func(t *testing.T) { // Second task failed, third skipped tasks := []types.Task{ {ID: "task_001", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted, Order: 0}, {ID: "task_002", ExecutorType: types.ExecutorProcess, Status: types.TaskFailed, Order: 1}, {ID: "task_003", ExecutorType: types.ExecutorAssistant, Status: types.TaskSkipped, Order: 2}, } current := &store.CurrentState{TaskIndex: 1, Progress: "Failed at 2/3"} err := s.UpdateTasks(ctx, "exec_test_tasks_001", tasks, current) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_tasks_001") require.NoError(t, err) assert.Equal(t, types.TaskCompleted, saved.Tasks[0].Status) assert.Equal(t, types.TaskFailed, saved.Tasks[1].Status) assert.Equal(t, types.TaskSkipped, saved.Tasks[2].Status) }) t.Run("updates_with_nil_current", func(t *testing.T) { // All tasks completed, no current tasks := []types.Task{ {ID: "task_001", 
ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted, Order: 0}, {ID: "task_002", ExecutorType: types.ExecutorProcess, Status: types.TaskCompleted, Order: 1}, {ID: "task_003", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted, Order: 2}, } err := s.UpdateTasks(ctx, "exec_test_tasks_001", tasks, nil) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_tasks_001") require.NoError(t, err) assert.Equal(t, types.TaskCompleted, saved.Tasks[0].Status) assert.Equal(t, types.TaskCompleted, saved.Tasks[1].Status) assert.Equal(t, types.TaskCompleted, saved.Tasks[2].Status) }) t.Run("preserves_task_description", func(t *testing.T) { // Create a new record with descriptions record2 := &store.ExecutionRecord{ ExecutionID: "exec_test_tasks_002", MemberID: "member_tasks_002", TeamID: "team_tasks_002", TriggerType: types.TriggerHuman, Status: types.ExecRunning, Phase: types.PhaseRun, StartTime: &startTime, Tasks: []types.Task{ {ID: "task_d01", Description: "Analyze data", ExecutorType: types.ExecutorAssistant, Status: types.TaskPending, Order: 0}, {ID: "task_d02", Description: "Generate report", ExecutorType: types.ExecutorAssistant, Status: types.TaskPending, Order: 1}, }, } err := s.Save(ctx, record2) require.NoError(t, err) // Update status preserving description tasks := []types.Task{ {ID: "task_d01", Description: "Analyze data", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted, Order: 0}, {ID: "task_d02", Description: "Generate report", ExecutorType: types.ExecutorAssistant, Status: types.TaskRunning, Order: 1}, } err = s.UpdateTasks(ctx, "exec_test_tasks_002", tasks, &store.CurrentState{TaskIndex: 1}) require.NoError(t, err) saved, err := s.Get(ctx, "exec_test_tasks_002") require.NoError(t, err) assert.Equal(t, "Analyze data", saved.Tasks[0].Description) assert.Equal(t, "Generate report", saved.Tasks[1].Description) assert.Equal(t, types.TaskCompleted, saved.Tasks[0].Status) assert.Equal(t, types.TaskRunning, 
saved.Tasks[1].Status) }) } // TestExecutionStoreDelete tests deleting execution records func TestExecutionStoreDelete(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() t.Run("deletes_existing_record", func(t *testing.T) { // Create a record startTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_delete_001", MemberID: "member_delete_001", TeamID: "team_delete_001", TriggerType: types.TriggerClock, Status: types.ExecCompleted, Phase: types.PhaseDelivery, StartTime: &startTime, } err := s.Save(ctx, record) require.NoError(t, err) // Verify it exists saved, err := s.Get(ctx, "exec_test_delete_001") require.NoError(t, err) require.NotNil(t, saved) // Delete it err = s.Delete(ctx, "exec_test_delete_001") require.NoError(t, err) // Verify it's gone saved, err = s.Get(ctx, "exec_test_delete_001") require.NoError(t, err) assert.Nil(t, saved) }) t.Run("no_error_for_non_existent_record", func(t *testing.T) { err := s.Delete(ctx, "exec_non_existent") assert.NoError(t, err) }) } // TestExecutionRecordConversion tests conversion between ExecutionRecord and Execution func TestExecutionRecordConversion(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("converts_from_execution", func(t *testing.T) { now := time.Now() endTime := now.Add(time.Hour) exec := &types.Execution{ ID: "exec_convert_001", MemberID: "member_convert_001", TeamID: "team_convert_001", TriggerType: types.TriggerHuman, Status: types.ExecCompleted, Phase: types.PhaseDelivery, StartTime: now, EndTime: &endTime, Error: "", Name: "Analyze sales data", CurrentTaskName: "Task 1/3: Processing", Inspiration: &types.InspirationReport{Content: "Test inspiration"}, Goals: &types.Goals{Content: "Test goals"}, Tasks: []types.Task{ {ID: 
"task_001", ExecutorType: types.ExecutorAssistant}, }, Results: []types.TaskResult{ {TaskID: "task_001", Success: true}, }, Current: &types.CurrentState{ TaskIndex: 1, Progress: "1/1 tasks", }, } record := store.FromExecution(exec) assert.Equal(t, "exec_convert_001", record.ExecutionID) assert.Equal(t, "member_convert_001", record.MemberID) assert.Equal(t, "team_convert_001", record.TeamID) assert.Equal(t, types.TriggerHuman, record.TriggerType) assert.Equal(t, types.ExecCompleted, record.Status) assert.Equal(t, types.PhaseDelivery, record.Phase) assert.NotNil(t, record.StartTime) assert.NotNil(t, record.EndTime) assert.NotNil(t, record.Inspiration) assert.NotNil(t, record.Goals) assert.Len(t, record.Tasks, 1) assert.Len(t, record.Results, 1) assert.NotNil(t, record.Current) assert.Equal(t, 1, record.Current.TaskIndex) // Verify UI fields conversion assert.Equal(t, "Analyze sales data", record.Name) assert.Equal(t, "Task 1/3: Processing", record.CurrentTaskName) }) t.Run("converts_to_execution", func(t *testing.T) { now := time.Now() endTime := now.Add(time.Hour) record := &store.ExecutionRecord{ ExecutionID: "exec_convert_002", MemberID: "member_convert_002", TeamID: "team_convert_002", TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseRun, StartTime: &now, EndTime: &endTime, Name: "定时执行", CurrentTaskName: "任务 1/2: 数据分析", Inspiration: &types.InspirationReport{Content: "Test inspiration"}, Goals: &types.Goals{Content: "Test goals"}, Tasks: []types.Task{ {ID: "task_002", ExecutorType: types.ExecutorProcess}, }, Results: []types.TaskResult{ {TaskID: "task_002", Success: false, Error: "Failed"}, }, Current: &store.CurrentState{ TaskIndex: 0, Progress: "0/1 tasks", }, } exec := record.ToExecution() assert.Equal(t, "exec_convert_002", exec.ID) assert.Equal(t, "member_convert_002", exec.MemberID) assert.Equal(t, "team_convert_002", exec.TeamID) assert.Equal(t, types.TriggerClock, exec.TriggerType) assert.Equal(t, types.ExecRunning, 
exec.Status) assert.Equal(t, types.PhaseRun, exec.Phase) assert.NotNil(t, exec.Inspiration) assert.NotNil(t, exec.Goals) assert.Len(t, exec.Tasks, 1) assert.Len(t, exec.Results, 1) assert.NotNil(t, exec.Current) assert.Equal(t, 0, exec.Current.TaskIndex) // Verify UI fields conversion assert.Equal(t, "定时执行", exec.Name) assert.Equal(t, "任务 1/2: 数据分析", exec.CurrentTaskName) }) } // Helper functions func cleanupTestExecutions(t *testing.T) { mod := model.Select("__yao.agent.execution") if mod == nil { return } // Delete all test execution records _, err := mod.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "execution_id", OP: "like", Value: "exec_test_%"}, }, }) if err != nil { t.Logf("Warning: failed to cleanup test executions: %v", err) } } func setupTestExecution(t *testing.T, s *store.ExecutionStore, ctx context.Context) { startTime := time.Now().Add(-time.Hour) endTime := time.Now() record := &store.ExecutionRecord{ ExecutionID: "exec_test_get_001", MemberID: "member_test_get", TeamID: "team_test_get", TriggerType: types.TriggerClock, Status: types.ExecCompleted, Phase: types.PhaseDelivery, StartTime: &startTime, EndTime: &endTime, Inspiration: &types.InspirationReport{ Content: "Test inspiration content", }, Goals: &types.Goals{ Content: "Test goals content", }, Tasks: []types.Task{ {ID: "task_001", ExecutorType: types.ExecutorAssistant, Status: types.TaskCompleted}, {ID: "task_002", ExecutorType: types.ExecutorProcess, Status: types.TaskCompleted}, }, Results: []types.TaskResult{ {TaskID: "task_001", Success: true, Output: "Result 1"}, {TaskID: "task_002", Success: true, Output: "Result 2"}, }, Delivery: &types.DeliveryResult{ Success: true, }, Learning: []types.LearningEntry{ {Type: types.LearnExecution, Content: "Test learning"}, }, } err := s.Save(ctx, record) require.NoError(t, err) } func setupTestExecutionsForList(t *testing.T, s *store.ExecutionStore, ctx context.Context) { startTime := time.Now() records := []*store.ExecutionRecord{ 
{ ExecutionID: "exec_test_list_001", MemberID: "member_list_001", TeamID: "team_list_001", TriggerType: types.TriggerClock, Status: types.ExecCompleted, Phase: types.PhaseDelivery, StartTime: &startTime, }, { ExecutionID: "exec_test_list_002", MemberID: "member_list_001", TeamID: "team_list_001", TriggerType: types.TriggerClock, Status: types.ExecCompleted, Phase: types.PhaseDelivery, StartTime: &startTime, }, { ExecutionID: "exec_test_list_003", MemberID: "member_list_002", TeamID: "team_list_001", TriggerType: types.TriggerHuman, Status: types.ExecRunning, Phase: types.PhaseRun, StartTime: &startTime, }, { ExecutionID: "exec_test_list_004", MemberID: "member_list_002", TeamID: "team_list_002", TriggerType: types.TriggerEvent, Status: types.ExecFailed, Phase: types.PhaseRun, StartTime: &startTime, Error: "Test error", }, } for _, record := range records { err := s.Save(ctx, record) require.NoError(t, err) } } // ==================== Results & Activities Tests ==================== // TestExecutionStoreListResults tests listing execution results (deliveries) func TestExecutionStoreListResults(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Setup test data with delivery content setupTestResultsData(t, s, ctx) t.Run("lists_results_without_filters", func(t *testing.T) { result, err := s.ListResults(ctx, &store.ResultListOptions{ MemberID: "member_result_001", }) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, 2, result.Total) assert.Len(t, result.Data, 2) // Should be ordered by end_time desc for _, r := range result.Data { assert.NotNil(t, r.Delivery) assert.NotNil(t, r.Delivery.Content) } }) t.Run("filters_by_trigger_type", func(t *testing.T) { result, err := s.ListResults(ctx, &store.ResultListOptions{ MemberID: "member_result_001", TriggerType: 
types.TriggerClock, }) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, 1, result.Total) assert.Len(t, result.Data, 1) assert.Equal(t, types.TriggerClock, result.Data[0].TriggerType) }) t.Run("filters_by_keyword", func(t *testing.T) { result, err := s.ListResults(ctx, &store.ResultListOptions{ MemberID: "member_result_001", Keyword: "Weekly", }) require.NoError(t, err) require.NotNil(t, result) // Should match "Weekly Sales Report" assert.GreaterOrEqual(t, result.Total, 1) }) t.Run("respects_pagination", func(t *testing.T) { result, err := s.ListResults(ctx, &store.ResultListOptions{ MemberID: "member_result_001", PageSize: 1, Page: 1, }) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, 1, len(result.Data)) assert.Equal(t, 2, result.Total) assert.Equal(t, 1, result.Page) }) t.Run("excludes_executions_without_delivery", func(t *testing.T) { result, err := s.ListResults(ctx, &store.ResultListOptions{ MemberID: "member_result_002", // Has no delivery content }) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, 0, result.Total) assert.Empty(t, result.Data) }) } // TestExecutionStoreCountResults tests counting results func TestExecutionStoreCountResults(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Setup test data with delivery content setupTestResultsData(t, s, ctx) t.Run("counts_all_results_for_member", func(t *testing.T) { count, err := s.CountResults(ctx, &store.ResultListOptions{ MemberID: "member_result_001", }) require.NoError(t, err) assert.Equal(t, 2, count) }) t.Run("counts_filtered_results", func(t *testing.T) { count, err := s.CountResults(ctx, &store.ResultListOptions{ MemberID: "member_result_001", TriggerType: types.TriggerHuman, }) require.NoError(t, err) assert.Equal(t, 1, count) }) 
t.Run("returns_zero_for_no_results", func(t *testing.T) { count, err := s.CountResults(ctx, &store.ResultListOptions{ MemberID: "member_result_002", }) require.NoError(t, err) assert.Equal(t, 0, count) }) } // TestExecutionStoreListActivities tests listing activities func TestExecutionStoreListActivities(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestExecutions(t) defer cleanupTestExecutions(t) s := store.NewExecutionStore() ctx := context.Background() // Setup test data setupTestActivitiesData(t, s, ctx) t.Run("lists_activities_for_team", func(t *testing.T) { activities, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", }) require.NoError(t, err) assert.GreaterOrEqual(t, len(activities), 3) }) t.Run("respects_limit", func(t *testing.T) { activities, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", Limit: 2, }) require.NoError(t, err) assert.LessOrEqual(t, len(activities), 2) }) t.Run("filters_by_since", func(t *testing.T) { // Without since, should get all activities activitiesAll, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", }) require.NoError(t, err) allCount := len(activitiesAll) assert.GreaterOrEqual(t, allCount, 3, "should have at least 3 activities without filter") // Use a time in the future to ensure we get no results future := time.Now().Add(24 * time.Hour) activitiesFuture, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", Since: &future, }) require.NoError(t, err) assert.Equal(t, 0, len(activitiesFuture), "should get no results with future since time") }) t.Run("generates_correct_activity_types", func(t *testing.T) { activities, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", }) require.NoError(t, err) // Should have activities of different types typeCount := 
make(map[store.ActivityType]int) for _, a := range activities { typeCount[a.Type]++ } // We should have at least completed and failed types assert.Greater(t, typeCount[store.ActivityExecutionCompleted], 0, "should have completed activities") assert.Greater(t, typeCount[store.ActivityExecutionFailed], 0, "should have failed activities") }) t.Run("filters_by_type_completed", func(t *testing.T) { activities, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", Type: store.ActivityExecutionCompleted, }) require.NoError(t, err) // All returned activities should be of type completed for _, a := range activities { assert.Equal(t, store.ActivityExecutionCompleted, a.Type, "all activities should be completed type") } assert.Greater(t, len(activities), 0, "should have at least one completed activity") }) t.Run("filters_by_type_failed", func(t *testing.T) { activities, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", Type: store.ActivityExecutionFailed, }) require.NoError(t, err) // All returned activities should be of type failed for _, a := range activities { assert.Equal(t, store.ActivityExecutionFailed, a.Type, "all activities should be failed type") } assert.Greater(t, len(activities), 0, "should have at least one failed activity") }) t.Run("filters_by_type_invalid_returns_empty", func(t *testing.T) { activities, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", Type: store.ActivityType("invalid.type"), }) require.NoError(t, err) // Invalid type should return empty result assert.Equal(t, 0, len(activities), "invalid type should return empty result") }) t.Run("includes_execution_name_in_message", func(t *testing.T) { activities, err := s.ListActivities(ctx, &store.ActivityListOptions{ TeamID: "team_activity_001", }) require.NoError(t, err) // Find a completed activity var completedActivity *store.Activity for _, a := range activities { if a.Type == 
store.ActivityExecutionCompleted && a.Message != "" { completedActivity = a break } } require.NotNil(t, completedActivity, "should find a completed activity") assert.Contains(t, completedActivity.Message, "Completed") }) } // Helper function to setup test results data func setupTestResultsData(t *testing.T, s *store.ExecutionStore, ctx context.Context) { startTime := time.Now().Add(-2 * time.Hour) endTime := time.Now().Add(-1 * time.Hour) endTime2 := time.Now().Add(-30 * time.Minute) records := []*store.ExecutionRecord{ { ExecutionID: "exec_test_result_001", MemberID: "member_result_001", TeamID: "team_result_001", TriggerType: types.TriggerClock, Status: types.ExecCompleted, Phase: types.PhaseDelivery, Name: "Weekly Sales Report", StartTime: &startTime, EndTime: &endTime, Delivery: &types.DeliveryResult{ Success: true, Content: &types.DeliveryContent{ Summary: "Weekly sales report generated successfully", Body: "## Weekly Sales Report\n\nTotal sales: $50,000", }, }, }, { ExecutionID: "exec_test_result_002", MemberID: "member_result_001", TeamID: "team_result_001", TriggerType: types.TriggerHuman, Status: types.ExecCompleted, Phase: types.PhaseDelivery, Name: "Custom Analysis", StartTime: &startTime, EndTime: &endTime2, Delivery: &types.DeliveryResult{ Success: true, Content: &types.DeliveryContent{ Summary: "Custom analysis completed", Body: "## Analysis Results\n\nFindings...", Attachments: []types.DeliveryAttachment{ {Title: "Report.pdf", File: "__attachment://file_001"}, }, }, }, }, { // Completed but no delivery content - should be excluded ExecutionID: "exec_test_result_003", MemberID: "member_result_002", TeamID: "team_result_001", TriggerType: types.TriggerClock, Status: types.ExecCompleted, Phase: types.PhaseDelivery, Name: "No Delivery Content", StartTime: &startTime, EndTime: &endTime, // No Delivery field }, { // Running - should be excluded from results ExecutionID: "exec_test_result_004", MemberID: "member_result_001", TeamID: "team_result_001", 
TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseRun, Name: "Running Task", StartTime: &startTime, }, } for _, record := range records { err := s.Save(ctx, record) require.NoError(t, err) } } // Helper function to setup test activities data func setupTestActivitiesData(t *testing.T, s *store.ExecutionStore, ctx context.Context) { startTime := time.Now().Add(-2 * time.Hour) endTime := time.Now().Add(-1 * time.Hour) endTimeFailed := time.Now().Add(-45 * time.Minute) records := []*store.ExecutionRecord{ { ExecutionID: "exec_test_activity_001", MemberID: "member_activity_001", TeamID: "team_activity_001", TriggerType: types.TriggerClock, Status: types.ExecCompleted, Phase: types.PhaseDelivery, Name: "Daily Report", StartTime: &startTime, EndTime: &endTime, }, { ExecutionID: "exec_test_activity_002", MemberID: "member_activity_001", TeamID: "team_activity_001", TriggerType: types.TriggerHuman, Status: types.ExecFailed, Phase: types.PhaseRun, Name: "Custom Task", StartTime: &startTime, EndTime: &endTimeFailed, Error: "Task timeout", }, { ExecutionID: "exec_test_activity_003", MemberID: "member_activity_002", TeamID: "team_activity_001", TriggerType: types.TriggerEvent, Status: types.ExecCancelled, Phase: types.PhaseTasks, Name: "Lead Processing", StartTime: &startTime, EndTime: &endTime, }, { ExecutionID: "exec_test_activity_004", MemberID: "member_activity_002", TeamID: "team_activity_001", TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseRun, Name: "Data Analysis", StartTime: &startTime, }, } for _, record := range records { err := s.Save(ctx, record) require.NoError(t, err) } } ================================================ FILE: agent/robot/store/robot.go ================================================ package store import ( "context" "fmt" "time" "github.com/yaoapp/gou/model" "github.com/yaoapp/kun/maps" "github.com/yaoapp/yao/agent/robot/types" "github.com/yaoapp/yao/agent/robot/utils" ) // RobotRecord - 
persistent storage for robot member // Maps to __yao.member model type RobotRecord struct { ID int64 `json:"id,omitempty"` // Auto-increment primary key MemberID string `json:"member_id"` // Unique robot identifier TeamID string `json:"team_id"` // Team ID MemberType string `json:"member_type"` // Always "robot" for robots Status string `json:"status"` // Member status: active | inactive | pending | suspended RobotStatus string `json:"robot_status"` // Robot status: idle | working | paused | error | maintenance AutonomousMode bool `json:"autonomous_mode"` // Whether autonomous mode is enabled // Profile DisplayName string `json:"display_name"` // Display name Bio string `json:"bio,omitempty"` // Robot description Avatar string `json:"avatar,omitempty"` // Identity & Role SystemPrompt string `json:"system_prompt"` // System prompt RoleID string `json:"role_id"` // Role within team ManagerID string `json:"manager_id"` // Direct manager user_id (who manages this robot) // Communication RobotEmail string `json:"robot_email"` // Robot email address AuthorizedSenders interface{} `json:"authorized_senders,omitempty"` // Email whitelist (JSON array) EmailFilterRules interface{} `json:"email_filter_rules,omitempty"` // Email filter rules (JSON array) // Capabilities RobotConfig interface{} `json:"robot_config"` // Robot config JSON Agents interface{} `json:"agents,omitempty"` // Accessible agents (JSON array) MCPServers interface{} `json:"mcp_servers,omitempty"` // MCP servers (JSON array) LanguageModel string `json:"language_model,omitempty"` // Language model name // Limits CostLimit float64 `json:"cost_limit,omitempty"` // Monthly cost limit USD // Ownership & Audit InvitedBy string `json:"invited_by,omitempty"` // Who created/added this robot JoinedAt *time.Time `json:"joined_at,omitempty"` // When robot was created // Timestamps CreatedAt *time.Time `json:"created_at,omitempty"` UpdatedAt *time.Time `json:"updated_at,omitempty"` // Yao Permission Fields (automatically 
handled by Yao model when permission:true) // These fields are passed through to the model layer for permission control YaoCreatedBy string `json:"__yao_created_by,omitempty"` // Creator user_id (set on create) YaoUpdatedBy string `json:"__yao_updated_by,omitempty"` // Updater user_id (set on update) YaoTeamID string `json:"__yao_team_id,omitempty"` // Permission team scope YaoTenantID string `json:"__yao_tenant_id,omitempty"` // Permission tenant scope } // RobotListOptions - options for listing robot records type RobotListOptions struct { TeamID string `json:"team_id,omitempty"` Status types.RobotStatus `json:"status,omitempty"` Keywords string `json:"keywords,omitempty"` // Search in display_name Limit int `json:"limit,omitempty"` Offset int `json:"offset,omitempty"` Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` OrderBy string `json:"order_by,omitempty"` } // RobotStore - persistent storage for robot members type RobotStore struct { modelID string } // NewRobotStore creates a new robot store instance func NewRobotStore() *RobotStore { return &RobotStore{ modelID: "__yao.member", } } // robotFields are the fields to select when loading robots var robotFields = []interface{}{ // Basic "id", "member_id", "team_id", "member_type", "status", "robot_status", "autonomous_mode", // Profile "display_name", "bio", "avatar", // Identity & Role "system_prompt", "role_id", "manager_id", // Communication "robot_email", "authorized_senders", "email_filter_rules", // Capabilities "robot_config", "agents", "mcp_servers", "language_model", // Limits "cost_limit", // Ownership & Audit "invited_by", "joined_at", // Timestamps "created_at", "updated_at", // Yao Permission Fields (for access control) "__yao_created_by", "__yao_updated_by", "__yao_team_id", "__yao_tenant_id", } // Save creates or updates a robot member record func (s *RobotStore) Save(ctx context.Context, record *RobotRecord) error { mod := model.Select(s.modelID) if mod == nil { return 
fmt.Errorf("model %s not found", s.modelID) } // Ensure member_type is robot record.MemberType = "robot" data := s.recordToMap(record) // Check if record exists by member_id existing, err := s.Get(ctx, record.MemberID) if err == nil && existing != nil { // Update existing record _, err = mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: record.MemberID}, }, }, data, ) if err != nil { return fmt.Errorf("failed to update robot record: %w", err) } return nil } // Create new record _, err = mod.Create(data) if err != nil { return fmt.Errorf("failed to create robot record: %w", err) } return nil } // Get retrieves a robot record by member_id func (s *RobotStore) Get(ctx context.Context, memberID string) (*RobotRecord, error) { mod := model.Select(s.modelID) if mod == nil { return nil, fmt.Errorf("model %s not found", s.modelID) } rows, err := mod.Get(model.QueryParam{ Select: robotFields, Wheres: []model.QueryWhere{ {Column: "member_id", Value: memberID}, {Column: "member_type", Value: "robot"}, }, Limit: 1, }) if err != nil { return nil, fmt.Errorf("failed to get robot record: %w", err) } if len(rows) == 0 { return nil, nil } return s.mapToRecord(rows[0]) } // List retrieves robot records with filters func (s *RobotStore) List(ctx context.Context, opts *RobotListOptions) ([]*RobotRecord, int, error) { mod := model.Select(s.modelID) if mod == nil { return nil, 0, fmt.Errorf("model %s not found", s.modelID) } // Build where conditions - only require member_type=robot wheres := []model.QueryWhere{ {Column: "member_type", Value: "robot"}, } if opts != nil { if opts.TeamID != "" { wheres = append(wheres, model.QueryWhere{Column: "team_id", Value: opts.TeamID}) } if opts.Status != "" { wheres = append(wheres, model.QueryWhere{Column: "robot_status", Value: string(opts.Status)}) } if opts.Keywords != "" { wheres = append(wheres, model.QueryWhere{ Column: "display_name", OP: "like", Value: "%" + opts.Keywords + "%", }) } } // Build 
order orders := []model.QueryOrder{} if opts != nil && opts.OrderBy != "" { orders = append(orders, model.QueryOrder{Column: opts.OrderBy}) } else { orders = append(orders, model.QueryOrder{Column: "created_at", Option: "desc"}) } // Determine pagination page := 1 pageSize := 100 if opts != nil { if opts.Page > 0 { page = opts.Page } if opts.PageSize > 0 { pageSize = opts.PageSize } // Limit overrides PageSize for simple limit queries if opts.Limit > 0 { pageSize = opts.Limit } } // Execute paginated query result, err := mod.Paginate(model.QueryParam{ Select: robotFields, Wheres: wheres, Orders: orders, }, page, pageSize) if err != nil { return nil, 0, fmt.Errorf("failed to list robots: %w", err) } // Get total count total := 0 if t, ok := result.Get("total").(int); ok { total = t } // Parse records records := []*RobotRecord{} data := result.Get("data") switch rows := data.(type) { case []maps.MapStr: for _, row := range rows { record, err := s.mapToRecord(map[string]interface{}(row)) if err != nil { continue // skip invalid records } records = append(records, record) } case []map[string]interface{}: for _, row := range rows { record, err := s.mapToRecord(row) if err != nil { continue // skip invalid records } records = append(records, record) } } return records, total, nil } // Delete removes a robot member by member_id func (s *RobotStore) Delete(ctx context.Context, memberID string) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } _, err := mod.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: memberID}, {Column: "member_type", Value: "robot"}, }, }) if err != nil { return fmt.Errorf("failed to delete robot record: %w", err) } return nil } // UpdateConfig updates only the robot_config field func (s *RobotStore) UpdateConfig(ctx context.Context, memberID string, config interface{}) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not 
found", s.modelID) } data := map[string]interface{}{ "robot_config": config, } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: memberID}, {Column: "member_type", Value: "robot"}, }, }, data, ) if err != nil { return fmt.Errorf("failed to update robot config: %w", err) } return nil } // UpdateStatus updates the robot_status field func (s *RobotStore) UpdateStatus(ctx context.Context, memberID string, status types.RobotStatus) error { mod := model.Select(s.modelID) if mod == nil { return fmt.Errorf("model %s not found", s.modelID) } data := map[string]interface{}{ "robot_status": string(status), } _, err := mod.UpdateWhere( model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", Value: memberID}, {Column: "member_type", Value: "robot"}, }, }, data, ) if err != nil { return fmt.Errorf("failed to update robot status: %w", err) } return nil } // recordToMap converts RobotRecord to map for model operations func (s *RobotStore) recordToMap(record *RobotRecord) map[string]interface{} { data := map[string]interface{}{ // Required fields "member_id": record.MemberID, "team_id": record.TeamID, "member_type": "robot", "autonomous_mode": record.AutonomousMode, } // Status if record.Status != "" { data["status"] = record.Status } else { data["status"] = "active" } if record.RobotStatus != "" { data["robot_status"] = record.RobotStatus } else { data["robot_status"] = "idle" } // Profile if record.DisplayName != "" { data["display_name"] = record.DisplayName } if record.Bio != "" { data["bio"] = record.Bio } if record.Avatar != "" { data["avatar"] = record.Avatar } // Identity & Role if record.SystemPrompt != "" { data["system_prompt"] = record.SystemPrompt } if record.RoleID != "" { data["role_id"] = record.RoleID } if record.ManagerID != "" { data["manager_id"] = record.ManagerID } // Communication if record.RobotEmail != "" { data["robot_email"] = record.RobotEmail } if record.AuthorizedSenders != nil { 
data["authorized_senders"] = record.AuthorizedSenders } if record.EmailFilterRules != nil { data["email_filter_rules"] = record.EmailFilterRules } // Capabilities if record.RobotConfig != nil { data["robot_config"] = record.RobotConfig } if record.Agents != nil { data["agents"] = record.Agents } if record.MCPServers != nil { data["mcp_servers"] = record.MCPServers } if record.LanguageModel != "" { data["language_model"] = record.LanguageModel } // Limits if record.CostLimit > 0 { data["cost_limit"] = record.CostLimit } // Ownership & Audit if record.InvitedBy != "" { data["invited_by"] = record.InvitedBy } if record.JoinedAt != nil { // Format time for Gou model (expects string format) data["joined_at"] = record.JoinedAt.Format("2006-01-02 15:04:05") } // Yao Permission Fields - pass through for model layer if record.YaoCreatedBy != "" { data["__yao_created_by"] = record.YaoCreatedBy } if record.YaoUpdatedBy != "" { data["__yao_updated_by"] = record.YaoUpdatedBy } if record.YaoTeamID != "" { data["__yao_team_id"] = record.YaoTeamID } if record.YaoTenantID != "" { data["__yao_tenant_id"] = record.YaoTenantID } return data } // mapToRecord converts a model row to RobotRecord func (s *RobotStore) mapToRecord(row map[string]interface{}) (*RobotRecord, error) { record := &RobotRecord{} // Basic fields if v, ok := row["id"]; ok { switch id := v.(type) { case float64: record.ID = int64(id) case int64: record.ID = id case int: record.ID = int64(id) } } if v, ok := row["member_id"].(string); ok { record.MemberID = v } if v, ok := row["team_id"].(string); ok { record.TeamID = v } if v, ok := row["member_type"].(string); ok { record.MemberType = v } if v, ok := row["status"].(string); ok { record.Status = v } if v, ok := row["robot_status"].(string); ok { record.RobotStatus = v } if v, ok := row["autonomous_mode"]; ok { record.AutonomousMode = utils.ToBool(v) } // Profile if v, ok := row["display_name"].(string); ok { record.DisplayName = v } if v, ok := row["bio"].(string); 
ok { record.Bio = v } if v, ok := row["avatar"].(string); ok { record.Avatar = v } // Identity & Role if v, ok := row["system_prompt"].(string); ok { record.SystemPrompt = v } if v, ok := row["role_id"].(string); ok { record.RoleID = v } if v, ok := row["manager_id"].(string); ok { record.ManagerID = v } // Communication if v, ok := row["robot_email"].(string); ok { record.RobotEmail = v } if v := row["authorized_senders"]; v != nil { record.AuthorizedSenders = utils.ToJSONValue(v) } if v := row["email_filter_rules"]; v != nil { record.EmailFilterRules = utils.ToJSONValue(v) } // Capabilities if v := row["robot_config"]; v != nil { record.RobotConfig = utils.ToJSONValue(v) } if v := row["agents"]; v != nil { record.Agents = utils.ToJSONValue(v) } if v := row["mcp_servers"]; v != nil { record.MCPServers = utils.ToJSONValue(v) } if v, ok := row["language_model"].(string); ok { record.LanguageModel = v } // Limits if v := row["cost_limit"]; v != nil { record.CostLimit = utils.ToFloat64(v) } // Ownership & Audit if v, ok := row["invited_by"].(string); ok { record.InvitedBy = v } if v := row["joined_at"]; v != nil { record.JoinedAt = utils.ToTimestamp(v) } // Timestamps if v := row["created_at"]; v != nil { record.CreatedAt = utils.ToTimestamp(v) } if v := row["updated_at"]; v != nil { record.UpdatedAt = utils.ToTimestamp(v) } // Yao Permission Fields if v, ok := row["__yao_created_by"].(string); ok { record.YaoCreatedBy = v } if v, ok := row["__yao_updated_by"].(string); ok { record.YaoUpdatedBy = v } if v, ok := row["__yao_team_id"].(string); ok { record.YaoTeamID = v } if v, ok := row["__yao_tenant_id"].(string); ok { record.YaoTenantID = v } return record, nil } // ToRobot converts a RobotRecord to types.Robot func (r *RobotRecord) ToRobot() (*types.Robot, error) { robot := &types.Robot{ MemberID: r.MemberID, TeamID: r.TeamID, DisplayName: r.DisplayName, Bio: r.Bio, SystemPrompt: r.SystemPrompt, AutonomousMode: r.AutonomousMode, RobotEmail: r.RobotEmail, } // Parse 
robot_status if r.RobotStatus != "" { robot.Status = types.RobotStatus(r.RobotStatus) } else { robot.Status = types.RobotIdle } // Parse robot_config if r.RobotConfig != nil { config, err := types.ParseConfig(r.RobotConfig) if err != nil { return nil, fmt.Errorf("failed to parse robot_config: %w", err) } robot.Config = config } // Ensure Config exists for merging agents/mcp_servers if robot.Config == nil { robot.Config = &types.Config{} } if robot.Config.Resources == nil { robot.Config.Resources = &types.Resources{} } // Merge agents from member table into Config.Resources.Agents if r.Agents != nil { agents := parseStringSlice(r.Agents) if len(agents) > 0 { robot.Config.Resources.Agents = agents } } // Merge mcp_servers from member table into Config.Resources.MCP if r.MCPServers != nil { mcpServers := parseStringSlice(r.MCPServers) if len(mcpServers) > 0 { // Convert string slice to MCPConfig slice (each server ID becomes an MCPConfig) for _, serverID := range mcpServers { robot.Config.Resources.MCP = append(robot.Config.Resources.MCP, types.MCPConfig{ ID: serverID, // Tools empty means all tools available }) } } } return robot, nil } // parseStringSlice converts interface{} to []string func parseStringSlice(v interface{}) []string { if v == nil { return nil } switch val := v.(type) { case []string: return val case []interface{}: result := make([]string, 0, len(val)) for _, item := range val { if s, ok := item.(string); ok { result = append(result, s) } } return result } return nil } // FromRobot creates a RobotRecord from types.Robot func FromRobot(robot *types.Robot) *RobotRecord { record := &RobotRecord{ MemberID: robot.MemberID, TeamID: robot.TeamID, DisplayName: robot.DisplayName, Bio: robot.Bio, SystemPrompt: robot.SystemPrompt, RobotStatus: string(robot.Status), AutonomousMode: robot.AutonomousMode, RobotEmail: robot.RobotEmail, MemberType: "robot", Status: "active", } if robot.Config != nil { record.RobotConfig = robot.Config } return record } 
================================================
FILE: agent/robot/store/robot_test.go
================================================
package store_test

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/yaoapp/gou/model"
	"github.com/yaoapp/yao/agent/robot/store"
	"github.com/yaoapp/yao/agent/testutils"
)

// TestRobotStoreSave tests creating and updating robot records.
// Integration test: skipped under -short; test rows are removed before and
// after the run via cleanupTestRobots.
func TestRobotStoreSave(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupTestRobots(t)
	defer cleanupTestRobots(t)

	s := store.NewRobotStore()
	ctx := context.Background()

	t.Run("creates_new_robot_record", func(t *testing.T) {
		now := time.Now()
		record := &store.RobotRecord{
			MemberID:       "robot_test_save_001",
			TeamID:         "team_test_001",
			DisplayName:    "Test Robot 001",
			Bio:            "A test robot for save operations",
			SystemPrompt:   "You are a helpful assistant",
			Status:         "active",
			RobotStatus:    "idle",
			AutonomousMode: true,
			RobotEmail:     "robot001@test.com",
			JoinedAt:       &now,
		}

		err := s.Save(ctx, record)
		require.NoError(t, err)

		// Verify it was created
		saved, err := s.Get(ctx, "robot_test_save_001")
		require.NoError(t, err)
		require.NotNil(t, saved)

		assert.Equal(t, "robot_test_save_001", saved.MemberID)
		assert.Equal(t, "team_test_001", saved.TeamID)
		assert.Equal(t, "Test Robot 001", saved.DisplayName)
		assert.Equal(t, "A test robot for save operations", saved.Bio)
		assert.Equal(t, "You are a helpful assistant", saved.SystemPrompt)
		assert.Equal(t, "active", saved.Status)
		assert.Equal(t, "idle", saved.RobotStatus)
		assert.True(t, saved.AutonomousMode)
		assert.Equal(t, "robot001@test.com", saved.RobotEmail)
		assert.Equal(t, "robot", saved.MemberType)
		assert.NotNil(t, saved.JoinedAt)
	})

	t.Run("updates_existing_robot_record", func(t *testing.T) {
		// First create a record
		record := &store.RobotRecord{
			MemberID:    "robot_test_save_002",
			TeamID:      "team_test_002",
			DisplayName: "Original Name",
			Status:      "active",
			RobotStatus: "idle",
		}

		err := s.Save(ctx, record)
		require.NoError(t, err)

		// Update the record (second Save with the same member_id)
		record.DisplayName = "Updated Name"
		record.Bio = "Updated bio"
		record.RobotStatus = "working"

		err = s.Save(ctx, record)
		require.NoError(t, err)

		// Verify the update
		saved, err := s.Get(ctx, "robot_test_save_002")
		require.NoError(t, err)
		require.NotNil(t, saved)

		assert.Equal(t, "Updated Name", saved.DisplayName)
		assert.Equal(t, "Updated bio", saved.Bio)
		assert.Equal(t, "working", saved.RobotStatus)
	})

	t.Run("saves_robot_with_config", func(t *testing.T) {
		record := &store.RobotRecord{
			MemberID:    "robot_test_save_003",
			TeamID:      "team_test_003",
			DisplayName: "Robot with Config",
			Status:      "active",
			RobotStatus: "idle",
			RobotConfig: map[string]interface{}{
				"clock_mode":      "on",
				"max_concurrent":  3,
				"timeout_seconds": 300,
			},
		}

		err := s.Save(ctx, record)
		require.NoError(t, err)

		// Only presence of the config is asserted, not its contents
		saved, err := s.Get(ctx, "robot_test_save_003")
		require.NoError(t, err)
		require.NotNil(t, saved)
		assert.NotNil(t, saved.RobotConfig)
	})

	t.Run("saves_robot_with_permission_fields", func(t *testing.T) {
		record := &store.RobotRecord{
			MemberID:     "robot_test_save_004",
			TeamID:       "team_test_004",
			DisplayName:  "Robot with Perms",
			Status:       "active",
			RobotStatus:  "idle",
			YaoCreatedBy: "user_001",
			YaoTeamID:    "team_001",
			YaoTenantID:  "tenant_001",
		}

		err := s.Save(ctx, record)
		require.NoError(t, err)

		// Yao permission fields are handled by the model layer
		saved, err := s.Get(ctx, "robot_test_save_004")
		require.NoError(t, err)
		require.NotNil(t, saved)
	})
}

// TestRobotStoreGet tests retrieving robot records.
// Integration test: skipped under -short.
func TestRobotStoreGet(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupTestRobots(t)
	defer cleanupTestRobots(t)

	s := store.NewRobotStore()
	ctx := context.Background()

	// Create a test record
	setupTestRobot(t, s, ctx)

	t.Run("returns_existing_record", func(t *testing.T) {
		record, err := s.Get(ctx, "robot_test_get_001")
		require.NoError(t, err)
		require.NotNil(t, record)

		assert.Equal(t, "robot_test_get_001", record.MemberID)
		assert.Equal(t, "team_test_get", record.TeamID)
		assert.Equal(t, "Test Robot Get", record.DisplayName)
		assert.Equal(t, "Test robot description", record.Bio)
		assert.Equal(t, "robot", record.MemberType)
		assert.Equal(t, "active", record.Status)
		assert.Equal(t, "idle", record.RobotStatus)
	})

	t.Run("returns_nil_for_non_existent_record", func(t *testing.T) {
		// A missing record is reported as (nil, nil), not as an error
		record, err := s.Get(ctx, "robot_non_existent")
		require.NoError(t, err)
		assert.Nil(t, record)
	})

	t.Run("ignores_non_robot_members", func(t *testing.T) {
		// Get should only return member_type="robot" records
		// NOTE(review): no non-robot member is created here, so this only
		// checks the member_type of the robot fixture.
		record, err := s.Get(ctx, "robot_test_get_001")
		require.NoError(t, err)
		require.NotNil(t, record)
		assert.Equal(t, "robot", record.MemberType)
	})
}

// TestRobotStoreList tests listing robot records with filters.
// Integration test: skipped under -short. Fixtures come from
// setupTestRobotsForList (four robots across two teams).
func TestRobotStoreList(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupTestRobots(t)
	defer cleanupTestRobots(t)

	s := store.NewRobotStore()
	ctx := context.Background()

	// Create multiple test records
	setupTestRobotsForList(t, s, ctx)

	t.Run("lists_all_robot_records", func(t *testing.T) {
		// List with keywords filter to only get our test records
		// Test robots have display names like "Robot Alpha", "Robot Beta", etc.
		records, total, err := s.List(ctx, &store.RobotListOptions{
			Keywords: "Robot",
		})
		require.NoError(t, err)
		// Should find at least our 4 test robots
		assert.GreaterOrEqual(t, len(records), 4)
		assert.GreaterOrEqual(t, total, 4)
	})

	t.Run("filters_by_team_id", func(t *testing.T) {
		records, total, err := s.List(ctx, &store.RobotListOptions{
			TeamID: "team_list_001",
		})
		require.NoError(t, err)
		assert.Equal(t, 2, len(records))
		assert.Equal(t, 2, total)

		for _, r := range records {
			assert.Equal(t, "team_list_001", r.TeamID)
		}
	})

	t.Run("filters_by_robot_status", func(t *testing.T) {
		records, _, err := s.List(ctx, &store.RobotListOptions{
			Status: "working",
		})
		require.NoError(t, err)
		assert.GreaterOrEqual(t, len(records), 1)

		for _, r := range records {
			assert.Equal(t, "working", r.RobotStatus)
		}
	})

	t.Run("filters_by_keywords", func(t *testing.T) {
		records, _, err := s.List(ctx, &store.RobotListOptions{
			Keywords: "Alpha",
		})
		require.NoError(t, err)
		assert.Equal(t, 1, len(records))
		assert.Contains(t, records[0].DisplayName, "Alpha")
	})

	t.Run("respects_pagination", func(t *testing.T) {
		records, total, err := s.List(ctx, &store.RobotListOptions{
			Page:     1,
			PageSize: 2,
		})
		require.NoError(t, err)
		assert.Equal(t, 2, len(records))
		assert.GreaterOrEqual(t, total, 4) // total count should be full count
	})

	t.Run("respects_limit", func(t *testing.T) {
		records, _, err := s.List(ctx, &store.RobotListOptions{
			Limit: 2,
		})
		require.NoError(t, err)
		assert.Equal(t, 2, len(records))
	})

	t.Run("combines_multiple_filters", func(t *testing.T) {
		records, total, err := s.List(ctx, &store.RobotListOptions{
			TeamID: "team_list_001",
			Status: "idle",
		})
		require.NoError(t, err)
		assert.Equal(t, 1, len(records))
		assert.Equal(t, 1, total)
		assert.Equal(t, "team_list_001", records[0].TeamID)
		assert.Equal(t, "idle", records[0].RobotStatus)
	})
}

// TestRobotStoreDelete tests deleting robot records.
// Integration test: skipped under -short.
func TestRobotStoreDelete(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupTestRobots(t)
	defer cleanupTestRobots(t)

	s := store.NewRobotStore()
	ctx := context.Background()

	t.Run("deletes_existing_record", func(t *testing.T) {
		// Create a record
		record := &store.RobotRecord{
			MemberID:    "robot_test_delete_001",
			TeamID:      "team_delete_001",
			DisplayName: "Robot to Delete",
			Status:      "active",
			RobotStatus: "idle",
		}
		err := s.Save(ctx, record)
		require.NoError(t, err)

		// Verify it exists
		saved, err := s.Get(ctx, "robot_test_delete_001")
		require.NoError(t, err)
		require.NotNil(t, saved)

		// Delete it
		err = s.Delete(ctx, "robot_test_delete_001")
		require.NoError(t, err)

		// Verify it's gone
		saved, err = s.Get(ctx, "robot_test_delete_001")
		require.NoError(t, err)
		assert.Nil(t, saved)
	})

	t.Run("no_error_for_non_existent_record", func(t *testing.T) {
		// Deleting an unknown member_id is expected to be a no-op
		err := s.Delete(ctx, "robot_non_existent")
		assert.NoError(t, err)
	})
}

// TestRobotStoreUpdateConfig tests updating robot config.
// Integration test: skipped under -short.
func TestRobotStoreUpdateConfig(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test")
	}

	testutils.Prepare(t)
	defer testutils.Clean(t)

	cleanupTestRobots(t)
	defer cleanupTestRobots(t)

	s := store.NewRobotStore()
	ctx := context.Background()

	// Create a base record
	record := &store.RobotRecord{
		MemberID:    "robot_test_config_001",
		TeamID:      "team_config_001",
		DisplayName: "Config Test Robot",
		Status:      "active",
		RobotStatus: "idle",
		RobotConfig: map[string]interface{}{
			"clock_mode": "off",
		},
	}
	err := s.Save(ctx, record)
	require.NoError(t, err)

	t.Run("updates_config_only", func(t *testing.T) {
		newConfig := map[string]interface{}{
			"clock_mode":      "on",
			"max_concurrent":  5,
			"timeout_seconds": 600,
		}

		err := s.UpdateConfig(ctx, "robot_test_config_001", newConfig)
		require.NoError(t, err)

		saved, err := s.Get(ctx, "robot_test_config_001")
		require.NoError(t, err)
		require.NotNil(t, saved)
		assert.NotNil(t, saved.RobotConfig)
		// Display name should be unchanged
		assert.Equal(t, "Config Test Robot", saved.DisplayName)
	})
}

// TestRobotStoreUpdateStatus tests
updating robot status func TestRobotStoreUpdateStatus(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) cleanupTestRobots(t) defer cleanupTestRobots(t) s := store.NewRobotStore() ctx := context.Background() // Create a base record record := &store.RobotRecord{ MemberID: "robot_test_status_001", TeamID: "team_status_001", DisplayName: "Status Test Robot", Status: "active", RobotStatus: "idle", } err := s.Save(ctx, record) require.NoError(t, err) t.Run("updates_robot_status", func(t *testing.T) { err := s.UpdateStatus(ctx, "robot_test_status_001", "working") require.NoError(t, err) saved, err := s.Get(ctx, "robot_test_status_001") require.NoError(t, err) require.NotNil(t, saved) assert.Equal(t, "working", saved.RobotStatus) // Display name should be unchanged assert.Equal(t, "Status Test Robot", saved.DisplayName) }) t.Run("updates_to_paused", func(t *testing.T) { err := s.UpdateStatus(ctx, "robot_test_status_001", "paused") require.NoError(t, err) saved, err := s.Get(ctx, "robot_test_status_001") require.NoError(t, err) assert.Equal(t, "paused", saved.RobotStatus) }) t.Run("updates_to_error", func(t *testing.T) { err := s.UpdateStatus(ctx, "robot_test_status_001", "error") require.NoError(t, err) saved, err := s.Get(ctx, "robot_test_status_001") require.NoError(t, err) assert.Equal(t, "error", saved.RobotStatus) }) } // TestRobotRecordConversion tests conversion between RobotRecord and Robot types func TestRobotRecordConversion(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) t.Run("converts_record_to_robot", func(t *testing.T) { now := time.Now() record := &store.RobotRecord{ MemberID: "robot_convert_001", TeamID: "team_convert_001", DisplayName: "Conversion Test Robot", Bio: "Test description", SystemPrompt: "You are helpful", Status: "active", RobotStatus: "idle", AutonomousMode: true, RobotEmail: "convert@test.com", 
JoinedAt: &now, RobotConfig: map[string]interface{}{ "clock_mode": "on", }, } robot, err := record.ToRobot() require.NoError(t, err) require.NotNil(t, robot) assert.Equal(t, "robot_convert_001", robot.MemberID) assert.Equal(t, "team_convert_001", robot.TeamID) assert.Equal(t, "Conversion Test Robot", robot.DisplayName) assert.Equal(t, "Test description", robot.Bio) assert.Equal(t, "You are helpful", robot.SystemPrompt) assert.True(t, robot.AutonomousMode) assert.Equal(t, "convert@test.com", robot.RobotEmail) }) t.Run("converts_robot_to_record", func(t *testing.T) { robot := &store.RobotRecord{ MemberID: "robot_from_001", TeamID: "team_from_001", DisplayName: "From Robot Test", Bio: "From robot description", SystemPrompt: "System prompt", RobotStatus: "working", AutonomousMode: false, RobotEmail: "from@test.com", } // ToRobot and verify converted, err := robot.ToRobot() require.NoError(t, err) assert.Equal(t, "robot_from_001", converted.MemberID) assert.Equal(t, "team_from_001", converted.TeamID) assert.Equal(t, "From Robot Test", converted.DisplayName) }) } // Helper functions func cleanupTestRobots(t *testing.T) { mod := model.Select("__yao.member") if mod == nil { return } // Delete all test robot records _, err := mod.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "member_id", OP: "like", Value: "robot_test_%"}, {Column: "member_type", Value: "robot"}, }, }) if err != nil { t.Logf("Warning: failed to cleanup test robots: %v", err) } } func setupTestRobot(t *testing.T, s *store.RobotStore, ctx context.Context) { now := time.Now() record := &store.RobotRecord{ MemberID: "robot_test_get_001", TeamID: "team_test_get", DisplayName: "Test Robot Get", Bio: "Test robot description", SystemPrompt: "You are a test assistant", Status: "active", RobotStatus: "idle", AutonomousMode: false, RobotEmail: "test@robot.com", JoinedAt: &now, } err := s.Save(ctx, record) require.NoError(t, err) } func setupTestRobotsForList(t *testing.T, s *store.RobotStore, ctx 
context.Context) { now := time.Now() records := []*store.RobotRecord{ { MemberID: "robot_test_list_001", TeamID: "team_list_001", DisplayName: "Robot Alpha", Status: "active", RobotStatus: "idle", JoinedAt: &now, }, { MemberID: "robot_test_list_002", TeamID: "team_list_001", DisplayName: "Robot Beta", Status: "active", RobotStatus: "working", JoinedAt: &now, }, { MemberID: "robot_test_list_003", TeamID: "team_list_002", DisplayName: "Robot Gamma", Status: "active", RobotStatus: "idle", JoinedAt: &now, }, { MemberID: "robot_test_list_004", TeamID: "team_list_002", DisplayName: "Robot Delta", Status: "inactive", RobotStatus: "paused", JoinedAt: &now, }, } for _, record := range records { err := s.Save(ctx, record) require.NoError(t, err) } } ================================================ FILE: agent/robot/store/store.go ================================================ package store import "github.com/yaoapp/yao/agent/robot/types" // Store implements types.Store interface // This is a stub implementation for Phase 2 type Store struct{} // New creates a new store instance func New() *Store { return &Store{} } // SaveLearning saves learning entries to private KB // Stub: returns nil (will be implemented in Phase 9) func (s *Store) SaveLearning(ctx *types.Context, memberID string, entries []types.LearningEntry) error { return nil } // GetHistory retrieves learning history from private KB // Stub: returns empty slice (will be implemented in Phase 9) func (s *Store) GetHistory(ctx *types.Context, memberID string, limit int) ([]types.LearningEntry, error) { return []types.LearningEntry{}, nil } // SearchKB searches knowledge base collections // Stub: returns empty slice (will be implemented in Phase 4+) func (s *Store) SearchKB(ctx *types.Context, collections []string, query string) ([]interface{}, error) { return []interface{}{}, nil } // QueryDB queries database models // Stub: returns empty slice (will be implemented in Phase 4+) func (s *Store) QueryDB(ctx 
*types.Context, models []string, query interface{}) ([]interface{}, error) { return []interface{}{}, nil } ================================================ FILE: agent/robot/trigger/clock.go ================================================ package trigger import ( "time" "github.com/yaoapp/yao/agent/robot/types" ) // ClockMatcher provides clock trigger matching logic // This is extracted from Manager for reuse and testing type ClockMatcher struct{} // NewClockMatcher creates a new clock matcher func NewClockMatcher() *ClockMatcher { return &ClockMatcher{} } // ShouldTrigger checks if a robot should be triggered based on its clock config func (cm *ClockMatcher) ShouldTrigger(robot *types.Robot, now time.Time) bool { if robot == nil || robot.Config == nil || robot.Config.Clock == nil { return false } clock := robot.Config.Clock // Get time in robot's timezone loc := clock.GetLocation() localNow := now.In(loc) switch clock.Mode { case types.ClockTimes: return cm.shouldTriggerTimes(robot, clock, localNow) case types.ClockInterval: return cm.shouldTriggerInterval(robot, clock, localNow) case types.ClockDaemon: return cm.shouldTriggerDaemon(robot, clock, localNow) default: return false } } // shouldTriggerTimes checks if current time matches any configured times // times mode: run at specific times (e.g., ["09:00", "14:00", "17:00"]) func (cm *ClockMatcher) shouldTriggerTimes(robot *types.Robot, clock *types.Clock, now time.Time) bool { // Check day of week first if !cm.matchesDay(clock, now) { return false } // Check if current time matches any configured time currentTime := now.Format("15:04") for _, t := range clock.Times { if t == currentTime { // Check if already triggered in this minute if !robot.LastRun.IsZero() { lastRunInLoc := robot.LastRun.In(now.Location()) if lastRunInLoc.Format("15:04") == currentTime && lastRunInLoc.Day() == now.Day() { return false // Already triggered this minute today } } return true } } return false } // shouldTriggerInterval checks if 
enough time has passed since last run // interval mode: run every X duration (e.g., "30m", "2h") func (cm *ClockMatcher) shouldTriggerInterval(robot *types.Robot, clock *types.Clock, now time.Time) bool { interval, err := time.ParseDuration(clock.Every) if err != nil { return false } // First run if never executed if robot.LastRun.IsZero() { return true } // Check if interval has passed return now.Sub(robot.LastRun) >= interval } // shouldTriggerDaemon checks if robot can restart immediately after last run // daemon mode: restart immediately after each run completes func (cm *ClockMatcher) shouldTriggerDaemon(robot *types.Robot, clock *types.Clock, now time.Time) bool { // Daemon mode: trigger if not currently running // CanRun() checks if robot has available execution slots return robot.CanRun() } // matchesDay checks if current day matches the configured days func (cm *ClockMatcher) matchesDay(clock *types.Clock, now time.Time) bool { // Empty days or ["*"] means all days if len(clock.Days) == 0 { return true } for _, day := range clock.Days { if day == "*" { return true } // Match day name (Mon, Tue, Wed, Thu, Fri, Sat, Sun) // or full name (Monday, Tuesday, etc.) weekday := now.Weekday().String() shortDay := weekday[:3] // Mon, Tue, etc. 
if day == weekday || day == shortDay { return true } } return false } // ParseTime parses a time string in "HH:MM" format func ParseTime(timeStr string) (hour, minute int, err error) { t, err := time.Parse("15:04", timeStr) if err != nil { return 0, 0, err } return t.Hour(), t.Minute(), nil } // FormatTime formats hour and minute to "HH:MM" string func FormatTime(hour, minute int) string { return time.Date(0, 1, 1, hour, minute, 0, 0, time.UTC).Format("15:04") } ================================================ FILE: agent/robot/trigger/clock_test.go ================================================ package trigger_test import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/trigger" "github.com/yaoapp/yao/agent/robot/types" ) // ==================== ClockMatcher Tests ==================== func TestClockMatcherShouldTrigger(t *testing.T) { cm := trigger.NewClockMatcher() t.Run("nil robot returns false", func(t *testing.T) { result := cm.ShouldTrigger(nil, time.Now()) assert.False(t, result) }) t.Run("nil config returns false", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: nil, } result := cm.ShouldTrigger(robot, time.Now()) assert.False(t, result) }) t.Run("nil clock config returns false", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{Clock: nil}, } result := cm.ShouldTrigger(robot, time.Now()) assert.False(t, result) }) } // ==================== Times Mode Tests ==================== func TestClockMatcherTimesMode(t *testing.T) { cm := trigger.NewClockMatcher() t.Run("matches configured time", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00", "14:00", "17:00"}, Days: []string{"*"}, TZ: "UTC", }, }, } // Create time at 09:00 UTC now := time.Date(2025, 1, 15, 9, 0, 0, 0, time.UTC) result := cm.ShouldTrigger(robot, now) assert.True(t, result) }) 
t.Run("does not match non-configured time", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00", "14:00", "17:00"}, Days: []string{"*"}, TZ: "UTC", }, }, } // Create time at 10:00 UTC (not in configured times) now := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) result := cm.ShouldTrigger(robot, now) assert.False(t, result) }) t.Run("respects day filter - weekday", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00"}, Days: []string{"Mon", "Tue", "Wed", "Thu", "Fri"}, TZ: "UTC", }, }, } // Wednesday 09:00 - should trigger wed := time.Date(2025, 1, 15, 9, 0, 0, 0, time.UTC) // Wednesday assert.True(t, cm.ShouldTrigger(robot, wed)) // Saturday 09:00 - should NOT trigger sat := time.Date(2025, 1, 18, 9, 0, 0, 0, time.UTC) // Saturday assert.False(t, cm.ShouldTrigger(robot, sat)) }) t.Run("dedup - same minute same day should not trigger twice", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00"}, Days: []string{"*"}, TZ: "UTC", }, }, } now := time.Date(2025, 1, 15, 9, 0, 0, 0, time.UTC) // First trigger - should succeed assert.True(t, cm.ShouldTrigger(robot, now)) // Simulate LastRun was set robot.LastRun = now // Second trigger same minute - should fail now2 := time.Date(2025, 1, 15, 9, 0, 30, 0, time.UTC) assert.False(t, cm.ShouldTrigger(robot, now2)) }) t.Run("different day should trigger again", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00"}, Days: []string{"*"}, TZ: "UTC", }, }, } // First day day1 := time.Date(2025, 1, 15, 9, 0, 0, 0, time.UTC) robot.LastRun = day1 // Next day same time - should trigger day2 := time.Date(2025, 1, 16, 9, 
0, 0, 0, time.UTC) assert.True(t, cm.ShouldTrigger(robot, day2)) }) } // ==================== Interval Mode Tests ==================== func TestClockMatcherIntervalMode(t *testing.T) { cm := trigger.NewClockMatcher() t.Run("first run triggers immediately", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockInterval, Every: "30m", TZ: "UTC", }, }, } // LastRun is zero - should trigger now := time.Now() result := cm.ShouldTrigger(robot, now) assert.True(t, result) }) t.Run("triggers after interval passed", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockInterval, Every: "30m", TZ: "UTC", }, }, } now := time.Now() robot.LastRun = now.Add(-31 * time.Minute) // 31 minutes ago result := cm.ShouldTrigger(robot, now) assert.True(t, result) }) t.Run("does not trigger before interval", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockInterval, Every: "30m", TZ: "UTC", }, }, } now := time.Now() robot.LastRun = now.Add(-15 * time.Minute) // Only 15 minutes ago result := cm.ShouldTrigger(robot, now) assert.False(t, result) }) t.Run("invalid interval format returns false", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockInterval, Every: "invalid", TZ: "UTC", }, }, } result := cm.ShouldTrigger(robot, time.Now()) assert.False(t, result) }) t.Run("various interval formats", func(t *testing.T) { intervals := []struct { every string lastAgo time.Duration expected bool }{ {"1h", 61 * time.Minute, true}, {"1h", 30 * time.Minute, false}, {"2h", 121 * time.Minute, true}, {"2h", 60 * time.Minute, false}, {"10s", 11 * time.Second, true}, {"10s", 5 * time.Second, false}, } for _, tt := range intervals { t.Run(tt.every, func(t *testing.T) { robot := &types.Robot{ MemberID: 
"robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockInterval, Every: tt.every, TZ: "UTC", }, }, } now := time.Now() robot.LastRun = now.Add(-tt.lastAgo) result := cm.ShouldTrigger(robot, now) assert.Equal(t, tt.expected, result) }) } }) } // ==================== Daemon Mode Tests ==================== func TestClockMatcherDaemonMode(t *testing.T) { cm := trigger.NewClockMatcher() t.Run("triggers when robot can run", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockDaemon, TZ: "UTC", }, Quota: &types.Quota{Max: 2}, }, } // No running executions - should trigger result := cm.ShouldTrigger(robot, time.Now()) assert.True(t, result) }) t.Run("does not trigger when at quota", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockDaemon, TZ: "UTC", }, Quota: &types.Quota{Max: 1}, }, } // Add one execution to fill quota exec := &types.Execution{ID: "exec_001"} robot.AddExecution(exec) result := cm.ShouldTrigger(robot, time.Now()) assert.False(t, result) // Remove execution robot.RemoveExecution("exec_001") // Now should trigger result = cm.ShouldTrigger(robot, time.Now()) assert.True(t, result) }) } // ==================== Timezone Tests ==================== func TestClockMatcherTimezone(t *testing.T) { cm := trigger.NewClockMatcher() t.Run("respects timezone for times mode", func(t *testing.T) { // Robot configured for Asia/Shanghai (UTC+8) robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00"}, Days: []string{"*"}, TZ: "Asia/Shanghai", }, }, } // 01:00 UTC = 09:00 Shanghai - should trigger utc0100 := time.Date(2025, 1, 15, 1, 0, 0, 0, time.UTC) assert.True(t, cm.ShouldTrigger(robot, utc0100)) // 09:00 UTC = 17:00 Shanghai - should NOT trigger utc0900 := time.Date(2025, 1, 15, 9, 0, 0, 0, time.UTC) assert.False(t, 
cm.ShouldTrigger(robot, utc0900)) }) t.Run("invalid timezone falls back to local", func(t *testing.T) { robot := &types.Robot{ MemberID: "robot_001", Config: &types.Config{ Clock: &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00"}, Days: []string{"*"}, TZ: "Invalid/Timezone", }, }, } // Should still work with local time local0900 := time.Date(2025, 1, 15, 9, 0, 0, 0, time.Local) result := cm.ShouldTrigger(robot, local0900) // Result depends on local timezone, just verify no panic assert.IsType(t, true, result) }) } // ==================== ParseTime/FormatTime Tests ==================== func TestParseTime(t *testing.T) { t.Run("parses valid time", func(t *testing.T) { hour, minute, err := trigger.ParseTime("09:30") assert.NoError(t, err) assert.Equal(t, 9, hour) assert.Equal(t, 30, minute) }) t.Run("parses midnight", func(t *testing.T) { hour, minute, err := trigger.ParseTime("00:00") assert.NoError(t, err) assert.Equal(t, 0, hour) assert.Equal(t, 0, minute) }) t.Run("parses 23:59", func(t *testing.T) { hour, minute, err := trigger.ParseTime("23:59") assert.NoError(t, err) assert.Equal(t, 23, hour) assert.Equal(t, 59, minute) }) t.Run("invalid format returns error", func(t *testing.T) { // Note: time.Parse("15:04", "9:30") actually succeeds // Only truly invalid formats fail _, _, err := trigger.ParseTime("09:30:00") assert.Error(t, err) _, _, err = trigger.ParseTime("invalid") assert.Error(t, err) _, _, err = trigger.ParseTime("") assert.Error(t, err) }) } func TestFormatTime(t *testing.T) { tests := []struct { hour int minute int expected string }{ {9, 0, "09:00"}, {9, 30, "09:30"}, {0, 0, "00:00"}, {23, 59, "23:59"}, {14, 5, "14:05"}, } for _, tt := range tests { t.Run(tt.expected, func(t *testing.T) { result := trigger.FormatTime(tt.hour, tt.minute) assert.Equal(t, tt.expected, result) }) } } ================================================ FILE: agent/robot/trigger/control.go ================================================ package trigger import ( 
"context" "fmt" "sync" "time" "github.com/yaoapp/yao/agent/robot/types" ) // ExecutionController manages execution lifecycle (pause/resume/stop) type ExecutionController struct { executions map[string]*ControlledExecution mu sync.RWMutex } // ControlledExecution represents an execution that can be controlled type ControlledExecution struct { ID string MemberID string TeamID string Status types.ExecStatus Phase types.Phase StartTime time.Time PausedAt *time.Time // Control channels ctx context.Context cancel context.CancelFunc paused bool pauseMu sync.Mutex resumeCh chan struct{} // signaled (closed) when resumed } // NewExecutionController creates a new execution controller func NewExecutionController() *ExecutionController { return &ExecutionController{ executions: make(map[string]*ControlledExecution), } } // Track starts tracking an execution func (c *ExecutionController) Track(execID, memberID, teamID string) *ControlledExecution { c.mu.Lock() defer c.mu.Unlock() ctx, cancel := context.WithCancel(context.Background()) exec := &ControlledExecution{ ID: execID, MemberID: memberID, TeamID: teamID, Status: types.ExecRunning, Phase: types.PhaseInspiration, StartTime: time.Now(), ctx: ctx, cancel: cancel, paused: false, resumeCh: nil, // nil when not paused, created on pause } c.executions[execID] = exec return exec } // Untrack stops tracking an execution func (c *ExecutionController) Untrack(execID string) { c.mu.Lock() defer c.mu.Unlock() delete(c.executions, execID) } // Get returns a tracked execution func (c *ExecutionController) Get(execID string) *ControlledExecution { c.mu.RLock() defer c.mu.RUnlock() return c.executions[execID] } // List returns all tracked executions func (c *ExecutionController) List() []*ControlledExecution { c.mu.RLock() defer c.mu.RUnlock() result := make([]*ControlledExecution, 0, len(c.executions)) for _, exec := range c.executions { result = append(result, exec) } return result } // ListByMember returns all executions for a specific 
member func (c *ExecutionController) ListByMember(memberID string) []*ControlledExecution { c.mu.RLock() defer c.mu.RUnlock() var result []*ControlledExecution for _, exec := range c.executions { if exec.MemberID == memberID { result = append(result, exec) } } return result } // Pause pauses an execution func (c *ExecutionController) Pause(execID string) error { exec := c.Get(execID) if exec == nil { return fmt.Errorf("execution not found: %s", execID) } exec.pauseMu.Lock() defer exec.pauseMu.Unlock() if exec.paused { return fmt.Errorf("execution already paused: %s", execID) } exec.paused = true now := time.Now() exec.PausedAt = &now // Create a new resume channel that will be closed on resume exec.resumeCh = make(chan struct{}) return nil } // Resume resumes a paused execution func (c *ExecutionController) Resume(execID string) error { exec := c.Get(execID) if exec == nil { return fmt.Errorf("execution not found: %s", execID) } exec.pauseMu.Lock() defer exec.pauseMu.Unlock() if !exec.paused { return fmt.Errorf("execution not paused: %s", execID) } exec.paused = false exec.PausedAt = nil // Close the resume channel to signal resume to waiting goroutines if exec.resumeCh != nil { close(exec.resumeCh) exec.resumeCh = nil } return nil } // Stop stops an execution func (c *ExecutionController) Stop(execID string) error { c.mu.Lock() defer c.mu.Unlock() exec, ok := c.executions[execID] if !ok { return fmt.Errorf("execution not found: %s", execID) } // Cancel the context to signal stop if exec.cancel != nil { exec.cancel() } exec.Status = types.ExecCancelled // Remove from tracking delete(c.executions, execID) return nil } // ==================== ControlledExecution methods ==================== // IsPaused returns true if the execution is paused func (e *ControlledExecution) IsPaused() bool { e.pauseMu.Lock() defer e.pauseMu.Unlock() return e.paused } // IsCancelled returns true if the execution is cancelled func (e *ControlledExecution) IsCancelled() bool { select { 
case <-e.ctx.Done(): return true default: return false } } // Context returns the execution's context func (e *ControlledExecution) Context() context.Context { return e.ctx } // WaitIfPaused blocks until the execution is resumed or cancelled // Returns error if cancelled func (e *ControlledExecution) WaitIfPaused() error { e.pauseMu.Lock() paused := e.paused resumeCh := e.resumeCh e.pauseMu.Unlock() if !paused { return nil } // Safety check: if paused but resumeCh is nil (shouldn't happen in normal flow), // treat as not paused to avoid blocking forever on nil channel if resumeCh == nil { return nil } // resumeCh is created when paused and closed when resumed // Wait for resume signal or cancellation select { case <-e.ctx.Done(): return types.ErrExecutionCancelled case <-resumeCh: // Resume signal received, execution can continue return nil } } // CheckCancelled checks if the execution is cancelled and returns error if so func (e *ControlledExecution) CheckCancelled() error { if e.IsCancelled() { return types.ErrExecutionCancelled } return nil } // UpdatePhase updates the current phase func (e *ControlledExecution) UpdatePhase(phase types.Phase) { e.Phase = phase } // UpdateStatus updates the execution status func (e *ControlledExecution) UpdateStatus(status types.ExecStatus) { e.Status = status } ================================================ FILE: agent/robot/trigger/control_test.go ================================================ package trigger_test import ( "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/trigger" "github.com/yaoapp/yao/agent/robot/types" ) // ==================== ExecutionController Tests ==================== func TestExecutionControllerTrack(t *testing.T) { t.Run("tracks new execution", func(t *testing.T) { ctrl := trigger.NewExecutionController() exec := ctrl.Track("exec_001", "robot_001", "team_001") assert.NotNil(t, exec) assert.Equal(t, "exec_001", exec.ID) assert.Equal(t, "robot_001", 
// TestExecutionControllerList covers List and ListByMember on a controller
// holding three executions spread over two members.
func TestExecutionControllerList(t *testing.T) {
	t.Run("list all executions", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		ctrl.Track("exec_001", "robot_001", "team_001")
		ctrl.Track("exec_002", "robot_002", "team_001")
		ctrl.Track("exec_003", "robot_001", "team_002")

		list := ctrl.List()
		assert.Len(t, list, 3)
	})

	t.Run("list by member", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		ctrl.Track("exec_001", "robot_001", "team_001")
		ctrl.Track("exec_002", "robot_002", "team_001")
		ctrl.Track("exec_003", "robot_001", "team_002")

		list := ctrl.ListByMember("robot_001")
		assert.Len(t, list, 2)

		list = ctrl.ListByMember("robot_002")
		assert.Len(t, list, 1)

		// A member with no executions yields an empty result.
		list = ctrl.ListByMember("robot_003")
		assert.Len(t, list, 0)
	})
}

// TestExecutionControllerUntrack verifies that Untrack removes an entry and
// tolerates unknown IDs.
func TestExecutionControllerUntrack(t *testing.T) {
	t.Run("untrack removes execution", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		ctrl.Track("exec_001", "robot_001", "team_001")
		assert.NotNil(t, ctrl.Get("exec_001"))

		ctrl.Untrack("exec_001")
		assert.Nil(t, ctrl.Get("exec_001"))
	})

	t.Run("untrack non-existent does not panic", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		assert.NotPanics(t, func() {
			ctrl.Untrack("non_existent")
		})
	})
}

// ==================== Pause/Resume Tests ====================

// TestExecutionControllerPause covers the success path of Pause and its two
// error paths (unknown ID, already paused).
func TestExecutionControllerPause(t *testing.T) {
	t.Run("pause execution", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")

		err := ctrl.Pause("exec_001")
		assert.NoError(t, err)
		assert.True(t, exec.IsPaused())
		assert.NotNil(t, exec.PausedAt)
	})

	t.Run("pause non-existent returns error", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		err := ctrl.Pause("non_existent")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "not found")
	})

	t.Run("pause already paused returns error", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		ctrl.Track("exec_001", "robot_001", "team_001")

		err := ctrl.Pause("exec_001")
		assert.NoError(t, err)

		// A second Pause on the same execution must be rejected.
		err = ctrl.Pause("exec_001")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "already paused")
	})
}

// TestExecutionControllerResume covers the success path of Resume and its two
// error paths (unknown ID, not paused).
func TestExecutionControllerResume(t *testing.T) {
	t.Run("resume paused execution", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")
		ctrl.Pause("exec_001")
		assert.True(t, exec.IsPaused())

		err := ctrl.Resume("exec_001")
		assert.NoError(t, err)
		assert.False(t, exec.IsPaused())
		assert.Nil(t, exec.PausedAt)
	})

	t.Run("resume non-existent returns error", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		err := ctrl.Resume("non_existent")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "not found")
	})

	t.Run("resume not paused returns error", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		ctrl.Track("exec_001", "robot_001", "team_001")

		err := ctrl.Resume("exec_001")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "not paused")
	})
}

// ==================== Stop Tests ====================

// TestExecutionControllerStop verifies that Stop cancels the execution, marks
// it ExecCancelled and removes it from the controller.
func TestExecutionControllerStop(t *testing.T) {
	t.Run("stop execution", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")

		err := ctrl.Stop("exec_001")
		assert.NoError(t, err)
		assert.True(t, exec.IsCancelled())
		assert.Equal(t, types.ExecCancelled, exec.Status)

		// Should be removed from tracking
		assert.Nil(t, ctrl.Get("exec_001"))
	})

	t.Run("stop non-existent returns error", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		err := ctrl.Stop("non_existent")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "not found")
	})
}

// ==================== ControlledExecution Methods Tests ====================

// TestControlledExecutionContext checks that the execution's context is live
// after Track and done after Stop.
func TestControlledExecutionContext(t *testing.T) {
	t.Run("context is valid", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")

		ctx := exec.Context()
		assert.NotNil(t, ctx)

		// Context should not be done yet
		select {
		case <-ctx.Done():
			t.Fatal("context should not be done")
		default:
			// OK
		}
	})

	t.Run("context done after stop", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")
		ctx := exec.Context()

		ctrl.Stop("exec_001")

		// The timeout only bounds the wait so a regression fails fast
		// instead of hanging the test run.
		select {
		case <-ctx.Done():
			// OK
		case <-time.After(100 * time.Millisecond):
			t.Fatal("context should be done after stop")
		}
	})
}

// TestControlledExecutionCheckCancelled checks CheckCancelled before and after
// Stop.
func TestControlledExecutionCheckCancelled(t *testing.T) {
	t.Run("not cancelled returns nil", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")

		err := exec.CheckCancelled()
		assert.NoError(t, err)
	})

	t.Run("cancelled returns error", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")
		ctrl.Stop("exec_001")

		err := exec.CheckCancelled()
		assert.Error(t, err)
		assert.Equal(t, types.ErrExecutionCancelled, err)
	})
}

// TestControlledExecutionUpdatePhase checks the initial phase set by Track and
// that UpdatePhase overwrites it.
func TestControlledExecutionUpdatePhase(t *testing.T) {
	ctrl := trigger.NewExecutionController()
	exec := ctrl.Track("exec_001", "robot_001", "team_001")

	assert.Equal(t, types.PhaseInspiration, exec.Phase)

	exec.UpdatePhase(types.PhaseGoals)
	assert.Equal(t, types.PhaseGoals, exec.Phase)

	exec.UpdatePhase(types.PhaseTasks)
	assert.Equal(t, types.PhaseTasks, exec.Phase)
}

// TestControlledExecutionUpdateStatus checks the initial status set by Track
// and that UpdateStatus overwrites it.
func TestControlledExecutionUpdateStatus(t *testing.T) {
	ctrl := trigger.NewExecutionController()
	exec := ctrl.Track("exec_001", "robot_001", "team_001")

	assert.Equal(t, types.ExecRunning, exec.Status)

	exec.UpdateStatus(types.ExecCompleted)
	assert.Equal(t, types.ExecCompleted, exec.Status)

	exec.UpdateStatus(types.ExecFailed)
	assert.Equal(t, types.ExecFailed, exec.Status)
}

// ==================== WaitIfPaused Tests ====================

// TestControlledExecutionWaitIfPaused exercises WaitIfPaused from a helper
// goroutine that reports its result on an unbuffered channel. Each subtest
// bounds its wait with time.After so a regression fails instead of hanging.
func TestControlledExecutionWaitIfPaused(t *testing.T) {
	t.Run("returns immediately if not paused", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")

		done := make(chan error)
		go func() {
			done <- exec.WaitIfPaused()
		}()

		select {
		case err := <-done:
			assert.NoError(t, err)
		case <-time.After(100 * time.Millisecond):
			t.Fatal("WaitIfPaused should return immediately when not paused")
		}
	})

	t.Run("does not infinite loop when paused without resume", func(t *testing.T) {
		// Regression test for an earlier busy-loop bug in WaitIfPaused: a
		// paused execution that is never resumed must leave the waiter
		// blocked (on resumeCh or cancellation), not spinning.
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")
		ctrl.Pause("exec_001")

		// Start WaitIfPaused in a goroutine
		done := make(chan error)
		go func() {
			done <- exec.WaitIfPaused()
		}()

		// Wait a bit - if there's an infinite loop, CPU would spike
		// The goroutine should be blocked, not spinning
		time.Sleep(100 * time.Millisecond)

		// Now stop the execution - this should unblock WaitIfPaused
		ctrl.Stop("exec_001")

		select {
		case err := <-done:
			// Should get cancellation error
			assert.Error(t, err)
		case <-time.After(200 * time.Millisecond):
			t.Fatal("WaitIfPaused should unblock after stop")
		}
	})

	t.Run("rapid pause-resume-pause does not cause issues", func(t *testing.T) {
		// Test TOCTOU race condition handling
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")

		// Pause first
		ctrl.Pause("exec_001")

		done := make(chan error)
		go func() {
			done <- exec.WaitIfPaused()
		}()

		// Resume shortly after the waiter has started.
		// NOTE: despite the subtest name, no second Pause is issued here;
		// only the pause -> resume transition is exercised.
		time.Sleep(10 * time.Millisecond)
		ctrl.Resume("exec_001")

		// WaitIfPaused should return (the original resumeCh was closed)
		select {
		case err := <-done:
			assert.NoError(t, err)
		case <-time.After(200 * time.Millisecond):
			t.Fatal("WaitIfPaused should return after resume")
		}
	})

	t.Run("blocks when paused, resumes after resume", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")
		ctrl.Pause("exec_001")

		done := make(chan error)
		go func() {
			done <- exec.WaitIfPaused()
		}()

		// Should be blocked
		select {
		case <-done:
			t.Fatal("WaitIfPaused should block when paused")
		case <-time.After(50 * time.Millisecond):
			// OK, still blocked
		}

		// Resume
		ctrl.Resume("exec_001")

		// Should unblock
		select {
		case err := <-done:
			assert.NoError(t, err)
		case <-time.After(100 * time.Millisecond):
			t.Fatal("WaitIfPaused should unblock after resume")
		}
	})

	t.Run("returns error when cancelled while paused", func(t *testing.T) {
		ctrl := trigger.NewExecutionController()
		exec := ctrl.Track("exec_001", "robot_001", "team_001")
		ctrl.Pause("exec_001")

		done := make(chan error)
		go func() {
			done <- exec.WaitIfPaused()
		}()

		// Should be blocked
		select {
		case <-done:
			t.Fatal("WaitIfPaused should block when paused")
		case <-time.After(50 * time.Millisecond):
			// OK, still blocked
		}

		// Stop instead of resume
		ctrl.Stop("exec_001")

		// Should unblock with error
		select {
		case err := <-done:
			assert.Error(t, err)
			assert.Equal(t, types.ErrExecutionCancelled, err)
		case <-time.After(100 * time.Millisecond):
			t.Fatal("WaitIfPaused should unblock after stop")
		}
	})
}
"robot_"+string(rune('0'+id%5)), "team_001", ) }(i) } // Concurrent listing for i := 0; i < 50; i++ { wg.Add(1) go func() { defer wg.Done() _ = ctrl.List() }() } wg.Wait() // No race conditions or panics }) t.Run("concurrent pause/resume", func(t *testing.T) { ctrl := trigger.NewExecutionController() ctrl.Track("exec_001", "robot_001", "team_001") var wg sync.WaitGroup // Concurrent pause/resume attempts for i := 0; i < 50; i++ { wg.Add(2) go func() { defer wg.Done() _ = ctrl.Pause("exec_001") }() go func() { defer wg.Done() _ = ctrl.Resume("exec_001") }() } wg.Wait() // No race conditions or panics }) } ================================================ FILE: agent/robot/trigger/trigger.go ================================================ // Package trigger provides trigger-related utilities and execution control // The main trigger logic is in the manager package. // This package provides: // - Validation functions for intervention and event requests // - Builder helpers for TriggerInput // - ExecutionController for pause/resume/stop // - ClockMatcher for clock trigger matching (reusable) package trigger import ( "fmt" "github.com/yaoapp/yao/agent/robot/types" ) // ValidateIntervention validates a human intervention request func ValidateIntervention(req *types.InterveneRequest) error { if req == nil { return fmt.Errorf("request is nil") } if req.MemberID == "" { return fmt.Errorf("member_id is required") } if !isValidAction(req.Action) { return fmt.Errorf("invalid action: %s", req.Action) } // Validate action-specific requirements switch req.Action { case types.ActionTaskAdd, types.ActionGoalAdd, types.ActionInstruct: // These actions require messages if len(req.Messages) == 0 { return fmt.Errorf("messages required for action: %s", req.Action) } case types.ActionPlanAdd: // Plan add requires plan_time if req.PlanTime == nil { return fmt.Errorf("plan_time required for action: plan.add") } } return nil } // ValidateEvent validates an event trigger request func 
ValidateEvent(req *types.EventRequest) error { if req == nil { return fmt.Errorf("request is nil") } if req.MemberID == "" { return fmt.Errorf("member_id is required") } if req.Source == "" { return fmt.Errorf("source is required") } if req.EventType == "" { return fmt.Errorf("event_type is required") } return nil } // BuildEventInput creates a TriggerInput from an event request func BuildEventInput(req *types.EventRequest) *types.TriggerInput { return &types.TriggerInput{ Source: types.EventSource(req.Source), EventType: req.EventType, Data: req.Data, } } // isValidAction checks if the intervention action is valid func isValidAction(action types.InterventionAction) bool { switch action { case types.ActionTaskAdd, types.ActionTaskCancel, types.ActionTaskUpdate, types.ActionGoalAdjust, types.ActionGoalAdd, types.ActionGoalComplete, types.ActionGoalCancel, types.ActionPlanAdd, types.ActionPlanRemove, types.ActionPlanUpdate, types.ActionInstruct: return true default: return false } } // GetActionCategory returns the category of an intervention action func GetActionCategory(action types.InterventionAction) string { switch action { case types.ActionTaskAdd, types.ActionTaskCancel, types.ActionTaskUpdate: return "task" case types.ActionGoalAdjust, types.ActionGoalAdd, types.ActionGoalComplete, types.ActionGoalCancel: return "goal" case types.ActionPlanAdd, types.ActionPlanRemove, types.ActionPlanUpdate: return "plan" case types.ActionInstruct: return "instruct" default: return "unknown" } } // GetActionDescription returns a human-readable description of an action func GetActionDescription(action types.InterventionAction) string { switch action { case types.ActionTaskAdd: return "Add a new task" case types.ActionTaskCancel: return "Cancel a task" case types.ActionTaskUpdate: return "Update task details" case types.ActionGoalAdjust: return "Adjust current goal" case types.ActionGoalAdd: return "Add a new goal" case types.ActionGoalComplete: return "Mark goal as complete" 
case types.ActionGoalCancel: return "Cancel a goal" case types.ActionPlanAdd: return "Add to plan queue" case types.ActionPlanRemove: return "Remove from plan queue" case types.ActionPlanUpdate: return "Update planned item" case types.ActionInstruct: return "Direct instruction to robot" default: return "Unknown action" } } ================================================ FILE: agent/robot/trigger/trigger_test.go ================================================ package trigger_test import ( "testing" "time" "github.com/stretchr/testify/assert" agentcontext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/robot/trigger" "github.com/yaoapp/yao/agent/robot/types" ) // ==================== ValidateIntervention Tests ==================== func TestValidateIntervention(t *testing.T) { t.Run("nil request returns error", func(t *testing.T) { err := trigger.ValidateIntervention(nil) assert.Error(t, err) assert.Contains(t, err.Error(), "request is nil") }) t.Run("empty member_id returns error", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "", Action: types.ActionTaskAdd, } err := trigger.ValidateIntervention(req) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("invalid action returns error", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.InterventionAction("invalid.action"), } err := trigger.ValidateIntervention(req) assert.Error(t, err) assert.Contains(t, err.Error(), "invalid action") }) t.Run("task.add without messages returns error", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionTaskAdd, Messages: nil, } err := trigger.ValidateIntervention(req) assert.Error(t, err) assert.Contains(t, err.Error(), "messages required") }) t.Run("goal.add without messages returns error", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionGoalAdd, Messages: nil, } err := 
trigger.ValidateIntervention(req) assert.Error(t, err) assert.Contains(t, err.Error(), "messages required") }) t.Run("instruct without messages returns error", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionInstruct, Messages: nil, } err := trigger.ValidateIntervention(req) assert.Error(t, err) assert.Contains(t, err.Error(), "messages required") }) t.Run("plan.add without plan_time returns error", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionPlanAdd, PlanTime: nil, } err := trigger.ValidateIntervention(req) assert.Error(t, err) assert.Contains(t, err.Error(), "plan_time required") }) t.Run("valid task.add request passes", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionTaskAdd, Messages: []agentcontext.Message{ {Role: agentcontext.RoleUser, Content: "Add a new task"}, }, } err := trigger.ValidateIntervention(req) assert.NoError(t, err) }) t.Run("valid plan.add request passes", func(t *testing.T) { planTime := time.Now().Add(time.Hour) req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionPlanAdd, PlanTime: &planTime, } err := trigger.ValidateIntervention(req) assert.NoError(t, err) }) t.Run("task.cancel without messages passes", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionTaskCancel, } err := trigger.ValidateIntervention(req) assert.NoError(t, err) }) t.Run("goal.adjust without messages passes", func(t *testing.T) { req := &types.InterveneRequest{ MemberID: "robot_001", Action: types.ActionGoalAdjust, } err := trigger.ValidateIntervention(req) assert.NoError(t, err) }) } // ==================== ValidateEvent Tests ==================== func TestValidateEvent(t *testing.T) { t.Run("nil request returns error", func(t *testing.T) { err := trigger.ValidateEvent(nil) assert.Error(t, err) assert.Contains(t, err.Error(), "request is nil") }) t.Run("empty 
member_id returns error", func(t *testing.T) { req := &types.EventRequest{ MemberID: "", Source: "webhook", EventType: "lead.created", } err := trigger.ValidateEvent(req) assert.Error(t, err) assert.Contains(t, err.Error(), "member_id is required") }) t.Run("empty source returns error", func(t *testing.T) { req := &types.EventRequest{ MemberID: "robot_001", Source: "", EventType: "lead.created", } err := trigger.ValidateEvent(req) assert.Error(t, err) assert.Contains(t, err.Error(), "source is required") }) t.Run("empty event_type returns error", func(t *testing.T) { req := &types.EventRequest{ MemberID: "robot_001", Source: "webhook", EventType: "", } err := trigger.ValidateEvent(req) assert.Error(t, err) assert.Contains(t, err.Error(), "event_type is required") }) t.Run("valid request passes", func(t *testing.T) { req := &types.EventRequest{ MemberID: "robot_001", Source: "webhook", EventType: "lead.created", Data: map[string]interface{}{"name": "John"}, } err := trigger.ValidateEvent(req) assert.NoError(t, err) }) } // ==================== BuildEventInput Tests ==================== func TestBuildEventInput(t *testing.T) { t.Run("builds correct TriggerInput", func(t *testing.T) { req := &types.EventRequest{ MemberID: "robot_001", Source: "webhook", EventType: "lead.created", Data: map[string]interface{}{"name": "John", "email": "john@example.com"}, } input := trigger.BuildEventInput(req) assert.NotNil(t, input) assert.Equal(t, types.EventSource("webhook"), input.Source) assert.Equal(t, "lead.created", input.EventType) assert.Equal(t, "John", input.Data["name"]) assert.Equal(t, "john@example.com", input.Data["email"]) }) t.Run("handles nil data", func(t *testing.T) { req := &types.EventRequest{ MemberID: "robot_001", Source: "database", EventType: "order.paid", Data: nil, } input := trigger.BuildEventInput(req) assert.NotNil(t, input) assert.Equal(t, types.EventSource("database"), input.Source) assert.Equal(t, "order.paid", input.EventType) assert.Nil(t, 
// ClockContext - time context for P0 inspiration.
// All calendar fields are derived from Now, i.e. from the instant expressed in
// the timezone named by TZ.
type ClockContext struct {
	Now          time.Time `json:"now"`
	Hour         int       `json:"hour"`          // 0-23
	DayOfWeek    string    `json:"day_of_week"`   // Monday, Tuesday...
	DayOfMonth   int       `json:"day_of_month"`  // 1-31
	WeekOfYear   int       `json:"week_of_year"`  // ISO 8601 week number, 1-53
	Month        int       `json:"month"`         // 1-12
	Year         int       `json:"year"`          // calendar year (not the ISO week-numbering year)
	IsWeekend    bool      `json:"is_weekend"`    // Saturday or Sunday
	IsMonthStart bool      `json:"is_month_start"` // 1st-3rd
	IsMonthEnd   bool      `json:"is_month_end"`   // last 3 days
	IsQuarterEnd bool      `json:"is_quarter_end"` // last 3 days of March, June, September or December
	IsYearEnd    bool      `json:"is_year_end"`    // December 29th-31st
	TZ           string    `json:"tz"`
}

// NewClockContext creates clock context from time.
//
// t is converted into the location named by tz before any field is derived.
// When tz is empty or cannot be loaded, time.Local is used instead. The TZ
// field of the result holds the name of the location actually applied.
func NewClockContext(t time.Time, tz string) *ClockContext {
	location := time.Local
	if tz != "" {
		if named, err := time.LoadLocation(tz); err == nil {
			location = named
		}
	}

	local := t.In(location)
	year, month, day := local.Date()
	weekday := local.Weekday()
	_, isoWeek := local.ISOWeek()

	// Day 0 of the following month normalizes to the last day of this month.
	daysInMonth := time.Date(year, month+1, 0, 0, 0, 0, 0, location).Day()
	inLastThreeDays := day >= daysInMonth-2

	weekend := false
	switch weekday {
	case time.Saturday, time.Sunday:
		weekend = true
	}

	clock := &ClockContext{
		Now: local,
		TZ:  location.String(),
	}
	clock.Hour = local.Hour()
	clock.DayOfWeek = weekday.String()
	clock.DayOfMonth = day
	clock.WeekOfYear = isoWeek
	clock.Month = int(month)
	clock.Year = year
	clock.IsWeekend = weekend
	clock.IsMonthStart = day <= 3
	clock.IsMonthEnd = inLastThreeDays
	clock.IsQuarterEnd = month%3 == 0 && inLastThreeDays
	// December always has 31 days, so its last three days are the 29th-31st.
	clock.IsYearEnd = month == time.December && inLastThreeDays
	return clock
}
}) t.Run("weekend detection", func(t *testing.T) { // Saturday saturday := time.Date(2024, 1, 13, 10, 0, 0, 0, time.UTC) ctx := types.NewClockContext(saturday, "") assert.True(t, ctx.IsWeekend) // Sunday sunday := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(sunday, "") assert.True(t, ctx.IsWeekend) }) t.Run("month start detection", func(t *testing.T) { // 1st day day1 := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) ctx := types.NewClockContext(day1, "") assert.True(t, ctx.IsMonthStart) // 3rd day day3 := time.Date(2024, 1, 3, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(day3, "") assert.True(t, ctx.IsMonthStart) // 4th day - not month start day4 := time.Date(2024, 1, 4, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(day4, "") assert.False(t, ctx.IsMonthStart) }) t.Run("month end detection", func(t *testing.T) { // Last day of January (31st) lastDay := time.Date(2024, 1, 31, 10, 0, 0, 0, time.UTC) ctx := types.NewClockContext(lastDay, "") assert.True(t, ctx.IsMonthEnd) // 29th day of January (31 days total) day29 := time.Date(2024, 1, 29, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(day29, "") assert.True(t, ctx.IsMonthEnd) // 28th day of January - not month end day28 := time.Date(2024, 1, 28, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(day28, "") assert.False(t, ctx.IsMonthEnd) }) t.Run("quarter end detection", func(t *testing.T) { // March 31 - Q1 end q1End := time.Date(2024, 3, 31, 10, 0, 0, 0, time.UTC) ctx := types.NewClockContext(q1End, "") assert.True(t, ctx.IsQuarterEnd) // June 30 - Q2 end q2End := time.Date(2024, 6, 30, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(q2End, "") assert.True(t, ctx.IsQuarterEnd) // September 30 - Q3 end q3End := time.Date(2024, 9, 30, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(q3End, "") assert.True(t, ctx.IsQuarterEnd) // December 31 - Q4 end q4End := time.Date(2024, 12, 31, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(q4End, "") assert.True(t, 
ctx.IsQuarterEnd) // Not quarter end notQEnd := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(notQEnd, "") assert.False(t, ctx.IsQuarterEnd) }) t.Run("year end detection", func(t *testing.T) { // December 29 dec29 := time.Date(2024, 12, 29, 10, 0, 0, 0, time.UTC) ctx := types.NewClockContext(dec29, "") assert.True(t, ctx.IsYearEnd) // December 31 dec31 := time.Date(2024, 12, 31, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(dec31, "") assert.True(t, ctx.IsYearEnd) // December 28 - not year end dec28 := time.Date(2024, 12, 28, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(dec28, "") assert.False(t, ctx.IsYearEnd) // January - not year end jan1 := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(jan1, "") assert.False(t, ctx.IsYearEnd) }) t.Run("timezone handling", func(t *testing.T) { testTime := time.Date(2024, 1, 15, 14, 30, 0, 0, time.UTC) // With Asia/Shanghai timezone ctx := types.NewClockContext(testTime, "Asia/Shanghai") assert.Equal(t, "Asia/Shanghai", ctx.TZ) // Time should be converted to Shanghai timezone assert.NotEqual(t, testTime, ctx.Now) assert.Equal(t, 22, ctx.Hour) // UTC 14:00 = Shanghai 22:00 (UTC+8) // With invalid timezone - should fall back to local ctx = types.NewClockContext(testTime, "Invalid/Timezone") assert.NotEmpty(t, ctx.TZ) }) t.Run("week of year", func(t *testing.T) { // First week of 2024 jan1 := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) ctx := types.NewClockContext(jan1, "") assert.Equal(t, 1, ctx.WeekOfYear) // Mid year july15 := time.Date(2024, 7, 15, 10, 0, 0, 0, time.UTC) ctx = types.NewClockContext(july15, "") assert.Greater(t, ctx.WeekOfYear, 20) assert.Less(t, ctx.WeekOfYear, 35) }) } func TestClockContextFields(t *testing.T) { // Test all fields are populated correctly testTime := time.Date(2024, 12, 30, 23, 45, 30, 0, time.UTC) ctx := types.NewClockContext(testTime, "UTC") assert.NotZero(t, ctx.Now) assert.Equal(t, 23, ctx.Hour) assert.Equal(t, "Monday", 
ctx.DayOfWeek) assert.Equal(t, 30, ctx.DayOfMonth) assert.Equal(t, 1, ctx.WeekOfYear) // Dec 30, 2024 is week 1 of 2025 assert.Equal(t, 12, ctx.Month) assert.Equal(t, 2024, ctx.Year) assert.False(t, ctx.IsWeekend) // Monday assert.False(t, ctx.IsMonthStart) assert.True(t, ctx.IsMonthEnd) assert.True(t, ctx.IsQuarterEnd) assert.True(t, ctx.IsYearEnd) assert.Equal(t, "UTC", ctx.TZ) } ================================================ FILE: agent/robot/types/config.go ================================================ package types import ( "encoding/json" "time" ) // Config - robot_config in __yao.member type Config struct { Triggers *Triggers `json:"triggers,omitempty"` Clock *Clock `json:"clock,omitempty"` Identity *Identity `json:"identity"` Quota *Quota `json:"quota,omitempty"` KB *KB `json:"kb,omitempty"` // shared knowledge base (same as assistant) DB *DB `json:"db,omitempty"` // shared database (same as assistant) Learn *Learn `json:"learn,omitempty"` // learning config for private KB Resources *Resources `json:"resources,omitempty"` Delivery *DeliveryPreferences `json:"delivery,omitempty"` // delivery preferences (see robot.go) Events []Event `json:"events,omitempty"` Executor *ExecutorConfig `json:"executor,omitempty"` // executor mode settings DefaultLocale string `json:"default_locale,omitempty"` // default language for clock/event triggers ("en", "zh") Integrations *Integrations `json:"integrations,omitempty"` // external channel integrations (telegram, etc.) } // Integrations holds configuration for external platform integrations. type Integrations struct { Telegram *TelegramConfig `json:"telegram,omitempty"` Feishu *FeishuConfig `json:"feishu,omitempty"` DingTalk *DingTalkConfig `json:"dingtalk,omitempty"` Discord *DiscordConfig `json:"discord,omitempty"` } // TelegramConfig holds Telegram Bot integration settings. 
type TelegramConfig struct { Enabled bool `json:"enabled"` BotToken string `json:"bot_token"` Host string `json:"host,omitempty"` // custom Bot API server, defaults to https://api.telegram.org AppID string `json:"app_id,omitempty"` // auto-generated, used for webhook URL routing ChatID string `json:"chat_id,omitempty"` // default reply chat WebhookSecret string `json:"webhook_secret,omitempty"` // sent with SetWebhook, verified on incoming webhooks } // FeishuConfig holds Feishu (Lark) Bot integration settings. type FeishuConfig struct { Enabled bool `json:"enabled"` AppID string `json:"app_id"` AppSecret string `json:"app_secret"` } // DingTalkConfig holds DingTalk Bot integration settings. type DingTalkConfig struct { Enabled bool `json:"enabled"` ClientID string `json:"client_id"` ClientSecret string `json:"client_secret"` } // DiscordConfig holds Discord Bot integration settings. type DiscordConfig struct { Enabled bool `json:"enabled"` BotToken string `json:"bot_token"` AppID string `json:"app_id,omitempty"` } // ExecutorConfig - executor settings type ExecutorConfig struct { Mode ExecutorMode `json:"mode,omitempty"` // standard | dryrun | sandbox MaxDuration string `json:"max_duration,omitempty"` // max execution time (e.g., "30m") } // GetMode returns the executor mode (default: standard) func (e *ExecutorConfig) GetMode() ExecutorMode { if e == nil || e.Mode == "" { return ExecutorStandard } return e.Mode } // GetMaxDuration returns the max duration (default: 30m) func (e *ExecutorConfig) GetMaxDuration() time.Duration { if e == nil || e.MaxDuration == "" { return 30 * time.Minute } d, err := time.ParseDuration(e.MaxDuration) if err != nil { return 30 * time.Minute } return d } // Validate validates the config func (c *Config) Validate() error { if c.Identity == nil || c.Identity.Role == "" { return ErrMissingIdentity } if c.Clock != nil { if err := c.Clock.Validate(); err != nil { return err } } return nil } // GetDefaultLocale returns the default locale 
(default: "en") func (c *Config) GetDefaultLocale() string { if c == nil || c.DefaultLocale == "" { return "en" } return c.DefaultLocale } // Triggers - trigger enable/disable type Triggers struct { Clock *TriggerSwitch `json:"clock,omitempty"` Intervene *TriggerSwitch `json:"intervene,omitempty"` Event *TriggerSwitch `json:"event,omitempty"` } // TriggerSwitch - trigger enable/disable switch type TriggerSwitch struct { Enabled bool `json:"enabled"` Actions []string `json:"actions,omitempty"` // for intervene } // IsEnabled checks if trigger is enabled (default: true) func (t *Triggers) IsEnabled(typ TriggerType) bool { if t == nil { return true } switch typ { case TriggerClock: return t.Clock == nil || t.Clock.Enabled case TriggerHuman: return t.Intervene == nil || t.Intervene.Enabled case TriggerEvent: return t.Event == nil || t.Event.Enabled } return false } // Clock - when to wake up type Clock struct { Mode ClockMode `json:"mode"` // times | interval | daemon Times []string `json:"times,omitempty"` // ["09:00", "14:00"] Days []string `json:"days,omitempty"` // ["Mon", "Tue"] or ["*"] Every string `json:"every,omitempty"` // "30m", "1h" TZ string `json:"tz,omitempty"` // "Asia/Shanghai" Timeout string `json:"timeout,omitempty"` // "30m" } // Validate validates clock config func (c *Clock) Validate() error { switch c.Mode { case ClockTimes: if len(c.Times) == 0 { return ErrClockTimesEmpty } case ClockInterval: if c.Every == "" { return ErrClockIntervalEmpty } case ClockDaemon: // no extra validation default: return ErrClockModeInvalid } return nil } // GetTimeout returns parsed timeout duration func (c *Clock) GetTimeout() time.Duration { if c.Timeout == "" { return 30 * time.Minute // default } d, err := time.ParseDuration(c.Timeout) if err != nil { return 30 * time.Minute } return d } // GetLocation returns timezone location func (c *Clock) GetLocation() *time.Location { if c.TZ == "" { return time.Local } loc, err := time.LoadLocation(c.TZ) if err != nil { 
// Quota describes the concurrency limits of a robot. Non-positive fields
// are treated as "not set" by the getters below.
type Quota struct {
	Max      int `json:"max"`      // max running (default: 2)
	Queue    int `json:"queue"`    // queue size (default: 10)
	Priority int `json:"priority"` // 1-10 (default: 5)
}

// GetMax returns the maximum number of concurrent runs, or 2 when the
// receiver is nil or Max is not positive.
func (q *Quota) GetMax() int {
	if q != nil && q.Max > 0 {
		return q.Max
	}
	return 2
}

// GetQueue returns the queue size, or 10 when the receiver is nil or Queue
// is not positive.
func (q *Quota) GetQueue() int {
	if q != nil && q.Queue > 0 {
		return q.Queue
	}
	return 10
}

// GetPriority returns the priority, or 5 when the receiver is nil or
// Priority is not positive.
func (q *Quota) GetPriority() int {
	if q != nil && q.Priority > 0 {
		return q.Priority
	}
	return 5
}
r.Phases[phase]; ok && id != "" { return id } } return "__yao." + string(phase) } // MCPConfig - MCP server configuration type MCPConfig struct { ID string `json:"id"` Tools []string `json:"tools,omitempty"` // empty = all } // Event - event trigger config type Event struct { Type EventSource `json:"type"` // webhook | database Source string `json:"source"` // webhook path or table name Filter map[string]interface{} `json:"filter,omitempty"` } // ParseConfig parses robot_config from various formats (string, []byte, map) func ParseConfig(data interface{}) (*Config, error) { if data == nil { return nil, nil } var configBytes []byte switch v := data.(type) { case string: if v == "" { return nil, nil } configBytes = []byte(v) case []byte: if len(v) == 0 { return nil, nil } configBytes = v case map[string]interface{}: var err error configBytes, err = json.Marshal(v) if err != nil { return nil, err } default: var err error configBytes, err = json.Marshal(v) if err != nil { return nil, err } } var config Config if err := json.Unmarshal(configBytes, &config); err != nil { return nil, err } return &config, nil } ================================================ FILE: agent/robot/types/config_global.go ================================================ package types import "sync" // Global configuration for robot agent // These values can be set during agent initialization var ( // defaultEmailChannel - default messenger channel name for sending emails // Can be configured via SetDefaultEmailChannel() // Default: "default" (maps to messengers/channels.yao configuration) defaultEmailChannel = "default" // configMu protects global configuration configMu sync.RWMutex ) // DefaultEmailChannel returns the default email channel name func DefaultEmailChannel() string { configMu.RLock() defer configMu.RUnlock() return defaultEmailChannel } // SetDefaultEmailChannel sets the default messenger channel for email delivery // This should be called during agent initialization // The channel 
name must match a channel defined in messengers/channels.yao func SetDefaultEmailChannel(channel string) { if channel == "" { return } configMu.Lock() defer configMu.Unlock() defaultEmailChannel = channel } ================================================ FILE: agent/robot/types/config_test.go ================================================ package types_test import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/types" ) func TestConfigValidate(t *testing.T) { t.Run("valid config", func(t *testing.T) { config := &types.Config{ Identity: &types.Identity{ Role: "Sales Manager", }, } err := config.Validate() assert.NoError(t, err) }) t.Run("missing identity", func(t *testing.T) { config := &types.Config{} err := config.Validate() assert.Error(t, err) assert.Equal(t, types.ErrMissingIdentity, err) }) t.Run("missing identity role", func(t *testing.T) { config := &types.Config{ Identity: &types.Identity{}, } err := config.Validate() assert.Error(t, err) assert.Equal(t, types.ErrMissingIdentity, err) }) t.Run("invalid clock config", func(t *testing.T) { config := &types.Config{ Identity: &types.Identity{Role: "Test"}, Clock: &types.Clock{ Mode: types.ClockTimes, // Times is empty - should fail }, } err := config.Validate() assert.Error(t, err) assert.Equal(t, types.ErrClockTimesEmpty, err) }) } func TestClockValidate(t *testing.T) { t.Run("valid times mode", func(t *testing.T) { clock := &types.Clock{ Mode: types.ClockTimes, Times: []string{"09:00", "14:00"}, } err := clock.Validate() assert.NoError(t, err) }) t.Run("times mode without times", func(t *testing.T) { clock := &types.Clock{ Mode: types.ClockTimes, } err := clock.Validate() assert.Error(t, err) assert.Equal(t, types.ErrClockTimesEmpty, err) }) t.Run("valid interval mode", func(t *testing.T) { clock := &types.Clock{ Mode: types.ClockInterval, Every: "30m", } err := clock.Validate() assert.NoError(t, err) }) t.Run("interval mode without every", func(t *testing.T) { 
clock := &types.Clock{ Mode: types.ClockInterval, } err := clock.Validate() assert.Error(t, err) assert.Equal(t, types.ErrClockIntervalEmpty, err) }) t.Run("valid daemon mode", func(t *testing.T) { clock := &types.Clock{ Mode: types.ClockDaemon, } err := clock.Validate() assert.NoError(t, err) }) t.Run("invalid mode", func(t *testing.T) { clock := &types.Clock{ Mode: types.ClockMode("invalid"), } err := clock.Validate() assert.Error(t, err) assert.Equal(t, types.ErrClockModeInvalid, err) }) } func TestClockGetTimeout(t *testing.T) { t.Run("default timeout", func(t *testing.T) { clock := &types.Clock{} timeout := clock.GetTimeout() assert.Equal(t, 30*time.Minute, timeout) }) t.Run("custom timeout", func(t *testing.T) { clock := &types.Clock{ Timeout: "10m", } timeout := clock.GetTimeout() assert.Equal(t, 10*time.Minute, timeout) }) t.Run("invalid timeout returns default", func(t *testing.T) { clock := &types.Clock{ Timeout: "invalid", } timeout := clock.GetTimeout() assert.Equal(t, 30*time.Minute, timeout) }) } func TestClockGetLocation(t *testing.T) { t.Run("default location", func(t *testing.T) { clock := &types.Clock{} loc := clock.GetLocation() assert.Equal(t, time.Local, loc) }) t.Run("valid timezone", func(t *testing.T) { clock := &types.Clock{ TZ: "Asia/Shanghai", } loc := clock.GetLocation() assert.NotNil(t, loc) assert.Equal(t, "Asia/Shanghai", loc.String()) }) t.Run("invalid timezone returns local", func(t *testing.T) { clock := &types.Clock{ TZ: "Invalid/Timezone", } loc := clock.GetLocation() assert.Equal(t, time.Local, loc) }) } func TestTriggersIsEnabled(t *testing.T) { t.Run("nil triggers - all enabled by default", func(t *testing.T) { var triggers *types.Triggers assert.True(t, triggers.IsEnabled(types.TriggerClock)) assert.True(t, triggers.IsEnabled(types.TriggerHuman)) assert.True(t, triggers.IsEnabled(types.TriggerEvent)) }) t.Run("clock enabled", func(t *testing.T) { triggers := &types.Triggers{ Clock: &types.TriggerSwitch{Enabled: true}, } 
assert.True(t, triggers.IsEnabled(types.TriggerClock)) }) t.Run("clock disabled", func(t *testing.T) { triggers := &types.Triggers{ Clock: &types.TriggerSwitch{Enabled: false}, } assert.False(t, triggers.IsEnabled(types.TriggerClock)) }) t.Run("intervene enabled by default", func(t *testing.T) { triggers := &types.Triggers{} assert.True(t, triggers.IsEnabled(types.TriggerHuman)) }) t.Run("event disabled", func(t *testing.T) { triggers := &types.Triggers{ Event: &types.TriggerSwitch{Enabled: false}, } assert.False(t, triggers.IsEnabled(types.TriggerEvent)) }) } func TestQuotaDefaults(t *testing.T) { t.Run("nil quota", func(t *testing.T) { var quota *types.Quota assert.Equal(t, 2, quota.GetMax()) assert.Equal(t, 10, quota.GetQueue()) assert.Equal(t, 5, quota.GetPriority()) }) t.Run("zero values", func(t *testing.T) { quota := &types.Quota{} assert.Equal(t, 2, quota.GetMax()) assert.Equal(t, 10, quota.GetQueue()) assert.Equal(t, 5, quota.GetPriority()) }) t.Run("custom values", func(t *testing.T) { quota := &types.Quota{ Max: 5, Queue: 20, Priority: 8, } assert.Equal(t, 5, quota.GetMax()) assert.Equal(t, 20, quota.GetQueue()) assert.Equal(t, 8, quota.GetPriority()) }) } func TestResourcesGetPhaseAgent(t *testing.T) { t.Run("nil resources - returns default", func(t *testing.T) { var resources *types.Resources agent := resources.GetPhaseAgent(types.PhaseGoals) assert.Equal(t, "__yao.goals", agent) }) t.Run("phase not configured - returns default", func(t *testing.T) { resources := &types.Resources{ Phases: map[types.Phase]string{}, } agent := resources.GetPhaseAgent(types.PhaseGoals) assert.Equal(t, "__yao.goals", agent) }) t.Run("custom phase agent", func(t *testing.T) { resources := &types.Resources{ Phases: map[types.Phase]string{ types.PhaseGoals: "custom.goals.agent", }, } agent := resources.GetPhaseAgent(types.PhaseGoals) assert.Equal(t, "custom.goals.agent", agent) }) t.Run("all phases default names", func(t *testing.T) { resources := &types.Resources{} 
assert.Equal(t, "__yao.inspiration", resources.GetPhaseAgent(types.PhaseInspiration)) assert.Equal(t, "__yao.goals", resources.GetPhaseAgent(types.PhaseGoals)) assert.Equal(t, "__yao.tasks", resources.GetPhaseAgent(types.PhaseTasks)) assert.Equal(t, "__yao.run", resources.GetPhaseAgent(types.PhaseRun)) assert.Equal(t, "__yao.delivery", resources.GetPhaseAgent(types.PhaseDelivery)) assert.Equal(t, "__yao.learning", resources.GetPhaseAgent(types.PhaseLearning)) }) } func TestExecutorConfigGetMode(t *testing.T) { t.Run("nil config - returns default", func(t *testing.T) { var config *types.ExecutorConfig assert.Equal(t, types.ExecutorStandard, config.GetMode()) }) t.Run("empty mode - returns default", func(t *testing.T) { config := &types.ExecutorConfig{} assert.Equal(t, types.ExecutorStandard, config.GetMode()) }) t.Run("standard mode", func(t *testing.T) { config := &types.ExecutorConfig{Mode: types.ExecutorStandard} assert.Equal(t, types.ExecutorStandard, config.GetMode()) }) t.Run("dryrun mode", func(t *testing.T) { config := &types.ExecutorConfig{Mode: types.ExecutorDryRun} assert.Equal(t, types.ExecutorDryRun, config.GetMode()) }) t.Run("sandbox mode", func(t *testing.T) { config := &types.ExecutorConfig{Mode: types.ExecutorSandbox} assert.Equal(t, types.ExecutorSandbox, config.GetMode()) }) } func TestExecutorConfigGetMaxDuration(t *testing.T) { t.Run("nil config - returns default 30m", func(t *testing.T) { var config *types.ExecutorConfig assert.Equal(t, 30*time.Minute, config.GetMaxDuration()) }) t.Run("empty duration - returns default 30m", func(t *testing.T) { config := &types.ExecutorConfig{} assert.Equal(t, 30*time.Minute, config.GetMaxDuration()) }) t.Run("custom duration", func(t *testing.T) { config := &types.ExecutorConfig{MaxDuration: "10m"} assert.Equal(t, 10*time.Minute, config.GetMaxDuration()) }) t.Run("invalid duration - returns default", func(t *testing.T) { config := &types.ExecutorConfig{MaxDuration: "invalid"} assert.Equal(t, 30*time.Minute, 
config.GetMaxDuration()) }) t.Run("various valid durations", func(t *testing.T) { tests := []struct { input string expected time.Duration }{ {"1h", time.Hour}, {"30s", 30 * time.Second}, {"2h30m", 2*time.Hour + 30*time.Minute}, } for _, tt := range tests { config := &types.ExecutorConfig{MaxDuration: tt.input} assert.Equal(t, tt.expected, config.GetMaxDuration(), "for input %s", tt.input) } }) } ================================================ FILE: agent/robot/types/context.go ================================================ package types import ( "context" "github.com/yaoapp/yao/openapi/oauth/types" ) // Context - robot execution context (lightweight) type Context struct { context.Context // embed standard context Auth *types.AuthorizedInfo `json:"auth,omitempty"` // reuse oauth AuthorizedInfo MemberID string `json:"member_id,omitempty"` // current robot member ID RequestID string `json:"request_id,omitempty"` // request trace ID Locale string `json:"locale,omitempty"` // locale (e.g., "en-US") } // NewContext creates a new robot context func NewContext(parent context.Context, auth *types.AuthorizedInfo) *Context { if parent == nil { parent = context.Background() } return &Context{ Context: parent, Auth: auth, } } // UserID returns user ID from auth func (c *Context) UserID() string { if c.Auth == nil { return "" } return c.Auth.UserID } // TeamID returns team ID from auth func (c *Context) TeamID() string { if c.Auth == nil { return "" } return c.Auth.TeamID } ================================================ FILE: agent/robot/types/enums.go ================================================ package types // Phase - execution phase type Phase string // Phase constants define the execution phases for robot agent const ( PhaseInspiration Phase = "inspiration" // P0: Clock only PhaseGoals Phase = "goals" // P1 PhaseTasks Phase = "tasks" // P2 PhaseRun Phase = "run" // P3 PhaseDelivery Phase = "delivery" // P4 PhaseLearning Phase = "learning" // P5 PhaseHost Phase = 
"host" // V2: Host Agent (human interaction) ) // AllPhases lists phases in execution order (PhaseHost is excluded — it is a cross-phase service role, not a pipeline stage) var AllPhases = []Phase{ PhaseInspiration, PhaseGoals, PhaseTasks, PhaseRun, PhaseDelivery, PhaseLearning, } // AllConfigurablePhases lists phases that can be bound to custom agents var AllConfigurablePhases = []Phase{ PhaseInspiration, PhaseGoals, PhaseTasks, PhaseRun, PhaseDelivery, PhaseLearning, PhaseHost, } // ClockMode - clock trigger mode type ClockMode string // ClockMode constants define the clock trigger modes const ( ClockTimes ClockMode = "times" // run at specific times ClockInterval ClockMode = "interval" // run every X duration ClockDaemon ClockMode = "daemon" // run continuously ) // TriggerType - trigger source type TriggerType string // TriggerType constants define the trigger sources const ( TriggerClock TriggerType = "clock" TriggerHuman TriggerType = "human" TriggerEvent TriggerType = "event" ) // ExecStatus - execution status type ExecStatus string // ExecStatus constants define the execution status values const ( ExecPending ExecStatus = "pending" ExecRunning ExecStatus = "running" ExecPaused ExecStatus = "paused" ExecCompleted ExecStatus = "completed" ExecFailed ExecStatus = "failed" ExecCancelled ExecStatus = "cancelled" ExecConfirming ExecStatus = "confirming" // V2: awaiting human confirmation before running ExecWaiting ExecStatus = "waiting" // V2: suspended, waiting for human input ) // RobotStatus - matches __yao.member.robot_status type RobotStatus string // RobotStatus constants define the robot status values const ( RobotIdle RobotStatus = "idle" RobotWorking RobotStatus = "working" RobotPaused RobotStatus = "paused" RobotError RobotStatus = "error" RobotMaintenance RobotStatus = "maintenance" ) // InterventionAction - human intervention action // Format: category.action (e.g., "task.add", "goal.adjust") type InterventionAction string // InterventionAction 
constants define the human intervention actions const ( // ActionTaskAdd adds a new task ActionTaskAdd InterventionAction = "task.add" // ActionTaskCancel cancels a task ActionTaskCancel InterventionAction = "task.cancel" // ActionTaskUpdate updates task details ActionTaskUpdate InterventionAction = "task.update" // ActionGoalAdjust modifies current goal ActionGoalAdjust InterventionAction = "goal.adjust" // ActionGoalAdd adds a new goal ActionGoalAdd InterventionAction = "goal.add" // ActionGoalComplete marks goal as complete ActionGoalComplete InterventionAction = "goal.complete" // ActionGoalCancel cancels a goal ActionGoalCancel InterventionAction = "goal.cancel" // ActionPlanAdd adds to plan queue ActionPlanAdd InterventionAction = "plan.add" // ActionPlanRemove removes from plan queue ActionPlanRemove InterventionAction = "plan.remove" // ActionPlanUpdate updates planned item ActionPlanUpdate InterventionAction = "plan.update" // ActionInstruct is a direct instruction to robot ActionInstruct InterventionAction = "instruct" ) // Priority - task/goal priority type Priority string // Priority constants define the priority levels const ( PriorityHigh Priority = "high" PriorityNormal Priority = "normal" PriorityLow Priority = "low" ) // DeliveryType - output delivery type type DeliveryType string // DeliveryType constants define the output delivery types const ( DeliveryEmail DeliveryType = "email" // Send via yao/messenger DeliveryWebhook DeliveryType = "webhook" // POST to external URL DeliveryProcess DeliveryType = "process" // Call Yao Process DeliveryNotify DeliveryType = "notify" // In-app notification (future, auto by subscriptions) ) // DedupResult - deduplication result type DedupResult string // DedupResult constants define the deduplication results const ( DedupSkip DedupResult = "skip" // skip execution DedupMerge DedupResult = "merge" // merge with existing DedupProceed DedupResult = "proceed" // proceed normally ) // EventSource - event trigger 
source type EventSource string // EventSource constants define the event trigger sources const ( EventWebhook EventSource = "webhook" // HTTP webhook EventDatabase EventSource = "database" // DB change trigger ) // LearningType - learning entry type type LearningType string // LearningType constants define the learning entry types const ( LearnExecution LearningType = "execution" // execution record LearnFeedback LearningType = "feedback" // error/fix feedback LearnInsight LearningType = "insight" // pattern/tip insight ) // TaskSource - how task was created type TaskSource string // TaskSource constants define how a task was created const ( TaskSourceAuto TaskSource = "auto" // generated by P2 (task planning) TaskSourceHuman TaskSource = "human" // added via human intervention TaskSourceEvent TaskSource = "event" // added via event trigger ) // ExecutorType - task executor type type ExecutorType string // ExecutorType constants define the task executor types const ( ExecutorAssistant ExecutorType = "assistant" ExecutorMCP ExecutorType = "mcp" ExecutorProcess ExecutorType = "process" ) // TaskStatus - task execution status type TaskStatus string // TaskStatus constants define the task execution status values const ( TaskPending TaskStatus = "pending" TaskRunning TaskStatus = "running" TaskCompleted TaskStatus = "completed" TaskFailed TaskStatus = "failed" TaskSkipped TaskStatus = "skipped" TaskCancelled TaskStatus = "cancelled" TaskWaitingInput TaskStatus = "waiting_input" // V2: task suspended, waiting for human input ) // InsertPosition - where to insert task in queue type InsertPosition string // InsertPosition constants define where to insert task in queue const ( InsertFirst InsertPosition = "first" // insert at beginning (highest priority) InsertLast InsertPosition = "last" // append at end (default) InsertNext InsertPosition = "next" // insert after current task InsertAt InsertPosition = "at" // insert at specific index (use AtIndex) ) // ExecutorMode - 
// ExecutorMode selects how a robot execution is carried out.
type ExecutorMode string

// ExecutorMode constants define the executor modes.
const (
	// ExecutorStandard uses real Agent calls (production mode).
	ExecutorStandard ExecutorMode = "standard"
	// ExecutorDryRun simulates execution without LLM calls (testing/demo).
	ExecutorDryRun ExecutorMode = "dryrun"
	// ExecutorSandbox runs in a container-isolated environment (NOT IMPLEMENTED).
	// Requires Docker/gVisor/Firecracker infrastructure.
	ExecutorSandbox ExecutorMode = "sandbox"
)

// IsValid reports whether m is one of the declared executor modes. The
// empty mode is accepted as well.
func (m ExecutorMode) IsValid() bool {
	return m == "" ||
		m == ExecutorStandard ||
		m == ExecutorDryRun ||
		m == ExecutorSandbox
}

// GetDefault returns m itself, or ExecutorStandard when m is empty.
func (m ExecutorMode) GetDefault() ExecutorMode {
	if m != "" {
		return m
	}
	return ExecutorStandard
}
"github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/types" ) func TestPhaseEnum(t *testing.T) { assert.Equal(t, types.Phase("inspiration"), types.PhaseInspiration) assert.Equal(t, types.Phase("goals"), types.PhaseGoals) assert.Equal(t, types.Phase("tasks"), types.PhaseTasks) assert.Equal(t, types.Phase("run"), types.PhaseRun) assert.Equal(t, types.Phase("delivery"), types.PhaseDelivery) assert.Equal(t, types.Phase("learning"), types.PhaseLearning) } func TestAllPhases(t *testing.T) { // AllPhases is the execution pipeline — PhaseHost is excluded (it is a cross-phase service role) assert.Len(t, types.AllPhases, 6) assert.Equal(t, types.PhaseInspiration, types.AllPhases[0]) assert.Equal(t, types.PhaseGoals, types.AllPhases[1]) assert.Equal(t, types.PhaseTasks, types.AllPhases[2]) assert.Equal(t, types.PhaseRun, types.AllPhases[3]) assert.Equal(t, types.PhaseDelivery, types.AllPhases[4]) assert.Equal(t, types.PhaseLearning, types.AllPhases[5]) // AllConfigurablePhases includes PhaseHost for configuration validation assert.Len(t, types.AllConfigurablePhases, 7) assert.Contains(t, types.AllConfigurablePhases, types.PhaseHost) } func TestClockModeEnum(t *testing.T) { assert.Equal(t, types.ClockMode("times"), types.ClockTimes) assert.Equal(t, types.ClockMode("interval"), types.ClockInterval) assert.Equal(t, types.ClockMode("daemon"), types.ClockDaemon) } func TestTriggerTypeEnum(t *testing.T) { assert.Equal(t, types.TriggerType("clock"), types.TriggerClock) assert.Equal(t, types.TriggerType("human"), types.TriggerHuman) assert.Equal(t, types.TriggerType("event"), types.TriggerEvent) } func TestExecStatusEnum(t *testing.T) { assert.Equal(t, types.ExecStatus("pending"), types.ExecPending) assert.Equal(t, types.ExecStatus("running"), types.ExecRunning) assert.Equal(t, types.ExecStatus("paused"), types.ExecPaused) assert.Equal(t, types.ExecStatus("completed"), types.ExecCompleted) assert.Equal(t, types.ExecStatus("failed"), types.ExecFailed) 
assert.Equal(t, types.ExecStatus("cancelled"), types.ExecCancelled) } func TestRobotStatusEnum(t *testing.T) { assert.Equal(t, types.RobotStatus("idle"), types.RobotIdle) assert.Equal(t, types.RobotStatus("working"), types.RobotWorking) assert.Equal(t, types.RobotStatus("paused"), types.RobotPaused) assert.Equal(t, types.RobotStatus("error"), types.RobotError) assert.Equal(t, types.RobotStatus("maintenance"), types.RobotMaintenance) } func TestInterventionActionEnum(t *testing.T) { // Task operations assert.Equal(t, types.InterventionAction("task.add"), types.ActionTaskAdd) assert.Equal(t, types.InterventionAction("task.cancel"), types.ActionTaskCancel) assert.Equal(t, types.InterventionAction("task.update"), types.ActionTaskUpdate) // Goal operations assert.Equal(t, types.InterventionAction("goal.adjust"), types.ActionGoalAdjust) assert.Equal(t, types.InterventionAction("goal.add"), types.ActionGoalAdd) assert.Equal(t, types.InterventionAction("goal.complete"), types.ActionGoalComplete) assert.Equal(t, types.InterventionAction("goal.cancel"), types.ActionGoalCancel) // Plan operations assert.Equal(t, types.InterventionAction("plan.add"), types.ActionPlanAdd) assert.Equal(t, types.InterventionAction("plan.remove"), types.ActionPlanRemove) assert.Equal(t, types.InterventionAction("plan.update"), types.ActionPlanUpdate) // Instruction assert.Equal(t, types.InterventionAction("instruct"), types.ActionInstruct) } func TestPriorityEnum(t *testing.T) { assert.Equal(t, types.Priority("high"), types.PriorityHigh) assert.Equal(t, types.Priority("normal"), types.PriorityNormal) assert.Equal(t, types.Priority("low"), types.PriorityLow) } func TestDeliveryTypeEnum(t *testing.T) { assert.Equal(t, types.DeliveryType("email"), types.DeliveryEmail) assert.Equal(t, types.DeliveryType("webhook"), types.DeliveryWebhook) assert.Equal(t, types.DeliveryType("process"), types.DeliveryProcess) assert.Equal(t, types.DeliveryType("notify"), types.DeliveryNotify) } func TestDedupResultEnum(t 
*testing.T) { assert.Equal(t, types.DedupResult("skip"), types.DedupSkip) assert.Equal(t, types.DedupResult("merge"), types.DedupMerge) assert.Equal(t, types.DedupResult("proceed"), types.DedupProceed) } func TestEventSourceEnum(t *testing.T) { assert.Equal(t, types.EventSource("webhook"), types.EventWebhook) assert.Equal(t, types.EventSource("database"), types.EventDatabase) } func TestLearningTypeEnum(t *testing.T) { assert.Equal(t, types.LearningType("execution"), types.LearnExecution) assert.Equal(t, types.LearningType("feedback"), types.LearnFeedback) assert.Equal(t, types.LearningType("insight"), types.LearnInsight) } func TestTaskSourceEnum(t *testing.T) { assert.Equal(t, types.TaskSource("auto"), types.TaskSourceAuto) assert.Equal(t, types.TaskSource("human"), types.TaskSourceHuman) assert.Equal(t, types.TaskSource("event"), types.TaskSourceEvent) } func TestExecutorTypeEnum(t *testing.T) { assert.Equal(t, types.ExecutorType("assistant"), types.ExecutorAssistant) assert.Equal(t, types.ExecutorType("mcp"), types.ExecutorMCP) assert.Equal(t, types.ExecutorType("process"), types.ExecutorProcess) } func TestTaskStatusEnum(t *testing.T) { assert.Equal(t, types.TaskStatus("pending"), types.TaskPending) assert.Equal(t, types.TaskStatus("running"), types.TaskRunning) assert.Equal(t, types.TaskStatus("completed"), types.TaskCompleted) assert.Equal(t, types.TaskStatus("failed"), types.TaskFailed) assert.Equal(t, types.TaskStatus("skipped"), types.TaskSkipped) assert.Equal(t, types.TaskStatus("cancelled"), types.TaskCancelled) } func TestInsertPositionEnum(t *testing.T) { assert.Equal(t, types.InsertPosition("first"), types.InsertFirst) assert.Equal(t, types.InsertPosition("last"), types.InsertLast) assert.Equal(t, types.InsertPosition("next"), types.InsertNext) assert.Equal(t, types.InsertPosition("at"), types.InsertAt) } func TestExecutorModeEnum(t *testing.T) { assert.Equal(t, types.ExecutorMode("standard"), types.ExecutorStandard) assert.Equal(t, 
types.ExecutorMode("dryrun"), types.ExecutorDryRun) assert.Equal(t, types.ExecutorMode("sandbox"), types.ExecutorSandbox) } func TestExecutorModeIsValid(t *testing.T) { tests := []struct { mode types.ExecutorMode valid bool }{ {types.ExecutorStandard, true}, {types.ExecutorDryRun, true}, {types.ExecutorSandbox, true}, {"", true}, // empty is valid (defaults to standard) {types.ExecutorMode("invalid"), false}, {types.ExecutorMode("unknown"), false}, } for _, tt := range tests { t.Run(string(tt.mode), func(t *testing.T) { assert.Equal(t, tt.valid, tt.mode.IsValid()) }) } } func TestExecutorModeGetDefault(t *testing.T) { tests := []struct { mode types.ExecutorMode expected types.ExecutorMode }{ {"", types.ExecutorStandard}, {types.ExecutorStandard, types.ExecutorStandard}, {types.ExecutorDryRun, types.ExecutorDryRun}, {types.ExecutorSandbox, types.ExecutorSandbox}, } for _, tt := range tests { t.Run(string(tt.mode), func(t *testing.T) { assert.Equal(t, tt.expected, tt.mode.GetDefault()) }) } } ================================================ FILE: agent/robot/types/errors.go ================================================ package types import "errors" // ErrMissingIdentity indicates identity.role is required var ErrMissingIdentity = errors.New("identity.role is required") // ErrClockTimesEmpty indicates clock.times is required for times mode var ErrClockTimesEmpty = errors.New("clock.times is required for times mode") // ErrClockIntervalEmpty indicates clock.every is required for interval mode var ErrClockIntervalEmpty = errors.New("clock.every is required for interval mode") // ErrClockModeInvalid indicates clock.mode must be times, interval, or daemon var ErrClockModeInvalid = errors.New("clock.mode must be times, interval, or daemon") // ErrRobotNotFound indicates robot not found var ErrRobotNotFound = errors.New("robot not found") // ErrRobotPaused indicates robot is paused var ErrRobotPaused = errors.New("robot is paused") // ErrRobotBusy indicates robot has 
// HostInput is the unified input format for Host Agent (§5.7).
//
// None of the fields use omitempty, so all three JSON keys are always
// emitted; a nil Messages slice or nil Context pointer is encoded as null.
type HostInput struct {
	Scenario string                 `json:"scenario"` // one of "assign" | "guide" | "clarify"
	Messages []agentcontext.Message `json:"messages"` // Messages from the human
	Context  *HostContext           `json:"context"`  // Current execution context (may be nil)
}
// Note: Goals is *Goals (struct with Content field), serialized as {"content":"..."}. // Host Agent prompts must expect this struct format rather than a plain string. type HostContext struct { RobotStatus *RobotStatusSnapshot `json:"robot_status,omitempty"` Goals *Goals `json:"goals,omitempty"` Tasks []Task `json:"tasks,omitempty"` CurrentTask *Task `json:"current_task,omitempty"` AgentReply string `json:"agent_reply,omitempty"` History []agentcontext.Message `json:"history,omitempty"` } // HostOutput is the structured output from Host Agent type HostOutput struct { Reply string `json:"reply"` Action HostAction `json:"action,omitempty"` ActionData interface{} `json:"action_data,omitempty"` WaitForMore bool `json:"wait_for_more,omitempty"` } ================================================ FILE: agent/robot/types/host_test.go ================================================ package types import ( "encoding/json" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestHostInputJSON(t *testing.T) { input := &HostInput{ Scenario: "assign", Context: &HostContext{ RobotStatus: &RobotStatusSnapshot{ ActiveCount: 1, MaxQuota: 5, }, Goals: &Goals{Content: "test goals"}, }, } data, err := json.Marshal(input) require.NoError(t, err) var parsed HostInput err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, "assign", parsed.Scenario) assert.NotNil(t, parsed.Context) assert.Equal(t, 1, parsed.Context.RobotStatus.ActiveCount) } func TestHostOutputJSON(t *testing.T) { output := &HostOutput{ Reply: "Task confirmed", Action: HostActionConfirm, WaitForMore: false, } data, err := json.Marshal(output) require.NoError(t, err) var parsed HostOutput err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, "Task confirmed", parsed.Reply) assert.Equal(t, HostActionConfirm, parsed.Action) assert.False(t, parsed.WaitForMore) } func TestHostOutputWithActionData(t *testing.T) { output := &HostOutput{ Reply: "I'll 
// InspirationReport is the P0 output: a simple markdown report for the LLM,
// together with the clock context it was produced under. It is stored on
// Execution.Inspiration.
type InspirationReport struct {
	Clock   *ClockContext `json:"clock"`   // time context
	Content string        `json:"content"` // markdown text for LLM (see layout below)
}

// Illustrative layout of InspirationReport.Content (markdown):
// ## Summary
// ...
// ## Highlights
// - [High] Sales up 50%
// - [Medium] New lead from BigCorp
// ## Opportunities
// ...
// ## Risks
// ...
// ## World News
// ...
// ## Pending
// ...
// ExecutionControl provides pause/resume/stop control for running executions.
// This interface is implemented by trigger.ControlledExecution and is passed
// to Executor.ExecuteWithControl as its optional control argument.
type ExecutionControl interface {
	// IsPaused returns true if execution is paused
	IsPaused() bool

	// IsCancelled returns true if execution is cancelled
	IsCancelled() bool

	// WaitIfPaused blocks until resumed or cancelled, returns error if cancelled
	WaitIfPaused() error

	// CheckCancelled returns ErrExecutionCancelled if cancelled
	// NOTE(review): presumably returns nil otherwise — confirm against the implementation.
	CheckCancelled() error
}
// Pool - worker pool for concurrent execution
type Pool interface {
	// Start starts the pool.
	Start() error
	// Stop stops the pool.
	Stop() error
	// Submit hands a robot execution request (trigger type plus trigger data)
	// to the pool.
	// NOTE(review): the returned string is presumably the pre-generated
	// execution ID that Executor.ExecuteWithControl describes as "used by
	// pool" — confirm against the implementation.
	Submit(ctx *Context, robot *Robot, trigger TriggerType, data interface{}) (string, error)
	// Running returns a count of running work.
	// NOTE(review): unit (jobs vs. workers) is not visible here — confirm.
	Running() int
	// Queued returns a count of queued work.
	Queued() int
}
// RobotState - robot status query result.
// LastRun and NextRun are pointers with omitempty, so they are left out of
// the JSON when nil; RunningIDs is omitted when empty.
type RobotState struct {
	MemberID    string      `json:"member_id"`
	TeamID      string      `json:"team_id"`
	DisplayName string      `json:"display_name"`
	Status      RobotStatus `json:"status"`
	Running     int         `json:"running"`               // current running execution count
	MaxRunning  int         `json:"max_running"`           // max concurrent allowed
	LastRun     *time.Time  `json:"last_run,omitempty"`    // nil when not set
	NextRun     *time.Time  `json:"next_run,omitempty"`    // nil when not set
	RunningIDs  []string    `json:"running_ids,omitempty"` // list of running execution IDs
}
string `json:"manager_id"` // Direct manager user_id (who manages this robot) ManagerEmail string `json:"manager_email"` // Manager's email address (for default delivery) // Parsed config (from robot_config JSON field) Config *Config `json:"-"` // Runtime state LastRun time.Time `json:"-"` // last execution start time NextRun time.Time `json:"-"` // next scheduled execution (for clock trigger) // Concurrency control // Each Robot can run multiple Executions concurrently (up to Quota.Max) executions map[string]*Execution // execID -> Execution execMu sync.RWMutex } // CanRun checks if robot can accept new execution // Note: This is a read-only check. For atomic check-and-acquire, use TryAcquireSlot() func (r *Robot) CanRun() bool { r.execMu.RLock() defer r.execMu.RUnlock() if r.Config == nil { return len(r.executions) < 2 // default max } return len(r.executions) < r.Config.Quota.GetMax() } // TryAcquireSlot atomically checks if robot can run and reserves a slot // Returns true if slot was acquired, false if quota is full // This prevents race conditions between CanRun() check and AddExecution() func (r *Robot) TryAcquireSlot(exec *Execution) bool { r.execMu.Lock() defer r.execMu.Unlock() // Get max quota maxQuota := 2 // default if r.Config != nil { maxQuota = r.Config.Quota.GetMax() } // Check if we can add if len(r.executions) >= maxQuota { return false // quota full } // Reserve slot by adding execution if r.executions == nil { r.executions = make(map[string]*Execution) } r.executions[exec.ID] = exec return true } // RunningCount returns current running execution count func (r *Robot) RunningCount() int { r.execMu.RLock() defer r.execMu.RUnlock() return len(r.executions) } // AddExecution adds an execution to tracking // Note: Prefer TryAcquireSlot() for atomic check-and-add func (r *Robot) AddExecution(exec *Execution) { r.execMu.Lock() defer r.execMu.Unlock() if r.executions == nil { r.executions = make(map[string]*Execution) } r.executions[exec.ID] = exec } 
// RemoveExecution removes an execution from tracking func (r *Robot) RemoveExecution(execID string) { r.execMu.Lock() defer r.execMu.Unlock() delete(r.executions, execID) } // GetExecution returns an execution by ID func (r *Robot) GetExecution(execID string) *Execution { r.execMu.RLock() defer r.execMu.RUnlock() return r.executions[execID] } // GetExecutions returns all tracked executions func (r *Robot) GetExecutions() []*Execution { r.execMu.RLock() defer r.execMu.RUnlock() execs := make([]*Execution, 0, len(r.executions)) for _, exec := range r.executions { execs = append(execs, exec) } return execs } // ActiveCount returns the number of actively running executions func (r *Robot) ActiveCount() int { r.execMu.RLock() defer r.execMu.RUnlock() count := 0 for _, exec := range r.executions { if exec.Status == ExecRunning { count++ } } return count } // WaitingCount returns the number of executions waiting for human input func (r *Robot) WaitingCount() int { r.execMu.RLock() defer r.execMu.RUnlock() count := 0 for _, exec := range r.executions { if exec.Status == ExecWaiting { count++ } } return count } // ListExecutionBriefs returns brief summaries of all tracked executions func (r *Robot) ListExecutionBriefs() []ExecBrief { r.execMu.RLock() defer r.execMu.RUnlock() briefs := make([]ExecBrief, 0, len(r.executions)) for _, exec := range r.executions { brief := ExecBrief{ ID: exec.ID, Status: exec.Status, Phase: exec.Phase, Name: exec.Name, StartTime: exec.StartTime, TaskCount: len(exec.Tasks), } for _, result := range exec.Results { if result.Success { brief.DoneCount++ } else { brief.FailedCount++ } } briefs = append(briefs, brief) } return briefs } // MaxQuota returns the maximum concurrent execution quota func (r *Robot) MaxQuota() int { if r.Config == nil { return 2 } return r.Config.Quota.GetMax() } // Execution - single execution instance // Each trigger creates a new Execution, stored in ExecutionStore type Execution struct { ID string `json:"id"` // unique 
execution ID MemberID string `json:"member_id"` // robot member ID TeamID string `json:"team_id"` TriggerType TriggerType `json:"trigger_type"` // clock | human | event StartTime time.Time `json:"start_time"` EndTime *time.Time `json:"end_time,omitempty"` Status ExecStatus `json:"status"` Phase Phase `json:"phase"` Error string `json:"error,omitempty"` // UI display fields (updated by executor at each phase) Name string `json:"name,omitempty"` // Execution title (updated when goals complete) CurrentTaskName string `json:"current_task_name,omitempty"` // Current task description (updated during run phase) // Trigger input (stored for traceability) Input *TriggerInput `json:"input,omitempty"` // original trigger input // Phase outputs Inspiration *InspirationReport `json:"inspiration,omitempty"` // P0: markdown Goals *Goals `json:"goals,omitempty"` // P1: markdown Tasks []Task `json:"tasks,omitempty"` // P2: structured tasks Current *CurrentState `json:"current,omitempty"` // current executing state Results []TaskResult `json:"results,omitempty"` // P3: task results Delivery *DeliveryResult `json:"delivery,omitempty"` Learning []LearningEntry `json:"learning,omitempty"` // V2: Conversation and suspend-resume fields ChatID string `json:"chat_id,omitempty"` // Unique conversation ID for Host Agent WaitingTaskID string `json:"waiting_task_id,omitempty"` // Task ID that is waiting for input WaitingQuestion string `json:"waiting_question,omitempty"` // Question posed to human WaitingSince *time.Time `json:"waiting_since,omitempty"` // When execution was suspended ResumeContext *ResumeContext `json:"resume_context,omitempty"` // State for resuming suspended execution // Runtime (internal, not serialized) ctx context.Context `json:"-"` cancel context.CancelFunc `json:"-"` robot *Robot `json:"-"` } // ResumeContext holds the state needed to resume a suspended execution type ResumeContext struct { TaskIndex int `json:"task_index"` // Index of the task to resume from 
// RobotStatusSnapshot provides real-time robot status for the Host Agent.
// It is carried in HostContext.RobotStatus. The four counters are always
// serialized; the identifying fields and the execution lists are omitted when empty.
type RobotStatusSnapshot struct {
	MemberID     string      `json:"member_id,omitempty"`    // Robot member ID
	Status       RobotStatus `json:"status,omitempty"`       // Current robot status (a RobotStatus value: idle, working, paused, error, maintenance)
	ActiveCount  int         `json:"active_count"`           // Currently running executions
	WaitingCount int         `json:"waiting_count"`          // Executions waiting for input
	QueuedCount  int         `json:"queued_count"`           // Executions in queue (not yet started)
	MaxQuota     int         `json:"max_quota"`              // Maximum concurrent executions
	ActiveExecs  []ExecBrief `json:"active_execs,omitempty"` // Currently running execution summaries
	RecentExecs  []ExecBrief `json:"recent_execs,omitempty"` // Recently completed execution summaries
}
// CurrentState - the task currently being executed and its position.
// It is stored on Execution.Current.
// NOTE(review): the previous header said "goal and task", but this struct has
// no goal field; the goal reference lives on Task.GoalRef.
type CurrentState struct {
	Task      *Task  `json:"task,omitempty"`     // current task being executed
	TaskIndex int    `json:"task_index"`         // index in Tasks slice (presumably Execution.Tasks — confirm)
	Progress  string `json:"progress,omitempty"` // human-readable progress (e.g., "2/5 tasks")
}
// DeliveryTarget - where to deliver results (defined in P1, used in P4).
// It is attached to Goals.Delivery.
type DeliveryTarget struct {
	// Type is a DeliveryType value. The values pinned by the enum tests are
	// email | webhook | process | notify.
	// NOTE(review): the previous comment listed "report | notification", which
	// do not match those constants — confirm which set is intended.
	Type       DeliveryType           `json:"type"`
	Recipients []string               `json:"recipients,omitempty"` // email addresses, webhook URLs, user IDs
	Format     string                 `json:"format,omitempty"`     // markdown | html | json | text
	Template   string                 `json:"template,omitempty"`   // template name or inline template
	Options    map[string]interface{} `json:"options,omitempty"`    // channel-specific options
}
// ValidationResult - semantic validation result, attached to TaskResult.Validation.
// NOTE(review): this type was labelled "P3", while the comment on
// TaskResult.Validation says it is populated by the Delivery Agent in P4 (not
// by the runner in V2) — confirm which phase produces it.
type ValidationResult struct {
	// Basic validation result
	Passed      bool     `json:"passed"`                // overall validation passed
	Score       float64  `json:"score,omitempty"`       // 0-1 confidence score
	Issues      []string `json:"issues,omitempty"`      // what failed
	Suggestions []string `json:"suggestions,omitempty"` // how to improve
	Details     string   `json:"details,omitempty"`     // detailed validation report (markdown)

	// Execution state (for multi-turn conversation control)
	Complete     bool   `json:"complete"`                // whether expected result is obtained
	NeedReply    bool   `json:"need_reply,omitempty"`    // whether to continue conversation
	ReplyContent string `json:"reply_content,omitempty"` // content for next turn (if NeedReply)
}
// DeliveryAttachment - Task output attachment with metadata.
// Attachments are carried in DeliveryContent.Attachments.
type DeliveryAttachment struct {
	Title       string `json:"title"`                 // Human-readable title
	Description string `json:"description,omitempty"` // What this artifact is
	TaskID      string `json:"task_id,omitempty"`     // Which task produced this
	// File is a wrapper reference string for the attached file.
	// NOTE(review): the previous comment read "Wrapper: __://", which looks
	// truncated (placeholder segments appear to be missing) — confirm the
	// exact wrapper format against the code that produces attachments.
	File string `json:"file"`
}
[]string `json:"to"` // Recipient addresses Template string `json:"template,omitempty"` // Email template ID Subject string `json:"subject,omitempty"` // Subject template } // WebhookPreference - Webhook delivery configuration type WebhookPreference struct { Enabled bool `json:"enabled"` // Whether webhook delivery is enabled Targets []WebhookTarget `json:"targets,omitempty"` // Multiple webhook targets } // WebhookTarget - Single webhook target type WebhookTarget struct { URL string `json:"url"` // Webhook URL Method string `json:"method,omitempty"` // HTTP method (default: POST) Headers map[string]string `json:"headers,omitempty"` // Custom headers Secret string `json:"secret,omitempty"` // Signing secret } // ProcessPreference - Process delivery configuration type ProcessPreference struct { Enabled bool `json:"enabled"` // Whether process delivery is enabled Targets []ProcessTarget `json:"targets,omitempty"` // Multiple process targets } // ProcessTarget - Single process target type ProcessTarget struct { Process string `json:"process"` // Yao Process name Args []any `json:"args,omitempty"` // Process arguments } // ChannelResult - Result of delivery to a single channel target type ChannelResult struct { Type DeliveryType `json:"type"` // email | webhook | process Target string `json:"target"` // Target identifier (email, URL, process name) Success bool `json:"success"` // Whether delivery succeeded Recipients []string `json:"recipients,omitempty"` // Who received (for email) Details interface{} `json:"details,omitempty"` // Channel-specific response Error string `json:"error,omitempty"` // Error message if failed SentAt *time.Time `json:"sent_at,omitempty"` // When this target was delivered } // LearningEntry - knowledge to save type LearningEntry struct { Type LearningType `json:"type"` // execution | feedback | insight Content string `json:"content"` Tags []string `json:"tags,omitempty"` Meta interface{} `json:"meta,omitempty"` } // NewRobotFromMap creates a 
Robot from a map (typically from DB record) func NewRobotFromMap(m map[string]interface{}) (*Robot, error) { memberID := getString(m, "member_id") teamID := getString(m, "team_id") // Validate required fields if memberID == "" || teamID == "" { return nil, fmt.Errorf("missing required fields: member_id or team_id") } robot := &Robot{ MemberID: memberID, TeamID: teamID, DisplayName: getString(m, "display_name"), Bio: getString(m, "bio"), SystemPrompt: getString(m, "system_prompt"), AutonomousMode: getBool(m, "autonomous_mode"), RobotEmail: getString(m, "robot_email"), ManagerID: getString(m, "manager_id"), ManagerEmail: getString(m, "manager_email"), LanguageModel: getString(m, "language_model"), } // Parse robot_status if status := getString(m, "robot_status"); status != "" { robot.Status = RobotStatus(status) } else { robot.Status = RobotIdle } // Parse robot_config JSON if configData, ok := m["robot_config"]; ok && configData != nil { config, err := ParseConfig(configData) if err != nil { return nil, fmt.Errorf("failed to parse robot_config: %w", err) } robot.Config = config } // Ensure Config exists for merging agents/mcp_servers if robot.Config == nil { robot.Config = &Config{} } if robot.Config.Resources == nil { robot.Config.Resources = &Resources{} } // Merge agents from member table into Config.Resources.Agents if agentsData, ok := m["agents"]; ok && agentsData != nil { agents := getStringSlice(agentsData) if len(agents) > 0 { robot.Config.Resources.Agents = agents } } // Merge mcp_servers from member table into Config.Resources.MCP if mcpData, ok := m["mcp_servers"]; ok && mcpData != nil { mcpServers := getStringSlice(mcpData) if len(mcpServers) > 0 { for _, serverID := range mcpServers { robot.Config.Resources.MCP = append(robot.Config.Resources.MCP, MCPConfig{ ID: serverID, }) } } } return robot, nil } // getStringSlice converts interface{} to []string func getStringSlice(v interface{}) []string { if v == nil { return nil } switch val := v.(type) { case 
// getStringSlice converts interface{} to []string.
//
// A []string is returned as-is (not copied). For a []interface{} the string
// elements are kept in order and every non-string element is dropped. nil and
// any other type yield nil.
func getStringSlice(v interface{}) []string {
	switch val := v.(type) {
	case []string:
		return val
	case []interface{}:
		result := make([]string, 0, len(val))
		for _, item := range val {
			if s, ok := item.(string); ok {
				result = append(result, s)
			}
		}
		return result
	}
	return nil
}

// getString safely gets a string value from map.
//
// A missing key, a nil value or a nil map yields "". Strings are returned
// unchanged and []byte is decoded as text instead of being printed as a list
// of numbers. Any other value is formatted with %v.
func getString(m map[string]interface{}, key string) string {
	// Indexing a nil map is safe and yields the nil interface.
	switch val := m[key].(type) {
	case nil:
		return ""
	case string:
		return val
	case []byte:
		return string(val)
	default:
		return fmt.Sprintf("%v", val)
	}
}

// getBool safely gets a bool value from map.
//
// Numeric values of every integer width and both float widths are true when
// non-zero. Strings (and []byte text) are true only for "true" or "1".
// A missing key, nil or an unsupported type yields false.
func getBool(m map[string]interface{}, key string) bool {
	switch b := m[key].(type) {
	case bool:
		return b
	case int:
		return b != 0
	case int8:
		return b != 0
	case int16:
		return b != 0
	case int32:
		return b != 0
	case int64:
		return b != 0
	case uint:
		return b != 0
	case uint8:
		return b != 0
	case uint16:
		return b != 0
	case uint32:
		return b != 0
	case uint64:
		return b != 0
	case float32:
		return b != 0
	case float64:
		return b != 0
	case string:
		return b == "true" || b == "1"
	case []byte:
		s := string(b)
		return s == "true" || s == "1"
	}
	return false
}
assert.False(t, robot.CanRun()) }) t.Run("can run after removing execution", func(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 2}, }, } exec1 := &types.Execution{ID: "exec1"} exec2 := &types.Execution{ID: "exec2"} robot.AddExecution(exec1) robot.AddExecution(exec2) assert.False(t, robot.CanRun()) robot.RemoveExecution("exec1") assert.True(t, robot.CanRun()) }) } func TestRobotRunningCount(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 5}, }, } assert.Equal(t, 0, robot.RunningCount()) exec1 := &types.Execution{ID: "exec1"} robot.AddExecution(exec1) assert.Equal(t, 1, robot.RunningCount()) exec2 := &types.Execution{ID: "exec2"} robot.AddExecution(exec2) assert.Equal(t, 2, robot.RunningCount()) robot.RemoveExecution("exec1") assert.Equal(t, 1, robot.RunningCount()) robot.RemoveExecution("exec2") assert.Equal(t, 0, robot.RunningCount()) } func TestRobotAddExecution(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 2}, }, } exec := &types.Execution{ ID: "exec1", MemberID: "member1", } robot.AddExecution(exec) assert.Equal(t, 1, robot.RunningCount()) retrieved := robot.GetExecution("exec1") assert.NotNil(t, retrieved) assert.Equal(t, "exec1", retrieved.ID) assert.Equal(t, "member1", retrieved.MemberID) } func TestRobotRemoveExecution(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 2}, }, } exec := &types.Execution{ID: "exec1"} robot.AddExecution(exec) assert.Equal(t, 1, robot.RunningCount()) robot.RemoveExecution("exec1") assert.Equal(t, 0, robot.RunningCount()) retrieved := robot.GetExecution("exec1") assert.Nil(t, retrieved) } func TestRobotGetExecution(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 2}, }, } t.Run("get existing execution", func(t *testing.T) { exec := &types.Execution{ ID: "exec1", MemberID: "member1", } robot.AddExecution(exec) retrieved := 
robot.GetExecution("exec1") assert.NotNil(t, retrieved) assert.Equal(t, "exec1", retrieved.ID) }) t.Run("get non-existing execution", func(t *testing.T) { retrieved := robot.GetExecution("non-existing") assert.Nil(t, retrieved) }) } func TestRobotGetExecutions(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 5}, }, } t.Run("empty executions", func(t *testing.T) { execs := robot.GetExecutions() assert.Empty(t, execs) }) t.Run("multiple executions", func(t *testing.T) { exec1 := &types.Execution{ID: "exec1"} exec2 := &types.Execution{ID: "exec2"} exec3 := &types.Execution{ID: "exec3"} robot.AddExecution(exec1) robot.AddExecution(exec2) robot.AddExecution(exec3) execs := robot.GetExecutions() assert.Len(t, execs, 3) // Check all executions are present ids := make(map[string]bool) for _, exec := range execs { ids[exec.ID] = true } assert.True(t, ids["exec1"]) assert.True(t, ids["exec2"]) assert.True(t, ids["exec3"]) }) } func TestRobotConcurrentAccess(t *testing.T) { // Test thread-safe execution management robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 10}, }, } // Add executions concurrently done := make(chan bool) for i := 0; i < 5; i++ { go func(id int) { exec := &types.Execution{ID: string(rune('0' + id))} robot.AddExecution(exec) done <- true }(i) } // Wait for all goroutines for i := 0; i < 5; i++ { <-done } // Verify count count := robot.RunningCount() assert.Equal(t, 5, count) // Remove executions concurrently for i := 0; i < 5; i++ { go func(id int) { robot.RemoveExecution(string(rune('0' + id))) done <- true }(i) } // Wait for all goroutines for i := 0; i < 5; i++ { <-done } // Verify count count = robot.RunningCount() assert.Equal(t, 0, count) } func TestRobotTryAcquireSlot(t *testing.T) { t.Run("acquire slot when under quota", func(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 2}, }, } exec := &types.Execution{ID: "exec1"} acquired := 
robot.TryAcquireSlot(exec) assert.True(t, acquired) assert.Equal(t, 1, robot.RunningCount()) assert.NotNil(t, robot.GetExecution("exec1")) }) t.Run("fail to acquire when at quota", func(t *testing.T) { robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 2}, }, } // Fill quota robot.TryAcquireSlot(&types.Execution{ID: "exec1"}) robot.TryAcquireSlot(&types.Execution{ID: "exec2"}) // Try to acquire one more exec3 := &types.Execution{ID: "exec3"} acquired := robot.TryAcquireSlot(exec3) assert.False(t, acquired) assert.Equal(t, 2, robot.RunningCount()) assert.Nil(t, robot.GetExecution("exec3")) }) t.Run("acquire with nil config uses default quota", func(t *testing.T) { robot := &types.Robot{ Config: nil, // default quota is 2 } exec1 := &types.Execution{ID: "exec1"} exec2 := &types.Execution{ID: "exec2"} exec3 := &types.Execution{ID: "exec3"} assert.True(t, robot.TryAcquireSlot(exec1)) assert.True(t, robot.TryAcquireSlot(exec2)) assert.False(t, robot.TryAcquireSlot(exec3)) // should fail at default max=2 }) } func TestRobotTryAcquireSlotConcurrent(t *testing.T) { // Test that TryAcquireSlot is atomic and prevents exceeding quota robot := &types.Robot{ Config: &types.Config{ Quota: &types.Quota{Max: 5}, }, } // Launch 20 goroutines trying to acquire slots successCount := make(chan bool, 20) for i := 0; i < 20; i++ { go func(id int) { exec := &types.Execution{ID: string(rune('A' + id))} success := robot.TryAcquireSlot(exec) successCount <- success }(i) } // Count successes acquired := 0 for i := 0; i < 20; i++ { if <-successCount { acquired++ } } // Should have exactly 5 successful acquisitions (quota max) assert.Equal(t, 5, acquired, "Should acquire exactly quota max slots") assert.Equal(t, 5, robot.RunningCount(), "Running count should match quota max") } func TestRobotTryAcquireSlotRaceCondition(t *testing.T) { // Stress test to verify no race condition in TryAcquireSlot for iteration := 0; iteration < 100; iteration++ { robot := &types.Robot{ 
Config: &types.Config{ Quota: &types.Quota{Max: 3}, }, } // Launch many goroutines simultaneously successCount := make(chan bool, 50) for i := 0; i < 50; i++ { go func(id int) { exec := &types.Execution{ID: string(rune('A'+id%26)) + string(rune('0'+id/26))} success := robot.TryAcquireSlot(exec) successCount <- success }(i) } // Count successes acquired := 0 for i := 0; i < 50; i++ { if <-successCount { acquired++ } } // Should never exceed quota assert.Equal(t, 3, acquired, "Iteration %d: Should acquire exactly quota max slots", iteration) assert.Equal(t, 3, robot.RunningCount(), "Iteration %d: Running count should match quota max", iteration) } } func TestExecutionStructure(t *testing.T) { t.Run("execution with all fields", func(t *testing.T) { exec := &types.Execution{ ID: "exec1", MemberID: "member1", TeamID: "team1", TriggerType: types.TriggerClock, Status: types.ExecRunning, Phase: types.PhaseGoals, } assert.Equal(t, "exec1", exec.ID) assert.Equal(t, "member1", exec.MemberID) assert.Equal(t, "team1", exec.TeamID) assert.Equal(t, types.TriggerClock, exec.TriggerType) assert.Equal(t, types.ExecRunning, exec.Status) assert.Equal(t, types.PhaseGoals, exec.Phase) }) t.Run("execution with trigger input", func(t *testing.T) { exec := &types.Execution{ ID: "exec1", Input: &types.TriggerInput{ Action: types.ActionTaskAdd, UserID: "user1", }, } assert.NotNil(t, exec.Input) assert.Equal(t, types.ActionTaskAdd, exec.Input.Action) assert.Equal(t, "user1", exec.Input.UserID) }) } func TestTaskStructure(t *testing.T) { task := &types.Task{ ID: "task1", GoalRef: "Goal 1", Source: types.TaskSourceAuto, ExecutorType: types.ExecutorAssistant, ExecutorID: "assistant1", Status: types.TaskPending, Order: 0, // P3 validation fields ExpectedOutput: "JSON with sales_total and growth_rate fields", ValidationRules: []string{ // Natural language rules (matched by validator) "output must be valid JSON", "must contain 'sales_total'", // Structured rule: check field type `{"type": "type", 
"path": "growth_rate", "value": "number"}`, }, } assert.Equal(t, "task1", task.ID) assert.Equal(t, "Goal 1", task.GoalRef) assert.Equal(t, types.TaskSourceAuto, task.Source) assert.Equal(t, types.ExecutorAssistant, task.ExecutorType) assert.Equal(t, "assistant1", task.ExecutorID) assert.Equal(t, types.TaskPending, task.Status) assert.Equal(t, 0, task.Order) // Validation fields assert.Contains(t, task.ExpectedOutput, "sales_total") assert.Len(t, task.ValidationRules, 3) } func TestGoalsStructure(t *testing.T) { goals := &types.Goals{ Content: "## Goals\n1. [High] Complete project\n2. [Normal] Review code", Delivery: &types.DeliveryTarget{ Type: types.DeliveryEmail, Recipients: []string{"team@example.com"}, Format: "markdown", }, } assert.Contains(t, goals.Content, "Goals") assert.Contains(t, goals.Content, "Complete project") assert.NotNil(t, goals.Delivery) assert.Equal(t, types.DeliveryEmail, goals.Delivery.Type) } func TestTaskResultStructure(t *testing.T) { result := &types.TaskResult{ TaskID: "task1", Success: true, Output: "Task completed successfully", Duration: 1500, Validation: &types.ValidationResult{ Passed: true, Score: 0.98, }, } assert.Equal(t, "task1", result.TaskID) assert.True(t, result.Success) assert.Equal(t, "Task completed successfully", result.Output) assert.Equal(t, int64(1500), result.Duration) assert.NotNil(t, result.Validation) assert.True(t, result.Validation.Passed) assert.Equal(t, 0.98, result.Validation.Score) } func TestValidationResultStructure(t *testing.T) { validation := &types.ValidationResult{ Passed: false, Score: 0.45, Issues: []string{"Missing required field: sales_total", "Growth rate is negative"}, Suggestions: []string{"Add sales_total calculation", "Verify data source"}, Details: "Detailed validation report...", } assert.False(t, validation.Passed) assert.Equal(t, 0.45, validation.Score) assert.Len(t, validation.Issues, 2) assert.Contains(t, validation.Issues[0], "sales_total") assert.Len(t, validation.Suggestions, 2) } 
func TestValidationResultMultiTurnFields(t *testing.T) { // Test new multi-turn conversation control fields t.Run("complete and passed", func(t *testing.T) { validation := &types.ValidationResult{ Passed: true, Score: 0.95, Complete: true, } assert.True(t, validation.Passed) assert.True(t, validation.Complete) assert.False(t, validation.NeedReply) assert.Empty(t, validation.ReplyContent) }) t.Run("passed but not complete - needs reply", func(t *testing.T) { validation := &types.ValidationResult{ Passed: true, Score: 0.7, Complete: false, NeedReply: true, ReplyContent: "Please continue and provide the complete result.", } assert.True(t, validation.Passed) assert.False(t, validation.Complete) assert.True(t, validation.NeedReply) assert.NotEmpty(t, validation.ReplyContent) }) t.Run("failed with suggestions - needs reply", func(t *testing.T) { validation := &types.ValidationResult{ Passed: false, Score: 0.3, Complete: false, Issues: []string{"Missing required field"}, Suggestions: []string{"Add the field"}, NeedReply: true, ReplyContent: "## Validation Feedback\n\nPlease fix: Missing required field", } assert.False(t, validation.Passed) assert.False(t, validation.Complete) assert.True(t, validation.NeedReply) assert.Contains(t, validation.ReplyContent, "Validation Feedback") }) t.Run("failed without suggestions - no reply", func(t *testing.T) { validation := &types.ValidationResult{ Passed: false, Score: 0.0, Complete: false, Issues: []string{"Critical error: invalid format"}, NeedReply: false, } assert.False(t, validation.Passed) assert.False(t, validation.Complete) assert.False(t, validation.NeedReply) assert.Empty(t, validation.ReplyContent) }) } func TestDeliveryResultStructure(t *testing.T) { sentAt := time.Now() delivery := &types.DeliveryResult{ RequestID: "req-12345", Content: &types.DeliveryContent{ Summary: "Weekly sales report completed", Body: "# Weekly Report\n\nSales increased by 20%...", Attachments: []types.DeliveryAttachment{ { Title: "Sales Report", 
Description: "Detailed sales analysis", TaskID: "task-1", File: "__s3://report-12345.pdf", }, }, }, Results: []types.ChannelResult{ { Type: types.DeliveryEmail, Target: "user@example.com", Success: true, Details: map[string]interface{}{ "message_id": "msg-12345", }, }, { Type: types.DeliveryWebhook, Target: "https://webhook.example.com/notify", Success: true, }, }, Success: true, SentAt: &sentAt, } assert.Equal(t, "req-12345", delivery.RequestID) assert.True(t, delivery.Success) assert.NotNil(t, delivery.Content) assert.Equal(t, "Weekly sales report completed", delivery.Content.Summary) assert.Contains(t, delivery.Content.Body, "Weekly Report") assert.Len(t, delivery.Content.Attachments, 1) assert.Equal(t, "__s3://report-12345.pdf", delivery.Content.Attachments[0].File) assert.Len(t, delivery.Results, 2) assert.Equal(t, types.DeliveryEmail, delivery.Results[0].Type) assert.NotNil(t, delivery.SentAt) } func TestDeliveryContentStructure(t *testing.T) { content := &types.DeliveryContent{ Summary: "Task execution completed successfully", Body: "# Execution Report\n\n## Summary\n- 3 tasks completed\n- 1 task failed", Attachments: []types.DeliveryAttachment{ { Title: "Analysis Results", Description: "JSON data from analysis task", TaskID: "task-analysis", File: "__local://files/analysis-result.json", }, { Title: "Generated Chart", TaskID: "task-chart", File: "__s3://charts/sales-chart.png", }, }, } assert.NotEmpty(t, content.Summary) assert.Contains(t, content.Body, "Execution Report") assert.Len(t, content.Attachments, 2) assert.Equal(t, "Analysis Results", content.Attachments[0].Title) assert.Equal(t, "task-analysis", content.Attachments[0].TaskID) } func TestDeliveryAttachmentStructure(t *testing.T) { attachment := &types.DeliveryAttachment{ Title: "Sales Report PDF", Description: "Monthly sales analysis report", TaskID: "task-report", File: "__s3://reports/sales-2024-01.pdf", } assert.Equal(t, "Sales Report PDF", attachment.Title) assert.Equal(t, "Monthly sales 
analysis report", attachment.Description) assert.Equal(t, "task-report", attachment.TaskID) assert.Contains(t, attachment.File, "__s3://") } func TestDeliveryRequestStructure(t *testing.T) { request := &types.DeliveryRequest{ Content: &types.DeliveryContent{ Summary: "Report ready", Body: "# Report\n\nDetails...", }, Context: &types.DeliveryContext{ MemberID: "member-123", ExecutionID: "exec-456", TriggerType: types.TriggerClock, TeamID: "team-789", }, } assert.NotNil(t, request.Content) assert.NotNil(t, request.Context) assert.Equal(t, "member-123", request.Context.MemberID) assert.Equal(t, "exec-456", request.Context.ExecutionID) assert.Equal(t, types.TriggerClock, request.Context.TriggerType) } func TestDeliveryPreferencesStructure(t *testing.T) { prefs := &types.DeliveryPreferences{ Email: &types.EmailPreference{ Enabled: true, Targets: []types.EmailTarget{ { To: []string{"team@example.com"}, Template: "weekly-report", Subject: "Weekly Report - {{.Date}}", }, { To: []string{"backup@example.com"}, }, }, }, Webhook: &types.WebhookPreference{ Enabled: true, Targets: []types.WebhookTarget{ { URL: "https://api.example.com/webhook", Method: "POST", Headers: map[string]string{ "X-API-Key": "secret-key", }, Secret: "signing-secret", }, }, }, Process: &types.ProcessPreference{ Enabled: true, Targets: []types.ProcessTarget{ { Process: "scripts.notify.slack", Args: []any{"#general", "Report ready"}, }, }, }, } // Email assert.True(t, prefs.Email.Enabled) assert.Len(t, prefs.Email.Targets, 2) assert.Equal(t, "weekly-report", prefs.Email.Targets[0].Template) assert.Len(t, prefs.Email.Targets[0].To, 1) // Webhook assert.True(t, prefs.Webhook.Enabled) assert.Len(t, prefs.Webhook.Targets, 1) assert.Equal(t, "https://api.example.com/webhook", prefs.Webhook.Targets[0].URL) assert.Equal(t, "POST", prefs.Webhook.Targets[0].Method) // Process assert.True(t, prefs.Process.Enabled) assert.Len(t, prefs.Process.Targets, 1) assert.Equal(t, "scripts.notify.slack", 
prefs.Process.Targets[0].Process) assert.Len(t, prefs.Process.Targets[0].Args, 2) } func TestChannelResultStructure(t *testing.T) { t.Run("email result with recipients", func(t *testing.T) { sentAt := time.Now() result := &types.ChannelResult{ Type: types.DeliveryEmail, Target: "user@example.com", Success: true, Recipients: []string{"user@example.com", "manager@example.com"}, Details: map[string]interface{}{ "message_id": "msg-123", }, SentAt: &sentAt, } assert.Equal(t, types.DeliveryEmail, result.Type) assert.Equal(t, "user@example.com", result.Target) assert.True(t, result.Success) assert.Len(t, result.Recipients, 2) assert.NotNil(t, result.SentAt) }) t.Run("webhook result", func(t *testing.T) { sentAt := time.Now() result := &types.ChannelResult{ Type: types.DeliveryWebhook, Target: "https://api.example.com/webhook", Success: true, Details: map[string]interface{}{ "status_code": 200, "response": "OK", }, SentAt: &sentAt, } assert.Equal(t, types.DeliveryWebhook, result.Type) assert.True(t, result.Success) assert.NotNil(t, result.SentAt) }) t.Run("process result", func(t *testing.T) { result := &types.ChannelResult{ Type: types.DeliveryProcess, Target: "scripts.notify.slack", Success: true, Details: map[string]interface{}{ "output": "Message sent", }, } assert.Equal(t, types.DeliveryProcess, result.Type) assert.Equal(t, "scripts.notify.slack", result.Target) }) t.Run("failed result", func(t *testing.T) { result := &types.ChannelResult{ Type: types.DeliveryWebhook, Target: "https://api.example.com/webhook", Success: false, Error: "Connection refused", } assert.False(t, result.Success) assert.Equal(t, "Connection refused", result.Error) }) } func TestDeliveryTargetStructure(t *testing.T) { delivery := &types.DeliveryTarget{ Type: types.DeliveryEmail, Recipients: []string{"team@example.com"}, Format: "markdown", Template: "weekly-report", Options: map[string]interface{}{ "cc": []string{"manager@example.com"}, }, } assert.Equal(t, types.DeliveryEmail, delivery.Type) 
// ==================== To Functions ====================
// Convert any value to specified type (safe, returns zero value on failure)

// formatWholeFloat renders f without a fractional part or exponent when f is
// an integral value small enough to be exact in a float64. The second result
// is false when f is fractional, NaN, infinite or too large.
func formatWholeFloat(f float64) (string, bool) {
	// NaN fails every comparison, so it falls through to the caller's %g path.
	if f > -1e15 && f < 1e15 && f == float64(int64(f)) {
		return fmt.Sprintf("%d", int64(f)), true
	}
	return "", false
}

// ToString converts any value to string.
//
// nil yields "". Strings and []byte are returned as text, integers in base 10
// and booleans as "true"/"false". Floats use %g, except that integral values
// below 1e15 in magnitude are printed as plain integers: a JSON-decoded ID
// such as float64(12345678) becomes "12345678" rather than "1.2345678e+07".
// Error values yield their Error() text. Anything else is JSON-encoded, with
// %v as the fallback when encoding fails.
func ToString(v interface{}) string {
	if v == nil {
		return ""
	}
	switch val := v.(type) {
	case string:
		return val
	case []byte:
		return string(val)
	case int:
		return fmt.Sprintf("%d", val)
	case int8:
		return fmt.Sprintf("%d", val)
	case int16:
		return fmt.Sprintf("%d", val)
	case int32:
		return fmt.Sprintf("%d", val)
	case int64:
		return fmt.Sprintf("%d", val)
	case uint:
		return fmt.Sprintf("%d", val)
	case uint8:
		return fmt.Sprintf("%d", val)
	case uint16:
		return fmt.Sprintf("%d", val)
	case uint32:
		return fmt.Sprintf("%d", val)
	case uint64:
		return fmt.Sprintf("%d", val)
	case float32:
		if s, ok := formatWholeFloat(float64(val)); ok {
			return s
		}
		// Format the float32 itself so the shortest 32-bit form is used.
		return fmt.Sprintf("%g", val)
	case float64:
		if s, ok := formatWholeFloat(val); ok {
			return s
		}
		return fmt.Sprintf("%g", val)
	case bool:
		if val {
			return "true"
		}
		return "false"
	case error:
		// json.Marshal of a typical error yields "{}", so use the message.
		return val.Error()
	default:
		if str, err := json.Marshal(v); err == nil {
			return string(str)
		}
		return fmt.Sprintf("%v", v)
	}
}

// ToBool converts any value to bool.
//
// Numbers are true when non-zero. Strings are true only for the exact values
// "true", "1", "yes" or "on". nil and unsupported types yield false.
func ToBool(v interface{}) bool {
	if v == nil {
		return false
	}
	switch b := v.(type) {
	case bool:
		return b
	case int:
		return b != 0
	case int8:
		return b != 0
	case int16:
		return b != 0
	case int32:
		return b != 0
	case int64:
		return b != 0
	case uint:
		return b != 0
	case uint8:
		return b != 0
	case uint16:
		return b != 0
	case uint32:
		return b != 0
	case uint64:
		return b != 0
	case float32:
		return b != 0
	case float64:
		return b != 0
	case string:
		return b == "true" || b == "1" || b == "yes" || b == "on"
	}
	return false
}
// ToInt converts any value to int.
//
// Integers and floats are converted with a plain Go conversion (floats
// truncate toward zero). Booleans map to 1/0. Strings are scanned with "%d"
// and yield 0 when no integer can be read. nil and unsupported types yield 0.
func ToInt(v interface{}) int {
	switch x := v.(type) {
	case nil:
		return 0
	case bool:
		if x {
			return 1
		}
		return 0
	case string:
		var parsed int
		// A scan failure is deliberately ignored: parsed then stays 0.
		_, _ = fmt.Sscanf(x, "%d", &parsed)
		return parsed
	case float32:
		return int(x)
	case float64:
		return int(x)
	case int:
		return x
	case int8:
		return int(x)
	case int16:
		return int(x)
	case int32:
		return int(x)
	case int64:
		return int(x)
	case uint:
		return int(x)
	case uint8:
		return int(x)
	case uint16:
		return int(x)
	case uint32:
		return int(x)
	case uint64:
		return int(x)
	default:
		return 0
	}
}

// ToInt64 converts any value to int64.
//
// Conversion rules match ToInt, with int64 as the result type.
func ToInt64(v interface{}) int64 {
	switch x := v.(type) {
	case nil:
		return 0
	case bool:
		if x {
			return 1
		}
		return 0
	case string:
		var parsed int64
		// A scan failure is deliberately ignored: parsed then stays 0.
		_, _ = fmt.Sscanf(x, "%d", &parsed)
		return parsed
	case float32:
		return int64(x)
	case float64:
		return int64(x)
	case int:
		return int64(x)
	case int8:
		return int64(x)
	case int16:
		return int64(x)
	case int32:
		return int64(x)
	case int64:
		return x
	case uint:
		return int64(x)
	case uint8:
		return int64(x)
	case uint16:
		return int64(x)
	case uint32:
		return int64(x)
	case uint64:
		return int64(x)
	default:
		return 0
	}
}

// ToFloat64 converts any value to float64.
//
// Numbers are converted directly, booleans map to 1/0, and strings are
// scanned with "%f" (0 when nothing can be read). nil and unsupported types
// yield 0.
func ToFloat64(v interface{}) float64 {
	switch x := v.(type) {
	case nil:
		return 0
	case bool:
		if x {
			return 1
		}
		return 0
	case string:
		var parsed float64
		// A scan failure is deliberately ignored: parsed then stays 0.
		_, _ = fmt.Sscanf(x, "%f", &parsed)
		return parsed
	case float32:
		return float64(x)
	case float64:
		return x
	case int:
		return float64(x)
	case int8:
		return float64(x)
	case int16:
		return float64(x)
	case int32:
		return float64(x)
	case int64:
		return float64(x)
	case uint:
		return float64(x)
	case uint8:
		return float64(x)
	case uint16:
		return float64(x)
	case uint32:
		return float64(x)
	case uint64:
		return float64(x)
	default:
		return 0
	}
}
// ToTimestamp converts any value to *time.Time
// Handles: time.Time, *time.Time, string (various formats), int64/float64 (unix timestamp)
//
// A *time.Time is returned as-is (it may be nil). Strings are tried against
// a fixed list of layouts in order, and nil is returned when none matches or
// the string is empty. int64 and int are Unix seconds. float64 is Unix
// seconds whose fractional part is kept as nanoseconds (truncated, not
// rounded). nil and unsupported types yield nil.
func ToTimestamp(v interface{}) *time.Time {
	if v == nil {
		return nil
	}
	switch t := v.(type) {
	case time.Time:
		return &t
	case *time.Time:
		return t
	case string:
		if t == "" {
			return nil
		}
		// Try common time formats
		formats := []string{
			time.RFC3339,
			time.RFC3339Nano,
			"2006-01-02 15:04:05",
			"2006-01-02T15:04:05Z",
			"2006-01-02T15:04:05",
			"2006-01-02",
		}
		for _, format := range formats {
			if parsed, err := time.Parse(format, t); err == nil {
				return &parsed
			}
		}
	case int64:
		// Unix timestamp (seconds)
		parsed := time.Unix(t, 0)
		return &parsed
	case int:
		parsed := time.Unix(int64(t), 0)
		return &parsed
	case float64:
		// Unix timestamp (seconds as float): keep the sub-second part instead
		// of discarding it. time.Unix accepts a negative nsec, so timestamps
		// before the epoch also work.
		sec := int64(t)
		nsec := int64((t - float64(sec)) * 1e9)
		parsed := time.Unix(sec, nsec)
		return &parsed
	}
	return nil
}

// ToJSONValue parses JSON from string/[]byte or returns already-parsed value.
//
// Empty or invalid JSON input yields nil. Values of any other type, including
// already-decoded maps and slices, are returned unchanged.
func ToJSONValue(v interface{}) interface{} {
	// decode parses raw JSON, mapping both empty input and syntax errors to nil.
	decode := func(raw []byte) interface{} {
		if len(raw) == 0 {
			return nil
		}
		var result interface{}
		if err := json.Unmarshal(raw, &result); err != nil {
			return nil
		}
		return result
	}

	switch data := v.(type) {
	case nil:
		return nil
	case string:
		return decode([]byte(data))
	case []byte:
		return decode(data)
	default:
		// Already parsed (map, slice) or a plain scalar: pass through.
		return v
	}
}
GetInt(m map[string]interface{}, key string) int { if m == nil { return 0 } if v, ok := m[key]; ok { return ToInt(v) } return 0 } // GetInt64 safely gets an int64 value from map func GetInt64(m map[string]interface{}, key string) int64 { if m == nil { return 0 } if v, ok := m[key]; ok { return ToInt64(v) } return 0 } // GetFloat64 safely gets a float64 value from map func GetFloat64(m map[string]interface{}, key string) float64 { if m == nil { return 0 } if v, ok := m[key]; ok { return ToFloat64(v) } return 0 } // GetTimestamp safely gets a *time.Time value from map func GetTimestamp(m map[string]interface{}, key string) *time.Time { if m == nil { return nil } if v, ok := m[key]; ok { return ToTimestamp(v) } return nil } // GetJSONValue safely gets a parsed JSON value from map func GetJSONValue(m map[string]interface{}, key string) interface{} { if m == nil { return nil } if v, ok := m[key]; ok { return ToJSONValue(v) } return nil } // ==================== JSON/Map Conversion ==================== // ToJSON converts any value to JSON string func ToJSON(v interface{}) (string, error) { data, err := json.Marshal(v) if err != nil { return "", err } return string(data), nil } // FromJSON parses JSON string to target struct func FromJSON(jsonStr string, target interface{}) error { return json.Unmarshal([]byte(jsonStr), target) } // ToMap converts struct to map[string]interface{} func ToMap(v interface{}) (map[string]interface{}, error) { data, err := json.Marshal(v) if err != nil { return nil, err } var result map[string]interface{} if err := json.Unmarshal(data, &result); err != nil { return nil, err } return result, nil } // FromMap converts map to struct func FromMap(m map[string]interface{}, target interface{}) error { data, err := json.Marshal(m) if err != nil { return err } return json.Unmarshal(data, target) } // ==================== Map Utilities ==================== // MergeMap merges source map into target map (shallow copy) func MergeMap(target, source 
// ==================== JSON/Map Conversion ====================

// ToJSON converts any value to JSON string.
// On a marshalling error the returned string is empty.
func ToJSON(v interface{}) (string, error) {
	encoded, err := json.Marshal(v)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}

// FromJSON parses JSON string to target struct.
// target must be a pointer, as required by json.Unmarshal.
func FromJSON(jsonStr string, target interface{}) error {
	raw := []byte(jsonStr)
	return json.Unmarshal(raw, target)
}

// ToMap converts struct to map[string]interface{} by a JSON round trip,
// so keys follow the struct's json tags.
func ToMap(v interface{}) (map[string]interface{}, error) {
	encoded, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}

	var out map[string]interface{}
	if err = json.Unmarshal(encoded, &out); err != nil {
		return nil, err
	}
	return out, nil
}

// FromMap converts map to struct by a JSON round trip.
// target must be a pointer, as required by json.Unmarshal.
func FromMap(m map[string]interface{}, target interface{}) error {
	encoded, err := json.Marshal(m)
	if err != nil {
		return err
	}
	return json.Unmarshal(encoded, target)
}

// ==================== Map Utilities ====================

// MergeMap merges source map into target map (shallow copy).
// target is modified in place and returned; a nil target is replaced by a
// new map. Keys in source overwrite keys already in target.
func MergeMap(target, source map[string]interface{}) map[string]interface{} {
	merged := target
	if merged == nil {
		merged = make(map[string]interface{})
	}
	for key, value := range source {
		merged[key] = value
	}
	return merged
}

// CloneMap creates a shallow copy of a map.
// A nil map yields nil; nested values are shared with the original.
func CloneMap(m map[string]interface{}) map[string]interface{} {
	if m == nil {
		return nil
	}
	clone := make(map[string]interface{}, len(m))
	for key, value := range m {
		clone[key] = value
	}
	return clone
}
t.Run("from_float64", func(t *testing.T) { assert.Equal(t, 42, utils.ToInt(42.9)) // truncates assert.Equal(t, -5, utils.ToInt(-5.7)) }) t.Run("from_string", func(t *testing.T) { assert.Equal(t, 123, utils.ToInt("123")) assert.Equal(t, -456, utils.ToInt("-456")) assert.Equal(t, 0, utils.ToInt("invalid")) }) t.Run("from_bool", func(t *testing.T) { assert.Equal(t, 1, utils.ToInt(true)) assert.Equal(t, 0, utils.ToInt(false)) }) t.Run("from_nil", func(t *testing.T) { assert.Equal(t, 0, utils.ToInt(nil)) }) } func TestToInt64(t *testing.T) { t.Run("from_int64", func(t *testing.T) { assert.Equal(t, int64(9223372036854775807), utils.ToInt64(int64(9223372036854775807))) }) t.Run("from_int", func(t *testing.T) { assert.Equal(t, int64(42), utils.ToInt64(42)) }) t.Run("from_float64", func(t *testing.T) { assert.Equal(t, int64(42), utils.ToInt64(42.9)) }) t.Run("from_string", func(t *testing.T) { assert.Equal(t, int64(123456789), utils.ToInt64("123456789")) }) t.Run("from_nil", func(t *testing.T) { assert.Equal(t, int64(0), utils.ToInt64(nil)) }) } func TestToFloat64(t *testing.T) { t.Run("from_float64", func(t *testing.T) { assert.Equal(t, 3.14159, utils.ToFloat64(3.14159)) }) t.Run("from_float32", func(t *testing.T) { assert.InDelta(t, 3.14, utils.ToFloat64(float32(3.14)), 0.001) }) t.Run("from_int", func(t *testing.T) { assert.Equal(t, 42.0, utils.ToFloat64(42)) }) t.Run("from_int64", func(t *testing.T) { assert.Equal(t, 100.0, utils.ToFloat64(int64(100))) }) t.Run("from_string", func(t *testing.T) { assert.InDelta(t, 3.14, utils.ToFloat64("3.14"), 0.001) assert.Equal(t, 0.0, utils.ToFloat64("invalid")) }) t.Run("from_bool", func(t *testing.T) { assert.Equal(t, 1.0, utils.ToFloat64(true)) assert.Equal(t, 0.0, utils.ToFloat64(false)) }) t.Run("from_nil", func(t *testing.T) { assert.Equal(t, 0.0, utils.ToFloat64(nil)) }) } func TestToTimestamp(t *testing.T) { t.Run("from_time_Time", func(t *testing.T) { now := time.Now() result := utils.ToTimestamp(now) assert.NotNil(t, 
result) assert.Equal(t, now.Unix(), result.Unix()) }) t.Run("from_time_Time_pointer", func(t *testing.T) { now := time.Now() result := utils.ToTimestamp(&now) assert.NotNil(t, result) assert.Equal(t, now.Unix(), result.Unix()) }) t.Run("from_RFC3339_string", func(t *testing.T) { result := utils.ToTimestamp("2024-01-15T14:30:00Z") assert.NotNil(t, result) assert.Equal(t, 2024, result.Year()) assert.Equal(t, time.January, result.Month()) assert.Equal(t, 15, result.Day()) assert.Equal(t, 14, result.Hour()) assert.Equal(t, 30, result.Minute()) }) t.Run("from_datetime_string", func(t *testing.T) { result := utils.ToTimestamp("2024-01-15 14:30:00") assert.NotNil(t, result) assert.Equal(t, 2024, result.Year()) }) t.Run("from_date_string", func(t *testing.T) { result := utils.ToTimestamp("2024-01-15") assert.NotNil(t, result) assert.Equal(t, 2024, result.Year()) assert.Equal(t, 15, result.Day()) }) t.Run("from_unix_timestamp_int64", func(t *testing.T) { // 2024-01-15 00:00:00 UTC result := utils.ToTimestamp(int64(1705276800)) assert.NotNil(t, result) assert.Equal(t, 2024, result.Year()) }) t.Run("from_unix_timestamp_float64", func(t *testing.T) { result := utils.ToTimestamp(float64(1705276800)) assert.NotNil(t, result) assert.Equal(t, 2024, result.Year()) }) t.Run("from_empty_string", func(t *testing.T) { result := utils.ToTimestamp("") assert.Nil(t, result) }) t.Run("from_invalid_string", func(t *testing.T) { result := utils.ToTimestamp("not a date") assert.Nil(t, result) }) t.Run("from_nil", func(t *testing.T) { result := utils.ToTimestamp(nil) assert.Nil(t, result) }) } func TestToJSONValue(t *testing.T) { t.Run("from_json_string_object", func(t *testing.T) { result := utils.ToJSONValue(`{"name":"test","age":30}`) assert.NotNil(t, result) m, ok := result.(map[string]interface{}) assert.True(t, ok) assert.Equal(t, "test", m["name"]) assert.Equal(t, float64(30), m["age"]) }) t.Run("from_json_string_array", func(t *testing.T) { result := utils.ToJSONValue(`["a","b","c"]`) 
assert.NotNil(t, result) arr, ok := result.([]interface{}) assert.True(t, ok) assert.Len(t, arr, 3) assert.Equal(t, "a", arr[0]) }) t.Run("from_bytes", func(t *testing.T) { result := utils.ToJSONValue([]byte(`{"key":"value"}`)) assert.NotNil(t, result) m, ok := result.(map[string]interface{}) assert.True(t, ok) assert.Equal(t, "value", m["key"]) }) t.Run("from_already_parsed_map", func(t *testing.T) { input := map[string]interface{}{"foo": "bar"} result := utils.ToJSONValue(input) assert.Equal(t, input, result) }) t.Run("from_already_parsed_array", func(t *testing.T) { input := []interface{}{"a", "b"} result := utils.ToJSONValue(input) assert.Equal(t, input, result) }) t.Run("from_empty_string", func(t *testing.T) { result := utils.ToJSONValue("") assert.Nil(t, result) }) t.Run("from_empty_bytes", func(t *testing.T) { result := utils.ToJSONValue([]byte{}) assert.Nil(t, result) }) t.Run("from_invalid_json", func(t *testing.T) { result := utils.ToJSONValue("not json") assert.Nil(t, result) }) t.Run("from_nil", func(t *testing.T) { result := utils.ToJSONValue(nil) assert.Nil(t, result) }) t.Run("from_other_type_passthrough", func(t *testing.T) { // Non-string, non-[]byte types are passed through result := utils.ToJSONValue(42) assert.Equal(t, 42, result) }) } // ==================== Get Tests ==================== func TestGetString(t *testing.T) { m := map[string]interface{}{ "name": "test", "number": 42, "bool": true, "nil": nil, } t.Run("existing_string_key", func(t *testing.T) { assert.Equal(t, "test", utils.GetString(m, "name")) }) t.Run("converts_number_to_string", func(t *testing.T) { assert.Equal(t, "42", utils.GetString(m, "number")) }) t.Run("converts_bool_to_string", func(t *testing.T) { assert.Equal(t, "true", utils.GetString(m, "bool")) }) t.Run("non_existent_key", func(t *testing.T) { assert.Equal(t, "", utils.GetString(m, "missing")) }) t.Run("nil_map", func(t *testing.T) { assert.Equal(t, "", utils.GetString(nil, "key")) }) t.Run("nil_value", func(t 
*testing.T) { assert.Equal(t, "", utils.GetString(m, "nil")) }) } func TestGetBool(t *testing.T) { m := map[string]interface{}{ "bool_true": true, "bool_false": false, "int_one": 1, "int_zero": 0, "string_true": "true", } t.Run("bool_true", func(t *testing.T) { assert.True(t, utils.GetBool(m, "bool_true")) }) t.Run("bool_false", func(t *testing.T) { assert.False(t, utils.GetBool(m, "bool_false")) }) t.Run("int_one", func(t *testing.T) { assert.True(t, utils.GetBool(m, "int_one")) }) t.Run("int_zero", func(t *testing.T) { assert.False(t, utils.GetBool(m, "int_zero")) }) t.Run("string_true", func(t *testing.T) { assert.True(t, utils.GetBool(m, "string_true")) }) t.Run("non_existent_key", func(t *testing.T) { assert.False(t, utils.GetBool(m, "missing")) }) t.Run("nil_map", func(t *testing.T) { assert.False(t, utils.GetBool(nil, "key")) }) } func TestGetInt(t *testing.T) { m := map[string]interface{}{ "int": 42, "int64": int64(100), "float64": 3.14, "string": "123", } t.Run("int", func(t *testing.T) { assert.Equal(t, 42, utils.GetInt(m, "int")) }) t.Run("int64", func(t *testing.T) { assert.Equal(t, 100, utils.GetInt(m, "int64")) }) t.Run("float64", func(t *testing.T) { assert.Equal(t, 3, utils.GetInt(m, "float64")) }) t.Run("string", func(t *testing.T) { assert.Equal(t, 123, utils.GetInt(m, "string")) }) t.Run("non_existent_key", func(t *testing.T) { assert.Equal(t, 0, utils.GetInt(m, "missing")) }) t.Run("nil_map", func(t *testing.T) { assert.Equal(t, 0, utils.GetInt(nil, "key")) }) } func TestGetInt64(t *testing.T) { m := map[string]interface{}{ "int64": int64(9223372036854775807), "int": 42, "string": "123456789", } t.Run("int64", func(t *testing.T) { assert.Equal(t, int64(9223372036854775807), utils.GetInt64(m, "int64")) }) t.Run("int", func(t *testing.T) { assert.Equal(t, int64(42), utils.GetInt64(m, "int")) }) t.Run("string", func(t *testing.T) { assert.Equal(t, int64(123456789), utils.GetInt64(m, "string")) }) t.Run("nil_map", func(t *testing.T) { 
assert.Equal(t, int64(0), utils.GetInt64(nil, "key")) }) } func TestGetFloat64(t *testing.T) { m := map[string]interface{}{ "float64": 3.14159, "int": 42, "string": "2.718", } t.Run("float64", func(t *testing.T) { assert.Equal(t, 3.14159, utils.GetFloat64(m, "float64")) }) t.Run("int", func(t *testing.T) { assert.Equal(t, 42.0, utils.GetFloat64(m, "int")) }) t.Run("string", func(t *testing.T) { assert.InDelta(t, 2.718, utils.GetFloat64(m, "string"), 0.001) }) t.Run("nil_map", func(t *testing.T) { assert.Equal(t, 0.0, utils.GetFloat64(nil, "key")) }) } func TestGetTimestamp(t *testing.T) { now := time.Now() m := map[string]interface{}{ "time": now, "time_ptr": &now, "rfc3339": "2024-01-15T14:30:00Z", "unix": int64(1705276800), "empty": "", "nil_value": nil, } t.Run("time_value", func(t *testing.T) { result := utils.GetTimestamp(m, "time") assert.NotNil(t, result) assert.Equal(t, now.Unix(), result.Unix()) }) t.Run("time_ptr", func(t *testing.T) { result := utils.GetTimestamp(m, "time_ptr") assert.NotNil(t, result) }) t.Run("rfc3339_string", func(t *testing.T) { result := utils.GetTimestamp(m, "rfc3339") assert.NotNil(t, result) assert.Equal(t, 2024, result.Year()) }) t.Run("unix_timestamp", func(t *testing.T) { result := utils.GetTimestamp(m, "unix") assert.NotNil(t, result) }) t.Run("empty_string", func(t *testing.T) { result := utils.GetTimestamp(m, "empty") assert.Nil(t, result) }) t.Run("nil_value", func(t *testing.T) { result := utils.GetTimestamp(m, "nil_value") assert.Nil(t, result) }) t.Run("non_existent_key", func(t *testing.T) { result := utils.GetTimestamp(m, "missing") assert.Nil(t, result) }) t.Run("nil_map", func(t *testing.T) { result := utils.GetTimestamp(nil, "key") assert.Nil(t, result) }) } func TestGetJSONValue(t *testing.T) { m := map[string]interface{}{ "json_string": `{"nested":"value"}`, "json_array": `[1,2,3]`, "parsed_map": map[string]interface{}{"foo": "bar"}, "empty": "", "invalid": "not json", } t.Run("json_string", func(t *testing.T) { 
result := utils.GetJSONValue(m, "json_string") assert.NotNil(t, result) nested, ok := result.(map[string]interface{}) assert.True(t, ok) assert.Equal(t, "value", nested["nested"]) }) t.Run("json_array", func(t *testing.T) { result := utils.GetJSONValue(m, "json_array") assert.NotNil(t, result) arr, ok := result.([]interface{}) assert.True(t, ok) assert.Len(t, arr, 3) }) t.Run("parsed_map", func(t *testing.T) { result := utils.GetJSONValue(m, "parsed_map") assert.NotNil(t, result) parsed, ok := result.(map[string]interface{}) assert.True(t, ok) assert.Equal(t, "bar", parsed["foo"]) }) t.Run("empty_string", func(t *testing.T) { result := utils.GetJSONValue(m, "empty") assert.Nil(t, result) }) t.Run("invalid_json", func(t *testing.T) { result := utils.GetJSONValue(m, "invalid") assert.Nil(t, result) }) t.Run("nil_map", func(t *testing.T) { result := utils.GetJSONValue(nil, "key") assert.Nil(t, result) }) } // ==================== ToString Extended Tests ==================== func TestToStringExtended(t *testing.T) { t.Run("from_nil", func(t *testing.T) { assert.Equal(t, "", utils.ToString(nil)) }) t.Run("from_bytes", func(t *testing.T) { assert.Equal(t, "hello", utils.ToString([]byte("hello"))) }) t.Run("from_int_types", func(t *testing.T) { assert.Equal(t, "8", utils.ToString(int8(8))) assert.Equal(t, "16", utils.ToString(int16(16))) assert.Equal(t, "32", utils.ToString(int32(32))) assert.Equal(t, "64", utils.ToString(int64(64))) }) t.Run("from_uint_types", func(t *testing.T) { assert.Equal(t, "8", utils.ToString(uint8(8))) assert.Equal(t, "16", utils.ToString(uint16(16))) assert.Equal(t, "32", utils.ToString(uint32(32))) assert.Equal(t, "64", utils.ToString(uint64(64))) }) t.Run("from_float_formats_nicely", func(t *testing.T) { assert.Equal(t, "3.14", utils.ToString(3.14)) assert.Equal(t, "1000", utils.ToString(1000.0)) // no trailing zeros }) t.Run("from_struct_to_json", func(t *testing.T) { type TestStruct struct { Name string `json:"name"` } result := 
// ParseTime splits a clock string of the form HH:MM into its numeric parts.
//
// The input must contain exactly one ':' separator. The hour must parse as an
// integer in [0, 23] and the minute as an integer in [0, 59]. On any failure
// both numbers are returned as 0 together with a descriptive error.
func ParseTime(timeStr string) (hour, minute int, err error) {
	fields := strings.Split(timeStr, ":")
	if len(fields) != 2 {
		return 0, 0, fmt.Errorf("invalid time format: %s (expected HH:MM)", timeStr)
	}
	hourText, minuteText := fields[0], fields[1]

	h, hourErr := strconv.Atoi(hourText)
	switch {
	case hourErr != nil, h < 0, h > 23:
		return 0, 0, fmt.Errorf("invalid hour: %s", hourText)
	}

	m, minuteErr := strconv.Atoi(minuteText)
	switch {
	case minuteErr != nil, m < 0, m > 59:
		return 0, 0, fmt.Errorf("invalid minute: %s", minuteText)
	}

	return h, m, nil
}
IsTimeMatch checks if current time matches the specified time (HH:MM) func IsTimeMatch(now time.Time, timeStr string, loc *time.Location) bool { hour, minute, err := ParseTime(timeStr) if err != nil { return false } nowInLoc := now.In(loc) return nowInLoc.Hour() == hour && nowInLoc.Minute() == minute } // IsDayMatch checks if current day matches the specified day // days can be: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun", or "*" for any day func IsDayMatch(now time.Time, days []string) bool { if len(days) == 0 { return true } dayName := now.Weekday().String()[:3] // "Monday" -> "Mon" for _, day := range days { if day == "*" || day == dayName { return true } } return false } // NextScheduledTime calculates the next time a scheduled time will occur func NextScheduledTime(now time.Time, timeStr string, days []string, loc *time.Location) (time.Time, error) { hour, minute, err := ParseTime(timeStr) if err != nil { return time.Time{}, err } nowInLoc := now.In(loc) // Start from today at the specified time next := time.Date(nowInLoc.Year(), nowInLoc.Month(), nowInLoc.Day(), hour, minute, 0, 0, loc) // If the time has passed today, start from tomorrow if next.Before(nowInLoc) || next.Equal(nowInLoc) { next = next.Add(24 * time.Hour) } // Find the next matching day (within 7 days) for i := 0; i < 7; i++ { if IsDayMatch(next, days) { return next, nil } next = next.Add(24 * time.Hour) } // If no matching day found (should not happen with valid days), return the calculated time return next, nil } ================================================ FILE: agent/robot/utils/utils_test.go ================================================ package utils_test import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/robot/utils" ) // ID tests func TestNewID(t *testing.T) { id1 := utils.NewID() id2 := utils.NewID() assert.NotEmpty(t, id1) assert.NotEmpty(t, id2) assert.NotEqual(t, id1, id2, "IDs should be unique") } func TestNewIDWithPrefix(t 
*testing.T) { id := utils.NewIDWithPrefix("exec_") assert.NotEmpty(t, id) assert.Contains(t, id, "exec_") } // Time tests func TestParseTime(t *testing.T) { t.Run("valid time", func(t *testing.T) { hour, minute, err := utils.ParseTime("14:30") assert.NoError(t, err) assert.Equal(t, 14, hour) assert.Equal(t, 30, minute) }) t.Run("invalid format", func(t *testing.T) { _, _, err := utils.ParseTime("14-30") assert.Error(t, err) }) t.Run("invalid hour", func(t *testing.T) { _, _, err := utils.ParseTime("25:30") assert.Error(t, err) }) t.Run("invalid minute", func(t *testing.T) { _, _, err := utils.ParseTime("14:65") assert.Error(t, err) }) } func TestFormatTime(t *testing.T) { result := utils.FormatTime(9, 5) assert.Equal(t, "09:05", result) result = utils.FormatTime(14, 30) assert.Equal(t, "14:30", result) } func TestLoadLocation(t *testing.T) { t.Run("valid timezone", func(t *testing.T) { loc := utils.LoadLocation("Asia/Shanghai") assert.NotNil(t, loc) assert.Equal(t, "Asia/Shanghai", loc.String()) }) t.Run("empty timezone returns Local", func(t *testing.T) { loc := utils.LoadLocation("") assert.Equal(t, time.Local, loc) }) t.Run("invalid timezone returns Local", func(t *testing.T) { loc := utils.LoadLocation("Invalid/Timezone") assert.Equal(t, time.Local, loc) }) } func TestParseDuration(t *testing.T) { t.Run("valid duration", func(t *testing.T) { dur := utils.ParseDuration("30m", 10*time.Minute) assert.Equal(t, 30*time.Minute, dur) }) t.Run("empty returns default", func(t *testing.T) { dur := utils.ParseDuration("", 10*time.Minute) assert.Equal(t, 10*time.Minute, dur) }) t.Run("invalid returns default", func(t *testing.T) { dur := utils.ParseDuration("invalid", 10*time.Minute) assert.Equal(t, 10*time.Minute, dur) }) } func TestIsTimeMatch(t *testing.T) { loc := time.UTC testTime := time.Date(2024, 1, 15, 14, 30, 0, 0, loc) t.Run("exact match", func(t *testing.T) { assert.True(t, utils.IsTimeMatch(testTime, "14:30", loc)) }) t.Run("no match", func(t *testing.T) { 
assert.False(t, utils.IsTimeMatch(testTime, "14:31", loc)) assert.False(t, utils.IsTimeMatch(testTime, "15:30", loc)) }) t.Run("invalid time format", func(t *testing.T) { assert.False(t, utils.IsTimeMatch(testTime, "invalid", loc)) }) } func TestIsDayMatch(t *testing.T) { monday := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) // Monday t.Run("match specific day", func(t *testing.T) { assert.True(t, utils.IsDayMatch(monday, []string{"Mon"})) }) t.Run("match wildcard", func(t *testing.T) { assert.True(t, utils.IsDayMatch(monday, []string{"*"})) }) t.Run("no match", func(t *testing.T) { assert.False(t, utils.IsDayMatch(monday, []string{"Tue", "Wed"})) }) t.Run("empty days returns true", func(t *testing.T) { assert.True(t, utils.IsDayMatch(monday, []string{})) }) } // Convert tests func TestToJSON(t *testing.T) { data := map[string]interface{}{ "name": "test", "age": 30, } json, err := utils.ToJSON(data) assert.NoError(t, err) assert.Contains(t, json, "test") assert.Contains(t, json, "30") } func TestFromJSON(t *testing.T) { jsonStr := `{"name":"test","age":30}` var result map[string]interface{} err := utils.FromJSON(jsonStr, &result) assert.NoError(t, err) assert.Equal(t, "test", result["name"]) assert.Equal(t, float64(30), result["age"]) // JSON numbers are float64 } func TestToMap(t *testing.T) { type TestStruct struct { Name string `json:"name"` Age int `json:"age"` } s := TestStruct{Name: "test", Age: 30} m, err := utils.ToMap(s) assert.NoError(t, err) assert.Equal(t, "test", m["name"]) assert.Equal(t, float64(30), m["age"]) // JSON conversion makes it float64 } func TestFromMap(t *testing.T) { type TestStruct struct { Name string `json:"name"` Age int `json:"age"` } m := map[string]interface{}{ "name": "test", "age": 30, } var result TestStruct err := utils.FromMap(m, &result) assert.NoError(t, err) assert.Equal(t, "test", result.Name) assert.Equal(t, 30, result.Age) } func TestToString(t *testing.T) { assert.Equal(t, "test", utils.ToString("test")) 
assert.Equal(t, "42", utils.ToString(42)) assert.Equal(t, "true", utils.ToString(true)) } func TestMergeMap(t *testing.T) { target := map[string]interface{}{ "a": 1, "b": 2, } source := map[string]interface{}{ "b": 3, "c": 4, } result := utils.MergeMap(target, source) assert.Equal(t, 1, result["a"]) assert.Equal(t, 3, result["b"]) // overwritten assert.Equal(t, 4, result["c"]) } func TestCloneMap(t *testing.T) { original := map[string]interface{}{ "a": 1, "b": 2, } cloned := utils.CloneMap(original) cloned["a"] = 999 assert.Equal(t, 1, original["a"]) // original unchanged assert.Equal(t, 999, cloned["a"]) } // Validate tests func TestIsEmpty(t *testing.T) { assert.True(t, utils.IsEmpty("")) assert.False(t, utils.IsEmpty("test")) } func TestIsValidEmail(t *testing.T) { assert.True(t, utils.IsValidEmail("test@example.com")) assert.True(t, utils.IsValidEmail("user+tag@domain.co.uk")) assert.False(t, utils.IsValidEmail("invalid")) assert.False(t, utils.IsValidEmail("@example.com")) assert.False(t, utils.IsValidEmail("test@")) } func TestIsValidTime(t *testing.T) { assert.True(t, utils.IsValidTime("09:00")) assert.True(t, utils.IsValidTime("14:30")) assert.True(t, utils.IsValidTime("23:59")) assert.False(t, utils.IsValidTime("25:00")) assert.False(t, utils.IsValidTime("14:65")) assert.False(t, utils.IsValidTime("14-30")) } func TestValidateRequired(t *testing.T) { t.Run("nil value", func(t *testing.T) { err := utils.ValidateRequired("field", nil) assert.Error(t, err) }) t.Run("empty string", func(t *testing.T) { err := utils.ValidateRequired("field", "") assert.Error(t, err) }) t.Run("valid string", func(t *testing.T) { err := utils.ValidateRequired("field", "value") assert.NoError(t, err) }) t.Run("empty slice", func(t *testing.T) { err := utils.ValidateRequired("field", []string{}) assert.Error(t, err) }) } func TestValidateRange(t *testing.T) { t.Run("within range", func(t *testing.T) { err := utils.ValidateRange("field", 5, 1, 10) assert.NoError(t, err) }) 
// IsEmpty reports whether s is empty or consists only of ASCII whitespace
// (space, tab, newline, carriage return, vertical tab, form feed).
//
// The previous implementation only tested len(s) == 0, which contradicted the
// documented "empty or whitespace only" contract and let blank values such as
// "   " pass required-field validation.
func IsEmpty(s string) bool {
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case ' ', '\t', '\n', '\r', '\v', '\f':
			// Whitespace byte: keep scanning.
		default:
			return false
		}
	}
	return true
}
// ValidateOneOf verifies that value appears in the allowed list (exact,
// case-sensitive comparison). It returns nil on a match and otherwise an error
// naming the field and listing the permitted values.
func ValidateOneOf(fieldName string, value string, allowed []string) error {
	matched := false
	for i := 0; i < len(allowed) && !matched; i++ {
		matched = allowed[i] == value
	}
	if !matched {
		return fmt.Errorf("%s must be one of: %v", fieldName, allowed)
	}
	return nil
}
Design Benefits | Aspect | Standalone Executor | Pipeline Injection | | ------------- | ------------------------ | ----------------------- | | **Hooks** | Need re-implementation | ✅ Fully reused | | **Auth** | Need re-implementation | ✅ Fully reused | | **Logging** | Need re-implementation | ✅ Fully reused | | **Billing** | Need re-implementation | ✅ Fully reused | | **Errors** | Need re-implementation | ✅ Fully reused | | **Code size** | Large (full executor) | Small (one branch) | | **Extend** | One executor per agent | Just add agent type | ## 3. Configuration ### 3.1 Assistant Configuration Sandbox is configured at the **Assistant level** (not Connector), because: - Same Connector can be shared by multiple Assistants - Some Assistants need Sandbox (Coding), others don't (Q&A) - Assistant defines "behavior", Connector defines "connection" ```jsonc // assistants/coder/package.yao { "name": "Coder Assistant", "connector": "deepseek.v3", // Sandbox configuration "sandbox": { "command": "claude", // claude | cursor (future) "image": "yaoapp/sandbox-claude:latest", // Optional, auto-selected by agent "max_memory": "4g", // Optional "max_cpu": 2.0, // Optional "timeout": "10m", // Optional, execution timeout "arguments": { // Command-specific arguments (passed to Claude CLI) "max_turns": 20, "permission_mode": "acceptEdits" // Different commands may have different arguments } }, // MCP servers "mcp": { "servers": [ { "server_id": "filesystem", "tools": ["read_file", "write_file", "list_directory"] } ] } } ``` ### 3.2 Directory Structure Skills are auto-discovered from `skills/` directory following the [Agent Skills](https://agentskills.io) open standard: ``` assistants/coder/ ├── package.yao # Assistant config ├── prompts.yml # Prompts ├── mcps/ # MCP server definitions └── skills/ # Skills directory (auto-discovered) ├── code-review/ │ ├── SKILL.md # Required: instructions + metadata │ ├── scripts/ # Optional: executable code (Python, Bash, JS) │ ├── references/ # 
Optional: additional documentation │ └── assets/ # Optional: templates, images, data files └── deploy/ ├── SKILL.md └── scripts/ └── deploy.sh ``` ### 3.3 SKILL.md Format Each skill must have a `SKILL.md` file with YAML frontmatter: ```yaml --- name: code-review # Required: must match parent directory name description: > # Required: when to use this skill Review code for bugs, security issues, and best practices. Use when the user asks to review, audit, or analyze code quality. license: Apache-2.0 # Optional compatibility: Requires git # Optional: environment requirements metadata: # Optional: arbitrary key-value pairs author: yao-team version: "1.0" allowed-tools: Bash(git:*) Read # Optional: pre-approved tools (experimental) --- # Code Review ## When to use this skill Use this skill when the user asks to review code... ## Steps 1. Check for security vulnerabilities 2. Review code style and best practices 3. Identify potential bugs ... ``` | Field | Required | Description | |-------|----------|-------------| | `name` | Yes | 1-64 chars, lowercase + hyphens, must match directory name | | `description` | Yes | 1-1024 chars, describes what skill does and when to use it | | `license` | No | License name or reference to bundled LICENSE file | | `compatibility` | No | Environment requirements (tools, network access, etc.) | | `metadata` | No | Arbitrary key-value pairs for additional properties | | `allowed-tools` | No | Space-delimited list of pre-approved tools (experimental) | ### 3.4 Skills Progressive Disclosure Skills use progressive disclosure to manage context efficiently: 1. **Discovery**: At startup, agent loads only `name` and `description` of each skill 2. **Activation**: When task matches a skill's description, agent reads full `SKILL.md` 3. 
**Execution**: Agent follows instructions, loading `scripts/`, `references/`, `assets/` as needed ### 3.5 No Sandbox (Default) If `sandbox` is not configured, Assistant uses direct LLM API calls: ```jsonc // assistants/qa/package.yao { "name": "QA Assistant", "connector": "deepseek.v3" // No sandbox config → direct API call } ``` ## 4. Implementation ### 4.1 Package Structure ``` yao/ ├── sandbox/ # Low-level sandbox infrastructure │ ├── manager.go # Container management (✅ Done) │ ├── config.go # Configuration (✅ Done) │ ├── ipc/ # IPC communication (✅ Done) │ │ ├── manager.go │ │ ├── session.go │ │ └── types.go │ ├── bridge/ # yao-bridge binary (✅ Done) │ │ └── main.go │ └── docker/ # Docker images (✅ Done) │ ├── base/ │ ├── claude/ │ └── build.sh │ └── agent/ └── sandbox/ # Agent-level sandbox integration (NEW) ├── DESIGN.md # This document ├── types.go # Common types and interfaces ├── executor.go # Factory function and registry ├── claude/ # Claude CLI agent │ ├── types.go # Claude-specific types (StreamMessage, ToolCall, etc.) 
import (
	"os"
	"time"

	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/output/message"
)
auto-loaded from assistants/{name}/mcps/ MCPConfig []byte `json:"-"` // Skills directory - auto-resolved to assistants/{name}/skills/ SkillsDir string `json:"-"` // Connector settings - auto-resolved from connector config file // e.g., connectors/deepseek/v3.conn.yao → host, key, model ConnectorHost string `json:"-"` ConnectorKey string `json:"-"` Model string `json:"-"` } ``` ### 4.3 Executor Factory ```go // agent/sandbox/executor.go package sandbox import ( "fmt" "github.com/yaoapp/yao/agent/sandbox/claude" "github.com/yaoapp/yao/agent/sandbox/cursor" "github.com/yaoapp/yao/sandbox" ) // New creates an executor based on agent type func New(manager *sandbox.Manager, opts *Options) (Executor, error) { if opts == nil { return nil, fmt.Errorf("options cannot be nil") } // Set default image if not specified if opts.Image == "" { opts.Image = DefaultImage(opts.Command) } switch opts.Command { case "claude": return claude.New(manager, opts) case "cursor": return cursor.New(manager, opts) default: return nil, fmt.Errorf("unknown command type: %s", opts.Command) } } // DefaultImage returns the default Docker image for a command type func DefaultImage(command string) string { switch command { case "claude": return "yaoapp/sandbox-claude:latest" case "cursor": return "yaoapp/sandbox-cursor:latest" default: return "" } } ``` ### 4.4 Claude Agent Implementation The Claude agent is split into multiple files for maintainability: #### 4.4.1 Executor (claude/executor.go) ```go // agent/sandbox/claude/executor.go package claude import ( "fmt" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" sandbox "github.com/yaoapp/yao/agent/sandbox" infra "github.com/yaoapp/yao/sandbox" ) // Executor executes requests via Claude CLI in sandbox type Executor struct { manager *infra.Manager opts *sandbox.Options } // New creates a new Claude executor func New(manager *infra.Manager, opts *sandbox.Options) (*Executor, error) { return &Executor{ manager: manager, 
opts: opts, }, nil } // Stream executes with streaming output func (e *Executor) Stream( ctx *context.Context, messages []context.Message, opts *sandbox.Options, handler message.StreamFunc, ) (*context.CompletionResponse, error) { // 1. Get or create container container, err := e.manager.GetOrCreate(ctx.Context(), ctx.Authorized.UserID, ctx.ChatID) if err != nil { return nil, fmt.Errorf("failed to get container: %w", err) } // 2. Prepare environment workDir := fmt.Sprintf("/workspace/%s/chat-%s", ctx.Authorized.UserID, ctx.ChatID) if err := e.prepareEnvironment(ctx, container, workDir); err != nil { return nil, fmt.Errorf("failed to prepare environment: %w", err) } // 3. Build CLI command cmd, env := BuildCommand(messages, opts, workDir) // 4. Execute in container with streaming reader, err := e.manager.Stream(ctx.Context(), container.Name, cmd, &infra.ExecOptions{ Env: env, WorkingDir: workDir, }) if err != nil { return nil, fmt.Errorf("failed to execute: %w", err) } defer reader.Close() // 5. 
Parse streaming output return ParseStream(reader, handler) } // Execute runs without streaming (wraps Stream) func (e *Executor) Execute(ctx *context.Context, messages []context.Message, opts *sandbox.Options) (*context.CompletionResponse, error) { var result *context.CompletionResponse _, err := e.Stream(ctx, messages, opts, func(msg *message.Message) error { // Collect final result return nil }) return result, err } // Close releases resources func (e *Executor) Close() error { return nil } ``` #### 4.4.2 Environment Setup (claude/environment.go) ```go // agent/sandbox/claude/environment.go package claude import ( "path/filepath" "github.com/yaoapp/yao/agent/context" sandbox "github.com/yaoapp/yao/agent/sandbox" infra "github.com/yaoapp/yao/sandbox" ) // prepareEnvironment sets up the container environment func (e *Executor) prepareEnvironment(ctx *context.Context, container *infra.Container, workDir string) error { // Create work directory if err := e.manager.MkDir(ctx.Context(), container.Name, workDir); err != nil { return err } // Write MCP config if len(e.opts.MCPConfig) > 0 { mcpPath := filepath.Join(workDir, ".mcp.json") if err := e.manager.WriteFile(ctx.Context(), container.Name, mcpPath, e.opts.MCPConfig); err != nil { return err } } // Copy skills if configured if e.opts.SkillsDir != "" { claudeDir := filepath.Join(workDir, ".claude") if err := e.manager.MkDir(ctx.Context(), container.Name, claudeDir); err != nil { return err } targetSkillsDir := filepath.Join(claudeDir, "skills") if err := e.manager.CopyToContainer(ctx.Context(), container.Name, e.opts.SkillsDir, targetSkillsDir); err != nil { return err } } return nil } ``` #### 4.4.3 Command Builder (claude/command.go) ```go // agent/sandbox/claude/command.go package claude import ( "fmt" "strconv" "strings" "github.com/yaoapp/yao/agent/context" sandbox "github.com/yaoapp/yao/agent/sandbox" ) // BuildCommand constructs the Claude CLI command func BuildCommand(messages []context.Message, opts 
*sandbox.Options, workDir string) ([]string, map[string]string) { cmd := []string{ "claude", "--print", "--output-format", "stream-json", } // Model if opts.Model != "" { cmd = append(cmd, "--model", opts.Model) } // Agent-specific options from sandbox.options if opts.Arguments != nil { if maxTurns, ok := opts.Arguments["max_turns"].(int); ok && maxTurns > 0 { cmd = append(cmd, "--max-turns", strconv.Itoa(maxTurns)) } if permMode, ok := opts.Arguments["permission_mode"].(string); ok && permMode != "" { cmd = append(cmd, "--permission-mode", permMode) } } // MCP config cmd = append(cmd, "--mcp-config", ".mcp.json") // System prompt (with history) systemPrompt := buildSystemPrompt(messages) if systemPrompt != "" { cmd = append(cmd, "--system-prompt", systemPrompt) } // User prompt (last user message) prompt := extractUserPrompt(messages) cmd = append(cmd, prompt) // Environment variables env := map[string]string{} if opts.ConnectorHost != "" { env["ANTHROPIC_BASE_URL"] = opts.ConnectorHost } if opts.ConnectorKey != "" { env["ANTHROPIC_API_KEY"] = opts.ConnectorKey } return cmd, env } // buildSystemPrompt builds system prompt with conversation history func buildSystemPrompt(messages []context.Message) string { var systemParts []string var history []string for i, msg := range messages { if msg.Role == "system" { if content, ok := msg.Content.(string); ok { systemParts = append(systemParts, content) } continue } // Skip last user message (it becomes the prompt) if i == len(messages)-1 && msg.Role == "user" { continue } // Add to history if content, ok := msg.Content.(string); ok { history = append(history, fmt.Sprintf("[%s]: %s", msg.Role, content)) } } if len(history) > 0 { systemParts = append(systemParts, "\n## Conversation History:\n"+strings.Join(history, "\n\n")) } return strings.Join(systemParts, "\n") } // extractUserPrompt gets the last user message func extractUserPrompt(messages []context.Message) string { for i := len(messages) - 1; i >= 0; i-- { if 
messages[i].Role == "user" { if content, ok := messages[i].Content.(string); ok { return content } } } return "" } ``` #### 4.4.4 Types (claude/types.go) ```go // agent/sandbox/claude/types.go package claude // StreamMessage represents a Claude CLI stream-json message type StreamMessage struct { Type string `json:"type"` Message struct { Content []struct { Type string `json:"type"` Text string `json:"text"` } `json:"content"` } `json:"message,omitempty"` Result string `json:"result,omitempty"` TotalCostUSD float64 `json:"total_cost_usd,omitempty"` NumTurns int `json:"num_turns,omitempty"` DurationMs int64 `json:"duration_ms,omitempty"` Usage *struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` } `json:"usage,omitempty"` } // ToolCall represents a tool invocation from Claude CLI type ToolCall struct { ID string `json:"id"` Name string `json:"name"` Arguments map[string]interface{} `json:"arguments"` } // ToolResult represents a tool execution result type ToolResult struct { ID string `json:"id"` Content string `json:"content"` IsError bool `json:"is_error,omitempty"` } ``` #### 4.4.5 Stream Parser (claude/stream.go) ```go // agent/sandbox/claude/stream.go package claude import ( "bufio" "encoding/json" "io" "strings" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" ) // ParseStream parses Claude CLI stream-json output func ParseStream(reader io.ReadCloser, handler message.StreamFunc) (*context.CompletionResponse, error) { resp := &context.CompletionResponse{} var fullContent strings.Builder scanner := bufio.NewScanner(reader) for scanner.Scan() { line := scanner.Text() if line == "" { continue } var msg StreamMessage if err := json.Unmarshal([]byte(line), &msg); err != nil { continue } switch msg.Type { case "assistant": // Extract text content for _, content := range msg.Message.Content { if content.Type == "text" && content.Text != "" { fullContent.WriteString(content.Text) // Send to stream 
handler if handler != nil { handler(&message.Message{ Type: "text", Data: content.Text, }) } } } case "result": // Final result if msg.Result != "" { resp.Content = msg.Result } if msg.Usage != nil { resp.Usage = &context.Usage{ PromptTokens: msg.Usage.InputTokens, CompletionTokens: msg.Usage.OutputTokens, TotalTokens: msg.Usage.InputTokens + msg.Usage.OutputTokens, } } resp.Extra = map[string]interface{}{ "total_cost_usd": msg.TotalCostUSD, "num_turns": msg.NumTurns, "duration_ms": msg.DurationMs, } } } if err := scanner.Err(); err != nil { return nil, err } if resp.Content == "" { resp.Content = fullContent.String() } return resp, nil } ``` ### 4.5 Sandbox Lifecycle #### 4.5.1 Design Principles The sandbox follows a **stateless container + persistent workspace** model: | Component | Lifecycle | Storage | |-----------|-----------|---------| | **Container** | Per-request, disposable | None (stateless) | | **Workspace** | Persistent across requests | `{YAO_DATA_ROOT}/sandbox/workspace/{user}/chat-{chat_id}/` | | **Message History** | Managed by Yao | Yao's session store | This means: - Container can be recreated anytime without losing state - All files are preserved in the mounted workspace - Conversation history is passed in each request (not stored in container) #### 4.5.2 Container Lifecycle ``` Request Start │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ 1. Create Executor (docker run --rm) │ │ - Mount workspace directory │ │ - Set resource limits │ │ - ~500ms-1s cold start │ └─────────────────────────┬───────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ 2. Before Hook (can access Executor) │ │ - Write config files │ │ - Check/prepare environment │ │ - Can reject request (container auto-cleaned) │ └─────────────────────────┬───────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ 3. 
LLM Execute (Claude CLI in container) │ │ - Run with full message history │ │ - Stream output to handler │ └─────────────────────────┬───────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ 4. After Hook (can access Executor) │ │ - Read generated files │ │ - Execute post-commands (git commit, etc.) │ │ - Cleanup temp files │ └─────────────────────────┬───────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ 5. Close Executor (defer) │ │ - Container removed (--rm) │ │ - Workspace persists │ └─────────────────────────────────────────────────────────────┘ ``` #### 4.5.3 Workspace Cleanup Workspace cleanup is separate from container lifecycle: ```go // Global cleanup configuration type CleanupConfig struct { // Workspace retention period (default: 7 days) WorkspaceRetention time.Duration `json:"workspace_retention"` // Run cleanup on schedule (default: daily at 3am) CleanupSchedule string `json:"cleanup_schedule"` } // Cleanup worker func (m *Manager) cleanupStaleWorkspaces() { // Find workspaces older than retention period // Delete directories not accessed within the period } ``` ### 4.6 Hook Integration (via JSAPI) Hooks interact with the sandbox via **Context JSAPI**, not Go code. The Executor methods are exposed to JavaScript through `ctx.sandbox`: #### 4.6.1 Context JSAPI Extension ```typescript // Extension to Context interface (see agent/context/JSAPI.md) interface Context { // ... existing properties ... 
// Sandbox operations (only available when sandbox is configured) sandbox?: { // Filesystem operations ReadFile(path: string): string; // Read file content WriteFile(path: string, content: string): void; // Write file ListDir(path: string): FileInfo[]; // List directory // Command execution Exec(cmd: string[]): string; // Execute command, return output // Workspace info workdir: string; // Container workspace path }; } interface FileInfo { name: string; size: number; is_dir: boolean; mod_time: string; } ``` #### 4.6.2 Create Hook Examples (JavaScript) ```javascript // assistants/coder/src/index.ts /** * Create Hook - runs after container created, before LLM execution */ function Create(ctx, messages, options) { // Check if sandbox is available if (!ctx.sandbox) { return { messages }; } // Send loading message to user const loadingId = ctx.SendStream({ type: "loading", props: { message: "Preparing sandbox environment..." } }); // Write project configuration ctx.sandbox.WriteFile(".env", "DEBUG=true\nNODE_ENV=development"); // Check environment prerequisites try { const nodeVersion = ctx.sandbox.Exec(["node", "--version"]); log.Info(`Node.js version: ${nodeVersion}`); } catch (e) { log.Error("Node.js not available"); ctx.End(loadingId); throw new Error("Node.js is required"); } // List existing files const files = ctx.sandbox.ListDir("."); log.Debug(`Workspace files: ${files.map(f => f.name).join(", ")}`); // Update loading message and end it ctx.Replace(loadingId, { type: "text", props: { content: "Sandbox ready" } }); ctx.End(loadingId); return { messages }; } ``` #### 4.6.3 Next Hook Examples (JavaScript) ```javascript /** * Next Hook - runs after LLM execution, before container cleanup */ function Next(ctx, payload, options) { const { completion, error } = payload; if (error || !ctx.sandbox) { return null; } // Read generated files try { const result = ctx.sandbox.ReadFile("output/result.json"); log.Info(`Generated output: ${result}`); 
ctx.memory.context.Set("generated_result", JSON.parse(result)); } catch (e) { log.Debug("No result file generated"); } // List generated source files const srcFiles = ctx.sandbox.ListDir("src/"); for (const file of srcFiles) { log.Debug(`Generated: ${file.name} (${file.size} bytes)`); } // Execute git commit if files changed try { ctx.sandbox.Exec(["git", "add", "."]); ctx.sandbox.Exec(["git", "commit", "-m", "auto-commit by sandbox"]); log.Info("Changes committed"); } catch (e) { log.Debug(`git commit skipped: ${e.message}`); } // Cleanup temporary files try { ctx.sandbox.Exec(["rm", "-rf", "tmp/"]); } catch (e) { // Ignore cleanup errors } return null; // Use default response } ``` #### 4.6.4 Go Implementation (Internal) The Go layer only handles exposing the Executor to JSAPI: ```go // In agent/context/sandbox.go - expose Executor to JS runtime // SandboxJSAPI wraps Executor for JavaScript access type SandboxJSAPI struct { executor sandbox.Executor workdir string } // Methods are called from JavaScript via v8go bindings func (s *SandboxJSAPI) ReadFile(path string) (string, error) { content, err := s.executor.ReadFile(context.Background(), path) return string(content), err } func (s *SandboxJSAPI) WriteFile(path string, content string) error { return s.executor.WriteFile(context.Background(), path, []byte(content)) } func (s *SandboxJSAPI) ListDir(path string) ([]FileInfo, error) { return s.executor.ListDir(context.Background(), path) } func (s *SandboxJSAPI) Exec(cmd []string) (string, error) { return s.executor.Exec(context.Background(), cmd) } ``` ### 4.7 Integration with Assistant The sandbox integration is separated into two files: - `agent/assistant/sandbox.go` - Sandbox handler (new file, contains all sandbox logic) - `agent/assistant/llm.go` - Add sandbox detection (minimal change) #### 4.7.1 Sandbox Detection (`agent/assistant/llm.go`) ```go // In agent/assistant/llm.go - add sandbox detection func (ast *Assistant) executeLLMStream(...) 
(*context.CompletionResponse, error) { // Check if sandbox is configured if ast.Sandbox != nil && ast.Sandbox.Command != "" { return ast.executeSandboxStream(ctx, completionMessages, completionOptions, agentNode, streamHandler, opts) } // Default: direct LLM API call // ... existing code ... } ``` #### 4.7.2 Sandbox Handler (`agent/assistant/sandbox.go`) ```go // In agent/assistant/sandbox.go - sandbox execution logic func (ast *Assistant) executeSandboxStream(...) (*context.CompletionResponse, error) { // Get sandbox manager (singleton) manager := sandbox.GetManager() if manager == nil { return nil, fmt.Errorf("sandbox manager not initialized") } // Build executor options execOpts := &agentsandbox.Options{ Command: ast.Sandbox.Command, Image: ast.Sandbox.Image, MaxMemory: ast.Sandbox.MaxMemory, MaxCPU: ast.Sandbox.MaxCPU, Timeout: ast.Sandbox.Timeout, Arguments: ast.Sandbox.Arguments, SkillsDir: filepath.Join(ast.Path, "skills"), } // Get connector settings (resolved from connector config) conn, _, err := ast.GetConnector(ctx, opts) if err != nil { return nil, err } setting := conn.Setting() if host, ok := setting["host"].(string); ok { execOpts.ConnectorHost = host } if key, ok := setting["key"].(string); ok { execOpts.ConnectorKey = key } if model, ok := setting["model"].(string); ok { execOpts.Model = model } // Build MCP config execOpts.MCPConfig = ast.buildMCPConfig(ctx) // 1. Create executor (container starts here) ctx.Trace().Info("Creating sandbox container...") executor, err := agentsandbox.New(manager, execOpts) if err != nil { ctx.Trace().Error(fmt.Sprintf("Sandbox creation failed: %v", err)) return nil, err } ctx.Trace().Info("Sandbox container ready") defer executor.Close() // 5. Container cleanup on exit // 2. Before Hook - can access executor if err := ast.runBeforeHook(ctx, executor); err != nil { return nil, err // Container cleaned by defer } // 3. 
LLM Execute resp, err := executor.Stream(ctx, completionMessages, execOpts, streamHandler) if err != nil { return nil, err } // 4. After Hook - can access executor if err := ast.runAfterHook(ctx, executor, resp); err != nil { log.Printf("after hook error: %v", err) // Log but don't fail } return resp, nil } ``` ## 5. IPC Communication (Yao ↔ Sandbox) The sandbox can call Yao processes via MCP: ``` ┌─────────────────────────────────────────────────────────────────┐ │ Sandbox Container │ │ ┌─────────────────┐ ┌──────────────────┐ │ │ │ Claude CLI │───▶│ yao-bridge │ │ │ │ │ │ │ │ │ │ MCP: tools/call│ │ Unix Socket │ │ │ │ "yao.process" │ │ /tmp/yao.sock │ │ │ └─────────────────┘ └────────┬─────────┘ │ └──────────────────────────────────┼─────────────────────────────┘ │ JSON-RPC ▼ ┌──────────────────────────────────────────────────────────────────┐ │ Yao Host │ │ ┌─────────────────────────────────────────────────────────────┐ │ │ │ IPC Manager │ │ │ │ │ │ │ │ Session: user123-chat456 │ │ │ │ Socket: /data/sandbox/ipc/user123-chat456.sock │ │ │ │ │ │ │ │ Handlers: │ │ │ │ tools/list → Return authorized tool list │ │ │ │ tools/call → Execute Yao Process │ │ │ └─────────────────────────────────────────────────────────────┘ │ └──────────────────────────────────────────────────────────────────┘ ``` ## 6. Skills Support Skills follow [Agent Skills](https://agentskills.io) standard: ```markdown --- name: code-review description: "Review code for security, performance, and best practices" allowed-tools: - Read - Grep - Glob --- # Code Review When reviewing code: ## Security - Check for SQL injection, XSS, CSRF - Verify input validation ## Performance - Look for N+1 queries - Check for unnecessary re-renders ``` Skills are auto-copied to container at `{workDir}/.claude/skills/`. ## 7. 
Docker Images ### Available Images | Image | Contents | Size | |-------|----------|------| | `yaoapp/sandbox-base` | Ubuntu 24.04, git, vim, network tools, yao-bridge | ~200MB | | `yaoapp/sandbox-claude` | base + Node.js 22, Python 3.12, Claude CLI, CCR | ~900MB | | `yaoapp/sandbox-claude-full` | claude + Go 1.23 | ~1.2GB | ### Included Tools - **Editors**: vim, less, tree - **Network**: ping, netstat, nslookup, nc, telnet - **Compression**: zip, unzip, tar, gzip - **System**: htop, ps, sed, awk, grep - **Development**: git, Node.js, Python, Claude CLI, CCR ## 8. Implementation Status See [PLAN.md](./PLAN.md) for detailed implementation tasks and testing requirements. ### Completed (sandbox package) - [x] Container management (`sandbox/manager.go`) - [x] Configuration (`sandbox/config.go`) - [x] IPC communication (`sandbox/ipc/`) - [x] yao-bridge binary (`sandbox/bridge/`) - [x] Docker images (`sandbox/docker/`) ### To Implement (agent/sandbox package) - [ ] Types and interfaces - [ ] Executor factory - [ ] Claude executor (claude/) - [ ] Context JSAPI bindings - [ ] Assistant integration - [ ] Workspace cleanup ## 9. Testing See [PLAN.md](./PLAN.md) for complete testing strategy. **Quick test:** ```bash # Test Docker image docker run --rm yaoapp/sandbox-claude:latest bash -c " node --version && python3 --version && claude --version " # Run integration tests source /Users/max/Yao/yao/env.local.sh go test -v ./agent/sandbox/... ``` ## 10. Usage Example ```bash # Test script at yao-dev-app/agents/claude/run.sh ./run.sh "Hello, what is 1+1?" 
# Default: deepseek/v3 ./run.sh -c deepseek/r1 "Hello" # Use R1 model ./run.sh -c claude/sonnet-4_0 "Hello" # Use Claude ``` ================================================ FILE: agent/sandbox/PLAN.md ================================================ # Agent Sandbox Implementation Plan ## Overview This plan covers the implementation of the agent sandbox integration layer (`agent/sandbox/`), which enables coding agents (Claude CLI, Cursor CLI) to run in isolated Docker containers with Yao's LLM pipeline. ## Test Environment ### Environment Configuration Tests should run with the local development environment: ```bash # Source environment variables source /Users/max/Yao/yao/env.local.sh # Key variables used: # YAO_TEST_APPLICATION=/Users/max/Yao/yao-dev-app # YAO_ROOT=$YAO_TEST_APPLICATION # DEEPSEEK_API_KEY, DEEPSEEK_API_PROXY, DEEPSEEK_MODELS_V3 ``` ### Test Application Test assistants at `yao-dev-app/assistants/tests/sandbox/`: ``` yao-dev-app/assistants/tests/ └── sandbox/ ├── basic/ # Basic sandbox execution test │ ├── package.yao # uses.search: disabled │ └── prompts.yml ├── hooks/ # Hook integration test │ ├── package.yao # uses.search: disabled │ ├── prompts.yml │ └── src/index.ts └── full/ # Full test with MCPs, Skills, Hooks ├── package.yao # uses.search: disabled, mcp: {servers: [...]} ├── prompts.yml ├── src/index.ts └── skills/echo-test/ # Agent Skills standard ├── SKILL.md └── scripts/echo.sh ``` ### Connector Configuration Use `deepseek.v3` as the default connector (via Volcengine API). 
## Implementation Status ### Phase 1: Core Types and Interfaces ✅ COMPLETED - [x] Define `Executor` interface with all methods - [x] Define `Options` struct with JSON tags - [x] Define `FileInfo` alias to infrastructure sandbox - [x] Add `DefaultImage()` and `IsValidCommand()` helpers ### Phase 2: Claude Executor Implementation ✅ COMPLETED - [x] Implement `Executor` struct - [x] Implement `NewExecutor()` constructor with container reuse - [x] Implement `Stream()` method with CCR config writing - [x] Implement `Execute()` method (wrapper) - [x] Implement `Close()` method (removes container) - [x] Implement filesystem methods: `ReadFile`, `WriteFile`, `ListDir` - [x] Implement `Exec()` method - [x] Implement `GetWorkDir()` method ### Phase 3: CCR Configuration ✅ COMPLETED - [x] Implement `BuildCCRConfig()` with correct CCR format - [x] Auto-detect provider type (volcengine, deepseek, openai, claude) - [x] Add transformer for DeepSeek/Volcengine (maxtoken) - [x] Generate Router configuration - [x] Write config to container before execution ### Phase 4: Assistant Integration ✅ COMPLETED - [x] Implement `GetSandboxManager()` singleton - [x] Implement `HasSandbox()` method - [x] Implement `initSandbox()` with cleanup function - [x] Implement `executeSandboxStream()` method - [x] Build executor options from assistant config - [x] Resolve connector settings (host, key, model) - [x] Add trace logging for sandbox creation - [x] Send loading message during sandbox init - [x] Expose executor to hooks via `ctx.SetSandboxExecutor()` - [x] Handle sandbox lifecycle (create → hooks → execute → cleanup) ### Phase 5: JSAPI Integration ✅ COMPLETED - [x] Define `SandboxExecutor` interface - [x] Implement JS bindings for `ReadFile`, `WriteFile`, `ListDir`, `Exec` - [x] Expose `workdir` property - [x] Register in context's `NewObject` method ### Phase 6: Concurrency & Resource Management ✅ COMPLETED - [x] Container creation uses Double-Check Locking (in `manager.GetOrCreate`) - [x] Same 
chatID reuses container (by design) - [x] Container cleanup on request completion (`defer sandboxCleanup()`) - [x] Unique chatID in tests to avoid conflicts ### Phase 7: MCP & Skills Integration ✅ COMPLETED - [x] Build MCP config from assistant's `mcp.servers` configuration - [x] Write MCP config to container workspace (`.mcp.json`) - [x] Resolve skills directory from `assistants/{name}/skills/` - [x] Copy skills to container (`/workspace/.claude/skills/`) - [x] Skip MCP tool execution in `agent.go` for sandbox mode (Claude CLI handles internally) - [x] Add unit tests for MCP config building (`TestBuildMCPConfigForSandbox`) - [x] Add unit tests for skills directory resolution (`TestSandboxMCPAndSkillsOptions`) ### Phase 8: MCP IPC Bridge ✅ COMPLETED - [x] Modify `BuildMCPConfigForSandbox` to use `yao-bridge` command for IPC - [x] Create IPC session in `sandbox/manager.createContainer()` (socket created before container) - [x] Bind mount IPC socket to container at `/tmp/yao.sock` - [x] Add `SetMCPTools()` method to `ipc.Session` for runtime tool configuration - [x] Set MCP tools dynamically in `claude.Executor.Stream()` before execution - [x] IPC session lifecycle managed by `sandbox.Manager` (create on container create, close on remove) - [x] Load MCP tool definitions from gou/mcp and pass to IPC session - [x] Add `TestClaudeExecutorIPCSocketMount` to verify socket bind mount - [x] Verify E2E test shows "Loaded X MCP tools for IPC" ### Phase 9: Workspace Management ⏳ PENDING - [ ] Implement workspace cleanup configuration - [ ] Implement stale workspace detection - [ ] Implement cleanup scheduler ### Phase 10: Cursor Placeholder ⏳ PENDING - [ ] Create `cursor/README.md` placeholder ## Testing Status ### Unit Tests | Package | Test File | Status | |---------|-----------|--------| | `agent/sandbox` | `types_test.go` | ✅ PASS | | `agent/sandbox` | `executor_test.go` | ✅ PASS | | `agent/sandbox/claude` | `command_test.go` | ✅ PASS | | `agent/sandbox/claude` |
`executor_test.go` | ✅ PASS | ### Integration Tests | Package | Test File | Status | |---------|-----------|--------| | `agent/sandbox` | `integration_test.go` | ✅ PASS | ### JSAPI Tests | Package | Test File | Status | |---------|-----------|--------| | `agent/context` | `jsapi_sandbox_test.go` | ✅ PASS | ### Assistant Loading Tests | Package | Test File | Status | |---------|-----------|--------| | `agent/assistant` | `sandbox_test.go` | ✅ PASS | | `agent/assistant` | `sandbox_integration_test.go` | ✅ PASS | ### E2E Tests | Package | Test Case | Status | |---------|-----------|--------| | `agent/assistant` | `TestSandboxBasicE2E` | ✅ PASS | | `agent/assistant` | `TestSandboxHooksE2E` | ✅ PASS | | `agent/assistant` | `TestSandboxFullE2E` | ✅ PASS | | `agent/assistant` | `TestSandboxContextAccess` | ✅ PASS | | `agent/assistant` | `TestSandboxLoadConfiguration` | ✅ PASS | | `agent/assistant` | `TestSandboxMCPToolCall` | ✅ PASS | | `agent/assistant` | `TestSandboxMCPEchoTool` | ✅ PASS | ### Running Tests ```bash # Source environment source /Users/max/Yao/yao/env.local.sh # Run all sandbox tests go test -v ./agent/sandbox/... 
# Run assistant sandbox tests go test -v ./agent/assistant -run "Sandbox" # Run E2E tests (requires Docker) go test -v ./agent/assistant -run "TestSandbox.*E2E" -timeout 300s ``` ## File Structure ``` yao/agent/sandbox/ # Executor layer ├── DESIGN.md # ✅ Design document ├── PLAN.md # ✅ This file ├── types.go # ✅ Common types and interfaces ├── types_test.go # ✅ Types tests ├── executor.go # ✅ Factory function ├── executor_test.go # ✅ Factory tests ├── integration_test.go # ✅ Integration tests ├── claude/ │ ├── types.go # ✅ Claude-specific types │ ├── executor.go # ✅ Executor implementation │ ├── executor_test.go # ✅ Executor tests │ ├── command.go # ✅ Command builder + CCR config │ └── command_test.go # ✅ Command tests └── cursor/ └── README.md # ⏳ Placeholder (pending) yao/agent/assistant/ # Integration layer ├── sandbox.go # ✅ Sandbox handler ├── sandbox_test.go # ✅ Loading tests ├── sandbox_integration_test.go # ✅ Integration tests ├── sandbox_e2e_test.go # ✅ E2E tests ├── sandbox_debug_test.go # ✅ Debug tests └── agent.go # ✅ Modified: sandbox detection in Stream() yao/agent/context/ # Context layer ├── jsapi_sandbox.go # ✅ Sandbox JSAPI bindings └── jsapi_sandbox_test.go # ✅ Sandbox JSAPI tests yao-dev-app/assistants/tests/sandbox/ # Test assistants ├── basic/ # ✅ Basic sandbox test ├── hooks/ # ✅ Hooks test └── full/ # ✅ Full test with MCPs and Skills ``` ## Key Design Decisions ### 1. Container Reuse Requests with the same `userID + chatID` reuse the same container: - The workspace directory persists across requests - The CCR config is written on each request (same content, safe to overwrite) - The container is removed when the request completes ### 2. Concurrency - Container creation: Protected by mutex + double-check locking - Container execution: Multiple requests can run concurrently in the same container - Claude CLI: Supports concurrent execution ### 3.
CCR Configuration CCR requires a specific JSON format: ```json { "Providers": [{"name": "volcengine", "api_base_url": "...", ...}], "Router": {"default": "volcengine,model", ...} } ``` The provider type is auto-detected from the host URL. ### 4. Resource Cleanup - `executor.Close()` removes the container and closes the IPC session - `defer sandboxCleanup()` in `agent.go` ensures cleanup - Tests use a unique chatID (timestamp) to avoid conflicts ### 5. MCP IPC Architecture ``` Host (Yao) Container (Claude CLI) ┌────────────────────────┐ ┌────────────────────────┐ │ IPC Manager │ │ yao-bridge │ │ └─ Session │◄─────────────│ (stdio ↔ socket) │ │ └─ MCPTools │ Unix Socket │ │ │ └─ Process │ (/tmp/ │ Claude CLI reads │ │ executor │ yao.sock) │ .mcp.json and calls │ └────────────────────────┘ │ yao-bridge for tools │ └────────────────────────┘ ``` - `.mcp.json` points to a single "yao" server using `yao-bridge /tmp/yao.sock` - The IPC session is created with the authorized MCP tools from the assistant config - Tools are executed via `process.New()` in the IPC session handler ## Known Issues ### macOS Docker Desktop Socket Permissions On macOS with Docker Desktop (gRPC-FUSE), Unix socket permissions are not properly preserved when bind mounting from the host. The IPC socket created on the host with `0666` permissions appears as `0660` inside the container. **Solution**: After container start, we execute `chmod 666 /tmp/yao.sock` as root inside the container to fix permissions. This is handled automatically by `sandbox.Manager.fixIPCSocketPermissions()`.
## Notes - All tests validate return values (use `require`/`assert`) - Docker must be available for integration and E2E tests - Tests automatically clean up containers after completion - Use `uses.search: disabled` in test assistants to avoid auto-search LLM calls ================================================ FILE: agent/sandbox/claude/attachments_test.go ================================================ package claude import ( "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" )

// TestExtensionFromContentType is a table test for extensionFromContentType:
// each listed MIME type must map to the given file extension, and both
// "application/octet-stream" and an unrecognized type must map to "".
func TestExtensionFromContentType(t *testing.T) {
	tests := []struct {
		contentType string
		expected string
	}{
		{"image/png", ".png"},
		{"image/jpeg", ".jpg"},
		{"image/gif", ".gif"},
		{"image/webp", ".webp"},
		{"image/svg+xml", ".svg"},
		{"application/pdf", ".pdf"},
		{"text/plain", ".txt"},
		{"text/html", ".html"},
		{"text/css", ".css"},
		{"text/javascript", ".js"},
		{"application/javascript", ".js"},
		{"application/json", ".json"},
		{"application/zip", ".zip"},
		{"application/octet-stream", ""},
		{"unknown/type", ""},
	}

	// One subtest per content type so a failure names the offending MIME type.
	for _, tt := range tests {
		t.Run(tt.contentType, func(t *testing.T) {
			assert.Equal(t, tt.expected, extensionFromContentType(tt.contentType))
		})
	}
}

// TestFormatFileSize is a table test for formatFileSize covering the B/KB/MB
// boundaries (1023 vs 1024 bytes, and 1 MiB) and one-decimal formatting.
func TestFormatFileSize(t *testing.T) {
	tests := []struct {
		bytes int
		expected string
	}{
		{0, "0B"},
		{100, "100B"},
		{1023, "1023B"},
		{1024, "1.0KB"},
		{1536, "1.5KB"},
		{10240, "10.0KB"},
		{1048576, "1.0MB"},
		{1572864, "1.5MB"},
		{10485760, "10.0MB"},
	}

	// Subtests are named after the byte count under test.
	for _, tt := range tests {
		t.Run(fmt.Sprintf("%d", tt.bytes), func(t *testing.T) {
			assert.Equal(t, tt.expected, formatFileSize(tt.bytes))
		})
	}
}

// TestPrepareAttachmentsPlainText verifies that prepareAttachments returns
// plain string messages (system, user, assistant) with their content unchanged.
func TestPrepareAttachmentsPlainText(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// createTestManager is defined elsewhere in this package; the test ends
	// quietly when it yields nil (presumably Docker is unavailable — confirm).
	manager := createTestManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// The nanosecond timestamp keeps the chat ID unique per run.
	opts := &Options{
		Command: "claude",
		Image: "alpine:latest",
		UserID: "test-user",
		ChatID: fmt.Sprintf("test-chat-att-plain-%d", time.Now().UnixNano()),
	}

	exec, err := NewExecutor(manager, opts)
	require.NoError(t, err)
	defer exec.Close()

	ctx := context.Background()

	// Plain text messages should pass through unchanged
	messages := []agentContext.Message{
		{Role: "system", Content: "You are a helpful assistant"},
		{Role: "user", Content: "Hello, world!"},
		{Role: "assistant", Content: "Hi there!"},
		{Role: "user", Content: "What is 1+1?"},
	}

	result, err := exec.prepareAttachments(ctx, messages)
	require.NoError(t, err)
	require.Len(t, result, 4)

	// Verify messages are unchanged
	assert.Equal(t, "system", string(result[0].Role))
	assert.Equal(t, "You are a helpful assistant", result[0].Content)
	assert.Equal(t, "Hello, world!", result[1].Content)
	assert.Equal(t, "Hi there!", result[2].Content)
	assert.Equal(t, "What is 1+1?", result[3].Content)
}

// TestPrepareAttachmentsMultimodalNoWrapper verifies that a multimodal user
// message whose image_url is a regular https URL is flattened to a string that
// keeps the text part and references the image as "[Image: <url>]".
func TestPrepareAttachmentsMultimodalNoWrapper(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// The test ends quietly when createTestManager yields nil.
	manager := createTestManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	opts := &Options{
		Command: "claude",
		Image: "alpine:latest",
		UserID: "test-user",
		ChatID: fmt.Sprintf("test-chat-att-nowrap-%d", time.Now().UnixNano()),
	}

	exec, err := NewExecutor(manager, opts)
	require.NoError(t, err)
	defer exec.Close()

	ctx := context.Background()

	// Multimodal message with a non-wrapper URL (e.g. regular http URL)
	// Should convert to text description but not try to resolve attachment
	messages := []agentContext.Message{
		{
			Role: "user",
			Content: []interface{}{
				map[string]interface{}{"type": "text", "text": "Look at this"},
				map[string]interface{}{
					"type": "image_url",
					"image_url": map[string]interface{}{
						"url": "https://example.com/image.png",
						"detail": "auto",
					},
				},
			},
		},
	}

	result, err := exec.prepareAttachments(ctx, messages)
	require.NoError(t, err)
	require.Len(t, result, 1)

	// Content should be converted to text with URL reference
	content, ok := result[0].Content.(string)
	require.True(t, ok, "Content should be converted to string")
	assert.Contains(t, content, "Look at this")
	assert.Contains(t, content, "[Image: https://example.com/image.png]")
}

// TestPrepareAttachmentsTextOnlyMultimodal verifies that a multimodal user
// message made only of text parts is flattened to one string containing every
// part.
func TestPrepareAttachmentsTextOnlyMultimodal(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// The test ends quietly when createTestManager yields nil.
	manager := createTestManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	opts := &Options{
		Command: "claude",
		Image: "alpine:latest",
		UserID: "test-user",
		ChatID: fmt.Sprintf("test-chat-att-textonly-%d", time.Now().UnixNano()),
	}

	exec, err := NewExecutor(manager, opts)
	require.NoError(t, err)
	defer exec.Close()

	ctx := context.Background()

	// Multimodal message with only text parts
	messages := []agentContext.Message{
		{
			Role: "user",
			Content: []interface{}{
				map[string]interface{}{"type": "text", "text": "Hello"},
				map[string]interface{}{"type": "text", "text": "World"},
			},
		},
	}

	result, err := exec.prepareAttachments(ctx, messages)
	require.NoError(t, err)
	require.Len(t, result, 1)

	// Should combine text parts
	content, ok := result[0].Content.(string)
	require.True(t, ok, "Content should be converted to string")
	assert.Contains(t, content, "Hello")
	assert.Contains(t, content, "World")
}

// TestPrepareAttachmentsInvalidWrapperURL verifies that an attachment URL that
// names a non-existent uploader does not fail prepareAttachments: the call
// succeeds and the flattened text contains "failed to load".
func TestPrepareAttachmentsInvalidWrapperURL(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// The test ends quietly when createTestManager yields nil.
	manager := createTestManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	opts := &Options{
		Command: "claude",
		Image: "alpine:latest",
		UserID: "test-user",
		ChatID: fmt.Sprintf("test-chat-att-invalid-%d", time.Now().UnixNano()),
	}

	exec, err := NewExecutor(manager, opts)
	require.NoError(t, err)
	defer exec.Close()

	ctx := context.Background()

	// Message with an attachment URL pointing to a non-existent manager
	messages := []agentContext.Message{
		{
			Role: "user",
			Content: []interface{}{
				map[string]interface{}{"type": "text", "text": "See this image"},
				map[string]interface{}{
					"type": "image_url",
					"image_url": map[string]interface{}{
						"url": "__nonexistent.uploader://fakefile123",
						"detail": "auto",
					},
				},
			},
		},
	}

	result, err := exec.prepareAttachments(ctx, messages)
	require.NoError(t, err)
	require.Len(t, result, 1)

	// Should gracefully fall back to error text
	content, ok := result[0].Content.(string)
	require.True(t, ok, "Content should be converted to string")
	assert.Contains(t, content, "See this image")
	assert.Contains(t, content, "failed to load")
}

func TestPrepareAttachmentsMixedRoles(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() opts := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: fmt.Sprintf("test-chat-att-mixed-%d", time.Now().UnixNano()), } exec, err := NewExecutor(manager, opts) require.NoError(t, err) defer exec.Close() ctx := context.Background() // Only user messages should be processed; system and assistant messages pass through messages := []agentContext.Message{ {Role: "system", Content: "System prompt"}, { Role: "user", Content: []interface{}{ map[string]interface{}{"type": "text", "text": "User message with image"}, map[string]interface{}{ "type": "image_url", "image_url": map[string]interface{}{ "url": "https://example.com/photo.jpg", "detail": "auto", }, }, }, }, {Role: "assistant", Content: "I can see the photo"}, {Role: "user", Content: "Thanks!"}, } result, err := exec.prepareAttachments(ctx, messages)
require.NoError(t, err) require.Len(t, result, 4) // System and assistant messages unchanged assert.Equal(t, "System prompt", result[0].Content) assert.Equal(t, "I can see the photo", result[2].Content) assert.Equal(t, "Thanks!", result[3].Content) // User multimodal message converted content, ok := result[1].Content.(string) require.True(t, ok, "User multimodal content should be converted to string") assert.Contains(t, content, "User message with image") assert.Contains(t, content, "[Image: https://example.com/photo.jpg]") } ================================================ FILE: agent/sandbox/claude/command.go ================================================ package claude import ( "encoding/json" "fmt" "strings" "github.com/yaoapp/gou/connector" agentContext "github.com/yaoapp/yao/agent/context" ) // sandboxEnvPrompt is the system prompt injected for sandbox environment // This tells Claude CLI about the workspace and project structure const sandboxEnvPrompt = `## Sandbox Environment You are running in a sandboxed environment with the following setup: - **Working Directory**: /workspace - **Project Structure**: If this is a new project, create a dedicated project folder (e.g., /workspace/my-project/) and work inside it - **File Access**: You have full read/write access to /workspace - **Output Files**: Save all output files to the working directory When creating new projects: 1. Create a project directory with a descriptive name 2. Initialize the project structure inside that directory 3. 
Keep all related files organized within the project folder ## IMPORTANT: Restricted Tools The following tools are NOT available in this environment and you must NOT use them: - EnterPlanMode, ExitPlanMode (use regular text to explain plans instead) - Task, TaskOutput, TaskStop (complete tasks directly without delegation) - AskUserQuestion (make reasonable assumptions instead of asking) - Skill, ToolSearch (not supported) Focus on using the core tools: Bash, Read, Write, Edit, Glob, Grep, WebSearch, WebFetch. ## User Attachments User-uploaded files (images, documents, code files, etc.) are placed in /workspace/.attachments/ When the user references an attached file, read it from this directory using the Read or Bash tool. For image files, you can view them directly as Claude supports vision on local files. ## GitHub CLI (gh) Usage When working with GitHub and a token is provided: 1. First authenticate gh CLI using the token: echo "TOKEN" | gh auth login --with-token 2. Then use gh commands normally (gh repo create, gh pr create, etc.) 3. Do NOT use curl to call GitHub API directly - always prefer gh CLI ` // claudeArgWhitelist maps package.yao sandbox.arguments keys to Claude CLI flags. // Only keys listed here are passed through; everything else is ignored. var claudeArgWhitelist = map[string]string{ "max_turns": "--max-turns", // Maximum conversation turns "disallowed_tools": "--disallowed-tools", // Comma-separated tool blacklist (e.g. "WebSearch,WebFetch") "allowed_tools": "--allowedTools", // Comma-separated tool whitelist (e.g. 
"Bash,Read,Write") } // BuildCommand builds the Claude CLI command and environment variables // Uses stdin with --input-format stream-json for unlimited prompt length // isContinuation: if true, uses --continue to resume previous session (only sends last user message) func BuildCommand(messages []agentContext.Message, opts *Options) ([]string, map[string]string, error) { return BuildCommandWithContinuation(messages, opts, false) } // BuildCommandWithContinuation builds the Claude CLI command with continuation support // isContinuation: if true, uses --continue to resume previous session func BuildCommandWithContinuation(messages []agentContext.Message, opts *Options, isContinuation bool) ([]string, map[string]string, error) { // Build system prompt from conversation history (only for first request) var systemPrompt string if !isContinuation { systemPrompt, _ = buildPrompts(messages) // Inject sandbox environment prompt if systemPrompt != "" { systemPrompt = systemPrompt + "\n\n" + sandboxEnvPrompt } else { systemPrompt = sandboxEnvPrompt } } // Build input JSONL for Claude CLI (stream-json format) // For continuation, only send the last user message var inputJSONL []byte var err error if isContinuation { inputJSONL, err = BuildLastUserMessageJSONL(messages) } else { inputJSONL, err = BuildFirstRequestJSONL(messages) } if err != nil { return nil, nil, fmt.Errorf("failed to build input JSONL: %w", err) } // Build Claude CLI arguments var claudeArgs []string // Add permission mode (required for MCP tools to work) permMode := "bypassPermissions" // default for sandbox if opts != nil && opts.Arguments != nil { if mode, ok := opts.Arguments["permission_mode"].(string); ok && mode != "" { permMode = mode } } claudeArgs = append(claudeArgs, "--dangerously-skip-permissions") claudeArgs = append(claudeArgs, "--permission-mode", permMode) // Add streaming format flags (required for proper streaming output) claudeArgs = append(claudeArgs, "--input-format", "stream-json") 
claudeArgs = append(claudeArgs, "--output-format", "stream-json") claudeArgs = append(claudeArgs, "--include-partial-messages") // Enable realtime streaming claudeArgs = append(claudeArgs, "--verbose") // For continuation, use --continue to resume the previous session // Claude CLI will read session data from $HOME/.claude/ (which is /workspace/.claude/) if isContinuation { claudeArgs = append(claudeArgs, "--continue") } // Pass through whitelisted arguments to Claude CLI flags. // Map: package.yao arguments key → Claude CLI flag if opts != nil && opts.Arguments != nil { for key, flag := range claudeArgWhitelist { if val, ok := opts.Arguments[key]; ok { claudeArgs = append(claudeArgs, flag, fmt.Sprintf("%v", val)) } } } // Add MCP config if available if opts != nil && len(opts.MCPConfig) > 0 { claudeArgs = append(claudeArgs, "--mcp-config", "/workspace/.mcp.json") // Allow all tools from the "yao" MCP server claudeArgs = append(claudeArgs, "--allowedTools", "mcp__yao__*") } // Build the full bash command // Use heredoc for both system prompt and input JSONL to avoid shell escaping issues // System prompt may contain quotes, newlines, special characters that break shell quoting var bashCmd strings.Builder // Ensure $HOME/.Xauthority exists for PyAutoGUI/Xlib (HOME=/workspace). // Xvfb runs without auth, but Xlib requires the file to exist. 
bashCmd.WriteString("touch /home/sandbox/.Xauthority 2>/dev/null; touch \"$HOME/.Xauthority\" 2>/dev/null\n") // If we have a system prompt (first request only), write it to a temp file via heredoc first // then use --append-system-prompt-file if systemPrompt != "" { bashCmd.WriteString("cat << 'PROMPTEOF' > /tmp/.system-prompt.txt\n") bashCmd.WriteString(systemPrompt) bashCmd.WriteString("\nPROMPTEOF\n") claudeArgs = append(claudeArgs, "--append-system-prompt-file", "/tmp/.system-prompt.txt") } // Build claude command with all arguments // Append 2>&1 to the claude command so stderr is merged into stdout; // Docker's stdcopy discards the stderr stream, making errors invisible. bashCmd.WriteString("cat << 'INPUTEOF' | claude -p") for _, arg := range claudeArgs { bashCmd.WriteString(fmt.Sprintf(" %q", arg)) } bashCmd.WriteString(" 2>&1") bashCmd.WriteString("\n") bashCmd.WriteString(string(inputJSONL)) bashCmd.WriteString("\nINPUTEOF") cmd := []string{"bash", "-c", bashCmd.String()} // Build environment variables env := buildEnvironment(opts, systemPrompt) return cmd, env, nil } // BuildInputJSONL converts messages to Claude CLI stream-json input format // Deprecated: Use BuildFirstRequestJSONL or BuildLastUserMessageJSONL instead func BuildInputJSONL(messages []agentContext.Message) ([]byte, error) { return BuildFirstRequestJSONL(messages) } // BuildFirstRequestJSONL builds JSONL for the first request (all messages) // Sends all user and assistant messages to establish context func BuildFirstRequestJSONL(messages []agentContext.Message) ([]byte, error) { var lines []string for _, msg := range messages { // Skip system messages (handled via --system-prompt) if msg.Role == "system" { continue } // Build the message content var content interface{} if msg.Content != nil { content = msg.Content } else { content = "" } // Create stream-json message streamMsg := map[string]interface{}{ "type": string(msg.Role), // "user" or "assistant" "message": map[string]interface{}{ 
"role": string(msg.Role), "content": content, }, } jsonBytes, err := json.Marshal(streamMsg) if err != nil { return nil, fmt.Errorf("failed to marshal message: %w", err) } lines = append(lines, string(jsonBytes)) } return []byte(strings.Join(lines, "\n")), nil } // BuildLastUserMessageJSONL builds JSONL with only the last user message // Used for continuation requests where Claude CLI manages history via --continue func BuildLastUserMessageJSONL(messages []agentContext.Message) ([]byte, error) { // Find the last user message var lastUserMessage *agentContext.Message for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { lastUserMessage = &messages[i] break } } if lastUserMessage == nil { return nil, fmt.Errorf("no user message found") } var content interface{} if lastUserMessage.Content != nil { content = lastUserMessage.Content } else { content = "" } userMsg := map[string]interface{}{ "type": "user", "message": map[string]interface{}{ "role": "user", "content": content, }, } jsonBytes, err := json.Marshal(userMsg) if err != nil { return nil, fmt.Errorf("failed to marshal user message: %w", err) } return jsonBytes, nil } // buildPrompts extracts system prompt and user prompt from messages func buildPrompts(messages []agentContext.Message) (systemPrompt string, userPrompt string) { var systemParts []string var conversationParts []string var lastUserMessage string for _, msg := range messages { switch msg.Role { case "system": systemParts = append(systemParts, getMessageContent(msg)) case "user": lastUserMessage = getMessageContent(msg) conversationParts = append(conversationParts, fmt.Sprintf("User: %s", lastUserMessage)) case "assistant": conversationParts = append(conversationParts, fmt.Sprintf("Assistant: %s", getMessageContent(msg))) } } // Build system prompt with conversation history systemPrompt = strings.Join(systemParts, "\n\n") // If there's conversation history, include it in the system prompt if len(conversationParts) > 1 { 
historySection := "\n\n## Conversation History\n\n" + strings.Join(conversationParts[:len(conversationParts)-1], "\n\n") systemPrompt += historySection } // The user prompt is the last user message userPrompt = lastUserMessage return systemPrompt, userPrompt } // getMessageContent extracts text content from a message func getMessageContent(msg agentContext.Message) string { if msg.Content == nil { return "" } // Handle string content if str, ok := msg.Content.(string); ok { return str } // Handle content array (multimodal messages) if arr, ok := msg.Content.([]interface{}); ok { var parts []string for _, item := range arr { if m, ok := item.(map[string]interface{}); ok { if m["type"] == "text" { if text, ok := m["text"].(string); ok { parts = append(parts, text) } } } } return strings.Join(parts, "\n") } return "" } // buildEnvironment builds environment variables for Claude CLI func buildEnvironment(opts *Options, systemPrompt string) map[string]string { env := make(map[string]string) if opts == nil { return env } // Set HOME to /workspace so Claude CLI stores session data in the workspace // This allows session persistence across requests for the same chat // Session data is stored in $HOME/.claude/ (i.e., /workspace/.claude/) env["HOME"] = "/workspace" // Fix Python user-site-packages: changing HOME from /home/sandbox to /workspace // breaks Python's ability to find packages installed via pip --user (e.g., playwright, // pyautogui, playwright-stealth) which live in /home/sandbox/.local/lib/pythonX.Y/site-packages/ env["PYTHONPATH"] = "/home/sandbox/.local/lib/python3.12/site-packages" // Fix X11 auth: PyAutoGUI/Xlib looks for $HOME/.Xauthority, but HOME=/workspace // so it fails to find /home/sandbox/.Xauthority created during image build. // Explicitly set XAUTHORITY to the correct path. 
env["XAUTHORITY"] = "/home/sandbox/.Xauthority" if opts.ConnectorType == "anthropic" { // Anthropic mode: Claude CLI connects directly to the Anthropic-compatible backend // No proxy needed — the backend already speaks Anthropic Messages API env["ANTHROPIC_BASE_URL"] = opts.ConnectorHost env["ANTHROPIC_API_KEY"] = opts.ConnectorKey } else { // OpenAI mode (default): Claude CLI connects to claude-proxy on localhost:3456 // The proxy translates Anthropic Messages API → OpenAI Chat Completions API env["ANTHROPIC_BASE_URL"] = "http://127.0.0.1:3456" env["ANTHROPIC_API_KEY"] = "dummy" // Proxy doesn't verify this } // Set model environment variables from connector // Claude CLI uses these to select the model for all roles if opts.Model != "" { env["ANTHROPIC_MODEL"] = opts.Model env["ANTHROPIC_DEFAULT_OPUS_MODEL"] = opts.Model env["ANTHROPIC_DEFAULT_SONNET_MODEL"] = opts.Model env["ANTHROPIC_DEFAULT_HAIKU_MODEL"] = opts.Model env["CLAUDE_CODE_SUBAGENT_MODEL"] = opts.Model } // Pass secrets as environment variables for Claude CLI to use // These are configured in package.yao sandbox.secrets (e.g., LLM_API_KEY, GITHUB_TOKEN) // start-claude-proxy also exports them for the proxy process, but Claude CLI // is launched via a separate docker exec, so it needs them passed explicitly here. 
if len(opts.Secrets) > 0 { for k, v := range opts.Secrets { env[k] = v } } // Note: System prompt and max_turns are passed via CLI flags in BuildCommand // CLAUDE_SYSTEM_PROMPT environment variable is NOT supported by Claude CLI // --append-system-prompt or --system-prompt flags must be used instead return env } // BuildProxyConfig builds the claude-proxy configuration JSON // This config file is read by start-claude-proxy script in the container // Config is written to /tmp/.yao/proxy.json (not /workspace/) for security func BuildProxyConfig(opts *Options) ([]byte, error) { if opts == nil { return nil, fmt.Errorf("options is required") } // Build backend URL using the shared connector.BuildAPIURL helper // so that the /v1 prefix is applied consistently with the agent LLM path. backendURL := connector.BuildAPIURL(opts.ConnectorHost, "/chat/completions") config := map[string]interface{}{ "backend": backendURL, "api_key": opts.ConnectorKey, "model": opts.Model, } // Add extra connector options if present (e.g., thinking, max_tokens, temperature) // These will be passed to the proxy via CLAUDE_PROXY_OPTIONS environment variable if len(opts.ConnectorOptions) > 0 { config["options"] = opts.ConnectorOptions } // Add secrets if present (e.g., GITHUB_TOKEN, AWS_ACCESS_KEY) // These will be exported as environment variables for Claude CLI to use if len(opts.Secrets) > 0 { config["secrets"] = opts.Secrets } return json.MarshalIndent(config, "", " ") } // BuildCCRConfig is deprecated, kept for backward compatibility // Use BuildProxyConfig instead func BuildCCRConfig(opts *Options) ([]byte, error) { return BuildProxyConfig(opts) } ================================================ FILE: agent/sandbox/claude/command_test.go ================================================ package claude import ( "encoding/json" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentContext "github.com/yaoapp/yao/agent/context" ) func TestBuildCommand(t 
*testing.T) { messages := []agentContext.Message{ {Role: "system", Content: "You are a helpful assistant"}, {Role: "user", Content: "Hello"}, } opts := &Options{ ConnectorHost: "https://api.example.com", ConnectorKey: "key123", Model: "test-model", } cmd, env, err := BuildCommand(messages, opts) require.NoError(t, err) // Verify command structure // Command is now: ["bash", "-c", "cat << 'INPUTEOF' | claude -p ... INPUTEOF"] assert.Equal(t, "bash", cmd[0]) assert.Equal(t, "-c", cmd[1]) // User message should be in bash command (as JSONL via stdin) assert.Contains(t, cmd[2], "Hello") // Should have stream-json flags assert.Contains(t, cmd[2], "--input-format") assert.Contains(t, cmd[2], "--output-format") assert.Contains(t, cmd[2], "--include-partial-messages") assert.Contains(t, cmd[2], "--verbose") assert.Contains(t, cmd[2], "stream-json") // Verify environment variables (claude-proxy) assert.Equal(t, "http://127.0.0.1:3456", env["ANTHROPIC_BASE_URL"]) assert.Equal(t, "dummy", env["ANTHROPIC_API_KEY"]) }

// TestBuildCommandWithSystemPrompt verifies that the system prompt is written
// to /tmp/.system-prompt.txt through a PROMPTEOF heredoc and handed to the CLI
// with --append-system-prompt-file.
func TestBuildCommandWithSystemPrompt(t *testing.T) {
	messages := []agentContext.Message{
		{Role: "system", Content: "You are a code reviewer"},
		{Role: "user", Content: "Review this code"},
		{Role: "assistant", Content: "Sure, I'll review it"},
		{Role: "user", Content: "Here is the code"},
	}

	opts := &Options{}

	cmd, _, err := BuildCommand(messages, opts)
	require.NoError(t, err)

	// System prompt should be written to file via heredoc, then passed via --append-system-prompt-file
	bashCmd := cmd[2] // The bash -c command string
	assert.Contains(t, bashCmd, "cat << 'PROMPTEOF' > /tmp/.system-prompt.txt")
	assert.Contains(t, bashCmd, "You are a code reviewer")
	assert.Contains(t, bashCmd, "PROMPTEOF")
	assert.Contains(t, bashCmd, "--append-system-prompt-file")
	assert.Contains(t, bashCmd, "/tmp/.system-prompt.txt")
}

// TestBuildCommandWithSpecialCharsInPrompt verifies that quotes and Markdown in
// the system prompt appear unmodified in the generated bash script.
func TestBuildCommandWithSpecialCharsInPrompt(t *testing.T) {
	// Test that special characters in prompts are handled correctly
	messages := []agentContext.Message{
		{Role: "system", Content: "You are a helper.\n\n## Rules\n- Rule 1: Don't use \"quotes\" wrongly\n- Rule 2: Handle 'single quotes' too\n- Rule 3: Special chars like $VAR and `backticks`"},
		{Role: "user", Content: "Hello"},
	}

	opts := &Options{}

	cmd, _, err := BuildCommand(messages, opts)
	require.NoError(t, err)

	bashCmd := cmd[2]
	// The heredoc approach should preserve all special characters
	assert.Contains(t, bashCmd, "## Rules")
	assert.Contains(t, bashCmd, `Don't use "quotes" wrongly`)
	assert.Contains(t, bashCmd, "'single quotes'")
}

// TestBuildCommandWithArguments verifies that the "max_turns" argument becomes
// a --max-turns flag and that "permission_mode" replaces the default mode.
func TestBuildCommandWithArguments(t *testing.T) {
	messages := []agentContext.Message{
		{Role: "user", Content: "Hello"},
	}

	opts := &Options{
		Arguments: map[string]interface{}{
			"max_turns": 20,
			"permission_mode": "acceptEdits",
		},
	}

	cmd, _, err := BuildCommand(messages, opts)
	require.NoError(t, err)

	bashCmd := cmd[2] // The bash -c command string

	// max_turns should be in command args via --max-turns
	assert.Contains(t, bashCmd, "--max-turns")
	assert.Contains(t, bashCmd, "20")

	// permission_mode should be in command args
	assert.Contains(t, bashCmd, "acceptEdits")
}

// TestBuildProxyConfig verifies that BuildProxyConfig emits the backend URL,
// API key and model for a plain host.
func TestBuildProxyConfig(t *testing.T) {
	opts := &Options{
		ConnectorHost: "https://api.example.com",
		ConnectorKey: "key123",
		Model: "test-model",
	}

	configJSON, err := BuildProxyConfig(opts)
	require.NoError(t, err)

	configStr := string(configJSON)

	// Proxy config uses simple format
	// BuildAPIURL adds /v1 prefix for hosts that don't end with "/"
	assert.Contains(t, configStr, "backend")
	assert.Contains(t, configStr, "https://api.example.com/v1/chat/completions")
	assert.Contains(t, configStr, "api_key")
	assert.Contains(t, configStr, "key123")
	assert.Contains(t, configStr, "model")
	assert.Contains(t, configStr, "test-model")
}

// TestBuildProxyConfigVolcengine verifies the proxy config for a host that
// already carries a versioned path ending in "/" (Volcengine Ark).
func TestBuildProxyConfigVolcengine(t *testing.T) {
	opts := &Options{
		ConnectorHost: "https://ark.cn-beijing.volces.com/api/v3/",
		ConnectorKey: "test-key",
		Model: "ep-xxx",
	}

	configJSON, err := BuildProxyConfig(opts)
	require.NoError(t, err)

	configStr := string(configJSON)

	// URL should end with /chat/completions
	assert.Contains(t, configStr, "/chat/completions")
	assert.Contains(t, configStr, "ep-xxx")
}

// TestBuildInputJSONL verifies that BuildInputJSONL drops system messages,
// keeps user and assistant messages, and emits one valid JSON object per line
// with "type" and "message" keys.
func TestBuildInputJSONL(t *testing.T) {
	messages := []agentContext.Message{
		{Role: "system", Content: "You are helpful"},
		{Role: "user", Content: "Hello"},
		{Role: "assistant", Content: "Hi there!"},
		{Role: "user", Content: "How are you?"},
	}

	jsonl, err := BuildInputJSONL(messages)
	require.NoError(t, err)

	// Should not contain system messages (handled separately)
	assert.NotContains(t, string(jsonl), "You are helpful")

	// Should contain user and assistant messages
	assert.Contains(t, string(jsonl), "Hello")
	assert.Contains(t, string(jsonl), "Hi there!")
	assert.Contains(t, string(jsonl), "How are you?")

	// Verify JSONL format (each line is valid JSON)
	lines := splitLines(string(jsonl))
	for _, line := range lines {
		if line == "" {
			continue
		}
		var msg map[string]interface{}
		err := json.Unmarshal([]byte(line), &msg)
		assert.NoError(t, err, "Line should be valid JSON: %s", line)
		assert.Contains(t, msg, "type")
		assert.Contains(t, msg, "message")
	}
}

// TestBuildInputJSONLMultimodal verifies that array (multimodal) content is
// carried into the JSONL output.
func TestBuildInputJSONLMultimodal(t *testing.T) {
	// Test with multimodal content (image)
	messages := []agentContext.Message{
		{
			Role: "user",
			Content: []interface{}{
				map[string]interface{}{"type": "text", "text": "What's in this image?"},
				map[string]interface{}{
					"type": "image",
					"source": map[string]interface{}{
						"type": "base64",
						"media_type": "image/png",
						"data": "iVBORw0KGgo=",
					},
				},
			},
		},
	}

	jsonl, err := BuildInputJSONL(messages)
	require.NoError(t, err)

	// Should contain the multimodal content
	assert.Contains(t, string(jsonl), "What's in this image?")
	assert.Contains(t, string(jsonl), "image")
	assert.Contains(t, string(jsonl), "base64")
}

// TestGetMessageContent covers the three content shapes getMessageContent
// handles: a string, nil, and an array of text parts.
func TestGetMessageContent(t *testing.T) {
	// String content
	msg1 := agentContext.Message{Content: "Hello World"}
	assert.Equal(t, "Hello World", getMessageContent(msg1))

	// Nil content
	msg2 := agentContext.Message{Content: nil}
	assert.Equal(t, "", getMessageContent(msg2))

	// Array content (multimodal)
	msg3 := agentContext.Message{
		Content: []interface{}{
			map[string]interface{}{"type": "text", "text": "Part 1"},
			map[string]interface{}{"type": "text", "text": "Part 2"},
		},
	}
	assert.Contains(t, getMessageContent(msg3), "Part 1")
	assert.Contains(t, getMessageContent(msg3), "Part 2")
}

// splitLines splits s at every '\n' and returns the pieces without the
// newline. Text after the last newline is returned as a final element; a
// trailing newline does not add an empty element, and "" yields nil.
func splitLines(s string) []string {
	var lines []string
	start := 0
	for i := 0; i < len(s); i++ {
		if s[i] == '\n' {
			lines = append(lines, s[start:i])
			start = i + 1
		}
	}
	// Keep the unterminated last line, if any.
	if start < len(s) {
		lines = append(lines, s[start:])
	}
	return lines
}

================================================ FILE: agent/sandbox/claude/e2e_test.go ================================================ package claude import ( "context" "fmt" "os" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/config" infraSandbox "github.com/yaoapp/yao/sandbox" "github.com/yaoapp/yao/test" )

// TestE2ESkipClaudeCLI verifies that Claude CLI is skipped when no prompts/skills/mcp
// This is the "hook-only" mode where hooks take full control
func TestE2ESkipClaudeCLI(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping E2E test in short mode")
	}

	test.Prepare(t, config.Conf)
	defer test.Clean()

	// The test ends quietly when createTestManager yields nil.
	manager := createTestManager(t)
	if manager == nil {
		return
	}
	defer manager.Close()

	// Create options WITHOUT SystemPrompt, SkillsDir, or MCPConfig
	// This should trigger the skip logic
	opts := &Options{
		Command: "claude",
		Image: "alpine:latest", // Use alpine since we're not calling Claude CLI
		UserID: "test-user",
		ChatID: fmt.Sprintf("test-e2e-skip-%d", time.Now().UnixNano()),
		ConnectorHost: "",
		ConnectorKey: "",
		Model: "",
		// No SystemPrompt, SkillsDir, or MCPConfig
	}

	exec, err := NewExecutor(manager, opts)
	require.NoError(t, err)
	defer exec.Close()

	// Verify shouldSkipClaudeCLI returns true
	assert.True(t, exec.shouldSkipClaudeCLI(), "Should skip Claude CLI when no prompts/skills/mcp")

	// Execute Stream - it should return immediately without calling Claude CLI
	ctx := agentContext.New(context.Background(), nil, opts.ChatID)
	messages := []agentContext.Message{
		{Role: "user", Content: "Hello"},
	}

	response, err := exec.Stream(ctx, messages, nil)
	require.NoError(t, err, "Stream should succeed")
	require.NotNil(t, response, "Response should not be nil")

	// Verify response indicates skip
	assert.Contains(t, response.ID, "sandbox-skip", "Response ID should indicate skip")
	assert.Equal(t, "sandbox", response.Model, "Model should be 'sandbox' for skip mode")
	assert.Empty(t, response.Content, "Content should be empty for skip mode")

	t.Log("✓ Claude CLI skip mode verified")
}

// TestE2EExecuteClaudeCLI verifies that Claude CLI is called when prompts are configured
// This requires the real yaoapp/sandbox-claude image and a valid connector
//
// The test is skipped in short mode, when the DEEPSEEK_* or YAO_ROOT
// environment variables are missing, when the sandbox manager or executor
// cannot be created, and when Stream itself fails. An empty response is only
// logged, not treated as a failure.
func TestE2EExecuteClaudeCLI(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping E2E test in short mode")
	}

	// Check for required environment variables
	apiKey := os.Getenv("DEEPSEEK_API_KEY")
	apiProxy := os.Getenv("DEEPSEEK_API_PROXY")
	model := os.Getenv("DEEPSEEK_MODELS_V3")
	if apiKey == "" || apiProxy == "" || model == "" {
		t.Skip("Skipping test: DEEPSEEK_API_KEY, DEEPSEEK_API_PROXY, or DEEPSEEK_MODELS_V3 not set")
	}

	test.Prepare(t, config.Conf)
	defer test.Clean()

	// Get data root from environment
	dataRoot := os.Getenv("YAO_ROOT")
	if dataRoot == "" {
		t.Skip("Skipping test: YAO_ROOT not set")
	}

	// Create config with proper paths
	cfg := infraSandbox.DefaultConfig()
	cfg.Init(dataRoot)

	manager, err := infraSandbox.NewManager(cfg)
	if err != nil {
		t.Skipf("Skipping test: Docker not available: %v", err)
	}
	defer manager.Close()

	// Create options WITH SystemPrompt (triggers Claude CLI execution)
	opts := &Options{
		Command: "claude",
		Image: "yaoapp/sandbox-claude:latest",
		UserID: "test-user",
		ChatID: fmt.Sprintf("test-e2e-exec-%d", time.Now().UnixNano()),
		ConnectorHost: apiProxy,
		ConnectorKey: apiKey,
		Model: model,
		SystemPrompt: "You are a helpful assistant. Keep responses brief.",
		Timeout: 5 * time.Minute,
	}

	exec, err := NewExecutor(manager, opts)
	if err != nil {
		t.Skipf("Skipping test: Failed to create executor: %v", err)
	}
	defer exec.Close()

	// Verify shouldSkipClaudeCLI returns false
	assert.False(t, exec.shouldSkipClaudeCLI(), "Should NOT skip Claude CLI when prompts are configured")

	// Execute Stream with a simple prompt
	ctx := agentContext.New(context.Background(), nil, opts.ChatID)
	messages := []agentContext.Message{
		{Role: "user", Content: "Reply with exactly: TEST_SUCCESS"},
	}

	// Collect streaming output
	var streamedContent strings.Builder
	streamHandler := func(chunkType message.StreamChunkType, data []byte) int {
		if chunkType == message.ChunkText {
			streamedContent.Write(data)
		}
		return 0 // continue streaming
	}

	t.Log("Executing Claude CLI with real API call...")
	startTime := time.Now()

	response, err := exec.Stream(ctx, messages, streamHandler)

	duration := time.Since(startTime)
	t.Logf("Execution took: %v", duration)

	if err != nil {
		t.Logf("Stream error (might be expected if Docker/API issue): %v", err)
		t.Skipf("Skipping assertion: %v", err)
	}

	require.NotNil(t, response, "Response should not be nil")

	// Log response details
	t.Logf("Response ID: %s", response.ID)
	t.Logf("Response Model: %s", response.Model)
	t.Logf("Response Content: %v", response.Content)
	t.Logf("Streamed Content: %s", streamedContent.String())

	// Verify we got some response
	// Prefer the final response content and fall back to the streamed text.
	var fullResponse string
	if content, ok := response.Content.(string); ok {
		fullResponse = content
	}
	if fullResponse == "" {
		fullResponse = streamedContent.String()
	}

	if fullResponse != "" {
		t.Logf("✓ Claude CLI executed successfully with response: %s", truncate(fullResponse, 200))
	} else {
		t.Log("⚠ Empty response (Claude CLI might have issues)")
	}
}

// TestE2EBuildInputJSONLIntegration tests the full flow of building input JSONL
func TestE2EBuildInputJSONLIntegration(t *testing.T) {
	// Test with conversation history
	messages := []agentContext.Message{
		{Role: "system", Content: "You are a helpful assistant"},
		{Role: "user", Content: "What is 2+2?"},
		{Role: "assistant", Content: "4"},
		{Role: "user", Content: "What about 3+3?"},
	}

	jsonl, err := BuildInputJSONL(messages)
	require.NoError(t, err)

	t.Logf("Input JSONL:\n%s", string(jsonl))

	// Verify format
	lines := strings.Split(string(jsonl), "\n")
	assert.GreaterOrEqual(t, len(lines), 3, "Should have at least 3 lines (user, assistant, user)")

	// System message should NOT be in JSONL
	assert.NotContains(t, string(jsonl), "You are a helpful assistant", "System message should not be in JSONL")

	// User and assistant messages should be present
	assert.Contains(t, string(jsonl), "What is 2+2", "First user message should be present")
	assert.Contains(t, string(jsonl), "4", "Assistant response should be present")
	assert.Contains(t, string(jsonl), "What about 3+3", "Second user message should be present")

	t.Log("✓ Input JSONL format verified")
}

// TestE2EBuildCommand tests the full command building func TestE2EBuildCommand(t *testing.T) { messages := []agentContext.Message{ {Role: "system", Content: "You are helpful"}, {Role: "user", Content: "Hello"}, } opts := &Options{ ConnectorHost: "https://api.example.com", ConnectorKey: "test-key", Model: "test-model", Arguments: map[string]interface{}{ "permission_mode": "bypassPermissions", }, MCPConfig: []byte(`{"mcpServers":{}}`), } cmd, env, err := BuildCommand(messages, opts) require.NoError(t, err) t.Logf("Command: %v", cmd) t.Logf("Environment: %v", env) // Verify command structure assert.Equal(t, "bash", cmd[0]) assert.Equal(t, "-c", cmd[1]) bashCmd := cmd[2] // Should use heredoc with INPUTEOF assert.Contains(t, bashCmd, "cat << 'INPUTEOF'", "Should use heredoc") assert.Contains(t, bashCmd, "INPUTEOF", "Should have INPUTEOF delimiter") // Should have streaming flags assert.Contains(t, bashCmd,
// truncate shortens s to at most maxLen bytes and appends "..." when anything
// was cut off. Strings that already fit are returned unchanged.
//
// The cut is always placed on a rune boundary, so multi-byte UTF-8 characters
// are never split in half (the result may therefore be slightly shorter than
// maxLen bytes). A negative maxLen is treated as 0 instead of panicking.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// Ranging over a string yields the byte offset of each rune start; keep
	// the largest offset that does not exceed maxLen.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
"" { return nil, fmt.Errorf("UserID is required") } if execOpts.ChatID == "" { return nil, fmt.Errorf("ChatID is required") } // Create or get container // Note: IPC session is created by manager.createContainer, socket is already bind mounted ctx := context.Background() createOpts := infraSandbox.CreateOptions{ UserID: execOpts.UserID, ChatID: execOpts.ChatID, Image: execOpts.Image, } container, err := manager.GetOrCreate(ctx, execOpts.UserID, execOpts.ChatID, createOpts) if err != nil { return nil, fmt.Errorf("failed to create container: %w", err) } // Get workspace directory from config config := manager.GetConfig() workDir := config.ContainerWorkDir if workDir == "" { workDir = "/workspace" } return &Executor{ manager: manager, containerName: container.Name, opts: execOpts, workDir: workDir, }, nil } // SetLoadingMsgID sets the loading message ID for tool execution updates func (e *Executor) SetLoadingMsgID(id string) { e.loadingMsgID = id } // Stream runs the Claude CLI with streaming output func (e *Executor) Stream(ctx *agentContext.Context, messages []agentContext.Message, handler message.StreamFunc) (*agentContext.CompletionResponse, error) { // Create a cancellable context for this stream operation // We need to handle both: // 1. HTTP context cancellation (client disconnect) // 2. 
InterruptController cancellation (user clicks "stop" button) // // Note on InterruptController: // - ctx.Interrupt.Context() is only cancelled when InterruptForce && len(Messages) == 0 // - When user sends messages with the interrupt, the context is NOT cancelled // - We use ctx.Interrupt.IsInterrupted() to check for any interrupt signal stdCtx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() // Start a goroutine to monitor for interrupts and HTTP context cancellation go func() { ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() for { select { case <-stdCtx.Done(): // Already cancelled, exit return case <-ticker.C: // Check if there's a pending interrupt signal using Peek() // This works even when Messages are included (which doesn't cancel the context) if ctx != nil && ctx.Interrupt != nil { if signal := ctx.Interrupt.Peek(); signal != nil { cancelFunc() return } } // Check InterruptController.IsInterrupted() (for context-cancelled interrupts) if ctx != nil && ctx.Interrupt != nil && ctx.Interrupt.IsInterrupted() { cancelFunc() return } // Check HTTP context if ctx != nil && ctx.Context != nil { select { case <-ctx.Context.Done(): cancelFunc() return default: } } } } }() // Set MCP tools for this request (dynamic, runtime configuration) if len(e.opts.MCPTools) > 0 { ipcManager := e.manager.GetIPCManager() if ipcManager != nil { if session, ok := ipcManager.Get(e.opts.ChatID); ok { session.SetMCPTools(e.opts.MCPTools) } } } // Prepare environment: write configs and copy skills if err := e.prepareEnvironment(stdCtx); err != nil { return nil, fmt.Errorf("failed to prepare environment: %w", err) } // Resolve attachment URLs and write files to container // This converts __yao.attachment:// URLs to local file paths in /workspace/.attachments/ if resolved, attErr := e.prepareAttachments(stdCtx, messages); attErr != nil { // Non-fatal: log warning and continue with original messages log.Printf("[sandbox] Warning: failed to 
prepare attachments: %v", attErr) } else { messages = resolved } // Check if we should skip Claude CLI execution // Skip if no prompts, no skills, and no MCP config skipCLI := e.shouldSkipClaudeCLI() if skipCLI { // Return empty response - hooks can use sandbox API to do their work return &agentContext.CompletionResponse{ ID: fmt.Sprintf("sandbox-skip-%d", time.Now().UnixNano()), Model: "sandbox", Created: time.Now().Unix(), Role: "assistant", Content: "", FinishReason: agentContext.FinishReasonStop, }, nil } // Check if this is a continuation (Claude CLI session exists in workspace) isContinuation := e.hasExistingSession(stdCtx) // Build Claude CLI command using stored options cmd, env, err := BuildCommandWithContinuation(messages, e.opts, isContinuation) if err != nil { return nil, fmt.Errorf("failed to build command: %w", err) } // Prepare execution options execOpts := &infraSandbox.ExecOptions{ WorkDir: e.workDir, Env: env, } if e.opts != nil && e.opts.Timeout > 0 { execOpts.Timeout = e.opts.Timeout } reader, err := e.manager.Stream(stdCtx, e.containerName, cmd, execOpts) if err != nil { return nil, fmt.Errorf("failed to execute command: %w", err) } // Ensure reader is closed when context is cancelled or function returns // This is important for cleanup when user clicks "stop" done := make(chan struct{}) defer func() { close(done) reader.Close() }() // Monitor for context cancellation and forcefully kill Claude CLI process go func() { // Also start a ticker to periodically check context status for debugging ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for { select { case <-stdCtx.Done(): // First, kill the Claude CLI process inside the container // This is important because closing the reader/connection alone may not stop the process // Use a background context since stdCtx is already cancelled killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Kill claude process (the Claude CLI binary) 
e.manager.KillProcess(killCtx, e.containerName, "claude") // Also close the reader to unblock any pending reads reader.Close() return case <-done: // Normal completion, nothing to do return case <-ticker.C: // Periodic check - no action needed } } }() // DEBUG: Tee the reader to write raw output to a log file for debugging debugLogPath := e.workDir + "/claude-cli-raw.log" debugReader := e.createDebugReader(stdCtx, reader, debugLogPath) // Parse streaming output (uses e.loadingMsgID set via SetLoadingMsgID) return e.parseStream(ctx, debugReader, handler) } // shouldSkipClaudeCLI checks if Claude CLI execution should be skipped // Skip when: no system prompt, no skills, and no MCP config func (e *Executor) shouldSkipClaudeCLI() bool { hasPrompts := e.opts.SystemPrompt != "" hasSkills := e.opts.SkillsDir != "" hasMCP := len(e.opts.MCPConfig) > 0 // If any of these are present, execute Claude CLI return !hasPrompts && !hasSkills && !hasMCP } // hasExistingSession checks if Claude CLI has an existing session in the workspace // Claude CLI stores session data in $HOME/.claude/projects/ (which is /workspace/.claude/projects/) // If session data exists, we should use --continue to resume the session func (e *Executor) hasExistingSession(ctx context.Context) bool { // Check if /workspace/.claude/projects/ directory has any content // This indicates a previous session exists sessionDir := e.workDir + "/.claude/projects" files, err := e.manager.ListDir(ctx, e.containerName, sessionDir) if err != nil { // Directory doesn't exist or error reading - no existing session return false } // If there are any files/directories in the projects folder, session exists return len(files) > 0 } // prepareEnvironment prepares the container environment before execution // This includes: claude-proxy config, MCP config, and Skills directory func (e *Executor) prepareEnvironment(ctx context.Context) error { // 1. 
Write claude-proxy config and start the proxy if err := e.startClaudeProxy(ctx); err != nil { return fmt.Errorf("failed to start claude-proxy: %w", err) } // 2. Write MCP config if provided if len(e.opts.MCPConfig) > 0 { if err := e.writeMCPConfig(ctx); err != nil { return fmt.Errorf("failed to write MCP config: %w", err) } } // 3. Copy Skills directory if provided if e.opts.SkillsDir != "" { if err := e.copySkillsDirectory(ctx); err != nil { // Non-fatal: log warning but continue // Skills might not exist or be optional _ = err // Ignore error, skills are optional } } return nil } // startClaudeProxy writes proxy config and starts claude-proxy func (e *Executor) startClaudeProxy(ctx context.Context) error { // Skip if no connector configured (e.g., test containers without claude-proxy) if e.opts.ConnectorHost == "" || e.opts.ConnectorKey == "" { return nil } // Skip proxy for Anthropic connectors — Claude CLI connects directly // The backend already speaks Anthropic Messages API, no conversion needed if e.opts.ConnectorType == "anthropic" { return nil } // Build proxy config configJSON, err := BuildProxyConfig(e.opts) if err != nil { return fmt.Errorf("failed to build proxy config: %w", err) } // Create config directory (outside workspace for security - user can't see api_key/secrets) // /tmp/.yao/ is not visible to user's file manager configDir := "/tmp/.yao" if _, err := e.manager.Exec(ctx, e.containerName, []string{"mkdir", "-p", configDir}, nil); err != nil { return fmt.Errorf("failed to create config directory %s: %w", configDir, err) } // Write config to secure location (not in /workspace/) configPath := configDir + "/proxy.json" if err := e.manager.WriteFile(ctx, e.containerName, configPath, configJSON); err != nil { return fmt.Errorf("failed to write config to %s: %w", configPath, err) } // Start the proxy (only if start-claude-proxy exists in the image) result, err := e.manager.Exec(ctx, e.containerName, []string{"which", "start-claude-proxy"}, 
&infraSandbox.ExecOptions{ WorkDir: e.workDir, }) if err != nil || result.ExitCode != 0 { // start-claude-proxy not available (e.g., alpine test image), skip return nil } // Start the proxy result, err = e.manager.Exec(ctx, e.containerName, []string{"start-claude-proxy"}, &infraSandbox.ExecOptions{ WorkDir: e.workDir, Env: map[string]string{ "WORKSPACE": e.workDir, }, }) if err != nil { return fmt.Errorf("failed to start claude-proxy: %w", err) } if result.ExitCode != 0 { return fmt.Errorf("claude-proxy failed to start: %s", result.Stderr) } return nil } // writeMCPConfig writes the MCP configuration file to the container workspace func (e *Executor) writeMCPConfig(ctx context.Context) error { if len(e.opts.MCPConfig) == 0 { return nil } // Write MCP config to workspace (.mcp.json) mcpPath := e.workDir + "/.mcp.json" if err := e.manager.WriteFile(ctx, e.containerName, mcpPath, e.opts.MCPConfig); err != nil { return fmt.Errorf("failed to write MCP config to %s: %w", mcpPath, err) } return nil } // copySkillsDirectory copies the skills directory to the container func (e *Executor) copySkillsDirectory(ctx context.Context) error { if e.opts.SkillsDir == "" { return nil } // Target path in container: /workspace/.claude/skills/ // This follows Claude CLI's expected skills location claudeDir := e.workDir + "/.claude" // Create .claude directory first if _, err := e.manager.Exec(ctx, e.containerName, []string{"mkdir", "-p", claudeDir}, nil); err != nil { return fmt.Errorf("failed to create .claude directory: %w", err) } // Copy skills from host to container // CopyToContainer extracts tar to containerPath, and createTarFromPath uses // filepath.Dir(hostPath) as base, so if hostPath is /path/to/skills, // tar entries are like "skills/skill-name/SKILL.md" // Extracting to /workspace/.claude/ gives us /workspace/.claude/skills/skill-name/SKILL.md if err := e.manager.CopyToContainer(ctx, e.containerName, e.opts.SkillsDir, claudeDir); err != nil { return fmt.Errorf("failed to 
copy skills to container: %w", err) } return nil } // prepareAttachments resolves __yao.attachment:// URLs in messages, // writes the actual files to the container's /workspace/.attachments/ directory, // and replaces the attachment content parts with text references to the file paths. // This allows Claude CLI to read the files using its built-in Read/Bash tools. func (e *Executor) prepareAttachments(ctx context.Context, messages []agentContext.Message) ([]agentContext.Message, error) { // Track used filenames to handle duplicates usedNames := make(map[string]int) attachmentDir := e.workDir + "/.attachments" dirCreated := false hasAttachments := false result := make([]agentContext.Message, len(messages)) copy(result, messages) for i, msg := range result { if msg.Role != "user" { continue } // Handle content array (multimodal messages come as []interface{} from JSON) parts, ok := msg.Content.([]interface{}) if !ok { // Try typed content parts if typedParts, ok := msg.Content.([]agentContext.ContentPart); ok { iparts := make([]interface{}, len(typedParts)) for j, p := range typedParts { // Convert to map for uniform handling m := map[string]interface{}{"type": string(p.Type)} if p.Text != "" { m["text"] = p.Text } if p.ImageURL != nil { m["image_url"] = map[string]interface{}{ "url": p.ImageURL.URL, "detail": string(p.ImageURL.Detail), } } if p.File != nil { m["file"] = map[string]interface{}{ "url": p.File.URL, "filename": p.File.Filename, } } iparts[j] = m } parts = iparts } else { continue } } if len(parts) == 0 { continue } // Process each content part var textParts []string for _, item := range parts { m, ok := item.(map[string]interface{}) if !ok { continue } partType, _ := m["type"].(string) switch partType { case "text": if text, ok := m["text"].(string); ok && text != "" { textParts = append(textParts, text) } case "image_url": imgData, _ := m["image_url"].(map[string]interface{}) if imgData == nil { continue } url, _ := imgData["url"].(string) if url == "" 
{ continue } uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { // Not an attachment URL, keep as text reference textParts = append(textParts, fmt.Sprintf("[Image: %s]", url)) continue } // Resolve the attachment ref, err := e.resolveAttachment(ctx, uploaderName, fileID, "", attachmentDir, usedNames, &dirCreated) if err != nil { log.Printf("[sandbox] Warning: failed to resolve image attachment %s: %v", fileID, err) textParts = append(textParts, "[Attached image: failed to load]") continue } textParts = append(textParts, ref) hasAttachments = true case "file": fileData, _ := m["file"].(map[string]interface{}) if fileData == nil { continue } url, _ := fileData["url"].(string) hintName, _ := fileData["filename"].(string) if url == "" { continue } uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { textParts = append(textParts, fmt.Sprintf("[File: %s]", url)) continue } ref, err := e.resolveAttachment(ctx, uploaderName, fileID, hintName, attachmentDir, usedNames, &dirCreated) if err != nil { log.Printf("[sandbox] Warning: failed to resolve file attachment %s: %v", fileID, err) textParts = append(textParts, "[Attached file: failed to load]") continue } textParts = append(textParts, ref) hasAttachments = true default: // Keep other types as-is (shouldn't happen normally) continue } } // Merge text parts into a single string when the original content was // a multimodal array ([]interface{} / []ContentPart). This is needed // even when only "text" parts are present so that downstream code // (BuildInputJSONL, etc.) always sees a plain string. if len(textParts) > 0 { newMsg := result[i] newMsg.Content = strings.Join(textParts, "\n\n") result[i] = newMsg } } if !hasAttachments { return result, nil } return result, nil } // resolveAttachment reads an attachment from the attachment manager and writes it // to the container's .attachments directory. Returns a text reference string. 
func (e *Executor) resolveAttachment( ctx context.Context, uploaderName, fileID, hintName, attachmentDir string, usedNames map[string]int, dirCreated *bool, ) (string, error) { // Get attachment manager manager, exists := attachment.Managers[uploaderName] if !exists { return "", fmt.Errorf("attachment manager not found: %s", uploaderName) } // Get file info fileInfo, err := manager.Info(ctx, fileID) if err != nil { return "", fmt.Errorf("failed to get file info: %w", err) } // Read file data data, err := manager.Read(ctx, fileID) if err != nil { return "", fmt.Errorf("failed to read file: %w", err) } // Determine filename filename := fileInfo.Filename if filename == "" && hintName != "" { filename = hintName } if filename == "" { // Fallback: use fileID with extension from content type ext := extensionFromContentType(fileInfo.ContentType) filename = fileID + ext } // Handle duplicate filenames baseName := filename if count, exists := usedNames[baseName]; exists { ext := filepath.Ext(filename) name := strings.TrimSuffix(filename, ext) filename = fmt.Sprintf("%s_%d%s", name, count+1, ext) usedNames[baseName] = count + 1 } else { usedNames[baseName] = 0 } // Create attachments directory if not yet created if !*dirCreated { if err := e.manager.WriteFile(ctx, e.containerName, attachmentDir+"/.keep", []byte("")); err != nil { return "", fmt.Errorf("failed to create attachments directory: %w", err) } *dirCreated = true } // Write file to container containerPath := attachmentDir + "/" + filename if err := e.manager.WriteFile(ctx, e.containerName, containerPath, data); err != nil { return "", fmt.Errorf("failed to write file to container: %w", err) } // Build human-readable size string sizeStr := formatFileSize(fileInfo.Bytes) // Return text reference return fmt.Sprintf("[Attached file: %s (%s, %s)]", containerPath, fileInfo.ContentType, sizeStr), nil } // extensionFromContentType returns a file extension for a given content type func extensionFromContentType(contentType 
// debugWriter wraps an io.Reader and copies every byte read from it into a
// debug log file. Logging is best effort: write errors never affect the data
// returned to the caller.
type debugWriter struct {
	reader  io.Reader // underlying stream being read
	logFile *os.File  // destination for the raw copy; nil once released
	buffer  []byte    // NOTE(review): not used by the methods below — confirm before removing
}

// Read reads from the wrapped reader and mirrors the bytes into the log file.
// When the wrapped reader reports a terminal condition (io.EOF or any other
// error) the log file is closed, so the descriptor is released even if the
// caller never calls Close.
func (d *debugWriter) Read(p []byte) (n int, err error) {
	n, err = d.reader.Read(p)
	if n > 0 && d.logFile != nil {
		// Write raw bytes to log file; failures are deliberately ignored
		_, _ = d.logFile.Write(p[:n])
		_ = d.logFile.Sync()
	}
	if err != nil {
		_ = d.Close()
	}
	return n, err
}

// Close releases the log file. It is safe to call more than once; the wrapped
// reader is not closed and remains owned by the caller.
func (d *debugWriter) Close() error {
	if d.logFile == nil {
		return nil
	}
	f := d.logFile
	d.logFile = nil
	return f.Close()
}
logFile.WriteString("=== Claude CLI Raw Output Debug Log ===\n") logFile.WriteString(fmt.Sprintf("Container: %s\n", e.containerName)) logFile.WriteString(fmt.Sprintf("Time: %s\n", time.Now().Format(time.RFC3339))) logFile.WriteString(fmt.Sprintf("WorkDir: %s\n", e.workDir)) logFile.WriteString("=== BEGIN OUTPUT ===\n") logFile.Sync() return &debugWriter{ reader: reader, logFile: logFile, } } // parseStream parses Claude CLI streaming output (stream-json format) // Claude CLI output format with --include-partial-messages: // - {"type":"system","subtype":"init",...} - initialization // - {"type":"stream_event","event":{"delta":{"type":"text_delta","text":"..."}}} - real-time text deltas // - {"type":"assistant","message":{...,"content":[{"type":"text","text":"..."}],...}} - complete messages // - {"type":"result","subtype":"success",...,"result":"..."} - final result func (e *Executor) parseStream(ctx *agentContext.Context, reader io.Reader, handler message.StreamFunc) (*agentContext.CompletionResponse, error) { scanner := bufio.NewScanner(reader) // Increase buffer size for potentially large outputs buf := make([]byte, 0, 64*1024) scanner.Buffer(buf, 1024*1024) var textContent strings.Builder var toolCalls []agentContext.ToolCall var model string var usage *message.UsageInfo var finalResult string messageStarted := false // Track if we've sent ChunkMessageStart prepLoadingClosed := false // Track if "preparing sandbox" loading has been closed // Tool input accumulation state type toolState struct { name string index int inputJSON strings.Builder loadingID string // Each tool has its own loading message } var currentTool *toolState var lastToolLoadingID string // Track the last tool loading ID to close it // Helper function to close "preparing sandbox" loading on first output closePrepLoading := func() { if !prepLoadingClosed && e.loadingMsgID != "" && ctx != nil { doneMsg := &message.Message{ MessageID: e.loadingMsgID, Delta: true, DeltaAction: message.DeltaReplace, 
Type: message.TypeLoading, Props: map[string]interface{}{ "message": "", "done": true, }, } ctx.Send(doneMsg) prepLoadingClosed = true } } lineCount := 0 // Get the underlying context for cancellation checks var stdCtx context.Context if ctx != nil && ctx.Context != nil { stdCtx = ctx.Context } else { stdCtx = context.Background() } for scanner.Scan() { // Check for context cancellation on each iteration select { case <-stdCtx.Done(): return nil, stdCtx.Err() default: // Continue processing } line := scanner.Text() lineCount++ if line == "" { continue } // Try to parse as JSON (Claude CLI --output-format stream-json) var msg map[string]interface{} if err := json.Unmarshal([]byte(line), &msg); err != nil { // Not JSON, might be plain text output textContent.WriteString(line) textContent.WriteString("\n") continue } msgType, _ := msg["type"].(string) // Process Claude CLI stream-json message types switch msgType { case "system": // Initialization message - extract model if available if m, ok := msg["model"].(string); ok { model = m } case "stream_event": // Real-time streaming event (from --include-partial-messages) // Format: {"type":"stream_event","event":{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"..."}}} if event, ok := msg["event"].(map[string]interface{}); ok { eventType, _ := event["type"].(string) switch eventType { case "content_block_start": // Handle new content blocks // Format: {"event":{"type":"content_block_start","index":1,"content_block":{"type":"tool_use"|"text",...}}} if contentBlock, ok := event["content_block"].(map[string]interface{}); ok { blockType, _ := contentBlock["type"].(string) switch blockType { case "text": // New text block starting - add paragraph separator if we already have content // This ensures proper separation between text blocks across tool-use rounds if textContent.Len() > 0 { textContent.WriteString("\n\n") if handler != nil && messageStarted { handler(message.ChunkText, []byte("\n\n")) } } 
case "tool_use": toolName, _ := contentBlock["name"].(string) blockIndex := 0 if idx, ok := event["index"].(float64); ok { blockIndex = int(idx) } if toolName != "" && ctx != nil { // Close "preparing sandbox" loading on first tool closePrepLoading() // Close previous tool loading if exists if lastToolLoadingID != "" { doneMsg := &message.Message{ MessageID: lastToolLoadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]interface{}{ "message": "", "done": true, }, } ctx.Send(doneMsg) } // Create new loading message for this tool locale := ctx.Locale toolLoadingMsg := &message.Message{ Type: message.TypeLoading, Props: map[string]interface{}{ "message": getToolDescription(toolName, locale), }, } newLoadingID, _ := ctx.SendStream(toolLoadingMsg) // Initialize tool state for input accumulation currentTool = &toolState{ name: toolName, index: blockIndex, loadingID: newLoadingID, } lastToolLoadingID = newLoadingID log.Printf("[Sandbox] Tool started: %s", toolName) } } } case "content_block_delta": if delta, ok := event["delta"].(map[string]interface{}); ok { deltaType, _ := delta["type"].(string) switch deltaType { case "text_delta": if text, ok := delta["text"].(string); ok && text != "" { // Close "preparing sandbox" loading on first text output closePrepLoading() // Send to stream handler for real-time output if handler != nil { // Send ChunkMessageStart first if not already started if !messageStarted { startData := message.EventMessageStartData{ MessageID: fmt.Sprintf("sandbox-%d", time.Now().UnixNano()), Type: "text", Timestamp: time.Now().UnixMilli(), } startDataJSON, _ := json.Marshal(startData) handler(message.ChunkMessageStart, startDataJSON) messageStarted = true } handler(message.ChunkText, []byte(text)) } // Also accumulate for final response textContent.WriteString(text) } case "input_json_delta": // Accumulate tool input JSON fragments if currentTool != nil { if partialJSON, ok := 
delta["partial_json"].(string); ok { currentTool.inputJSON.WriteString(partialJSON) } } } } case "content_block_stop": // Tool input complete - parse and update loading with detailed info if currentTool != nil && currentTool.loadingID != "" && ctx != nil { inputStr := currentTool.inputJSON.String() if inputStr != "" { // Use gou/json.Parse for fault-tolerant parsing locale := ctx.Locale detailedMsg := getToolDetailedDescription(currentTool.name, inputStr, locale) if detailedMsg != "" { toolMsg := &message.Message{ MessageID: currentTool.loadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]interface{}{ "message": detailedMsg, }, } ctx.Send(toolMsg) log.Printf("[Sandbox] Tool: %s -> %s", currentTool.name, detailedMsg) } } // Note: Don't close loading here - it will be closed when next tool starts or at end // Reset tool state but keep lastToolLoadingID to close it later currentTool = nil } } } case "assistant": // Assistant message - extract content // With --include-partial-messages, we receive real-time text via stream_event // The assistant message contains the full accumulated content if msgData, ok := msg["message"].(map[string]interface{}); ok { // Get model from message if m, ok := msgData["model"].(string); ok && model == "" { model = m } // Check if this is the final message (has stop_reason) stopReason, hasStopReason := msgData["stop_reason"].(string) isFinalMessage := hasStopReason && stopReason != "" // Extract content from final message // This serves as a fallback if stream_event wasn't received if isFinalMessage { if contentArr, ok := msgData["content"].([]interface{}); ok { for _, item := range contentArr { if contentItem, ok := item.(map[string]interface{}); ok { itemType, _ := contentItem["type"].(string) switch itemType { case "text": // Only use this if we haven't already accumulated text from stream_event if textContent.Len() == 0 { if text, ok := contentItem["text"].(string); ok && text != "" { 
textContent.WriteString(text) // Send to stream handler if available if handler != nil { if !messageStarted { startData := message.EventMessageStartData{ MessageID: fmt.Sprintf("sandbox-%d", time.Now().UnixNano()), Type: "text", Timestamp: time.Now().UnixMilli(), } startDataJSON, _ := json.Marshal(startData) handler(message.ChunkMessageStart, startDataJSON) messageStarted = true } handler(message.ChunkText, []byte(text)) } } } case "tool_use": toolName := getString(contentItem, "name") toolCall := agentContext.ToolCall{ ID: getString(contentItem, "id"), Type: agentContext.ToolTypeFunction, Function: agentContext.Function{ Name: toolName, }, } // Get input as JSON string var inputJSONStr string if input, ok := contentItem["input"]; ok { if inputJSON, err := json.Marshal(input); err == nil { inputJSONStr = string(inputJSON) toolCall.Function.Arguments = inputJSONStr } } toolCalls = append(toolCalls, toolCall) // Create tool loading message (from complete assistant message) // This is a fallback for when stream_event wasn't received if toolName != "" && ctx != nil { // Close previous tool loading if exists if lastToolLoadingID != "" { doneMsg := &message.Message{ MessageID: lastToolLoadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]interface{}{ "message": "", "done": true, }, } ctx.Send(doneMsg) } // Create new loading for this tool locale := ctx.Locale detailedMsg := getToolDetailedDescription(toolName, inputJSONStr, locale) if detailedMsg != "" { toolLoadingMsg := &message.Message{ Type: message.TypeLoading, Props: map[string]interface{}{ "message": detailedMsg, }, } newLoadingID, _ := ctx.SendStream(toolLoadingMsg) lastToolLoadingID = newLoadingID log.Printf("[Sandbox] Tool: %s -> %s", toolName, detailedMsg) } } } } } } } // Extract usage (from any message that has it) if usageData, ok := msgData["usage"].(map[string]interface{}); ok { usage = &message.UsageInfo{} if v, ok := usageData["input_tokens"].(float64); 
ok { usage.PromptTokens = int(v) } if v, ok := usageData["output_tokens"].(float64); ok { usage.CompletionTokens = int(v) } usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } } case "result": // Final result message // Check if this is an error result (is_error: true) isError, _ := msg["is_error"].(bool) if result, ok := msg["result"].(string); ok { if isError { // This is an error - return it as an error return nil, fmt.Errorf("Claude CLI error: %s", result) } finalResult = result } // Send done signal to handler (only if message was started and not an error) if handler != nil && messageStarted && !isError { handler(message.ChunkMessageEnd, nil) } case "error": // Error message if errMsg, ok := msg["error"].(string); ok { return nil, fmt.Errorf("Claude CLI error: %s", errMsg) } if errObj, ok := msg["error"].(map[string]interface{}); ok { if errMsg, ok := errObj["message"].(string); ok { return nil, fmt.Errorf("Claude CLI error: %s", errMsg) } } } } scanErr := scanner.Err() if scanErr != nil { return nil, fmt.Errorf("error reading stream: %w", scanErr) } // Close the last tool loading message if exists if lastToolLoadingID != "" && ctx != nil { doneMsg := &message.Message{ MessageID: lastToolLoadingID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]interface{}{ "message": "", "done": true, }, } ctx.Send(doneMsg) } // Use final result if available, otherwise use accumulated text content content := textContent.String() if finalResult != "" && content == "" { content = finalResult } // Build response response := &agentContext.CompletionResponse{ ID: fmt.Sprintf("sandbox-%d", time.Now().UnixNano()), Model: model, Created: time.Now().Unix(), Role: "assistant", Content: content, FinishReason: agentContext.FinishReasonStop, } // Add tool calls if any if len(toolCalls) > 0 { response.ToolCalls = toolCalls response.FinishReason = agentContext.FinishReasonToolCalls } // Add usage if available if usage != nil { 
response.Usage = usage } return response, nil } // truncateStr truncates a string to maxLen characters func truncateStr(s string, maxLen int) string { if len(s) <= maxLen { return s } return s[:maxLen] + "..." } // getToolDescription returns a human-readable, localized description for a Claude CLI tool func getToolDescription(toolName string, locale string) string { // Map tool names to i18n keys toolKeys := map[string]string{ "Read": "sandbox.tool.read", "Write": "sandbox.tool.write", "Edit": "sandbox.tool.edit", "StrReplace": "sandbox.tool.edit", "Bash": "sandbox.tool.bash", "Shell": "sandbox.tool.bash", "Glob": "sandbox.tool.glob", "Grep": "sandbox.tool.grep", "LS": "sandbox.tool.ls", "Task": "sandbox.tool.task", "WebSearch": "sandbox.tool.web_search", "WebFetch": "sandbox.tool.web_fetch", "TodoWrite": "sandbox.tool.todo_write", "AskQuestion": "sandbox.tool.ask_question", "SwitchMode": "sandbox.tool.switch_mode", "ReadLints": "sandbox.tool.read_lints", "EditNotebook": "sandbox.tool.edit_notebook", } if key, ok := toolKeys[toolName]; ok { return i18n.T(locale, key) } // For unknown tools, use the unknown key and replace {{name}} manually template := i18n.T(locale, "sandbox.tool.unknown") return strings.Replace(template, "{{name}}", toolName, 1) } // getToolDetailedDescription returns a detailed description with specific parameters // It parses the tool input JSON and extracts key information to show users func getToolDetailedDescription(toolName string, inputJSON string, locale string) string { // Parse the input JSON using fault-tolerant parser parsed, err := goujson.Parse(inputJSON) if err != nil { // Fall back to basic description if parsing fails return getToolDescription(toolName, locale) } input, ok := parsed.(map[string]interface{}) if !ok { return getToolDescription(toolName, locale) } // Extract key information based on tool type var detail string switch toolName { case "Bash", "Shell": // Show the command being executed if cmd, ok := 
input["command"].(string); ok && cmd != "" { // Truncate long commands if len(cmd) > 50 { cmd = cmd[:47] + "..." } detail = cmd } case "Read": // Show the file being read if path, ok := input["path"].(string); ok && path != "" { detail = filepath.Base(path) } case "Write": // Show the file being written // Note: Claude CLI uses "file_path" for Write tool, not "path" if path, ok := input["file_path"].(string); ok && path != "" { detail = filepath.Base(path) } else if path, ok := input["path"].(string); ok && path != "" { detail = filepath.Base(path) } case "Edit", "StrReplace": // Show the file being edited if path, ok := input["path"].(string); ok && path != "" { detail = filepath.Base(path) } case "Glob": // Show the glob pattern if pattern, ok := input["glob_pattern"].(string); ok && pattern != "" { detail = pattern } else if pattern, ok := input["pattern"].(string); ok && pattern != "" { detail = pattern } case "Grep": // Show the search pattern if pattern, ok := input["pattern"].(string); ok && pattern != "" { if len(pattern) > 30 { pattern = pattern[:27] + "..." } detail = pattern } case "LS": // Show the directory if path, ok := input["target_directory"].(string); ok && path != "" { detail = filepath.Base(path) } else if path, ok := input["path"].(string); ok && path != "" { detail = filepath.Base(path) } case "WebSearch": // Show the search query if query, ok := input["search_term"].(string); ok && query != "" { if len(query) > 40 { query = query[:37] + "..." } detail = query } else if query, ok := input["query"].(string); ok && query != "" { if len(query) > 40 { query = query[:37] + "..." } detail = query } case "WebFetch": // Show the URL if url, ok := input["url"].(string); ok && url != "" { // Extract domain from URL if len(url) > 50 { url = url[:47] + "..." } detail = url } case "Task": // Show the task description if desc, ok := input["description"].(string); ok && desc != "" { if len(desc) > 40 { desc = desc[:37] + "..." 
} detail = desc } } // Build the message with detail baseMsg := getToolDescription(toolName, locale) if detail != "" { return baseMsg + ": " + detail } return baseMsg } // ReadFile reads a file from the container func (e *Executor) ReadFile(ctx context.Context, path string) ([]byte, error) { // Make path absolute if not if !strings.HasPrefix(path, "/") { path = e.workDir + "/" + path } return e.manager.ReadFile(ctx, e.containerName, path) } // WriteFile writes content to a file in the container func (e *Executor) WriteFile(ctx context.Context, path string, content []byte) error { // Make path absolute if not if !strings.HasPrefix(path, "/") { path = e.workDir + "/" + path } return e.manager.WriteFile(ctx, e.containerName, path, content) } // ListDir lists directory contents in the container func (e *Executor) ListDir(ctx context.Context, path string) ([]infraSandbox.FileInfo, error) { // Make path absolute if not if !strings.HasPrefix(path, "/") { path = e.workDir + "/" + path } return e.manager.ListDir(ctx, e.containerName, path) } // Exec executes a command in the container func (e *Executor) Exec(ctx context.Context, cmd []string) (string, error) { result, err := e.manager.Exec(ctx, e.containerName, cmd, &infraSandbox.ExecOptions{ WorkDir: e.workDir, }) if err != nil { return "", err } if result.ExitCode != 0 { return result.Stdout, fmt.Errorf("command exited with code %d: %s", result.ExitCode, result.Stderr) } return result.Stdout, nil } // GetWorkDir returns the container workspace directory func (e *Executor) GetWorkDir() string { return e.workDir } // GetSandboxID returns the sandbox ID (userID-chatID) func (e *Executor) GetSandboxID() string { if e.opts == nil { return "" } return fmt.Sprintf("%s-%s", e.opts.UserID, e.opts.ChatID) } // GetVNCUrl returns the VNC preview URL path // Returns empty string if VNC is not enabled for this sandbox image func (e *Executor) GetVNCUrl() string { if e.opts == nil { return "" } imageName := e.opts.Image if imageName == 
// getString looks up key in m and returns the value when it is a string.
// A missing key, a nil map or a value of any other type all yield "".
func getString(m map[string]interface{}, key string) string {
	// A failed comma-ok assertion leaves str at its zero value, the empty string.
	str, _ := m[key].(string)
	return str
}
ConnectorHost: "https://api.example.com", ConnectorKey: "key123", Model: "test-model", } exec, err := NewExecutor(manager, opts) require.NoError(t, err) require.NotNil(t, exec) // Verify executor was created assert.Equal(t, "/workspace", exec.GetWorkDir()) assert.NoError(t, exec.Close()) } func TestClaudeExecutorMissingRequiredFields(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() // Missing UserID _, err := NewExecutor(manager, &Options{ Command: "claude", ChatID: "test-chat", }) assert.Error(t, err) assert.Contains(t, err.Error(), "UserID is required") // Missing ChatID _, err = NewExecutor(manager, &Options{ Command: "claude", UserID: "test-user", }) assert.Error(t, err) assert.Contains(t, err.Error(), "ChatID is required") } func TestClaudeExecutorFileOperations(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() opts := &Options{ Command: "claude", Image: "alpine:latest", // Use alpine for simpler testing UserID: "test-user", ChatID: fmt.Sprintf("test-chat-file-ops-%d", time.Now().UnixNano()), } exec, err := NewExecutor(manager, opts) require.NoError(t, err) defer exec.Close() ctx := context.Background() // Test WriteFile content := []byte("Hello, World!") err = exec.WriteFile(ctx, "test-file.txt", content) require.NoError(t, err) // Test ReadFile readContent, err := exec.ReadFile(ctx, "test-file.txt") require.NoError(t, err) assert.Equal(t, content, readContent) // Test ListDir files, err := exec.ListDir(ctx, ".") require.NoError(t, err) assert.True(t, len(files) > 0, "Expected at least one file in directory") // Find our test file var found bool for _, f := range files { if f.Name == "test-file.txt" { found = true break } } assert.True(t, found, "Expected to find test-file.txt in directory listing") } func TestClaudeExecutorExec(t *testing.T) { test.Prepare(t, 
config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() opts := &Options{ Command: "claude", Image: "alpine:latest", // Use alpine for simpler testing UserID: "test-user", ChatID: fmt.Sprintf("test-chat-exec-%d", time.Now().UnixNano()), } exec, err := NewExecutor(manager, opts) require.NoError(t, err) defer exec.Close() ctx := context.Background() // Test simple echo command output, err := exec.Exec(ctx, []string{"echo", "hello-world"}) require.NoError(t, err) assert.Contains(t, output, "hello-world") } // TestClaudeExecutorMCPConfigWrite tests that MCP config is correctly written to container func TestClaudeExecutorMCPConfigWrite(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() // Create MCP config JSON mcpConfig := []byte(`{"mcpServers":{"echo":{"command":"yao-mcp-proxy","args":["echo"],"tools":["ping","echo"]}}}`) opts := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: fmt.Sprintf("test-chat-mcp-write-%d", time.Now().UnixNano()), MCPConfig: mcpConfig, // No connector config - skip proxy start for alpine test image } exec, err := NewExecutor(manager, opts) require.NoError(t, err) defer exec.Close() ctx := context.Background() // Call prepareEnvironment to write configs err = exec.prepareEnvironment(ctx) require.NoError(t, err, "prepareEnvironment should succeed") // Verify MCP config was written by reading it back readContent, err := exec.ReadFile(ctx, ".mcp.json") require.NoError(t, err, "Should be able to read .mcp.json") require.NotEmpty(t, readContent, "MCP config should not be empty") t.Logf("MCP config in container: %s", string(readContent)) // Verify content matches assert.JSONEq(t, string(mcpConfig), string(readContent), "MCP config content should match") t.Log("✓ MCP config verified in container") } // TestClaudeExecutorSkillsCopy tests that skills directory is 
correctly copied to container // Uses real test application skills directory func TestClaudeExecutorSkillsCopy(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() // Use real skills directory from test application appRoot := os.Getenv("YAO_ROOT") require.NotEmpty(t, appRoot, "YAO_ROOT should be set") skillsDir := appRoot + "/assistants/tests/sandbox/full/skills" // Verify skills directory exists on host info, err := os.Stat(skillsDir) require.NoError(t, err, "Skills directory should exist: %s", skillsDir) require.True(t, info.IsDir(), "Skills path should be a directory") t.Logf("Using real skills directory: %s", skillsDir) opts := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: fmt.Sprintf("test-chat-skills-%d", time.Now().UnixNano()), SkillsDir: skillsDir, // No connector config - skip proxy start for alpine test image } exec, err := NewExecutor(manager, opts) require.NoError(t, err) defer exec.Close() ctx := context.Background() // Call prepareEnvironment to copy skills err = exec.prepareEnvironment(ctx) require.NoError(t, err, "prepareEnvironment should succeed") // Verify .claude directory was created output, err := exec.Exec(ctx, []string{"ls", "-la", ".claude"}) require.NoError(t, err, ".claude directory should exist") t.Logf(".claude directory contents:\n%s", output) // Verify skills directory exists in container output, err = exec.Exec(ctx, []string{"ls", "-la", ".claude/skills"}) require.NoError(t, err, "skills directory should exist in container") t.Logf("skills directory contents:\n%s", output) assert.Contains(t, output, "echo-test", "echo-test skill should exist") // Verify echo-test skill was copied correctly output, err = exec.Exec(ctx, []string{"ls", "-la", ".claude/skills/echo-test"}) require.NoError(t, err, "echo-test skill directory should exist") assert.Contains(t, output, "SKILL.md", "SKILL.md should exist in 
echo-test") assert.Contains(t, output, "scripts", "scripts directory should exist in echo-test") t.Logf("echo-test skill contents:\n%s", output) // Read SKILL.md content to verify readContent, err := exec.ReadFile(ctx, ".claude/skills/echo-test/SKILL.md") require.NoError(t, err, "Should be able to read SKILL.md from container") require.NotEmpty(t, readContent, "SKILL.md content should not be empty") // Verify content contains expected strings from the real SKILL.md assert.Contains(t, string(readContent), "name: echo-test", "SKILL.md should contain skill name") assert.Contains(t, string(readContent), "# Echo Test", "SKILL.md should contain the title") t.Logf("✓ SKILL.md content verified (%d bytes)", len(readContent)) t.Log("✓ Skills directory verified in container with real test data") } // TestClaudeExecutorPrepareEnvironmentIntegration tests full environment preparation // Uses real test application data func TestClaudeExecutorPrepareEnvironmentIntegration(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() // Use real skills directory from test application appRoot := os.Getenv("YAO_ROOT") require.NotEmpty(t, appRoot, "YAO_ROOT should be set") skillsDir := appRoot + "/assistants/tests/sandbox/full/skills" // Verify skills directory exists _, err := os.Stat(skillsDir) require.NoError(t, err, "Skills directory should exist") // Create MCP config (simulating what buildMCPConfigForSandbox produces) mcpConfig := []byte(`{"mcpServers":{"echo":{"command":"yao-mcp-proxy","args":["echo"],"tools":["ping","echo","status"]}}}`) opts := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: fmt.Sprintf("test-chat-full-env-%d", time.Now().UnixNano()), MCPConfig: mcpConfig, SkillsDir: skillsDir, // No connector config - skip proxy start for alpine test image } exec, err := NewExecutor(manager, opts) require.NoError(t, err) defer exec.Close() ctx := 
context.Background() // Call prepareEnvironment err = exec.prepareEnvironment(ctx) require.NoError(t, err, "prepareEnvironment should succeed") // 1. Check MCP config mcpContent, err := exec.ReadFile(ctx, ".mcp.json") require.NoError(t, err, "MCP config should exist in container") assert.JSONEq(t, string(mcpConfig), string(mcpContent), "MCP config content should match") t.Logf("✓ MCP config verified: %s", string(mcpContent)) // 2. Check Skills directory structure output, err := exec.Exec(ctx, []string{"ls", "-la", ".claude/skills"}) require.NoError(t, err, "Skills directory should exist in container") assert.Contains(t, output, "echo-test", "echo-test skill should exist") t.Logf("✓ Skills directory contents:\n%s", output) // 3. Check skill content skillContent, err := exec.ReadFile(ctx, ".claude/skills/echo-test/SKILL.md") require.NoError(t, err, "SKILL.md should exist in container") require.NotEmpty(t, skillContent, "SKILL.md should not be empty") assert.Contains(t, string(skillContent), "name: echo-test", "SKILL.md should contain skill name") assert.Contains(t, string(skillContent), "# Echo Test", "SKILL.md should contain the title") t.Logf("✓ SKILL.md verified: %d bytes", len(skillContent)) t.Log("✓ Full environment preparation verified with real test data") } // TestClaudeExecutorIPCSocketMount verifies that IPC socket is bind mounted to container func TestClaudeExecutorIPCSocketMount(t *testing.T) { manager := createTestManager(t) opts := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: "test-ipc-socket-" + fmt.Sprintf("%d", time.Now().UnixNano()), ConnectorHost: "https://api.test.com", ConnectorKey: "test-key", Model: "test-model", } exec, err := NewExecutor(manager, opts) require.NoError(t, err) defer exec.Close() ctx := context.Background() // Check if IPC socket exists in container output, err := exec.Exec(ctx, []string{"ls", "-la", "/run/yao.sock"}) require.NoError(t, err, "IPC socket should exist in container") 
assert.Contains(t, output, "yao.sock", "Should find yao.sock file") t.Logf("✓ IPC socket mounted: %s", strings.TrimSpace(output)) // Verify it's a socket file (starts with 's' in ls output) assert.Contains(t, output, "srw", "Should be a socket file (starts with 's')") t.Log("✓ IPC socket is correctly bind mounted to container") } ================================================ FILE: agent/sandbox/claude/real_e2e_test.go ================================================ package claude import ( "context" "fmt" "os" "strings" "testing" "time" "github.com/stretchr/testify/require" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/config" infraSandbox "github.com/yaoapp/yao/sandbox" "github.com/yaoapp/yao/test" ) // TestRealClaudeCLIExecution tests real Claude CLI execution with streaming // This test requires: // 1. Docker running with yaoapp/sandbox-claude:latest image // 2. Environment variables: DEEPSEEK_API_KEY, DEEPSEEK_API_PROXY, DEEPSEEK_MODELS_V3 func TestRealClaudeCLIExecution(t *testing.T) { if testing.Short() { t.Skip("Skipping real E2E test in short mode") } // Check for required environment variables apiKey := os.Getenv("DEEPSEEK_API_KEY") apiProxy := os.Getenv("DEEPSEEK_API_PROXY") model := os.Getenv("DEEPSEEK_MODELS_V3") if apiKey == "" || apiProxy == "" || model == "" { t.Skip("Skipping test: DEEPSEEK_API_KEY, DEEPSEEK_API_PROXY, or DEEPSEEK_MODELS_V3 not set") } test.Prepare(t, config.Conf) defer test.Clean() // Get data root from environment dataRoot := os.Getenv("YAO_ROOT") if dataRoot == "" { t.Skip("Skipping test: YAO_ROOT not set") } // Create config with proper paths cfg := infraSandbox.DefaultConfig() cfg.Init(dataRoot) manager, err := infraSandbox.NewManager(cfg) if err != nil { t.Skipf("Skipping test: Docker not available: %v", err) } defer manager.Close() // Create options WITH SystemPrompt (triggers Claude CLI execution) opts := &Options{ Command: "claude", Image: 
"yaoapp/sandbox-claude:latest", UserID: "test-user", ChatID: fmt.Sprintf("test-real-e2e-%d", time.Now().UnixNano()), ConnectorHost: apiProxy, ConnectorKey: apiKey, Model: model, SystemPrompt: "You are a helpful assistant. Reply concisely.", Timeout: 3 * time.Minute, ConnectorOptions: map[string]interface{}{ "max_tokens": 4096, // Limit max_tokens to avoid backend API limits }, } t.Logf("Creating executor with options:") t.Logf(" ConnectorHost: %s", opts.ConnectorHost) t.Logf(" Model: %s", opts.Model) t.Logf(" SystemPrompt: %s", opts.SystemPrompt) exec, err := NewExecutor(manager, opts) if err != nil { t.Skipf("Skipping test: Failed to create executor: %v", err) } defer exec.Close() // Verify shouldSkipClaudeCLI returns false if exec.shouldSkipClaudeCLI() { t.Fatal("shouldSkipClaudeCLI should return false when SystemPrompt is set") } // Test 1: First, manually test claude-proxy t.Log("=== Test 1: Verify claude-proxy is working ===") stdCtx := context.Background() // Prepare environment (this starts claude-proxy) err = exec.prepareEnvironment(stdCtx) require.NoError(t, err, "prepareEnvironment should succeed") // Check proxy is running result, err := exec.manager.Exec(stdCtx, exec.containerName, []string{"pgrep", "-f", "claude-proxy"}, nil) if err != nil || result.ExitCode != 0 { t.Log("claude-proxy not running, checking why...") // Check proxy log logContent, _ := exec.ReadFile(stdCtx, "proxy.log") t.Logf("Proxy log: %s", string(logContent)) // Check config configContent, _ := exec.ReadFile(stdCtx, ".claude-proxy.json") t.Logf("Proxy config: %s", string(configContent)) } else { t.Logf("claude-proxy is running with PID: %s", strings.TrimSpace(result.Stdout)) } // Test 2: Test simple command execution t.Log("=== Test 2: Simple command execution ===") ctx := agentContext.New(stdCtx, nil, opts.ChatID) messages := []agentContext.Message{ {Role: "user", Content: "Reply with exactly: HELLO_TEST_SUCCESS"}, } // Collect streaming output var streamedChunks []string var 
streamedContent strings.Builder streamHandler := func(chunkType message.StreamChunkType, data []byte) int { chunk := string(data) streamedChunks = append(streamedChunks, chunk) streamedContent.Write(data) t.Logf("Stream chunk [%s]: %q", chunkType, chunk) return 0 // continue streaming } t.Log("Executing Claude CLI...") startTime := time.Now() response, err := exec.Stream(ctx, messages, streamHandler) duration := time.Since(startTime) t.Logf("Execution took: %v", duration) if err != nil { t.Logf("Stream error: %v", err) // Debug: check what's in the container t.Log("=== Debug info ===") // Check proxy log logContent, _ := exec.ReadFile(stdCtx, "proxy.log") t.Logf("Proxy log:\n%s", string(logContent)) // List workspace output, _ := exec.Exec(stdCtx, []string{"ls", "-la", "/workspace"}) t.Logf("Workspace contents:\n%s", output) // Check environment output, _ = exec.Exec(stdCtx, []string{"env"}) t.Logf("Environment:\n%s", output) t.Fatalf("Stream failed: %v", err) } require.NotNil(t, response, "Response should not be nil") // Log results t.Logf("=== Results ===") t.Logf("Response ID: %s", response.ID) t.Logf("Response Model: %s", response.Model) t.Logf("Response Content: %v", response.Content) t.Logf("Streamed chunks count: %d", len(streamedChunks)) t.Logf("Total streamed content: %s", streamedContent.String()) // Verify we got some response var fullResponse string if content, ok := response.Content.(string); ok { fullResponse = content } if fullResponse == "" { fullResponse = streamedContent.String() } if fullResponse == "" { // Check proxy log for errors logContent, _ := exec.ReadFile(stdCtx, "proxy.log") t.Logf("Proxy log (for debugging):\n%s", string(logContent)) t.Fatal("Got empty response from Claude CLI") } t.Logf("✓ Successfully got response: %s", fullResponse) // Check if streaming worked if len(streamedChunks) > 0 { t.Logf("✓ Streaming worked with %d chunks", len(streamedChunks)) } else { t.Log("⚠ No streaming chunks received (might be buffered)") } } // 
TestClaudeCLIDirectExecution tests running claude directly in the container func TestClaudeCLIDirectExecution(t *testing.T) { if testing.Short() { t.Skip("Skipping real E2E test in short mode") } // Check for required environment variables apiKey := os.Getenv("DEEPSEEK_API_KEY") apiProxy := os.Getenv("DEEPSEEK_API_PROXY") model := os.Getenv("DEEPSEEK_MODELS_V3") if apiKey == "" || apiProxy == "" || model == "" { t.Skip("Skipping test: DEEPSEEK_API_KEY, DEEPSEEK_API_PROXY, or DEEPSEEK_MODELS_V3 not set") } test.Prepare(t, config.Conf) defer test.Clean() dataRoot := os.Getenv("YAO_ROOT") if dataRoot == "" { t.Skip("Skipping test: YAO_ROOT not set") } cfg := infraSandbox.DefaultConfig() cfg.Init(dataRoot) manager, err := infraSandbox.NewManager(cfg) if err != nil { t.Skipf("Skipping test: Docker not available: %v", err) } defer manager.Close() opts := &Options{ Command: "claude", Image: "yaoapp/sandbox-claude:latest", UserID: "test-user", ChatID: fmt.Sprintf("test-direct-%d", time.Now().UnixNano()), ConnectorHost: apiProxy, ConnectorKey: apiKey, Model: model, Timeout: 3 * time.Minute, ConnectorOptions: map[string]interface{}{ "max_tokens": 4096, // Limit max_tokens to avoid backend API limits }, } exec, err := NewExecutor(manager, opts) if err != nil { t.Skipf("Skipping test: Failed to create executor: %v", err) } defer exec.Close() stdCtx := context.Background() // Step 1: Write proxy config and start proxy t.Log("=== Step 1: Start claude-proxy ===") err = exec.prepareEnvironment(stdCtx) require.NoError(t, err) // Wait for proxy to start time.Sleep(2 * time.Second) // Check proxy status result, err := exec.manager.Exec(stdCtx, exec.containerName, []string{"pgrep", "-f", "claude-proxy"}, nil) if err == nil && result.ExitCode == 0 { t.Logf("✓ claude-proxy running, PID: %s", strings.TrimSpace(result.Stdout)) } else { t.Log("⚠ claude-proxy might not be running") } // Step 2: Run claude CLI directly with simple prompt t.Log("=== Step 2: Run claude CLI directly ===") // 
// StreamMessage represents a parsed stream message from Claude CLI.
// Every field except Type is tagged omitempty, so it is left out of the
// encoded JSON when it holds its zero value.
type StreamMessage struct {
	Type    string      `json:"type"`              // message kind, JSON key "type"
	Subtype string      `json:"subtype,omitempty"` // optional refinement of Type
	Content interface{} `json:"content,omitempty"` // untyped payload; holds whatever JSON value was decoded
	Error   string      `json:"error,omitempty"`   // error text; NOTE(review): presumably only set on error messages - confirm with the producer
}

// ToolCall represents a tool invocation from the agent.
type ToolCall struct {
	ID        string                 `json:"id"`        // identifier of this invocation
	Name      string                 `json:"name"`      // name of the tool being invoked
	Arguments map[string]interface{} `json:"arguments"` // tool arguments as a decoded JSON object
}

// ToolResult represents a tool execution result.
type ToolResult struct {
	ID      string `json:"id"`                 // NOTE(review): presumably the ID of the ToolCall this answers - confirm
	Content string `json:"content"`            // result text
	IsError bool   `json:"is_error,omitempty"` // omitted from JSON when false
}

// CLIResponse represents the parsed response from Claude CLI.
// All fields are optional in the JSON form (omitempty).
type CLIResponse struct {
	Text      string     `json:"text,omitempty"`       // response text
	ToolCalls []ToolCall `json:"tool_calls,omitempty"` // tool invocations carried by the response
	Usage     *Usage     `json:"usage,omitempty"`      // token usage; nil when not available
	Model     string     `json:"model,omitempty"`      // model name
}

// Usage represents token usage statistics.
// Both counters are omitted from the JSON form when they are zero.
type Usage struct {
	InputTokens  int `json:"input_tokens,omitempty"`
	OutputTokens int `json:"output_tokens,omitempty"`
}
Add appropriate tests ## Docker Image A `yaoapp/sandbox-cursor` Docker image will need to be created with: - Ubuntu 24.04 LTS base - Node.js 22 LTS - Python 3.12 - Cursor CLI installed and configured ## References - [Cursor CLI Documentation](https://cursor.sh/docs) - [Claude Executor Implementation](../claude/executor.go) - [Sandbox Design Document](../DESIGN.md) ================================================ FILE: agent/sandbox/executor.go ================================================ package sandbox import ( "fmt" "github.com/yaoapp/yao/agent/sandbox/claude" infraSandbox "github.com/yaoapp/yao/sandbox" ) // New creates a new Executor based on the command type func New(manager *infraSandbox.Manager, opts *Options) (Executor, error) { if opts == nil { return nil, fmt.Errorf("options is required") } if !IsValidCommand(opts.Command) { return nil, fmt.Errorf("unsupported command type: %s, supported: %v", opts.Command, CommandTypes) } // Set default image if not specified if opts.Image == "" { opts.Image = DefaultImage(opts.Command) } switch opts.Command { case "claude": // Convert to claude.Options claudeOpts := &claude.Options{ Command: opts.Command, Image: opts.Image, MaxMemory: opts.MaxMemory, MaxCPU: opts.MaxCPU, Timeout: opts.Timeout, Arguments: opts.Arguments, UserID: opts.UserID, ChatID: opts.ChatID, MCPConfig: opts.MCPConfig, MCPTools: opts.MCPTools, SkillsDir: opts.SkillsDir, SystemPrompt: opts.SystemPrompt, // Required for Claude CLI execution ConnectorHost: opts.ConnectorHost, ConnectorKey: opts.ConnectorKey, Model: opts.Model, ConnectorType: opts.ConnectorType, // "openai" or "anthropic" ConnectorOptions: opts.ConnectorOptions, // Extra options like thinking, max_tokens Secrets: opts.Secrets, // Secrets for container env vars } return claude.NewExecutor(manager, claudeOpts) case "cursor": return nil, fmt.Errorf("cursor executor not implemented yet") default: return nil, fmt.Errorf("unsupported command type: %s", opts.Command) } } 
================================================ FILE: agent/sandbox/executor_test.go ================================================ package sandbox import ( "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/config" infraSandbox "github.com/yaoapp/yao/sandbox" "github.com/yaoapp/yao/test" ) // createTestManager creates a sandbox manager for testing with proper configuration func createTestManager(t *testing.T) *infraSandbox.Manager { // Get data root from environment or use temp directory dataRoot := os.Getenv("YAO_ROOT") if dataRoot == "" { dataRoot = t.TempDir() } // Create config with proper paths cfg := infraSandbox.DefaultConfig() cfg.Init(dataRoot) manager, err := infraSandbox.NewManager(cfg) if err != nil { t.Skipf("Skipping test: Docker not available: %v", err) return nil } return manager } func TestNewExecutorWithInvalidOptions(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() // Test with nil options _, err := New(manager, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "options is required") // Test with invalid command _, err = New(manager, &Options{ Command: "invalid", UserID: "user1", ChatID: "chat1", }) assert.Error(t, err) assert.Contains(t, err.Error(), "unsupported command type") } func TestNewExecutorWithValidOptions(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() // Test with valid claude options opts := &Options{ Command: "claude", UserID: "test-user", ChatID: "test-chat", ConnectorHost: "https://api.example.com", ConnectorKey: "key123", Model: "test-model", } exec, err := New(manager, opts) require.NoError(t, err) require.NotNil(t, exec) // Verify executor was created assert.NotEmpty(t, exec.GetWorkDir()) assert.NoError(t, exec.Close()) } func TestDefaultImageIsSetWhenEmpty(t 
*testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() opts := &Options{ Command: "claude", Image: "", // Empty, should be set to default UserID: "test-user", ChatID: "test-chat-2", } exec, err := New(manager, opts) require.NoError(t, err) defer exec.Close() // Ensure cleanup // The image should have been set to default assert.Equal(t, "yaoapp/sandbox-claude:latest", opts.Image) } func TestCursorExecutorNotImplemented(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createTestManager(t) if manager == nil { return } defer manager.Close() opts := &Options{ Command: "cursor", UserID: "test-user", ChatID: "test-chat-3", } _, err := New(manager, opts) assert.Error(t, err) assert.Contains(t, err.Error(), "not implemented") } ================================================ FILE: agent/sandbox/integration_test.go ================================================ package sandbox import ( "context" "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" infraSandbox "github.com/yaoapp/yao/sandbox" "github.com/yaoapp/yao/test" ) // createIntegrationTestManager creates a sandbox manager for integration testing func createIntegrationTestManager(t *testing.T) *infraSandbox.Manager { dataRoot := os.Getenv("YAO_ROOT") if dataRoot == "" { dataRoot = t.TempDir() } cfg := infraSandbox.DefaultConfig() cfg.Init(dataRoot) manager, err := infraSandbox.NewManager(cfg) if err != nil { t.Skipf("Skipping test: Docker not available: %v", err) return nil } return manager } // TestExecutorInterfaceCompatibility verifies that the executor implements both interfaces correctly func TestExecutorInterfaceCompatibility(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createIntegrationTestManager(t) if manager == nil { return } defer 
manager.Close() // Create executor via factory function opts := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: "test-compat", } executor, err := New(manager, opts) require.NoError(t, err) require.NotNil(t, executor) defer executor.Close() // Verify executor implements agent/sandbox.Executor interface var _ Executor = executor // Verify executor can be cast to context.SandboxExecutor ctxExecutor, ok := executor.(agentContext.SandboxExecutor) require.True(t, ok, "executor should implement context.SandboxExecutor") require.NotNil(t, ctxExecutor) // Test SandboxExecutor methods work ctx := context.Background() // WriteFile err = ctxExecutor.WriteFile(ctx, "compat-test.txt", []byte("compatibility test")) require.NoError(t, err) // ReadFile content, err := ctxExecutor.ReadFile(ctx, "compat-test.txt") require.NoError(t, err) assert.Equal(t, "compatibility test", string(content)) // ListDir files, err := ctxExecutor.ListDir(ctx, ".") require.NoError(t, err) assert.True(t, len(files) > 0) // Exec output, err := ctxExecutor.Exec(ctx, []string{"echo", "compat"}) require.NoError(t, err) assert.Contains(t, output, "compat") // GetWorkDir workDir := ctxExecutor.GetWorkDir() assert.NotEmpty(t, workDir) } // TestExecutorRoundTrip tests the full round-trip of creating executor and performing operations func TestExecutorRoundTrip(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createIntegrationTestManager(t) if manager == nil { return } defer manager.Close() opts := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: "test-roundtrip", ConnectorHost: "https://api.example.com", ConnectorKey: "test-key", Model: "test-model", } // Create executor executor, err := New(manager, opts) require.NoError(t, err) require.NotNil(t, executor) defer executor.Close() ctx := context.Background() // 1. Write a file testContent := "Hello, integration test!" 
err = executor.WriteFile(ctx, "integration.txt", []byte(testContent)) require.NoError(t, err, "WriteFile should succeed") // 2. Read the file back readContent, err := executor.ReadFile(ctx, "integration.txt") require.NoError(t, err, "ReadFile should succeed") assert.Equal(t, testContent, string(readContent), "Content should match") // 3. List directory files, err := executor.ListDir(ctx, ".") require.NoError(t, err, "ListDir should succeed") var found bool for _, f := range files { if f.Name == "integration.txt" { found = true assert.False(t, f.IsDir, "Should not be a directory") assert.Equal(t, int64(len(testContent)), f.Size, "Size should match") break } } assert.True(t, found, "Should find integration.txt in listing") // 4. Execute command output, err := executor.Exec(ctx, []string{"cat", "/workspace/integration.txt"}) require.NoError(t, err, "Exec should succeed") assert.Contains(t, output, testContent, "cat output should contain file content") // 5. Verify workdir assert.Equal(t, "/workspace", executor.GetWorkDir(), "WorkDir should be /workspace") } // TestMultipleExecutorsIsolation verifies that multiple executors have isolated workspaces func TestMultipleExecutorsIsolation(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager := createIntegrationTestManager(t) if manager == nil { return } defer manager.Close() // Create two executors with different chat IDs opts1 := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: "test-isolation-1", } opts2 := &Options{ Command: "claude", Image: "alpine:latest", UserID: "test-user", ChatID: "test-isolation-2", } exec1, err := New(manager, opts1) require.NoError(t, err) defer exec1.Close() exec2, err := New(manager, opts2) require.NoError(t, err) defer exec2.Close() ctx := context.Background() // Write different content to each executor err = exec1.WriteFile(ctx, "test.txt", []byte("executor 1")) require.NoError(t, err) err = exec2.WriteFile(ctx, "test.txt", []byte("executor 
// Executor executes LLM requests in sandbox.
//
// Besides running requests (Execute / Stream) it exposes filesystem and
// command helpers intended for hooks, plus accessors for the sandbox's
// identity and workspace. Close must be called to release the container.
type Executor interface {
	// Execute runs the request and returns response (uses options set at creation time)
	Execute(ctx *agentContext.Context, messages []agentContext.Message) (*agentContext.CompletionResponse, error)

	// Stream runs the request with streaming output (uses options set at creation time).
	// Streamed output is delivered through handler.
	Stream(ctx *agentContext.Context, messages []agentContext.Message, handler message.StreamFunc) (*agentContext.CompletionResponse, error)

	// SetLoadingMsgID sets the loading message ID for tool execution status updates
	SetLoadingMsgID(id string)

	// Filesystem operations (for Hooks).
	// NOTE(review): relative paths appear to be resolved against the work
	// directory (the tests read back "x.txt" as "/workspace/x.txt") — confirm.
	ReadFile(ctx context.Context, path string) ([]byte, error)
	WriteFile(ctx context.Context, path string, content []byte) error
	ListDir(ctx context.Context, path string) ([]infraSandbox.FileInfo, error)

	// Command execution (for Hooks). cmd is the argv of the command to run;
	// the returned string is its output.
	Exec(ctx context.Context, cmd []string) (string, error)

	// GetWorkDir returns the container workspace directory
	GetWorkDir() string

	// GetSandboxID returns the sandbox ID (userID-chatID)
	GetSandboxID() string

	// GetVNCUrl returns the VNC preview URL path (e.g., /api/__yao/vnc/{sandboxID}/)
	// Returns empty string if VNC is not enabled for this sandbox image
	GetVNCUrl() string

	// Close releases container resources
	Close() error
}
// DefaultImage maps a sandbox command type to the Docker image that is used
// when the configuration does not name one. Unknown command types yield "".
func DefaultImage(command string) string {
	if command == "claude" {
		return "yaoapp/sandbox-claude:latest"
	}
	if command == "cursor" {
		return "yaoapp/sandbox-cursor:latest"
	}
	return ""
}

// CommandTypes enumerates the command types accepted by IsValidCommand.
var CommandTypes = []string{"claude", "cursor"}

// IsValidCommand reports whether command is one of CommandTypes.
// The comparison is exact (case-sensitive); the empty string is not valid.
func IsValidCommand(command string) bool {
	for i := 0; i < len(CommandTypes); i++ {
		if CommandTypes[i] == command {
			return true
		}
	}
	return false
}
[]struct { command string expected bool }{ {"claude", true}, {"cursor", true}, {"unknown", false}, {"", false}, } for _, tt := range tests { t.Run(tt.command, func(t *testing.T) { result := IsValidCommand(tt.command) assert.Equal(t, tt.expected, result) }) } } func TestOptionsValidation(t *testing.T) { // Test that Options struct can be created with all fields opts := &Options{ Command: "claude", Image: "yaoapp/sandbox-claude:latest", MaxMemory: "4g", MaxCPU: 2.0, UserID: "user123", ChatID: "chat456", ConnectorHost: "https://api.example.com", ConnectorKey: "key123", Model: "deepseek-v3", Arguments: map[string]interface{}{ "max_turns": 20, "permission_mode": "acceptEdits", }, } assert.Equal(t, "claude", opts.Command) assert.Equal(t, "user123", opts.UserID) assert.Equal(t, "chat456", opts.ChatID) assert.Equal(t, 20, opts.Arguments["max_turns"]) } func TestSandboxConfigParsing(t *testing.T) { // Test that SandboxConfig can be used for parsing assistant config config := &SandboxConfig{ Command: "claude", Image: "custom-image:v1", MaxMemory: "8g", MaxCPU: 4.0, Timeout: "10m", Arguments: map[string]interface{}{ "permission_mode": "bypassPermissions", }, } assert.Equal(t, "claude", config.Command) assert.Equal(t, "custom-image:v1", config.Image) assert.Equal(t, "8g", config.MaxMemory) assert.Equal(t, "10m", config.Timeout) } ================================================ FILE: agent/sandbox/v2/claude/attachments.go ================================================ package claude import ( "context" "fmt" "path/filepath" "strings" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/attachment" workspace "github.com/yaoapp/yao/tai/workspace" ) // prepareAttachments resolves __yao.attachment:// URLs in messages, // copies actual files into the workspace .attachments/{chatID}/ directory via ws.Copy, // and replaces multimodal content parts with text references. 
func prepareAttachments(ctx context.Context, messages []agentContext.Message, chatID string, ws workspace.FS) ([]agentContext.Message, error) { usedNames := make(map[string]int) attachDir := ".attachments/" + chatID result := make([]agentContext.Message, len(messages)) copy(result, messages) for i, msg := range result { if msg.Role != "user" { continue } parts, ok := msg.Content.([]interface{}) if !ok { if typedParts, ok := msg.Content.([]agentContext.ContentPart); ok { iparts := make([]interface{}, len(typedParts)) for j, p := range typedParts { m := map[string]interface{}{"type": string(p.Type)} if p.Text != "" { m["text"] = p.Text } if p.ImageURL != nil { m["image_url"] = map[string]interface{}{ "url": p.ImageURL.URL, "detail": string(p.ImageURL.Detail), } } if p.File != nil { m["file"] = map[string]interface{}{ "url": p.File.URL, "filename": p.File.Filename, } } iparts[j] = m } parts = iparts } else { continue } } if len(parts) == 0 { continue } var textParts []string for _, item := range parts { m, ok := item.(map[string]interface{}) if !ok { continue } partType, _ := m["type"].(string) switch partType { case "text": if text, ok := m["text"].(string); ok && text != "" { textParts = append(textParts, text) } case "image_url": imgData, _ := m["image_url"].(map[string]interface{}) if imgData == nil { continue } url, _ := imgData["url"].(string) if url == "" { continue } uploaderName, fileID, isWrapper := attachment.Parse(url) if !isWrapper { textParts = append(textParts, fmt.Sprintf("[Image: %s]", url)) continue } ref, err := resolveAttachment(ctx, uploaderName, fileID, "", attachDir, usedNames, ws) if err != nil { textParts = append(textParts, "[Attached image: failed to load]") continue } textParts = append(textParts, ref) case "file": fileData, _ := m["file"].(map[string]interface{}) if fileData == nil { continue } url, _ := fileData["url"].(string) hintName, _ := fileData["filename"].(string) if url == "" { continue } uploaderName, fileID, isWrapper := 
attachment.Parse(url) if !isWrapper { textParts = append(textParts, fmt.Sprintf("[File: %s]", url)) continue } ref, err := resolveAttachment(ctx, uploaderName, fileID, hintName, attachDir, usedNames, ws) if err != nil { textParts = append(textParts, "[Attached file: failed to load]") continue } textParts = append(textParts, ref) } } if len(textParts) > 0 { newMsg := result[i] newMsg.Content = strings.Join(textParts, "\n\n") result[i] = newMsg } } return result, nil } // resolveAttachment gets the local path of an attachment and copies it into // the workspace via ws.Copy("local:///abs/path", ".attachments/{chatID}/filename"). func resolveAttachment( ctx context.Context, uploaderName, fileID, hintName, attachDir string, usedNames map[string]int, ws workspace.FS, ) (string, error) { manager, exists := attachment.Managers[uploaderName] if !exists { return "", fmt.Errorf("attachment manager not found: %s", uploaderName) } fileInfo, err := manager.Info(ctx, fileID) if err != nil { return "", fmt.Errorf("failed to get file info: %w", err) } absPath, _, err := manager.LocalPath(ctx, fileID) if err != nil { return "", fmt.Errorf("failed to get local path: %w", err) } filename := fileInfo.Filename if filename == "" && hintName != "" { filename = hintName } if filename == "" { ext := extensionFromContentType(fileInfo.ContentType) filename = fileID + ext } baseName := filename if count, exists := usedNames[baseName]; exists { ext := filepath.Ext(filename) name := strings.TrimSuffix(filename, ext) filename = fmt.Sprintf("%s_%d%s", name, count+1, ext) usedNames[baseName] = count + 1 } else { usedNames[baseName] = 0 } dstPath := attachDir + "/" + filename src := "local:///" + absPath if _, err := ws.Copy(src, dstPath); err != nil { return "", fmt.Errorf("failed to copy attachment to workspace: %w", err) } sizeStr := formatFileSize(fileInfo.Bytes) return fmt.Sprintf("[Attached file: %s (%s, %s)]", dstPath, fileInfo.ContentType, sizeStr), nil } func 
// extensionFromContentType returns a file extension (including the leading
// dot) for a handful of well-known media types, or "" when the type is not
// recognised.
//
// Media type parameters (e.g. "; charset=utf-8"), surrounding whitespace and
// letter case are ignored, so "Text/Plain; charset=utf-8" maps to ".txt".
func extensionFromContentType(contentType string) string {
	mediaType := contentType
	if i := strings.IndexByte(mediaType, ';'); i >= 0 {
		mediaType = mediaType[:i] // drop parameters
	}
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))

	switch mediaType {
	case "image/png":
		return ".png"
	case "image/jpeg":
		return ".jpg"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "image/svg+xml":
		return ".svg"
	case "application/pdf":
		return ".pdf"
	case "text/plain":
		return ".txt"
	case "text/html":
		return ".html"
	case "text/css":
		return ".css"
	case "text/javascript", "application/javascript":
		return ".js"
	case "application/json":
		return ".json"
	case "application/zip":
		return ".zip"
	default:
		return ""
	}
}

// formatFileSize renders a byte count for display using binary units:
// "%.1fMB" from 1 MiB upwards, "%.1fKB" from 1 KiB upwards, plain "%dB" below.
func formatFileSize(bytes int) string {
	switch {
	case bytes >= 1024*1024:
		return fmt.Sprintf("%.1fMB", float64(bytes)/(1024*1024))
	case bytes >= 1024:
		return fmt.Sprintf("%.1fKB", float64(bytes)/1024)
	default:
		return fmt.Sprintf("%dB", bytes)
	}
}
func resolveOSEnv(computer infra.Computer, _ *types.SandboxConfig) *osEnv { sys := computer.ComputerInfo().System env := &osEnv{ OS: strings.ToLower(sys.OS), Shell: sys.Shell, TempDir: sys.TempDir, WorkDir: computer.GetWorkDir(), } if env.TempDir == "" { env.TempDir = env.pathJoin(env.WorkDir, ".tmp") } return env } // shellCmd returns the command slice to run a script through the appropriate shell. func (e *osEnv) shellCmd(script string) []string { shell := strings.ToLower(e.Shell) switch shell { case "pwsh": return []string{"pwsh", "-NoProfile", "-Command", script} case "powershell": return []string{"powershell", "-NoProfile", "-Command", script} case "cmd.exe", "cmd": return []string{"cmd.exe", "/C", script} default: return []string{"bash", "-c", script} } } // mkdirCmd returns a shell command string to create a directory (with parents). func (e *osEnv) mkdirCmd(dir string) string { if e.isWindows() { return fmt.Sprintf(`if (!(Test-Path '%s')) { New-Item -ItemType Directory -Path '%s' -Force | Out-Null }`, dir, dir) } return fmt.Sprintf("mkdir -p %s", dir) } // listDirCmd returns a command slice to list directory contents. func (e *osEnv) listDirCmd(dir string) []string { if e.isWindows() { return e.shellCmd(fmt.Sprintf("Get-ChildItem -Name '%s'", dir)) } return []string{"ls", dir} } // killProcessCmd returns a command slice to kill processes matching a pattern. // On Windows, uses taskkill /T to kill the entire process tree, which handles // child processes (chrome.exe, python3, etc.) that Claude CLI may have spawned. func (e *osEnv) killProcessCmd(pattern string) []string { if e.isWindows() { // taskkill /F /T kills the process tree; fall back to Stop-Process. 
script := fmt.Sprintf( "Get-Process -ErrorAction SilentlyContinue | Where-Object {$_.ProcessName -like '*%s*'} | ForEach-Object { taskkill /F /T /PID $_.Id 2>$null }; "+ "Get-Process -ErrorAction SilentlyContinue | Where-Object {$_.ProcessName -like '*%s*'} | Stop-Process -Force -ErrorAction SilentlyContinue", pattern, pattern) return e.shellCmd(script) } return []string{"sh", "-c", fmt.Sprintf("pkill -f '%s' || true", pattern)} } // rootDir returns the filesystem root for the target OS. func (e *osEnv) rootDir() string { if e.isWindows() { return `C:\` } return "/" } // pathJoin joins path segments using the appropriate separator. func (e *osEnv) pathJoin(parts ...string) string { if e.isWindows() { return strings.Join(parts, `\`) } return path.Join(parts...) } // buildCLIScript builds the complete CLI invocation script for the target OS. // Returns (script, stdin) — on Linux stdin is nil (heredoc handles it), // on Windows stdin contains inputJSONL bytes to pass via gRPC Stdin. func (e *osEnv) buildCLIScript(args []string, systemPrompt, inputJSONL string) (string, []byte) { workDir := e.WorkDir promptFile := e.pathJoin(workDir, ".yao", ".system-prompt.txt") if e.isWindows() { return e.buildPowerShellScript(args, systemPrompt, inputJSONL, workDir, promptFile) } return e.buildBashScript(args, systemPrompt, inputJSONL, workDir, promptFile), nil } func (e *osEnv) buildBashScript(args []string, systemPrompt, inputJSONL, workDir, promptFile string) string { var b strings.Builder if e.UserHome != "" { b.WriteString(fmt.Sprintf("touch %s/.Xauthority 2>/dev/null; ", e.UserHome)) } b.WriteString("touch \"$HOME/.Xauthority\" 2>/dev/null\n") if systemPrompt != "" { b.WriteString(fmt.Sprintf("mkdir -p %s/.yao\n", workDir)) b.WriteString(fmt.Sprintf("cat << 'PROMPTEOF' > %s\n", promptFile)) b.WriteString(systemPrompt) b.WriteString("\nPROMPTEOF\n") args = append(args, "--append-system-prompt-file", promptFile) } b.WriteString("cat << 'INPUTEOF' | claude -p") for _, arg := 
range args { b.WriteString(fmt.Sprintf(" %q", arg)) } b.WriteString(" 2>&1\n") b.WriteString(inputJSONL) b.WriteString("\nINPUTEOF") return b.String() } // buildPowerShellScript builds a script that writes the system prompt file, // then launches claude -p. inputJSONL is returned as stdin bytes to be passed // directly via gRPC, bypassing PowerShell's encoding entirely. func (e *osEnv) buildPowerShellScript(args []string, systemPrompt, inputJSONL, workDir, promptFile string) (string, []byte) { var b strings.Builder noBOM := "(New-Object System.Text.UTF8Encoding $false)" // Force UTF-8 for both input and output streams. // On CJK Windows the default codepage is often GBK/GB2312 (936) // which corrupts Claude CLI's UTF-8 JSON output. b.WriteString("[Console]::InputEncoding = [System.Text.Encoding]::UTF8\n") b.WriteString("[Console]::OutputEncoding = [System.Text.Encoding]::UTF8\n") b.WriteString("$OutputEncoding = [System.Text.Encoding]::UTF8\n") // Ensure claude.exe can be found even when Tai runs as a different user. // Claude CLI is typically installed per-user (e.g. C:\Users\X\.local\bin) // which isn't in the PATH when Tai runs as a service or another account. // Scan all user profiles for common install locations. 
b.WriteString("foreach ($d in (Get-ChildItem 'C:\\Users' -Directory -ErrorAction SilentlyContinue)) {\n") b.WriteString(" $p = Join-Path $d.FullName '.local\\bin'\n") b.WriteString(" if (Test-Path (Join-Path $p 'claude.exe')) { $env:PATH = \"$p;$env:PATH\"; break }\n") b.WriteString("}\n") b.WriteString("if ($env:APPDATA) { $env:PATH = \"$env:APPDATA\\npm;$env:PATH\" }\n") yaoDir := e.pathJoin(workDir, ".yao") b.WriteString(fmt.Sprintf("if (!(Test-Path '%s')) { New-Item -ItemType Directory -Path '%s' -Force | Out-Null }\n", yaoDir, yaoDir)) if systemPrompt != "" { escaped := strings.ReplaceAll(systemPrompt, "'", "''") b.WriteString(fmt.Sprintf("[IO.File]::WriteAllText('%s', @'\n%s\n'@, %s)\n", promptFile, escaped, noBOM)) args = append(args, "--append-system-prompt-file", promptFile) } b.WriteString("claude -p") for _, arg := range args { b.WriteString(fmt.Sprintf(" '%s'", strings.ReplaceAll(arg, "'", "''"))) } return b.String(), []byte(inputJSONL + "\n") } ================================================ FILE: agent/sandbox/v2/claude/parse.go ================================================ package claude import ( "bufio" "context" "encoding/json" "errors" "fmt" "io" "log" "strings" "time" goujson "github.com/yaoapp/gou/json" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" ) // errStreamCompleted is a sentinel indicating the parser received the terminal // "result" message. It is NOT a real error — callers should treat it as // successful completion of the stream. var errStreamCompleted = errors.New("claude stream completed") // parseStreamJSON reads stream-json lines from Claude CLI stdout and // pushes them through handler as standard StreamChunkType events. func parseStreamJSON(ctx context.Context, stdout io.ReadCloser, handler message.StreamFunc) error { // When the context is cancelled (upstream timeout / interrupt), close // stdout so that scanner.Scan() unblocks immediately. 
Without this, // a failed TerminateProcess (Access is denied) would leave us stuck // forever on the read. doneParsing := make(chan struct{}) defer close(doneParsing) go func() { select { case <-ctx.Done(): stdout.Close() case <-doneParsing: } }() scanner := bufio.NewScanner(stdout) buf := make([]byte, 0, 64*1024) scanner.Buffer(buf, 1024*1024) messageStarted := false toolBlockActive := false toolIndex := 0 type toolState struct { id string name string index int inputJSON strings.Builder } var currentTool *toolState for scanner.Scan() { line := scanner.Text() if line == "" { continue } var msg map[string]any if err := json.Unmarshal([]byte(line), &msg); err != nil { continue } msgType, _ := msg["type"].(string) stopped := false switch msgType { case "system": if handler != nil { data, _ := json.Marshal(msg) if handler(message.ChunkMetadata, data) != 0 { stopped = true } } case "stream_event": event, _ := msg["event"].(map[string]any) if event == nil { continue } eventType, _ := event["type"].(string) switch eventType { case "content_block_start": if cb, ok := event["content_block"].(map[string]any); ok { blockType, _ := cb["type"].(string) if blockType == "tool_use" { toolName, _ := cb["name"].(string) toolID, _ := cb["id"].(string) if toolID == "" { toolID = fmt.Sprintf("tool_%d_%d", toolIndex, time.Now().UnixNano()) } currentTool = &toolState{id: toolID, name: toolName, index: toolIndex} toolIndex++ if handler != nil { if messageStarted { handler(message.ChunkMessageEnd, nil) messageStarted = false } if !toolBlockActive { startData := message.EventMessageStartData{ MessageID: fmt.Sprintf("sandbox-tool-%d", time.Now().UnixNano()), Type: "tool_call", Timestamp: time.Now().UnixMilli(), } sd, _ := json.Marshal(startData) if handler(message.ChunkMessageStart, sd) != 0 { stopped = true break } toolBlockActive = true } tcData, _ := json.Marshal([]map[string]any{{ "index": currentTool.index, "id": currentTool.id, "type": "function", "function": map[string]any{ "name": 
toolName, "arguments": "", }, }}) if handler(message.ChunkToolCall, tcData) != 0 { stopped = true } } } } case "content_block_delta": if delta, ok := event["delta"].(map[string]any); ok { deltaType, _ := delta["type"].(string) switch deltaType { case "text_delta": if text, ok := delta["text"].(string); ok && text != "" { text = strings.ReplaceAll(text, "\r\n", "\n") text = strings.ReplaceAll(text, "\r", "\n") if handler != nil { if toolBlockActive { handler(message.ChunkMessageEnd, nil) toolBlockActive = false messageStarted = false } if !messageStarted { startData := message.EventMessageStartData{ MessageID: fmt.Sprintf("sandbox-%d", time.Now().UnixNano()), Type: "text", Timestamp: time.Now().UnixMilli(), } sd, _ := json.Marshal(startData) if handler(message.ChunkMessageStart, sd) != 0 { stopped = true break } messageStarted = true } if handler(message.ChunkText, []byte(text)) != 0 { stopped = true } } } case "input_json_delta": if currentTool != nil { if partial, ok := delta["partial_json"].(string); ok { currentTool.inputJSON.WriteString(partial) if handler != nil { tcData, _ := json.Marshal([]map[string]any{{ "index": currentTool.index, "function": map[string]any{ "arguments": partial, }, }}) if handler(message.ChunkToolCall, tcData) != 0 { stopped = true } } } } } } case "content_block_stop": currentTool = nil } case "assistant": if msgData, ok := msg["message"].(map[string]any); ok { stopReason, _ := msgData["stop_reason"].(string) if stopReason != "" { if contentArr, ok := msgData["content"].([]any); ok { for _, item := range contentArr { ci, ok := item.(map[string]any) if !ok { continue } itemType, _ := ci["type"].(string) if itemType == "tool_use" && handler != nil { toolName, _ := ci["name"].(string) toolID, _ := ci["id"].(string) if toolID == "" { toolID = fmt.Sprintf("tool_%d_%d", toolIndex, time.Now().UnixNano()) } inputRaw, _ := json.Marshal(ci["input"]) idx := toolIndex toolIndex++ if !toolBlockActive { startData := message.EventMessageStartData{ 
MessageID: fmt.Sprintf("sandbox-tool-%d", time.Now().UnixNano()), Type: "tool_call", Timestamp: time.Now().UnixMilli(), } sd, _ := json.Marshal(startData) if handler(message.ChunkMessageStart, sd) != 0 { stopped = true break } toolBlockActive = true } tcData, _ := json.Marshal([]map[string]any{{ "index": idx, "id": toolID, "type": "function", "function": map[string]any{ "name": toolName, "arguments": string(inputRaw), }, }}) if handler(message.ChunkToolCall, tcData) != 0 { stopped = true break } } if itemType == "text" { if text, ok := ci["text"].(string); ok && text != "" && handler != nil && !messageStarted { text = strings.ReplaceAll(text, "\r\n", "\n") text = strings.ReplaceAll(text, "\r", "\n") if toolBlockActive { handler(message.ChunkMessageEnd, nil) toolBlockActive = false } startData := message.EventMessageStartData{ MessageID: fmt.Sprintf("sandbox-%d", time.Now().UnixNano()), Type: "text", Timestamp: time.Now().UnixMilli(), } sd, _ := json.Marshal(startData) if handler(message.ChunkMessageStart, sd) != 0 { stopped = true break } if handler(message.ChunkText, []byte(text)) != 0 { stopped = true break } messageStarted = true } } } } // Close any open message from the streaming phase. // stream_event text_deltas set messageStarted=true but // nothing resets it when the turn ends — the assistant // message marks the turn boundary, so we must close // the message here to keep state in sync with the // stream handler (which already sent message_end). 
if handler != nil { if toolBlockActive { handler(message.ChunkMessageEnd, nil) toolBlockActive = false } if messageStarted { handler(message.ChunkMessageEnd, nil) messageStarted = false } } } } case "result": isError, _ := msg["is_error"].(bool) if isError { if result, ok := msg["result"].(string); ok { if handler != nil { handler(message.ChunkError, []byte(result)) } return fmt.Errorf("Claude CLI error: %s", result) } } if handler != nil { if toolBlockActive { handler(message.ChunkMessageEnd, nil) toolBlockActive = false } if messageStarted { handler(message.ChunkMessageEnd, nil) } } // "result" is the terminal message in Claude CLI's stream-json // protocol. Return immediately instead of continuing to // scanner.Scan(), which would block forever if the process // stays alive (e.g. child processes like chrome.exe keep the // stdout pipe open). return errStreamCompleted case "error": var errMsg string switch e := msg["error"].(type) { case string: errMsg = e case map[string]any: errMsg, _ = e["message"].(string) } if errMsg != "" { if handler != nil { handler(message.ChunkError, []byte(errMsg)) } return fmt.Errorf("Claude CLI error: %s", errMsg) } } if stopped { break } } if err := scanner.Err(); err != nil { // If the context was cancelled (upstream timeout / interrupt), the // stdout pipe was closed by the goroutine above. The resulting // read error is expected — surface it as context.Canceled so the // caller can handle it uniformly. if ctx.Err() != nil { return ctx.Err() } return err } return nil } // buildFirstRequestJSONL builds JSONL with all messages for the first request. 
func buildFirstRequestJSONL(messages []agentContext.Message) string { var lines []string for _, msg := range messages { if msg.Role == "system" { continue } content := msg.Content if content == nil { content = "" } streamMsg := map[string]any{ "type": string(msg.Role), "message": map[string]any{ "role": string(msg.Role), "content": content, }, } data, _ := json.Marshal(streamMsg) lines = append(lines, string(data)) } return strings.Join(lines, "\n") } // buildLastUserMessageJSONL builds JSONL with only the last user message. func buildLastUserMessageJSONL(messages []agentContext.Message) string { for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { content := messages[i].Content if content == nil { content = "" } msg := map[string]any{ "type": "user", "message": map[string]any{ "role": "user", "content": content, }, } data, _ := json.Marshal(msg) return string(data) } } return "" } // Suppress unused import warnings — goujson.Parse is used for tool description // parsing in V1 and will be used for detailed tool descriptions in future. var _ = goujson.Parse var _ = log.Printf ================================================ FILE: agent/sandbox/v2/claude/runner.go ================================================ package claude import ( "context" "encoding/json" "errors" "fmt" "io" "os" "path" "strings" "time" "github.com/yaoapp/gou/connector" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/sandbox/v2/types" infra "github.com/yaoapp/yao/sandbox/v2" ) const defaultProxyPort = 3456 // ClaudeRunner implements the Runner interface for Claude CLI (mode=cli). type ClaudeRunner struct { mode string hasMCP bool mcpToolPattern string // e.g. "mcp__yao__*,mcp__github__*" servicePort int servicePath string serviceProtocol string streamCompleted bool // set when Stream received "result"; Cleanup skips kill } // New creates a new ClaudeRunner. 
func New() *ClaudeRunner { return &ClaudeRunner{mode: "cli"} } func (r *ClaudeRunner) Name() string { return "claude" } // Prepare executes user-defined and runner-specific prepare steps. func (r *ClaudeRunner) Prepare(ctx context.Context, req *types.PrepareRequest) error { r.mode = req.Config.Runner.Mode if r.mode == "" { r.mode = "cli" } steps := append([]types.PrepareStep{}, req.Config.Prepare...) if req.SkillsDir != "" { ws := req.Computer.Workplace() if ws != nil { src := "local:///" + req.SkillsDir dst := ".claude/skills" if _, err := ws.Copy(src, dst); err != nil { fmt.Fprintf(os.Stderr, "[claude] warn: copy skills %s -> %s: %v\n", src, dst, err) } } } if len(req.MCPServers) > 0 { r.hasMCP = true r.mcpToolPattern = buildMCPAllowedTools(req.MCPServers) mcpJSON := buildMCPConfig(req.MCPServers) steps = append(steps, types.PrepareStep{ Action: "file", Path: ".claude/mcp.json", Content: mcpJSON, }) } if req.RunSteps != nil && len(steps) > 0 { if err := req.RunSteps(ctx, steps, req.Computer, req.Config.ID, req.ConfigHash, req.AssistantDir); err != nil { return fmt.Errorf("claude prepare steps: %w", err) } } return nil } // Stream executes the Claude CLI and streams output to handler. 
func (r *ClaudeRunner) Stream(ctx context.Context, req *types.StreamRequest, handler message.StreamFunc) error { computer := req.Computer if computer == nil { return fmt.Errorf("computer is nil") } oe := resolveOSEnv(computer, req.Config) if req.ChatID != "" { ws := computer.Workplace() if ws != nil { processed, err := prepareAttachments(ctx, req.Messages, req.ChatID, ws) if err != nil { return fmt.Errorf("prepareAttachments: %w", err) } req.Messages = processed } } isContinuation := hasExistingSession(ctx, computer, oe) cmd, env, stdin := r.buildCLICommand(req, oe, isContinuation) streamOpts := []infra.ExecOption{infra.WithWorkDir(oe.WorkDir), infra.WithEnv(env)} if len(stdin) > 0 { streamOpts = append(streamOpts, infra.WithStdin(stdin)) } fmt.Fprintf(os.Stderr, "[claude] Stream cmd=%v hasMCP=%v isContinuation=%v stdinLen=%d workDir=%q\n", cmd, r.hasMCP, isContinuation, len(stdin), oe.WorkDir) execStream, err := computer.Stream(ctx, cmd, streamOpts...) if err != nil { return fmt.Errorf("computer.Stream: %w", err) } streamCtx, streamCancel := context.WithCancel(ctx) defer streamCancel() // Kill claude processes only when the context is cancelled externally // (upstream timeout, user interrupt) — NOT on normal return. 
go func() { <-streamCtx.Done() if ctx.Err() == nil { fmt.Fprintf(os.Stderr, "[claude] streamCtx done: normal return, skipping kill (ctx.Err=nil)\n") return } fmt.Fprintf(os.Stderr, "[claude] streamCtx done: context cancelled externally (ctx.Err=%v), killing processes\n", ctx.Err()) killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() computer.Exec(killCtx, oe.killProcessCmd("claude")) execStream.Cancel() }() var stderrBuf strings.Builder go func() { buf := make([]byte, 4096) for { n, err := execStream.Stderr.Read(buf) if n > 0 { stderrBuf.Write(buf[:n]) chunk := string(buf[:n]) if strings.Contains(strings.ToLower(chunk), "error") { streamCancel() io.Copy(&stderrBuf, execStream.Stderr) return } } if err != nil { return } } }() parseErr := parseStreamJSON(streamCtx, execStream.Stdout, handler) fmt.Fprintf(os.Stderr, "[claude] parseStreamJSON returned: %v\n", parseErr) // Received "result" — Claude finished normally. Return immediately. if errors.Is(parseErr, errStreamCompleted) { r.streamCompleted = true fmt.Fprintf(os.Stderr, "[claude] stream completed normally, returning nil\n") return nil } // Parse failed or stream ended without "result" — wait for process. fmt.Fprintf(os.Stderr, "[claude] stream did NOT complete normally, waiting for process exit...\n") exitCode, waitErr := execStream.Wait() stderrStr := strings.TrimSpace(stderrBuf.String()) if parseErr != nil { if stderrStr != "" { return fmt.Errorf("%w (stderr: %s)", parseErr, stderrStr) } return parseErr } if waitErr != nil { if stderrStr != "" { return fmt.Errorf("%w (stderr: %s)", waitErr, stderrStr) } return waitErr } if exitCode != 0 { fmt.Fprintf(os.Stderr, "[claude] exit code=%d stderr=%q\n", exitCode, stderrStr) if stderrStr != "" { return fmt.Errorf("claude CLI exited with code %d: %s", exitCode, stderrStr) } return fmt.Errorf("claude CLI exited with code %d", exitCode) } return nil } // Cleanup kills any remaining claude processes. 
If the stream completed // normally (received "result"), child processes are preserved — the user // may have asked Claude to launch a browser, server, etc. func (r *ClaudeRunner) Cleanup(ctx context.Context, computer infra.Computer) error { if computer == nil { return nil } if r.streamCompleted { fmt.Fprintf(os.Stderr, "[claude] cleanup: stream completed normally, skipping process kill (child processes preserved)\n") return nil } if r.mode != "service" { oe := resolveOSEnv(computer, nil) computer.Exec(ctx, oe.killProcessCmd("claude")) } return nil } // hasExistingSession checks if a Claude CLI session exists in the workspace. func hasExistingSession(ctx context.Context, computer infra.Computer, oe *osEnv) bool { sessionDir := oe.pathJoin(oe.WorkDir, ".claude", "projects") result, err := computer.Exec(ctx, oe.listDirCmd(sessionDir)) if err != nil || result.ExitCode != 0 { return false } return strings.TrimSpace(result.Stdout) != "" } // buildCLICommand constructs the Claude CLI command, environment variables, and optional stdin bytes. 
func (r *ClaudeRunner) buildCLICommand(req *types.StreamRequest, oe *osEnv, isContinuation bool) ([]string, map[string]string, []byte) { env := make(map[string]string) if oe.isWindows() { env["USERPROFILE"] = oe.WorkDir if len(oe.WorkDir) >= 2 && oe.WorkDir[1] == ':' { env["HOMEDRIVE"] = oe.WorkDir[:2] env["HOMEPATH"] = oe.WorkDir[2:] } } else { env["HOME"] = oe.WorkDir if oe.UserHome != "" { env["XAUTHORITY"] = path.Join(oe.UserHome, ".Xauthority") } } if req.Connector != nil { setting := req.Connector.Setting() host, _ := setting["host"].(string) key, _ := setting["key"].(string) model, _ := setting["model"].(string) if req.Connector.Is(connector.ANTHROPIC) { env["ANTHROPIC_BASE_URL"] = host env["ANTHROPIC_API_KEY"] = key } else { env["ANTHROPIC_BASE_URL"] = fmt.Sprintf("http://127.0.0.1:%d", defaultProxyPort) env["ANTHROPIC_API_KEY"] = "dummy" } if model != "" { env["ANTHROPIC_MODEL"] = model env["ANTHROPIC_DEFAULT_OPUS_MODEL"] = model env["ANTHROPIC_DEFAULT_SONNET_MODEL"] = model env["ANTHROPIC_DEFAULT_HAIKU_MODEL"] = model env["CLAUDE_CODE_SUBAGENT_MODEL"] = model } if thinking, ok := setting["thinking"].(map[string]interface{}); ok { thinkType, _ := thinking["type"].(string) switch thinkType { case "disabled": env["MAX_THINKING_TOKENS"] = "0" case "enabled": if budget, ok := thinking["budget_tokens"].(float64); ok && budget > 0 { env["MAX_THINKING_TOKENS"] = fmt.Sprintf("%d", int(budget)) } } } } if req.Config != nil && len(req.Config.Secrets) > 0 { for k, v := range req.Config.Secrets { env[k] = v } } if req.Token != nil { if req.Token.Token != "" { env["YAO_TOKEN"] = req.Token.Token } if req.Token.RefreshToken != "" { env["YAO_REFRESH_TOKEN"] = req.Token.RefreshToken } } var systemPrompt string envPrompt := buildSandboxEnvPrompt(oe) if !isContinuation && req.SystemPrompt != "" { systemPrompt = req.SystemPrompt + "\n\n" + envPrompt } else if !isContinuation { systemPrompt = envPrompt } var inputJSONL string if isContinuation { inputJSONL = 
buildLastUserMessageJSONL(req.Messages) } else { inputJSONL = buildFirstRequestJSONL(req.Messages) } var args []string permMode := "" if req.Config != nil && req.Config.Runner.Options != nil { if v, ok := req.Config.Runner.Options["permission_mode"]; ok { permMode = fmt.Sprintf("%v", v) } } if permMode == "bypassPermissions" { args = append(args, "--dangerously-skip-permissions") args = append(args, "--permission-mode", permMode) } args = append(args, "--input-format", "stream-json") args = append(args, "--output-format", "stream-json") args = append(args, "--include-partial-messages") args = append(args, "--verbose") if isContinuation { args = append(args, "--continue") } if req.Config != nil && req.Config.Runner.Options != nil { for key, val := range req.Config.Runner.Options { if flag, ok := claudeArgWhitelist[key]; ok { args = append(args, flag, fmt.Sprintf("%v", val)) } } } if r.hasMCP { mcpPath := oe.pathJoin(oe.WorkDir, ".claude", "mcp.json") args = append(args, "--mcp-config", mcpPath) if r.mcpToolPattern != "" { args = append(args, "--allowedTools", r.mcpToolPattern) } } script, stdin := oe.buildCLIScript(args, systemPrompt, inputJSONL) return oe.shellCmd(script), env, stdin } // buildMCPConfig creates the .mcp.json for Claude CLI based on declared servers. // Each server delegates to "tai mcp" which implements the standard MCP protocol // over stdio and bridges to Yao gRPC with authentication. // Connection is configured via env vars (YAO_GRPC_ADDR, YAO_TOKEN, etc.) // injected by the sandbox infrastructure at container start. 
func buildMCPConfig(servers []types.MCPServer) []byte { mcpServers := make(map[string]any, len(servers)) for _, s := range servers { name := s.ServerID if name == "" { continue } mcpServers[name] = map[string]any{ "command": "tai", "args": []string{"mcp", name}, } } if len(mcpServers) == 0 { mcpServers["yao"] = map[string]any{ "command": "tai", "args": []string{"mcp"}, } } config := map[string]any{"mcpServers": mcpServers} data, _ := json.Marshal(config) return data } // buildMCPAllowedTools generates the --allowedTools pattern from server IDs. func buildMCPAllowedTools(servers []types.MCPServer) string { patterns := make([]string, 0, len(servers)) for _, s := range servers { if s.ServerID != "" { patterns = append(patterns, fmt.Sprintf("mcp__%s__*", s.ServerID)) } } if len(patterns) == 0 { return "mcp__yao__*" } return strings.Join(patterns, ",") } // buildSandboxEnvPrompt generates the sandbox environment prompt with system info and working directory. func buildSandboxEnvPrompt(oe *osEnv) string { workDir := oe.WorkDir osName := oe.OS if osName == "" { osName = "linux" } shell := oe.Shell if shell == "" { shell = "bash" } shellNote := "" if oe.isWindows() { shellNote = ` - **Desktop Environment**: You have full access to the Windows desktop (GUI applications, browsers, etc.) - **Important**: When you launch GUI applications (browsers, editors, etc.), do NOT close them unless explicitly asked — the user expects them to remain open` } return fmt.Sprintf(`## Sandbox Environment - **Operating System**: %[2]s - **Shell**: %[3]s - **Working Directory**: %[1]s - **File Access**: You have full read/write access to %[1]s%[4]s ## User Attachments User-uploaded files (images, documents, code files, etc.) are placed in %[1]s/.attachments/{chatID}/ Each chat session has its own subdirectory to avoid conflicts. When the user references an attached file, read it from this directory using the Read or Bash tool. 
For image files, you can view them directly as Claude supports vision on local files. `, workDir, osName, shell, shellNote) } var claudeArgWhitelist = map[string]string{ "max_turns": "--max-turns", "disallowed_tools": "--disallowed-tools", "allowed_tools": "--allowedTools", } ================================================ FILE: agent/sandbox/v2/claude/runner_test.go ================================================ package claude_test import ( "bytes" "context" "fmt" "mime/multipart" "os" "path/filepath" "runtime" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/caller" agentcontext "github.com/yaoapp/yao/agent/context" sandboxtestutils "github.com/yaoapp/yao/agent/sandbox/v2/testutils" "github.com/yaoapp/yao/attachment" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) type e2eCase struct { ID string Prompt string Timeout time.Duration } var cases = []e2eCase{ { ID: "tests.sandbox-v2.oneshot-cli", Prompt: "Reply exactly with: hello sandbox v2", Timeout: 3 * time.Minute, }, } func TestSandboxV2_Claude_E2E(t *testing.T) { sandboxtestutils.Prepare(t) defer sandboxtestutils.Clean(t) require.NotNil(t, caller.AgentGetterFunc, "AgentGetterFunc should be registered after Prepare") for _, tc := range cases { tc := tc t.Run(tc.ID, func(t *testing.T) { agent, err := caller.AgentGetterFunc(tc.ID) require.NoError(t, err, "should load assistant %s", tc.ID) timeout := tc.Timeout if timeout == 0 { timeout = 3 * time.Minute } chatID := fmt.Sprintf("e2e-%s-%d", tc.ID, time.Now().UnixMilli()) ctx := agentcontext.New( context.Background(), &oauthtypes.AuthorizedInfo{ TeamID: "test-team-e2e", UserID: "test-user-e2e", }, chatID, ) messages := []agentcontext.Message{ {Role: "user", Content: tc.Prompt}, } done := make(chan struct{}) var resp *agentcontext.Response var streamErr error go func() { defer close(done) resp, streamErr = agent.Stream(ctx, messages) }() select { case <-done: case 
<-time.After(timeout): t.Fatalf("timeout after %v", timeout) } require.NoError(t, streamErr, "Stream should not return error") require.NotNil(t, resp, "response should not be nil") // ── 1. CompletionResponse should behave like the LLM path ── require.NotNil(t, resp.Completion, "completion should not be nil") assert.Equal(t, "assistant", resp.Completion.Role, "role should be assistant") assert.Equal(t, agentcontext.FinishReasonStop, resp.Completion.FinishReason, "finish_reason should be stop") assert.NotNil(t, resp.Completion.Content, "Content should be populated (same as LLM path)") contentStr, ok := resp.Completion.Content.(string) require.True(t, ok, "Content should be a string, got %T", resp.Completion.Content) t.Logf("CompletionResponse.Content (%d chars): %s", len(contentStr), contentStr) assert.Contains(t, contentStr, "hello sandbox v2", "Content should contain expected text") // ── 2. Buffer: frame sequence handled correctly ── require.NotNil(t, ctx.Buffer, "ctx.Buffer should not be nil") msgs := ctx.Buffer.GetMessages() t.Logf("buffer message count: %d", len(msgs)) for _, m := range msgs { t.Logf(" seq=%d role=%s type=%s streaming=%v props_keys=%v", m.Sequence, m.Role, m.Type, m.IsStreaming, mapKeys(m.Props)) } var userInputCount, assistantTextCount, loadingCount int var bufferTextContent string for _, m := range msgs { switch { case m.Role == "user" && m.Type == "user_input": userInputCount++ case m.Role == "assistant" && m.Type == "loading": loadingCount++ case m.Role == "assistant" && m.Type == "text": assistantTextCount++ assert.False(t, m.IsStreaming, "text message should not be streaming (handleMessageEnd should have finalized it)") require.NotNil(t, m.Props, "text message props should not be nil") if c, ok := m.Props["content"].(string); ok { bufferTextContent += c } } } assert.Equal(t, 1, userInputCount, "should have exactly 1 user_input message") assert.GreaterOrEqual(t, loadingCount, 1, "should have at least 1 loading message") assert.Equal(t, 1, 
assistantTextCount, "should have exactly 1 assistant text message (from handleMessageEnd)") assert.Contains(t, bufferTextContent, "hello sandbox v2", "buffer text should contain expected content") // ── 3. Buffer content matches CompletionResponse.Content ── assert.Equal(t, contentStr, bufferTextContent, "CompletionResponse.Content and Buffer text should match") }) } } func TestSandboxV2_Claude_Attachments(t *testing.T) { sandboxtestutils.Prepare(t) defer sandboxtestutils.Clean(t) require.NotNil(t, caller.AgentGetterFunc, "AgentGetterFunc should be registered after Prepare") agent, err := caller.AgentGetterFunc("tests.sandbox-v2.oneshot-cli") require.NoError(t, err) // ── 1. Locate testdata via runtime.Caller ── _, thisFile, _, ok := runtime.Caller(0) require.True(t, ok) testdataDir := filepath.Join(filepath.Dir(thisFile), "testdata") // ── 2. Create attachment manager and upload test files ── const uploaderName = "__yao.attachment" manager, err := attachment.New(attachment.ManagerOption{ Driver: "local", MaxSize: "50M", AllowedTypes: []string{"image/*", "text/*", "application/*", "video/*", ".ts", ".js", ".tsx", ".jsx"}, Options: map[string]interface{}{"path": filepath.Join(os.TempDir(), "test_sandbox_v2_attach")}, }) require.NoError(t, err) manager.Name = uploaderName attachment.Managers[uploaderName] = manager t.Cleanup(func() { delete(attachment.Managers, uploaderName) }) imgFile := uploadTestFile(t, manager, testdataDir, "test-image.png", "image/png") codeFile := uploadTestFile(t, manager, testdataDir, "code.ts", "text/plain") imgWrapper := fmt.Sprintf("%s://%s", uploaderName, imgFile.ID) codeWrapper := fmt.Sprintf("%s://%s", uploaderName, codeFile.ID) t.Logf("image wrapper: %s", imgWrapper) t.Logf("code wrapper: %s", codeWrapper) // ── 3. 
Build multimodal messages (same as CUI InputArea) ── chatID := fmt.Sprintf("e2e-attach-%d", time.Now().UnixMilli()) ctx := agentcontext.New( context.Background(), &oauthtypes.AuthorizedInfo{TeamID: "test-team-e2e", UserID: "test-user-e2e"}, chatID, ) messages := []agentcontext.Message{ { Role: "user", Content: []interface{}{ map[string]interface{}{"type": "text", "text": "Describe the attached image and summarize the attached code file. Reply in English."}, map[string]interface{}{ "type": "image_url", "image_url": map[string]interface{}{"url": imgWrapper, "detail": "auto"}, }, map[string]interface{}{ "type": "file", "file": map[string]interface{}{"url": codeWrapper, "filename": "code.ts"}, }, }, }, } // ── 4. Run E2E stream ── done := make(chan struct{}) var resp *agentcontext.Response var streamErr error go func() { defer close(done) resp, streamErr = agent.Stream(ctx, messages) }() select { case <-done: case <-time.After(5 * time.Minute): t.Fatalf("timeout after 5m") } require.NoError(t, streamErr, "Stream should not return error") require.NotNil(t, resp) require.NotNil(t, resp.Completion) contentStr, ok := resp.Completion.Content.(string) require.True(t, ok, "Content should be a string, got %T", resp.Completion.Content) t.Logf("Response (%d chars): %s", len(contentStr), contentStr) lower := strings.ToLower(contentStr) // ── 5. Verify Claude actually read the image ── imageKeywords := []string{"hello", "utf", "chinese", "text", "emoji"} imgHit := false for _, kw := range imageKeywords { if strings.Contains(lower, kw) { imgHit = true break } } assert.True(t, imgHit, "response should mention image content (tried: %v)", imageKeywords) // ── 6. 
Verify Claude actually read the code ── codeKeywords := []string{"excel", "typescript", "class", "volcengine"} codeHit := false for _, kw := range codeKeywords { if strings.Contains(lower, kw) { codeHit = true break } } assert.True(t, codeHit, "response should mention code content (tried: %v)", codeKeywords) } func uploadTestFile(t *testing.T, manager *attachment.Manager, testdataDir, filename, contentType string) *attachment.File { t.Helper() path := filepath.Join(testdataDir, filename) data, err := os.ReadFile(path) require.NoError(t, err, "read testdata/%s", filename) fh := &attachment.FileHeader{ FileHeader: &multipart.FileHeader{ Filename: filename, Size: int64(len(data)), Header: make(map[string][]string), }, } fh.Header.Set("Content-Type", contentType) file, err := manager.Upload(context.Background(), fh, bytes.NewReader(data), attachment.UploadOption{ Groups: []string{"e2e-sandbox-v2"}, }) require.NoError(t, err, "upload testdata/%s", filename) t.Logf("uploaded %s => ID=%s, Path=%s", filename, file.ID, file.Path) return file } func mapKeys(m map[string]interface{}) []string { if m == nil { return nil } keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } return keys } ================================================ FILE: agent/sandbox/v2/claude/testdata/code.ts ================================================ import { Process } from "@yao/runtime"; /** * Excel class for manipulating Excel files via Yao's Excel Module */ export class Excel { private handle: string | null = null; /** * Creates a new Excel instance * @param file Path to the Excel file */ constructor(private file: string, writable: boolean = false) { this.file = file; this.Open(writable); } /** * Read each sheet top n rows * @param file Path to the Excel file * @param n number of rows to read * @returns Object with sheet names as keys and arrays of row values as values */ static Heads( file: string, n: number = 5, filters?: string[] ): Record { const excel = new 
Excel(file); const heads = excel.Heads(n, filters); excel.Close(); return heads; } /** * Read each sheet top n rows * @param n number of rows to read * @returns Object with sheet names as keys and arrays of row values as values * @throws Error if file not opened */ Heads(n: number = 5, filters: string[] = []): Record { if (!this.handle) throw new Error("Excel file not opened"); const sheets = this.Sheets(); const result: Record = {}; for (const sheet of sheets) { if (filters.length > 0 && !filters.includes(sheet)) { continue; } // Open row iterator for the sheet const iterator = this.each.OpenRow(sheet); const rows: any[][] = []; // Read n rows let row; let count = 0; while ( count < n && (row = Process(`excel.each.NextRow`, iterator)) !== null ) { // Add column headers (A, B, C, ...) for the first row if (count === 0) { const headerRow = []; for (let i = 0; i < row.length; i++) { headerRow.push(this.convert.ColumnNumberToName(i + 1)); } rows.push(headerRow); } // Trim Each cell's value row = row.map((cell) => cell?.trim?.()); rows.push(row); count++; } // Close the row iterator this.each.CloseRow(iterator); // Find the max length of each row, and pad the column headers(A, B, C, ...) 
to the same length const maxLength = Math.max(...rows.map((row) => row.length)); const start = rows[0].length; const neededLength = maxLength - rows[0].length; for (let i = 0; i < neededLength; i++) { rows[0].push(this.convert.ColumnNumberToName(start + i + 1)); } // Add the sheet's rows to the result result[sheet] = rows; } return result; } /** * Check if a sheet exists in the Excel file * @param file Path to the Excel file * @param sheet Sheet name to check * @returns boolean - true if sheet exists, false otherwise */ static Exists(file: string, sheet: string) { const excel = new Excel(file); const exists = excel.sheet.Exists(sheet); excel.Close(); return exists; } /** * Opens an Excel file for reading or writing * @param writable Whether to open in writable mode (true) or read-only mode (false) * @returns Handle ID used for subsequent operations */ Open(writable: boolean = false) { this.handle = Process(`excel.Open`, this.file, writable); return this.handle; } /** * Closes the Excel file and releases resources * IMPORTANT: Always call this method when done to prevent memory leaks */ Close() { if (this.handle) { Process(`excel.Close`, this.handle); this.handle = null; } } /** * Saves changes to the Excel file * @throws Error if file not opened */ Save() { if (!this.handle) throw new Error("Excel file not opened"); return Process(`excel.Save`, this.handle); } /** * Gets all sheet names in the workbook * @returns Array of sheet names * @throws Error if file not opened */ Sheets() { if (!this.handle) throw new Error("Excel file not opened"); return Process(`excel.Sheets`, this.handle); } // Sheet operations sheet = { /** * Creates a new sheet in the workbook * @param name Name for the new sheet * @returns number Index of the new sheet * @throws Error if file not opened */ Create: (name: string) => { if (!this.handle) throw new Error("Excel file not opened"); return Process(`excel.sheet.create`, this.handle, name); }, /** * Lists all sheets in the workbook * @returns 
string[] Array of sheet names * @throws Error if file not opened */ List: () => { if (!this.handle) throw new Error("Excel file not opened"); return Process(`excel.sheet.list`, this.handle); }, /** * Checks if a sheet exists in the workbook * @param name Sheet name to check * @returns boolean - true if sheet exists, false otherwise * @throws Error if file not opened */ Exists: (name: string) => { if (!this.handle) throw new Error("Excel file not opened"); return Process(`excel.sheet.exists`, this.handle, name); }, /** * Reads all data from a sheet * @param name Sheet name * @returns any[][] Two-dimensional array of cell values * @throws Error if file not opened */ Read: (name: string) => { if (!this.handle) throw new Error("Excel file not opened"); return Process(`excel.sheet.read`, this.handle, name); }, /** * Reads all data from a sheet with pagination support * @param name Sheet name * @param from Starting row index (0-based) * @param chunk_size Number of rows to read * @returns any[][] Two-dimensional array of cell values * @throws Error if file not opened */ Rows: (name: string, from: number, chunk_size: number) => { if (!this.handle) throw new Error("Excel file not opened"); return Process(`excel.sheet.rows`, this.handle, name, from, chunk_size); }, /** * Updates data in a sheet. Creates the sheet if it doesn't exist. 
* @param name Sheet name
     * @param data Two-dimensional array of values to write
     * @throws Error if file not opened
     */
    Update: (name: string, data: any[][]) => {
      if (this.handle) {
        return Process("excel.sheet.update", this.handle, name, data);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Duplicates a sheet, including its content and formatting
     * @param source Name of the sheet to copy from
     * @param target Name of the new sheet (must not exist)
     * @throws Error if file not opened
     */
    Copy: (source: string, target: string) => {
      if (this.handle) {
        return Process("excel.sheet.copy", this.handle, source, target);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Removes a sheet from the workbook
     * @param name Name of the sheet to remove
     * @throws Error if file not opened
     */
    Delete: (name: string) => {
      if (this.handle) {
        return Process("excel.sheet.delete", this.handle, name);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Reports the size (row and column counts) of a sheet
     * @param name Sheet name
     * @returns {rows: number, cols: number} - Object containing row and column counts
     * @throws Error if file not opened
     */
    Dimension: (name: string) => {
      if (this.handle) {
        return Process("excel.sheet.dimension", this.handle, name);
      }
      throw new Error("Excel file not opened");
    },
  };

  // Reading operations
  read = {
    /**
     * Fetches the value stored in a single cell
     * @param sheet Sheet name
     * @param cell Cell reference (e.g. "A1")
     * @returns Cell value
     * @throws Error if file not opened
     */
    Cell: (sheet: string, cell: string) => {
      if (this.handle) {
        return Process("excel.read.Cell", this.handle, sheet, cell);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Fetches every row of a sheet
     * @param sheet Sheet name
     * @returns Two-dimensional array of cell values
     * @throws Error if file not opened
     */
    Row: (sheet: string) => {
      if (this.handle) {
        return Process("excel.read.Row", this.handle, sheet);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Fetches every column of a sheet
     * @param sheet Sheet name
     * @returns Two-dimensional array of cell values
     * @throws Error if file not opened
     */
    Column: (sheet: string) => {
      if (this.handle) {
        return Process("excel.read.Column", this.handle, sheet);
      }
      throw new Error("Excel file not opened");
    },
  };

  // Writing operations
  write = {
    /**
     * Stores a value in a single cell
     * @param sheet Sheet name
     * @param cell Cell reference (e.g. "A1")
     * @param value Value to write (string, number, boolean, etc.)
     * @throws Error if file not opened
     */
    Cell: (sheet: string, cell: string, value: any) => {
      if (this.handle) {
        return Process("excel.write.Cell", this.handle, sheet, cell, value);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Stores a list of values along a row, beginning at the given cell
     * @param sheet Sheet name
     * @param startCell Starting cell reference (e.g. "A1")
     * @param values Array of values to write
     * @throws Error if file not opened
     */
    Row: (sheet: string, startCell: string, values: any[]) => {
      if (this.handle) {
        return Process("excel.write.Row", this.handle, sheet, startCell, values);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Stores a list of values down a column, beginning at the given cell
     * @param sheet Sheet name
     * @param startCell Starting cell reference (e.g. "A1")
     * @param values Array of values to write
     * @throws Error if file not opened
     */
    Column: (sheet: string, startCell: string, values: any[]) => {
      if (this.handle) {
        return Process("excel.write.Column", this.handle, sheet, startCell, values);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Stores a two-dimensional block of values, beginning at the given cell
     * @param sheet Sheet name
     * @param startCell Starting cell reference (e.g. "A1")
     * @param values Two-dimensional array of values to write
     * @throws Error if file not opened
     */
    All: (sheet: string, startCell: string, values: any[][]) => {
      if (this.handle) {
        return Process("excel.write.All", this.handle, sheet, startCell, values);
      }
      throw new Error("Excel file not opened");
    },
  };

  // Setting properties
  set = {
    /**
     * Assigns a style to a cell
     * @param sheet Sheet name
     * @param cell Cell reference (e.g. "A1")
     * @param styleID Style ID to apply
     * @throws Error if file not opened
     */
    Style: (sheet: string, cell: string, styleID: number) => {
      if (this.handle) {
        return Process("excel.set.Style", this.handle, sheet, cell, styleID);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Changes the height of a row
     * @param sheet Sheet name
     * @param row Row number
     * @param height Height in points
     * @throws Error if file not opened
     */
    RowHeight: (sheet: string, row: number, height: number) => {
      if (this.handle) {
        return Process("excel.set.RowHeight", this.handle, sheet, row, height);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Changes the width of a contiguous range of columns
     * @param sheet Sheet name
     * @param startCol Starting column letter
     * @param endCol Ending column letter
     * @param width Width in points
     * @throws Error if file not opened
     */
    ColumnWidth: (sheet: string, startCol: string, endCol: string, width: number) => {
      if (this.handle) {
        return Process("excel.set.ColumnWidth", this.handle, sheet, startCol, endCol, width);
      }
      throw new Error("Excel file not opened");
    },

    /**
     * Merges cells in a range
     * @param sheet Sheet name
     *
@param startCell Starting cell reference (e.g. "A1")
     * @param endCell Ending cell reference (e.g. "B2")
     * @throws Error if file not opened
     */
    MergeCell: (sheet: string, startCell: string, endCell: string) => {
      if (!this.handle) throw new Error("Excel file not opened");
      return Process(`excel.set.MergeCell`, this.handle, sheet, startCell, endCell);
    },

    /**
     * Unmerges previously merged cells
     * @param sheet Sheet name
     * @param startCell Starting cell reference (e.g. "A1")
     * @param endCell Ending cell reference (e.g. "B2")
     * @throws Error if file not opened
     */
    UnmergeCell: (sheet: string, startCell: string, endCell: string) => {
      if (!this.handle) throw new Error("Excel file not opened");
      return Process(`excel.set.UnmergeCell`, this.handle, sheet, startCell, endCell);
    },

    /**
     * Sets a formula in a cell
     * @param sheet Sheet name
     * @param cell Cell reference (e.g. "C1")
     * @param formula Excel formula without the leading equals sign
     * @throws Error if file not opened
     */
    Formula: (sheet: string, cell: string, formula: string) => {
      if (!this.handle) throw new Error("Excel file not opened");
      return Process(`excel.set.Formula`, this.handle, sheet, cell, formula);
    },

    /**
     * Adds a hyperlink to a cell
     * @param sheet Sheet name
     * @param cell Cell reference (e.g. "A1")
     * @param url URL for the hyperlink
     * @param text Display text for the hyperlink
     * @throws Error if file not opened
     */
    Link: (sheet: string, cell: string, url: string, text: string) => {
      if (!this.handle) throw new Error("Excel file not opened");
      return Process(`excel.set.Link`, this.handle, sheet, cell, url, text);
    },
  };

  // Iteration methods
  each = {
    /**
     * Opens a row iterator
     * @param sheet Sheet name
     * @returns Row iterator ID
     * @throws Error if file not opened
     */
    OpenRow: (sheet: string) => {
      if (!this.handle) throw new Error("Excel file not opened");
      return Process(`excel.each.OpenRow`, this.handle, sheet);
    },

    /**
     * Gets the next row from the iterator
     * @param rowID Row iterator ID from excel.each.OpenRow
     * @returns Array of cell values or null if no more rows
     */
    NextRow: (rowID: string) => {
      return Process(`excel.each.NextRow`, rowID);
    },

    /**
     * Closes the row iterator
     * @param rowID Row iterator ID from excel.each.OpenRow
     */
    CloseRow: (rowID: string) => {
      return Process(`excel.each.CloseRow`, rowID);
    },

    /**
     * Opens a column iterator
     * @param sheet Sheet name
     * @returns Column iterator ID
     * @throws Error if file not opened
     */
    OpenColumn: (sheet: string) => {
      if (!this.handle) throw new Error("Excel file not opened");
      return Process(`excel.each.OpenColumn`, this.handle, sheet);
    },

    /**
     * Gets the next column from the iterator
     * @param colID Column iterator ID from excel.each.OpenColumn
     * @returns Array of cell values or null if no more columns
     */
    NextColumn: (colID: string) => {
      return Process(`excel.each.NextColumn`, colID);
    },

    /**
     * Closes the column iterator
     * @param colID Column iterator ID from excel.each.OpenColumn
     */
    CloseColumn: (colID: string) => {
      return Process(`excel.each.CloseColumn`, colID);
    },
  };

  // Conversion utilities
  convert = {
    /**
     * Converts a column name to a column number
     * @param colName Column name (e.g. "A", "AB")
     * @returns Column number (1-based)
     */
    ColumnNameToNumber: (colName: string) => {
      return Process(`excel.convert.ColumnNameToNumber`, colName);
    },

    /**
     * Converts a column number to a column name
     * @param colNum Column number (1-based)
     * @returns Column name
     */
    ColumnNumberToName: (colNum: number) => {
      return Process(`excel.convert.ColumnNumberToName`, colNum);
    },

    /**
     * Converts a cell reference to coordinates
     * @param cell Cell reference (e.g. "A1")
     * @returns Array with [columnNumber, rowNumber] (1-based)
     */
    CellNameToCoordinates: (cell: string) => {
      return Process(`excel.convert.CellNameToCoordinates`, cell);
    },

    /**
     * Converts coordinates to a cell reference
     * @param col Column number (1-based)
     * @param row Row number (1-based)
     * @returns Cell reference
     */
    CoordinatesToCellName: (col: number, row: number) => {
      return Process(`excel.convert.CoordinatesToCellName`, col, row);
    },
  };
}

/**
 * Volcengine OpenAPI SDK
 */
import { Exception, http, Process } from "@yao/runtime";

/**
 * Minimal client for Volcengine OpenAPI endpoints. It builds the
 * HMAC-SHA256 Authorization header and issues signed GET/POST requests.
 */
export class Volcengine {
  private AccessKeyId: string;
  private SecretAccessKey: string;
  private Region: string;
  private Service: string;
  private Endpoint: string;

  /**
   * @param option Credentials, region and service. When `Endpoint` is omitted
   *               the URL is derived as `https://{Service}.{Region}.volcengineapi.com`.
   */
  constructor(option: Option) {
    this.AccessKeyId = option.AccessKeyId;
    this.SecretAccessKey = option.SecretAccessKey;
    this.Region = option.Region;
    this.Service = option.Service;
    this.Endpoint = option.Endpoint
      ? `https://${option.Endpoint}`
      : `https://${this.Service}.${this.Region}.volcengineapi.com`;
  }

  /**
   * Get request
   * @param query Query parameters
   * @returns Response data (`resp.data`)
   * @throws Exception when the response code is outside 200-299
   */
  public Get(query: Record<string, any>) {
    const url = this.Endpoint;
    const headers: Record<string, string> = { host: this.hostOf(url) };
    const request: Request = {
      Method: "GET",
      URI: "/",
      Query: query,
      Headers: headers,
      Payload: null,
    };

    // getAuthorization also stamps "x-date" onto `headers`, so the signed
    // timestamp is sent together with the request.
    headers["Authorization"] = this.getAuthorization(request);
    headers["Content-Type"] = "application/json";

    return this.unwrap(http.Get(url, query, headers));
  }

  /**
   * Post request
   * @param query Query parameters
   * @param payload Payload (sent as a JSON string)
   * @returns Response data (`resp.data`)
   * @throws Exception when the response code is outside 200-299
   */
  public Post(query: Record<string, any>, payload: Record<string, any>) {
    const url = this.Endpoint;
    const headers: Record<string, string> = { host: this.hostOf(url) };
    const body = JSON.stringify(payload);
    const request: Request = {
      Method: "POST",
      URI: "/",
      Query: query,
      Headers: headers,
      Payload: body,
    };

    headers["Authorization"] = this.getAuthorization(request);
    headers["Content-Type"] = "application/json";

    return this.unwrap(http.Post(url, body, null, query, headers));
  }

  /** Extracts the host part ("name[:port]") from an absolute URL. */
  private hostOf(url: string): string {
    return url.split("://")[1].split("/")[0];
  }

  /**
   * Returns `resp.data` for 2xx responses, otherwise throws an Exception
   * carrying the API error message (or the transport message when code is 0).
   */
  private unwrap(resp: any) {
    if (resp.code > 299 || resp.code < 200) {
      const { ResponseMetadata } = resp.data || {};
      // Renamed on destructuring so the global Error is not shadowed.
      const { Error: apiError } = ResponseMetadata || {};
      const message =
        apiError?.Message ||
        (resp.code === 0 ? resp.message : "Unknown error");
      throw new Exception(message, resp.code);
    }
    return resp.data;
  }

  /**
   * Percent-encodes a value per RFC 3986. encodeURIComponent leaves
   * `! ' ( ) *` untouched, so they are escaped explicitly here.
   */
  private encodeRFC3986(value: unknown): string {
    return encodeURIComponent(String(value)).replace(
      /[!'()*]/g,
      (ch) => `%${ch.charCodeAt(0).toString(16).toUpperCase()}`
    );
  }

  /**
   * Builds the canonical query string: null/undefined/empty values are
   * dropped, pairs are encoded, sorted and joined with "&".
   */
  private canonicalQuery(query: Request["Query"]): string {
    if (!query) return "";
    const sources = Array.isArray(query) ? query : [query];
    const pairs: string[] = [];
    for (const source of sources) {
      for (const [key, value] of Object.entries(source)) {
        if (value !== null && value !== undefined && value !== "") {
          pairs.push(`${this.encodeRFC3986(key)}=${this.encodeRFC3986(value)}`);
        }
      }
    }
    return pairs.sort().join("&");
  }

  /**
   * Normalizes request headers (lower-cased keys, trimmed values, blank
   * values skipped) and selects the ones that are signed. Only "host" and
   * "x-date" are signed; any other header is intentionally left out of the
   * signature. Both the canonical request and the Authorization header use
   * this single source so they can never disagree.
   * @param source Request headers (object, array of objects, or null)
   * @param xDate Timestamp used when the headers carry no x-date
   */
  private collectSignedHeaders(
    source: Request["Headers"],
    xDate: string
  ): { keys: string[]; values: Record<string, string> } {
    const values: Record<string, string> = { "x-date": xDate };
    const sources = !source ? [] : Array.isArray(source) ? source : [source];
    for (const headerObj of sources) {
      for (const [key, value] of Object.entries(headerObj)) {
        if (value !== null && value !== undefined && value.trim() !== "") {
          values[key.toLowerCase()] = value.trim();
        }
      }
    }
    const keys = ["host", "x-date"].filter((key) => values[key]);
    return { keys, values };
  }

  /**
   * Makes sure the request headers carry an x-date value so the signed
   * timestamp is actually transmitted. Existing values (any key casing) win.
   * A null header set is left untouched.
   */
  private ensureXDate(request: Request, xDate: string): void {
    const headers = request.Headers;
    if (!headers) return;
    const sources = Array.isArray(headers) ? headers : [headers];
    const present = sources.some((h) =>
      Object.keys(h).some(
        (k) => k.toLowerCase() === "x-date" && String(h[k] ?? "").trim() !== ""
      )
    );
    if (present) return;
    if (Array.isArray(headers)) {
      headers.push({ "x-date": xDate });
    } else {
      headers["x-date"] = xDate;
    }
  }

  /**
   * Create a canonical request
   * @param request Request object
   * @param xDate Timestamp to sign when the headers carry no x-date
   *              (defaults to the current time)
   * @returns Canonical request string
   */
  private canonicalRequest(
    request: Request,
    xDate: string = this.formatDate(new Date())
  ): string {
    // 1. HTTP Method
    const method = request.Method;

    // 2. URI (default to '/' if null)
    const uri = request.URI || "/";

    // 3. Query String
    const queryString = this.canonicalQuery(request.Query);

    // 4. Headers
    const { keys, values } = this.collectSignedHeaders(request.Headers, xDate);
    const canonicalHeaders = keys.map((key) => `${key}:${values[key]}`).join("\n");
    const signedHeaders = keys.join(";");

    // 5. Payload/Body: empty payloads ("", {}, []) hash as the empty string
    let hashedPayload = Process("crypto.Hash", "SHA256", "");
    if (request.Payload !== null && request.Payload !== undefined) {
      if (typeof request.Payload === "string") {
        if (request.Payload !== "") {
          hashedPayload = Process("crypto.Hash", "SHA256", request.Payload);
        }
      } else {
        const payload = JSON.stringify(request.Payload);
        if (payload !== "{}" && payload !== "[]") {
          hashedPayload = Process("crypto.Hash", "SHA256", payload);
        }
      }
    }

    // Combine all components
    const parts = [
      method,
      uri,
      queryString,
      canonicalHeaders,
      "", // Empty line after headers
      signedHeaders,
      hashedPayload,
    ];
    return parts.join("\n");
  }

  /**
   * Format date to YYYYMMDDTHHMMSSZ
   * @param date Date object
   * @returns Formatted date string (UTC)
   */
  private formatDate(date: Date): string {
    const year = date.getUTCFullYear();
    const month = String(date.getUTCMonth() + 1).padStart(2, "0");
    const day = String(date.getUTCDate()).padStart(2, "0");
    const hours = String(date.getUTCHours()).padStart(2, "0");
    const minutes = String(date.getUTCMinutes()).padStart(2, "0");
    const seconds = String(date.getUTCSeconds()).padStart(2, "0");
    return `${year}${month}${day}T${hours}${minutes}${seconds}Z`;
  }

  /**
   * Create string to sign
   * @param canonicalRequest Canonical request string
   * @param requestDateTime Timestamp (YYYYMMDDTHHMMSSZ); must equal the signed
   *                        x-date header. Defaults to the current time.
   * @returns String to sign
   */
  private stringToSign(
    canonicalRequest: string,
    requestDateTime: string = this.formatDate(new Date())
  ): string {
    const algorithm = "HMAC-SHA256";
    const requestDate = requestDateTime.slice(0, 8); // YYYYMMDD
    const credentialScope = `${requestDate}/${this.Region}/${this.Service}/request`;
    const hashedCanonicalRequest = Process("crypto.Hash", "SHA256", canonicalRequest);
    return `${algorithm}\n${requestDateTime}\n${credentialScope}\n${hashedCanonicalRequest}`;
  }

  /**
   * Derive signing key
   * @param date Date in YYYYMMDD format
   * @returns Signing key
   */
  private getSigningKey(date: string): string {
    const kDate = Process("crypto.HMAC", "SHA256", date, this.SecretAccessKey);
    // Each following step is keyed with the previous (hex-encoded) digest.
    const kRegion = Process("crypto.HMACWith", { key: "hex" }, this.Region, kDate);
    const kService = Process("crypto.HMACWith", { key: "hex" }, this.Service, kRegion);
    const kSigning = Process("crypto.HMACWith", { key: "hex" }, "request", kService);
    return kSigning;
  }

  /**
   * Calculate signature
   * @param stringToSign String to sign
   * @param signingKey Signing key
   * @returns Signature
   */
  private signature(stringToSign: string, signingKey: string): string {
    return Process("crypto.HMACWith", { key: "hex" }, stringToSign, signingKey);
  }

  /**
   * Build authorization header
   *
   * A single timestamp is resolved once and used for the x-date header, the
   * canonical request, the credential scope and the string to sign. When the
   * request has headers and none of them is x-date, x-date is added to them.
   * @param request Request object
   * @returns Authorization header value
   */
  public getAuthorization(request: Request): string {
    const now = this.formatDate(new Date());
    this.ensureXDate(request, now);

    // The value that ends up signed: caller-supplied x-date, otherwise `now`.
    const xDate = this.collectSignedHeaders(request.Headers, now).values["x-date"];

    // 1. Create canonical request
    const canonicalReq = this.canonicalRequest(request, xDate);

    // 2. Create string to sign
    const stringToSign = this.stringToSign(canonicalReq, xDate);

    // 3. Get date from string to sign
    const [algorithm, requestDateTime, credentialScope] = stringToSign.split("\n");
    const date = requestDateTime.slice(0, 8);

    // 4. Derive signing key
    const signingKey = this.getSigningKey(date);

    // 5. Calculate signature
    const signature = this.signature(stringToSign, signingKey);

    // 6. Build authorization header; SignedHeaders must list exactly the
    //    headers that went into the canonical request.
    const signedHeaders = this.collectSignedHeaders(request.Headers, xDate).keys.join(";");

    return `${algorithm} Credential=${this.AccessKeyId}/${credentialScope}, SignedHeaders=${signedHeaders}, Signature=${signature}`;
  }
}

export interface Option {
  AccessKeyId: string;
  SecretAccessKey: string;
  Endpoint?: string;
  Region: string;
  Service: string;
}

export interface Request {
  Method: "GET" | "POST";
  URI: string | null; // Default /
  Query: Record<string, any> | Record<string, any>[] | null;
  Headers: Record<string, string> | Record<string, string>[] | null;
  Payload: string | Record<string, any> | any[] | null;
}

================================================
FILE: agent/sandbox/v2/init.go
================================================
package sandboxv2

import (
	"github.com/yaoapp/yao/agent/sandbox/v2/claude"
	"github.com/yaoapp/yao/agent/sandbox/v2/types"
	yaorunner "github.com/yaoapp/yao/agent/sandbox/v2/yao"
)

func init() {
	Register("claude", func() types.Runner { return claude.New() })
	Register("claude/cli", func() types.Runner { return claude.New() })
	Register("yao", func() types.Runner { return yaorunner.New() })
}

================================================
FILE: agent/sandbox/v2/lifecycle.go
================================================
package sandboxv2

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"log"

	"github.com/yaoapp/gou/connector"
	agentContext "github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/sandbox/v2/types"
	infra "github.com/yaoapp/yao/sandbox/v2"
	"github.com/yaoapp/yao/tai"
	"github.com/yaoapp/yao/workspace"
)

// BuildIdentifier determines the Computer identifier based on lifecycle policy
// and optional metadata override.
Returns "" for oneshot (always new). func BuildIdentifier(cfg *types.SandboxConfig, ownerID, chatID, assistantID, workspaceID string, metadata map[string]any) string { if cfg.Lifecycle == "oneshot" { return "" } switch cfg.Lifecycle { case "session": return fmt.Sprintf("%s-%s-%s", ownerID, assistantID, chatID) case "longrunning", "persistent": return fmt.Sprintf("%s-%s.%s", ownerID, assistantID, workspaceID) default: return "" } } // ResolveNodeID determines the target nodeID and computer kind based on // metadata and DSL configuration, without creating or acquiring a container. // Returns (nodeID, kind, error). kind is "box" or "host". func ResolveNodeID(ctx *agentContext.Context, cfg *types.SandboxConfig, manager *infra.Manager) (string, string, error) { computerID := "" if ctx.Metadata != nil { if cid, ok := ctx.Metadata["computer_id"].(string); ok && cid != "" { computerID = cid } } workspaceID := "" if ctx.Metadata != nil { if ws, ok := ctx.Metadata["workspace_id"].(string); ok && ws != "" { workspaceID = ws } } ownerID := resolveOwnerID(ctx) if workspaceID == "" { workspaceID = ownerID } if workspaceID != "" && workspaceID != ownerID { wsNode, err := workspace.M().NodeForWorkspace(context.Background(), workspaceID) if err == nil && wsNode != "" { computerID = wsNode } } if computerID != "" { if node, ok := tai.GetNodeMeta(computerID); ok { hasContainerRuntime := node.Capabilities.Docker || node.Capabilities.K8s if node.Capabilities.HostExec && !hasContainerRuntime { return computerID, "host", nil } if node.Capabilities.HostExec && hasContainerRuntime && cfg.Computer.Image == "" { return computerID, "host", nil } if !hasContainerRuntime { return "", "", fmt.Errorf("node %q has no container runtime and no host_exec capability", computerID) } return computerID, "box", nil } return computerID, "box", nil } if cfg.Computer.Image == "" { nodeID := cfg.NodeID return nodeID, "host", nil } nodeID := cfg.NodeID return nodeID, "box", nil } // GetComputer obtains or 
creates a Computer for the current request. // An optional connector may be passed to inject OPENAI_PROXY_* env vars. // Returns the Computer, the resolved identifier, and any error. func GetComputer(ctx *agentContext.Context, cfg *types.SandboxConfig, manager *infra.Manager, conn ...connector.Connector) (infra.Computer, string, error) { ownerID := resolveOwnerID(ctx) workspaceID := "" if ctx.Metadata != nil { if ws, ok := ctx.Metadata["workspace_id"].(string); ok && ws != "" { workspaceID = ws } } if workspaceID == "" { workspaceID = ownerID } identifier := BuildIdentifier(cfg, ownerID, ctx.ChatID, ctx.AssistantID, workspaceID, ctx.Metadata) // Fill runtime fields. cfg.Owner = ownerID cfg.ID = identifier cfg.WorkspaceID = workspaceID // Resolve computer_id from metadata to determine kind and nodeID. computerID := "" if ctx.Metadata != nil { if cid, ok := ctx.Metadata["computer_id"].(string); ok && cid != "" { computerID = cid } } // Workspace-wins rule: when both workspace_id and computer_id are present, // the workspace's bound node takes precedence over computer_id. if workspaceID != "" && workspaceID != ownerID { wsNode, err := workspace.M().NodeForWorkspace(context.Background(), workspaceID) if err == nil && wsNode != "" { if computerID != "" && computerID != wsNode { log.Printf("[sandbox/v2] workspace %s bound to node %s overrides computer_id %s", workspaceID, wsNode, computerID) } computerID = wsNode } } if computerID != "" { return resolveComputerByID(cfg, manager, computerID, ownerID, identifier, workspaceID, conn...) } // No computer_id: fall back to DSL-based dispatch (original logic). return resolveComputerByDSL(cfg, manager, ownerID, identifier, workspaceID, conn...) } // resolveComputerByID dispatches based on the runtime computer_id from metadata. // It queries the registry and sandbox manager to determine the computer kind. 
func resolveComputerByID( cfg *types.SandboxConfig, manager *infra.Manager, computerID, ownerID, identifier, workspaceID string, conn ...connector.Connector, ) (infra.Computer, string, error) { // 1) Check if computer_id is a known Tai node (host or node kind). if node, ok := tai.GetNodeMeta(computerID); ok { cfg.NodeID = computerID hasContainerRuntime := node.Capabilities.Docker || node.Capabilities.K8s if node.Capabilities.HostExec && !hasContainerRuntime { // Host-only node: must use host mode regardless of DSL image config. cfg.Kind = "host" host, err := manager.Host(context.Background(), computerID) if err != nil { return nil, identifier, fmt.Errorf("get host computer: %w", err) } host.BindWorkplace(workspaceID) return host, identifier, nil } if node.Capabilities.HostExec && hasContainerRuntime && cfg.Computer.Image == "" { // Dual-capable node with no image in DSL: prefer host mode. cfg.Kind = "host" host, err := manager.Host(context.Background(), computerID) if err != nil { return nil, identifier, fmt.Errorf("get host computer: %w", err) } host.BindWorkplace(workspaceID) return host, identifier, nil } if !hasContainerRuntime { return nil, identifier, fmt.Errorf("node %q has no container runtime and no host_exec capability", computerID) } // Node with container runtime and DSL has image: create/reuse a box. cfg.Kind = "box" return resolveBox(cfg, manager, ownerID, identifier, workspaceID, conn...) } // 2) Check if computer_id is an existing box ID. if manager != nil { box, err := manager.Get(context.Background(), computerID) if err == nil && box != nil { cfg.Kind = "box" box.BindWorkplace(workspaceID) return box, computerID, nil } } return nil, identifier, fmt.Errorf("computer %q not found in registry or sandbox manager", computerID) } // resolveComputerByDSL dispatches based on DSL static configuration (cfg.Computer.Image). 
func resolveComputerByDSL( cfg *types.SandboxConfig, manager *infra.Manager, ownerID, identifier, workspaceID string, conn ...connector.Connector, ) (infra.Computer, string, error) { // Host mode: no image → host computer. if cfg.Computer.Image == "" { cfg.Kind = "host" nodeID := cfg.NodeID if nodeID == "" { return nil, identifier, fmt.Errorf("host mode requires a nodeID (set in sandbox.yao or workspace)") } host, err := manager.Host(context.Background(), nodeID) if err != nil { return nil, identifier, fmt.Errorf("get host computer: %w", err) } host.BindWorkplace(workspaceID) return host, identifier, nil } cfg.Kind = "box" return resolveBox(cfg, manager, ownerID, identifier, workspaceID, conn...) } // resolveBox reuses or creates a box container. func resolveBox( cfg *types.SandboxConfig, manager *infra.Manager, ownerID, identifier, workspaceID string, conn ...connector.Connector, ) (infra.Computer, string, error) { // Reuse: non-empty identifier → try Get first. if identifier != "" { box, err := manager.Get(context.Background(), identifier) if err == nil && box != nil { if box.IsStopped() { if startErr := manager.StartBox(context.Background(), identifier); startErr != nil { log.Printf("[sandbox/v2] auto-start stopped box %s failed: %v, creating new", identifier, startErr) } else { box.BindWorkplace(workspaceID) return box, identifier, nil } } else { box.BindWorkplace(workspaceID) return box, identifier, nil } } } // Create new box. var c connector.Connector if len(conn) > 0 { c = conn[0] } createOpts, err := BuildCreateOptions(cfg, identifier, ownerID, workspaceID, c) if err != nil { return nil, identifier, fmt.Errorf("build create options: %w", err) } // Oneshot with empty identifier: generate a random one. 
if createOpts.ID == "" { createOpts.ID = randomID() identifier = createOpts.ID cfg.ID = identifier } box, err := manager.Create(context.Background(), createOpts) if err != nil { return nil, identifier, fmt.Errorf("create computer: %w", err) } return box, identifier, nil } // LifecycleAction performs the post-request lifecycle operation based on policy. // Called in defer after executeSandboxStream completes. func LifecycleAction(ctx context.Context, cfg *types.SandboxConfig, computer infra.Computer, manager *infra.Manager) { if computer == nil || cfg == nil { return } info := computer.ComputerInfo() switch cfg.Lifecycle { case "oneshot": if info.Kind == "box" && manager != nil { if err := manager.Remove(ctx, cfg.ID); err != nil { log.Printf("[sandbox/v2] oneshot remove %s: %v", cfg.ID, err) } } case "session", "longrunning": if info.Kind == "box" && manager != nil { manager.Heartbeat(cfg.ID, false, 0) // active=false: request finished, start idle timer } case "persistent": // No action — persistent boxes survive indefinitely. } } // resolveOwnerID returns teamID if available, otherwise userID. 
func resolveOwnerID(ctx *agentContext.Context) string {
	if ctx.Authorized != nil {
		// Team identity takes precedence over the individual user.
		if ctx.Authorized.TeamID != "" {
			return ctx.Authorized.TeamID
		}
		if ctx.Authorized.UserID != "" {
			return ctx.Authorized.UserID
		}
	}
	// No authorization info, or both IDs empty.
	return "anonymous"
}

// randomID returns a 16-character hex string encoding 8 bytes obtained from
// rand.Read. The error returned by rand.Read is deliberately discarded.
func randomID() string {
	b := make([]byte, 8)
	_, _ = rand.Read(b)
	return hex.EncodeToString(b)
}

================================================
FILE: agent/sandbox/v2/lifecycle_test.go
================================================
package sandboxv2_test

import (
	"context"
	"fmt"
	"strings"
	"testing"
	"time"

	agentContext "github.com/yaoapp/yao/agent/context"
	sandboxv2 "github.com/yaoapp/yao/agent/sandbox/v2"
	"github.com/yaoapp/yao/agent/sandbox/v2/types"
	oauthTypes "github.com/yaoapp/yao/openapi/oauth/types"
	infra "github.com/yaoapp/yao/sandbox/v2"
)

// ===========================================================================
// BuildIdentifier — pure-function tests (no infra needed)
// ===========================================================================

// TestBuildIdentifier_Oneshot expects an empty identifier for the oneshot lifecycle.
func TestBuildIdentifier_Oneshot(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "oneshot"}
	id := sandboxv2.BuildIdentifier(cfg, "owner1", "chat1", "ast1", "ws1", nil)
	if id != "" {
		t.Errorf("oneshot should return empty, got %q", id)
	}
}

// TestBuildIdentifier_Session expects "<owner>-<assistant>-<chat>" for the session lifecycle.
func TestBuildIdentifier_Session(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "session"}
	id := sandboxv2.BuildIdentifier(cfg, "owner1", "chat42", "ast1", "ws1", nil)
	if id != "owner1-ast1-chat42" {
		t.Errorf("session: got %q, want %q", id, "owner1-ast1-chat42")
	}
}

// TestBuildIdentifier_Longrunning expects "<owner>-<assistant>.<workspace>" for the longrunning lifecycle.
func TestBuildIdentifier_Longrunning(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "longrunning"}
	id := sandboxv2.BuildIdentifier(cfg, "owner1", "chat1", "ast99", "ws1", nil)
	if id != "owner1-ast99.ws1" {
		t.Errorf("longrunning: got %q, want %q", id, "owner1-ast99.ws1")
	}
}

// TestBuildIdentifier_Persistent expects the same "<owner>-<assistant>.<workspace>" form as longrunning.
func TestBuildIdentifier_Persistent(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "persistent"}
	id := sandboxv2.BuildIdentifier(cfg, "owner1", "chat1", "ast99", "ws1", nil)
	if id != "owner1-ast99.ws1" {
		t.Errorf("persistent: got %q, want %q", id, "owner1-ast99.ws1")
	}
}

// TestBuildIdentifier_MetadataOverride checks that a "computer_id" metadata
// entry does not change the generated session identifier.
func TestBuildIdentifier_MetadataOverride(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "session"}
	meta := map[string]any{"computer_id": "custom-box"}
	id := sandboxv2.BuildIdentifier(cfg, "owner1", "chat1", "ast1", "ws1", meta)
	// computer_id is used for routing only, not for identifier generation.
	if id != "owner1-ast1-chat1" {
		t.Errorf("metadata override: got %q, want %q", id, "owner1-ast1-chat1")
	}
}

// TestBuildIdentifier_MetadataEmptyIgnored checks that an empty "computer_id"
// still yields the regular session identifier.
func TestBuildIdentifier_MetadataEmptyIgnored(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "session"}
	meta := map[string]any{"computer_id": ""}
	id := sandboxv2.BuildIdentifier(cfg, "owner1", "chat42", "ast1", "ws1", meta)
	if id != "owner1-ast1-chat42" {
		t.Errorf("empty metadata should fall through to session, got %q", id)
	}
}

// TestBuildIdentifier_UnknownLifecycle expects an empty identifier for an unrecognized lifecycle.
func TestBuildIdentifier_UnknownLifecycle(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "unknown"}
	id := sandboxv2.BuildIdentifier(cfg, "owner1", "chat1", "ast1", "ws1", nil)
	if id != "" {
		t.Errorf("unknown lifecycle should return empty, got %q", id)
	}
}

// ===========================================================================
// GetComputer — real container tests
// ===========================================================================

// makeAgentCtx builds a minimal agent context for tests. Authorized is left
// nil when both teamID and userID are empty.
func makeAgentCtx(teamID, userID, chatID, assistantID string, metadata map[string]any) *agentContext.Context {
	var auth *oauthTypes.AuthorizedInfo
	if teamID != "" || userID != "" {
		auth = &oauthTypes.AuthorizedInfo{TeamID: teamID, UserID: userID}
	}
	return &agentContext.Context{
		Context:     context.Background(),
		Authorized:  auth,
		ChatID:      chatID,
		AssistantID: assistantID,
		Metadata:    metadata,
	}
}

// TestGetComputer_BoxCreate checks that a oneshot config with an image yields
// a box with a non-empty identifier and that cfg.Owner / cfg.Kind are filled in.
func TestGetComputer_BoxCreate(t *testing.T) {
	skipIfNoDocker(t)
	for _, nc := range boxNodes() {
		nc := nc
		t.Run(nc.Name, func(t *testing.T) {
			m := setupManager(t, &nc)
			ensureImage(t, m, nc)

			wsID := fmt.Sprintf("lc-create-%d", time.Now().UnixNano())
			createTestWorkspace(t, nc.TaiID, wsID)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "oneshot",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			meta := map[string]any{"workspace_id": wsID}
			ctx := makeAgentCtx("team-t1", "", "chat-1", "ast-1", meta)

			computer, identifier, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("GetComputer: %v", err)
			}
			defer cleanupComputer(t, m, cfg)

			if identifier == "" {
				t.Fatal("oneshot should get a random identifier, got empty")
			}
			info := computer.ComputerInfo()
			if info.Kind != "box" {
				t.Errorf("kind = %q, want %q", info.Kind, "box")
			}
			if cfg.Owner != "team-t1" {
				t.Errorf("cfg.Owner = %q, want %q", cfg.Owner, "team-t1")
			}
			if cfg.Kind != "box" {
				t.Errorf("cfg.Kind = %q, want %q", cfg.Kind, "box")
			}
		})
	}
}

// TestGetComputer_BoxReuse calls GetComputer twice with equivalent session
// configs and the same context, expecting the same identifier and container.
func TestGetComputer_BoxReuse(t *testing.T) {
	skipIfNoDocker(t)
	for _, nc := range boxNodes() {
		nc := nc
		t.Run(nc.Name, func(t *testing.T) {
			m := setupManager(t, &nc)
			ensureImage(t, m, nc)

			wsID := fmt.Sprintf("lc-reuse-%d", time.Now().UnixNano())
			createTestWorkspace(t, nc.TaiID, wsID)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "session",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			meta := map[string]any{"workspace_id": wsID}
			ctx := makeAgentCtx("team-reuse", "", "chat-reuse", "ast-1", meta)

			computer1, id1, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("first GetComputer: %v", err)
			}
			defer cleanupComputer(t, m, cfg)

			// A fresh config value, so nothing is carried over through cfg itself.
			cfg2 := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "session",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			computer2, id2, err := sandboxv2.GetComputer(ctx, cfg2, m)
			if err != nil {
				t.Fatalf("second GetComputer: %v", err)
			}

			if id1 != id2 {
				t.Errorf("identifiers differ: %q vs %q", id1, id2)
			}
			info1 := computer1.ComputerInfo()
			info2 := computer2.ComputerInfo()
			if info1.ContainerID != info2.ContainerID {
				t.Errorf("container IDs differ: %q vs %q (should reuse)", info1.ContainerID, info2.ContainerID)
			}
		})
	}
}

// TestGetComputer_WorkspaceBindAlways checks that the "workspace_id" metadata
// value ends up in cfg.WorkspaceID and that the computer has a workplace bound.
func TestGetComputer_WorkspaceBindAlways(t *testing.T) {
	skipIfNoDocker(t)
	for _, nc := range boxNodes() {
		nc := nc
		t.Run(nc.Name, func(t *testing.T) {
			m := setupManager(t, &nc)
			ensureImage(t, m, nc)

			wsID := fmt.Sprintf("lc-ws-%d", time.Now().UnixNano())
			createTestWorkspace(t, nc.TaiID, wsID)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "oneshot",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			meta := map[string]any{"workspace_id": wsID}
			ctx := makeAgentCtx("team-ws", "", "chat-ws", "ast-ws", meta)

			computer, _, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("GetComputer: %v", err)
			}
			defer cleanupComputer(t, m, cfg)

			if cfg.WorkspaceID != wsID {
				t.Errorf("WorkspaceID = %q, want %q", cfg.WorkspaceID, wsID)
			}

			ws := computer.Workplace()
			if ws == nil {
				t.Fatal("Workplace() returned nil, workspace should always be bound")
			}
		})
	}
}

// TestGetComputer_WorkspaceFallbackOwner passes no metadata and expects
// cfg.WorkspaceID to equal the owner (team) ID.
func TestGetComputer_WorkspaceFallbackOwner(t *testing.T) {
	skipIfNoDocker(t)
	for _, nc := range boxNodes() {
		nc := nc
		t.Run(nc.Name, func(t *testing.T) {
			m := setupManager(t, &nc)
			ensureImage(t, m, nc)

			// The workspace is created under the owner ID so the fallback can bind to it.
			ownerID := fmt.Sprintf("lc-owner-%d", time.Now().UnixNano())
			createTestWorkspace(t, nc.TaiID, ownerID)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "oneshot",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			ctx := makeAgentCtx(ownerID, "", "chat-fb", "ast-fb", nil)

			computer, _, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("GetComputer: %v", err)
			}
			defer cleanupComputer(t, m, cfg)

			if cfg.WorkspaceID != ownerID {
				t.Errorf("WorkspaceID = %q, want %q (should fallback to ownerID)", cfg.WorkspaceID, ownerID)
			}

			ws := computer.Workplace()
			if ws == nil {
				t.Fatal("Workplace() returned nil")
			}
		})
	}
}

// TestGetComputer_OwnerPriority checks cfg.Owner for three cases: team and
// user both set (team wins), user only, and neither ("anonymous").
// Only the first entry of boxNodes() is used.
func TestGetComputer_OwnerPriority(t *testing.T) {
	skipIfNoDocker(t)
	nc := boxNodes()[0]
	m := setupManager(t, &nc)
	ensureImage(t, m, nc)

	t.Run("teamID", func(t *testing.T) {
		wsID := fmt.Sprintf("lc-ownp-team-%d", time.Now().UnixNano())
		createTestWorkspace(t, nc.TaiID, wsID)

		cfg := &types.SandboxConfig{
			Version:   "2.0",
			Lifecycle: "oneshot",
			Computer:  types.ComputerConfig{Image: testImage()},
			NodeID:    nc.TaiID,
		}
		ctx := makeAgentCtx("my-team", "my-user", "c", "a", map[string]any{"workspace_id": wsID})
		_, _, err := sandboxv2.GetComputer(ctx, cfg, m)
		if err != nil {
			t.Fatalf("GetComputer: %v", err)
		}
		defer cleanupComputer(t, m, cfg)
		if cfg.Owner != "my-team" {
			t.Errorf("Owner = %q, want %q (teamID takes precedence)", cfg.Owner, "my-team")
		}
	})

	t.Run("userID", func(t *testing.T) {
		wsID := fmt.Sprintf("lc-ownp-user-%d", time.Now().UnixNano())
		createTestWorkspace(t, nc.TaiID, wsID)

		cfg := &types.SandboxConfig{
			Version:   "2.0",
			Lifecycle: "oneshot",
			Computer:  types.ComputerConfig{Image: testImage()},
			NodeID:    nc.TaiID,
		}
		ctx := makeAgentCtx("", "my-user", "c", "a", map[string]any{"workspace_id": wsID})
		_, _, err := sandboxv2.GetComputer(ctx, cfg, m)
		if err != nil {
			t.Fatalf("GetComputer: %v", err)
		}
		defer cleanupComputer(t, m, cfg)
		if cfg.Owner != "my-user" {
			t.Errorf("Owner = %q, want %q", cfg.Owner, "my-user")
		}
	})

	t.Run("anonymous", func(t *testing.T) {
		wsID := fmt.Sprintf("lc-ownp-anon-%d", time.Now().UnixNano())
		createTestWorkspace(t, nc.TaiID, wsID)

		cfg := &types.SandboxConfig{
			Version:   "2.0",
			Lifecycle: "oneshot",
			Computer:  types.ComputerConfig{Image: testImage()},
			NodeID:    nc.TaiID,
		}
		ctx := makeAgentCtx("", "", "c", "a", map[string]any{"workspace_id": wsID})
		_, _, err := sandboxv2.GetComputer(ctx, cfg, m)
		if err != nil {
			t.Fatalf("GetComputer: %v", err)
		}
		defer cleanupComputer(t, m, cfg)
		if cfg.Owner != "anonymous" {
			t.Errorf("Owner = %q, want %q", cfg.Owner, "anonymous")
		}
	})
}

// TestGetComputer_HostMode uses a config without an image and expects a host
// computer with a bound workplace.
func TestGetComputer_HostMode(t *testing.T) {
	skipIfNoHostExec(t)
	for _, tgt := range hostTargets() {
		tgt := tgt
		t.Run(tgt.Name, func(t *testing.T) {
			m := setupHostManager(t, &tgt)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "session",
				Computer:  types.ComputerConfig{},
				NodeID:    tgt.TaiID,
			}
			ctx := makeAgentCtx("team-host", "", "chat-host", "ast-host", nil)

			computer, _, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("GetComputer host: %v", err)
			}

			if cfg.Kind != "host" {
				t.Errorf("Kind = %q, want %q", cfg.Kind, "host")
			}
			info := computer.ComputerInfo()
			if info.Kind != "host" {
				t.Errorf("ComputerInfo.Kind = %q, want %q", info.Kind, "host")
			}

			ws := computer.Workplace()
			if ws == nil {
				t.Fatal("Workplace() returned nil on host mode")
			}
		})
	}
}

// TestGetComputer_HostMissingNodeID expects an error mentioning "nodeID" when
// neither an image nor a node ID is configured.
func TestGetComputer_HostMissingNodeID(t *testing.T) {
	skipIfNoDocker(t)
	nc := boxNodes()[0]
	m := setupManager(t, &nc)

	cfg := &types.SandboxConfig{
		Version:   "2.0",
		Lifecycle: "session",
		Computer:  types.ComputerConfig{},
		NodeID:    "",
	}
	ctx := makeAgentCtx("team-err", "", "c", "a", nil)

	_, _, err := sandboxv2.GetComputer(ctx, cfg, m)
	if err == nil {
		t.Fatal("expected error for host mode without nodeID")
	}
	if !strings.Contains(err.Error(), "nodeID") {
		t.Errorf("error should mention nodeID, got: %v", err)
	}
}

// ===========================================================================
// LifecycleAction — behavior tests
// ===========================================================================

// TestLifecycleAction_Oneshot expects the box to be gone (Get returns an
// error) after LifecycleAction runs for a oneshot config.
func TestLifecycleAction_Oneshot(t *testing.T) {
	skipIfNoDocker(t)
	for _, nc := range boxNodes() {
		nc := nc
		t.Run(nc.Name, func(t *testing.T) {
			m := setupManager(t, &nc)
			ensureImage(t, m, nc)

			wsID := fmt.Sprintf("lc-oneshot-%d", time.Now().UnixNano())
			createTestWorkspace(t, nc.TaiID, wsID)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "oneshot",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			ctx := makeAgentCtx("team-oneshot", "", "c", "a", map[string]any{"workspace_id": wsID})

			computer, _, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("GetComputer: %v", err)
			}

			// No deferred cleanup here: LifecycleAction itself is expected to remove the box.
			boxID := cfg.ID
			sandboxv2.LifecycleAction(context.Background(), cfg, computer, m)

			_, getErr := m.Get(context.Background(), boxID)
			if getErr == nil {
				t.Error("box should be removed after oneshot LifecycleAction")
			}
		})
	}
}

// TestLifecycleAction_Session expects the box to still be retrievable after
// LifecycleAction runs for a session config.
func TestLifecycleAction_Session(t *testing.T) {
	skipIfNoDocker(t)
	for _, nc := range boxNodes() {
		nc := nc
		t.Run(nc.Name, func(t *testing.T) {
			m := setupManager(t, &nc)
			ensureImage(t, m, nc)

			wsID := fmt.Sprintf("lc-sess-%d", time.Now().UnixNano())
			createTestWorkspace(t, nc.TaiID, wsID)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "session",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			ctx := makeAgentCtx("team-sess", "", "chat-sess", "ast-sess", map[string]any{"workspace_id": wsID})

			computer, _, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("GetComputer: %v", err)
			}
			defer cleanupComputer(t, m, cfg)

			sandboxv2.LifecycleAction(context.Background(), cfg, computer, m)

			box, err := m.Get(context.Background(), cfg.ID)
			if err != nil {
				t.Fatalf("box should still exist after session LifecycleAction: %v", err)
			}
			if box == nil {
				t.Fatal("box is nil after session LifecycleAction")
			}
		})
	}
}

// TestLifecycleAction_Persistent expects the box to still be retrievable after
// LifecycleAction runs for a persistent config.
func TestLifecycleAction_Persistent(t *testing.T) {
	skipIfNoDocker(t)
	for _, nc := range boxNodes() {
		nc := nc
		t.Run(nc.Name, func(t *testing.T) {
			m := setupManager(t, &nc)
			ensureImage(t, m, nc)

			wsID := fmt.Sprintf("lc-pers-%d", time.Now().UnixNano())
			createTestWorkspace(t, nc.TaiID, wsID)

			cfg := &types.SandboxConfig{
				Version:   "2.0",
				Lifecycle: "persistent",
				Computer:  types.ComputerConfig{Image: testImage()},
				NodeID:    nc.TaiID,
			}
			ctx := makeAgentCtx("team-pers", "", "chat-pers", "ast-pers", map[string]any{"workspace_id": wsID})

			computer, _, err := sandboxv2.GetComputer(ctx, cfg, m)
			if err != nil {
				t.Fatalf("GetComputer: %v", err)
			}
			defer cleanupComputer(t, m, cfg)

			sandboxv2.LifecycleAction(context.Background(), cfg, computer, m)

			box, err := m.Get(context.Background(), cfg.ID)
			if err != nil {
				t.Fatalf("box should still exist after persistent LifecycleAction: %v", err)
			}
			if box == nil {
				t.Fatal("box is nil after persistent LifecycleAction")
			}
		})
	}
}

// TestLifecycleAction_NilSafe calls LifecycleAction with a nil computer and
// with a nil config; the test passes as long as neither call panics.
func TestLifecycleAction_NilSafe(t *testing.T) {
	cfg := &types.SandboxConfig{Lifecycle: "oneshot"}
	sandboxv2.LifecycleAction(context.Background(), cfg, nil, nil)
	sandboxv2.LifecycleAction(context.Background(), nil, nil, nil)
}

// ===========================================================================
// helpers
// ===========================================================================

// ensureImage makes sure the test image is available on the node, failing the
// test if EnsureImage does not succeed within 60 seconds.
func ensureImage(t *testing.T, m *infra.Manager, nc nodeConfig) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	if err := m.EnsureImage(ctx, nc.TaiID, testImage(), infra.ImagePullOptions{}); err != nil {
		t.Fatalf("EnsureImage: %v", err)
	}
}

// cleanupComputer removes the box recorded in cfg.ID, if any. A failed removal
// is only logged so that cleanup never fails the test.
func cleanupComputer(t *testing.T, m *infra.Manager, cfg *types.SandboxConfig) {
	t.Helper()
	if cfg.ID == "" {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	if err := m.Remove(ctx, cfg.ID); err != nil {
		t.Logf("cleanup Remove(%s): %v", cfg.ID, err)
	}
}

================================================
FILE: agent/sandbox/v2/options.go
================================================
package sandboxv2

import (
	"encoding/json"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/yaoapp/gou/connector"
	"github.com/yaoapp/yao/agent/sandbox/v2/types"
	infra "github.com/yaoapp/yao/sandbox/v2"
)

// resolveEnvRef resolves $ENV.XXX references to os.Getenv("XXX").
// Values without the "$ENV." prefix are returned unchanged.
func resolveEnvRef(value string) string {
	if strings.HasPrefix(value, "$ENV.") {
		// value[5:] strips the 5-byte "$ENV." prefix.
		return os.Getenv(value[5:])
	}
	return value
}

// BuildCreateOptions converts a SandboxConfig into the V2 infrastructure
// CreateOptions. An optional connector is used to inject OPENAI_PROXY_*
// environment variables when the connector is OpenAI-compatible (non-Anthropic).
func BuildCreateOptions(cfg *types.SandboxConfig, identifier, ownerID, workspaceID string, conn ...connector.Connector) (infra.CreateOptions, error) { opts := infra.CreateOptions{ ID: identifier, Owner: ownerID, Image: cfg.Computer.Image, WorkDir: cfg.Computer.WorkDir, User: cfg.Computer.User, MountPath: cfg.Computer.MountPath, MountMode: cfg.Computer.MountMode, WorkspaceID: workspaceID, Labels: cfg.Labels, DisplayName: cfg.DisplayName, } if opts.Labels == nil { opts.Labels = make(map[string]string) } // Lifecycle policy switch cfg.Lifecycle { case "oneshot": opts.Policy = infra.OneShot case "session": opts.Policy = infra.Session case "longrunning": opts.Policy = infra.LongRunning case "persistent": opts.Policy = infra.Persistent default: opts.Policy = infra.OneShot } // Timeouts if cfg.IdleTimeout != "" { d, err := time.ParseDuration(cfg.IdleTimeout) if err != nil { return opts, fmt.Errorf("idle_timeout: %w", err) } opts.IdleTimeout = d } if opts.IdleTimeout == 0 { switch opts.Policy { case infra.Session: opts.IdleTimeout = infra.DefaultSessionIdleTimeout case infra.LongRunning: opts.IdleTimeout = infra.DefaultLongRunningIdleTimeout } } if cfg.MaxLifetime != "" { d, err := time.ParseDuration(cfg.MaxLifetime) if err != nil { return opts, fmt.Errorf("max_lifetime: %w", err) } opts.MaxLifetime = d } if cfg.StopTimeout != "" { d, err := time.ParseDuration(cfg.StopTimeout) if err != nil { return opts, fmt.Errorf("stop_timeout: %w", err) } opts.StopTimeout = d } // Memory (string like "4g" → bytes) if cfg.Computer.Memory != "" { mem, err := parseMemory(cfg.Computer.Memory) if err != nil { return opts, fmt.Errorf("memory: %w", err) } opts.Memory = mem } opts.CPUs = cfg.Computer.CPUs // VNC opts.VNC = cfg.Computer.VNC.Enabled // Ports for _, p := range cfg.Computer.Ports { opts.Ports = append(opts.Ports, infra.PortMapping{ ContainerPort: p.Port, HostPort: p.HostPort, Protocol: p.Protocol, }) } // NodeID (host mode pre-selection) if cfg.NodeID != "" { opts.NodeID = 
cfg.NodeID } // Merge environment + secrets into CreateOptions.Env. // Secrets override environment for same-name keys. // $ENV.XXX references are resolved at runtime. envSize := len(cfg.Environment) + len(cfg.Secrets) if envSize > 0 { opts.Env = make(map[string]string, envSize) for k, v := range cfg.Environment { opts.Env[k] = resolveEnvRef(v) } for k, v := range cfg.Secrets { opts.Env[k] = resolveEnvRef(v) } } if opts.Env == nil { opts.Env = make(map[string]string) } // Inject OPENAI_PROXY_* when connector is OpenAI-compatible (non-Anthropic). // The a2o proxy inside the container translates Anthropic API → OpenAI API. if len(conn) > 0 && conn[0] != nil && !conn[0].Is(connector.ANTHROPIC) { injectProxyEnv(opts.Env, conn[0]) } // Inject VNC_* environment variables from config. if cfg.Computer.VNC.Enabled { opts.Env["VNC_ENABLED"] = "true" if cfg.Computer.VNC.Password != "" { opts.Env["VNC_PASSWORD"] = resolveEnvRef(cfg.Computer.VNC.Password) } if cfg.Computer.VNC.Resolution != "" { opts.Env["VNC_RESOLUTION"] = cfg.Computer.VNC.Resolution } if cfg.Computer.VNC.ViewOnly { opts.Env["VNC_VIEW_ONLY"] = "true" } } return opts, nil } // injectProxyEnv extracts backend URL, model, and API key from an // OpenAI-compatible connector's settings and writes them as OPENAI_PROXY_* // environment variables into env. func injectProxyEnv(env map[string]string, conn connector.Connector) { settings := conn.Setting() if settings == nil { return } if host, ok := settings["host"].(string); ok && host != "" { env["OPENAI_PROXY_BACKEND"] = host } if model, ok := settings["model"].(string); ok && model != "" { env["OPENAI_PROXY_MODEL"] = model } if key, ok := settings["key"].(string); ok && key != "" { env["OPENAI_PROXY_API_KEY"] = key } // Forward extra options as JSON. 
extra := make(map[string]interface{}) for k, v := range settings { switch k { case "host", "model", "key", "proxy", "type": continue default: extra[k] = v } } if len(extra) > 0 { if data, err := json.Marshal(extra); err == nil { env["OPENAI_PROXY_OPTIONS"] = string(data) } } } // parseMemory converts a human-readable memory string to bytes. // Supported formats: "4GB", "4G", "4g", "512MB", "512M", "512m", "1024KB", "1024K", "1024". func parseMemory(s string) (int64, error) { if len(s) == 0 { return 0, nil } upper := strings.ToUpper(s) var num string var multiplier int64 switch { case strings.HasSuffix(upper, "GB"): num = s[:len(s)-2] multiplier = 1 << 30 case strings.HasSuffix(upper, "MB"): num = s[:len(s)-2] multiplier = 1 << 20 case strings.HasSuffix(upper, "KB"): num = s[:len(s)-2] multiplier = 1 << 10 case strings.HasSuffix(upper, "TB"): num = s[:len(s)-2] multiplier = 1 << 40 case strings.HasSuffix(upper, "G"): num = s[:len(s)-1] multiplier = 1 << 30 case strings.HasSuffix(upper, "M"): num = s[:len(s)-1] multiplier = 1 << 20 case strings.HasSuffix(upper, "K"): num = s[:len(s)-1] multiplier = 1 << 10 case strings.HasSuffix(upper, "T"): num = s[:len(s)-1] multiplier = 1 << 40 default: num = s multiplier = 1 } var val float64 if _, err := fmt.Sscanf(num, "%f", &val); err != nil { return 0, fmt.Errorf("invalid memory value %q", s) } return int64(val * float64(multiplier)), nil } ================================================ FILE: agent/sandbox/v2/prepare.go ================================================ package sandboxv2 import ( "context" "fmt" "log" "path" pathpkg "path/filepath" "strings" "github.com/yaoapp/yao/agent/sandbox/v2/types" infra "github.com/yaoapp/yao/sandbox/v2" "github.com/yaoapp/yao/tai/workspace" ) const onceMarkerDir = ".yao/prepare" // RunPrepareSteps executes a list of PrepareStep actions on the given Computer. // file/copy/marker operations use computer.Workplace() (gRPC volume, cross-platform). 
// exec operations use shell via Computer.Exec. // assistantDir is the absolute host path to the assistant source directory; // copy steps with a relative src resolve against it (host → workspace push). func RunPrepareSteps(ctx context.Context, steps []types.PrepareStep, computer infra.Computer, assistantID, configHash, assistantDir string) error { if len(steps) == 0 { return nil } var ws workspace.FS if computer != nil { ws = computer.Workplace() } markerDir := onceMarkerDir if assistantID != "" { markerDir = onceMarkerDir + "/" + assistantID } markerPath := markerDir + "/done" skipOnce := false if configHash != "" && ws != nil { if data, err := ws.ReadFile(markerPath); err == nil { if strings.TrimSpace(string(data)) == configHash { skipOnce = true } } } for i, step := range steps { if step.Once && skipOnce { continue } var err error switch step.Action { case "file": err = runFileStep(ws, step) case "copy": err = runCopyStep(ws, step, assistantDir) case "exec": err = runExecStep(ctx, computer, step) case "process": log.Printf("[sandbox/v2] prepare step %d: action=process (reserved, skipping)", i) default: err = fmt.Errorf("unknown prepare action %q", step.Action) } if err != nil { if step.IgnoreError { log.Printf("[sandbox/v2] prepare step %d (%s): ignored error: %v", i, step.Action, err) continue } return fmt.Errorf("prepare step %d (%s): %w", i, step.Action, err) } } if configHash != "" && ws != nil { ws.MkdirAll(markerDir, 0755) ws.WriteFile(markerPath, []byte(configHash), 0644) } return nil } // --------------------------------------------------------------------------- // Step runners // --------------------------------------------------------------------------- func runFileStep(ws workspace.FS, step types.PrepareStep) error { if step.Path == "" { return fmt.Errorf("file step requires path") } if ws == nil { return fmt.Errorf("file step requires workspace") } dir := path.Dir(step.Path) if dir != "." 
&& dir != "/" { if err := ws.MkdirAll(dir, 0755); err != nil { return fmt.Errorf("mkdir %s: %w", dir, err) } } if err := ws.WriteFile(step.Path, step.Content, 0644); err != nil { return fmt.Errorf("write file %s: %w", step.Path, err) } return nil } // runCopyStep copies files into the workspace using ws.Copy which supports // the "local:///" URI scheme for host-to-workspace transfers. // // src resolution: // - Already a host URI ("local:///..." or "tmp:///...") → used as-is // - Relative path + assistantDir provided → resolved to "local:////" // - Relative path without assistantDir → treated as workspace-internal path func runCopyStep(ws workspace.FS, step types.PrepareStep, assistantDir string) error { if step.Src == "" || step.Dst == "" { return fmt.Errorf("copy step requires src and dst") } if ws == nil { return fmt.Errorf("copy step requires workspace") } src := step.Src if !isHostURI(src) && assistantDir != "" { src = "local:///" + pathpkg.Join(assistantDir, src) } if _, err := ws.Copy(src, step.Dst); err != nil { return fmt.Errorf("copy %s -> %s: %w", src, step.Dst, err) } return nil } func isHostURI(s string) bool { return strings.HasPrefix(s, "local:///") || strings.HasPrefix(s, "tmp:///") } func runExecStep(ctx context.Context, computer infra.Computer, step types.PrepareStep) error { if step.Cmd == "" { return fmt.Errorf("exec step requires cmd") } kind := shellFromSystem(computer) script := step.Cmd if step.Background { if kind == shellSh { script = fmt.Sprintf("nohup %s > /dev/null 2>&1 &", step.Cmd) } else { script = fmt.Sprintf("Start-Process -NoNewWindow -FilePath 'cmd.exe' -ArgumentList '/C %s'", step.Cmd) } } rootDir := "/" if isWindowsComputer(computer) { rootDir = `C:\` } result, err := computer.Exec(ctx, shellWrap(kind, script), infra.WithWorkDir(rootDir)) if err != nil { return err } label := "exec" if step.Background { label = "exec(background)" } return checkResult(result, label) } func isWindowsComputer(computer infra.Computer) bool { return 
strings.EqualFold(computer.ComputerInfo().System.OS, "windows") } // checkResult inspects ExecResult for errors. func checkResult(result *infra.ExecResult, label string) error { if result.Error != "" { return fmt.Errorf("%s: %s", label, result.Error) } if result.ExitCode != 0 { stderr := result.Stderr if len(stderr) > 200 { stderr = stderr[:200] + "..." } return fmt.Errorf("%s: exit %d: %s", label, result.ExitCode, stderr) } return nil } ================================================ FILE: agent/sandbox/v2/prepare_test.go ================================================ package sandboxv2_test import ( "context" "fmt" "strings" "testing" "time" sandboxv2 "github.com/yaoapp/yao/agent/sandbox/v2" "github.com/yaoapp/yao/agent/sandbox/v2/types" ) // --------------------------------------------------------------------------- // Box tests (local + remote) // --------------------------------------------------------------------------- func TestRunPrepareSteps_Exec(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "exec", Cmd: "echo hello > /tmp/prep-test"}, {Action: "exec", Cmd: "echo world >> /tmp/prep-test"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, box, "test-assistant", "", "") if err != nil { t.Fatalf("RunPrepareSteps: %v", err) } result, err := box.Exec(ctx, []string{"cat", "/tmp/prep-test"}) if err != nil { t.Fatalf("cat: %v", err) } got := strings.TrimSpace(result.Stdout) if got != "hello\nworld" { t.Errorf("content = %q, want %q", got, "hello\nworld") } }) } } func TestRunPrepareSteps_File(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) wsID := fmt.Sprintf("test-file-%d", time.Now().UnixNano()) 
box.BindWorkplace(wsID) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "file", Path: "config/test.txt", Content: []byte("file-content-v2")}, } err := sandboxv2.RunPrepareSteps(ctx, steps, box, "test-assistant", "", "") if err != nil { t.Fatalf("RunPrepareSteps: %v", err) } ws := box.Workplace() data, err := ws.ReadFile("config/test.txt") if err != nil { t.Fatalf("ReadFile: %v", err) } if string(data) != "file-content-v2" { t.Errorf("content = %q, want %q", string(data), "file-content-v2") } }) } } func TestRunPrepareSteps_Copy(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) wsID := fmt.Sprintf("test-copy-%d", time.Now().UnixNano()) box.BindWorkplace(wsID) ws := box.Workplace() ws.WriteFile("src.txt", []byte("copy-src"), 0644) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "copy", Src: "src.txt", Dst: "dst.txt"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, box, "test-assistant", "", "") if err != nil { t.Fatalf("RunPrepareSteps: %v", err) } data, err := ws.ReadFile("dst.txt") if err != nil { t.Fatalf("ReadFile: %v", err) } if string(data) != "copy-src" { t.Errorf("content = %q, want %q", string(data), "copy-src") } }) } } func TestRunPrepareSteps_OnceMarker(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) wsID := fmt.Sprintf("test-once-%d", time.Now().UnixNano()) box.BindWorkplace(wsID) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() counter := "/tmp/once-counter" steps := []types.PrepareStep{ {Action: "exec", Cmd: "echo -n x >> " + counter, Once: true}, } hash := "abc123" assistantID := "test-once" if err := sandboxv2.RunPrepareSteps(ctx, 
steps, box, assistantID, hash, ""); err != nil { t.Fatalf("first run: %v", err) } r1, _ := box.Exec(ctx, []string{"cat", counter}) if r1.Stdout != "x" { t.Fatalf("first run: got %q, want %q", r1.Stdout, "x") } if err := sandboxv2.RunPrepareSteps(ctx, steps, box, assistantID, hash, ""); err != nil { t.Fatalf("second run: %v", err) } r2, _ := box.Exec(ctx, []string{"cat", counter}) if r2.Stdout != "x" { t.Errorf("second run: got %q, want %q (once step should be skipped)", r2.Stdout, "x") } if err := sandboxv2.RunPrepareSteps(ctx, steps, box, assistantID, "new-hash", ""); err != nil { t.Fatalf("third run: %v", err) } r3, _ := box.Exec(ctx, []string{"cat", counter}) if r3.Stdout != "xx" { t.Errorf("third run: got %q, want %q (hash changed, should re-execute)", r3.Stdout, "xx") } }) } } func TestRunPrepareSteps_OnceIsolation(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) wsID := fmt.Sprintf("test-iso-%d", time.Now().UnixNano()) box.BindWorkplace(wsID) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() stepsA := []types.PrepareStep{ {Action: "exec", Cmd: "echo -n A >> /tmp/iso-a", Once: true}, } stepsB := []types.PrepareStep{ {Action: "exec", Cmd: "echo -n B >> /tmp/iso-b", Once: true}, } hash := "same-hash" if err := sandboxv2.RunPrepareSteps(ctx, stepsA, box, "assistant-a", hash, ""); err != nil { t.Fatalf("assistant-a: %v", err) } if err := sandboxv2.RunPrepareSteps(ctx, stepsB, box, "assistant-b", hash, ""); err != nil { t.Fatalf("assistant-b: %v", err) } rA, _ := box.Exec(ctx, []string{"cat", "/tmp/iso-a"}) rB, _ := box.Exec(ctx, []string{"cat", "/tmp/iso-b"}) if rA.Stdout != "A" { t.Errorf("assistant-a: got %q, want %q", rA.Stdout, "A") } if rB.Stdout != "B" { t.Errorf("assistant-b: got %q, want %q", rB.Stdout, "B") } if err := sandboxv2.RunPrepareSteps(ctx, stepsA, box, "assistant-a", hash, ""); err != nil { 
t.Fatalf("assistant-a re-run: %v", err) } if err := sandboxv2.RunPrepareSteps(ctx, stepsB, box, "assistant-b", hash, ""); err != nil { t.Fatalf("assistant-b re-run: %v", err) } rA2, _ := box.Exec(ctx, []string{"cat", "/tmp/iso-a"}) rB2, _ := box.Exec(ctx, []string{"cat", "/tmp/iso-b"}) if rA2.Stdout != "A" { t.Errorf("assistant-a re-run: got %q, want %q (should be skipped)", rA2.Stdout, "A") } if rB2.Stdout != "B" { t.Errorf("assistant-b re-run: got %q, want %q (should be skipped)", rB2.Stdout, "B") } }) } } func TestRunPrepareSteps_IgnoreError(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "exec", Cmd: "false", IgnoreError: true}, {Action: "exec", Cmd: "echo survived > /tmp/survived"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, box, "test-assistant", "", "") if err != nil { t.Fatalf("RunPrepareSteps: %v (ignore_error should have prevented failure)", err) } result, _ := box.Exec(ctx, []string{"cat", "/tmp/survived"}) if strings.TrimSpace(result.Stdout) != "survived" { t.Errorf("second step should have executed, got %q", result.Stdout) } }) } } func TestRunPrepareSteps_FailOnError(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "exec", Cmd: "false"}, {Action: "exec", Cmd: "echo should-not-reach > /tmp/unreachable"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, box, "test-assistant", "", "") if err == nil { t.Fatal("expected error from failing step without ignore_error") } result, _ := box.Exec(ctx, []string{"cat", "/tmp/unreachable"}) if result.ExitCode == 0 { t.Error("second step 
should not have executed") } }) } } func TestRunPrepareSteps_UnknownAction(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) _ = createBox(t, m, nc) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "unknown_action"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, nil, "test-assistant", "", "") if err == nil { t.Fatal("expected error for unknown action") } if !strings.Contains(err.Error(), "unknown_action") { t.Errorf("error should mention action name, got: %v", err) } }) } } func TestRunPrepareSteps_EmptySteps(t *testing.T) { err := sandboxv2.RunPrepareSteps(context.Background(), nil, nil, "test-assistant", "hash", "") if err != nil { t.Fatalf("empty steps should succeed: %v", err) } } func TestRunPrepareSteps_Background(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "exec", Cmd: "sleep 30", Background: true}, {Action: "exec", Cmd: "echo after-bg > /tmp/after-bg"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, box, "test-assistant", "", "") if err != nil { t.Fatalf("RunPrepareSteps: %v", err) } result, _ := box.Exec(ctx, []string{"cat", "/tmp/after-bg"}) if strings.TrimSpace(result.Stdout) != "after-bg" { t.Errorf("background step blocked execution, got %q", result.Stdout) } }) } } func TestRunPrepareSteps_MixedActions(t *testing.T) { skipIfNoDocker(t) for _, nc := range boxNodes() { nc := nc t.Run(nc.Name, func(t *testing.T) { m := setupManager(t, &nc) box := createBox(t, m, nc) wsID := fmt.Sprintf("test-mixed-%d", time.Now().UnixNano()) box.BindWorkplace(wsID) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := 
[]types.PrepareStep{ {Action: "file", Path: "mixed.conf", Content: []byte("key=value")}, {Action: "exec", Cmd: "echo exec-ok > /tmp/mixed-exec"}, {Action: "copy", Src: "mixed.conf", Dst: "mixed-copy.conf"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, box, "test-assistant", "", "") if err != nil { t.Fatalf("RunPrepareSteps: %v", err) } ws := box.Workplace() data, err := ws.ReadFile("mixed-copy.conf") if err != nil { t.Fatalf("ReadFile mixed-copy.conf: %v", err) } if string(data) != "key=value" { t.Errorf("copy result: got %q, want %q", string(data), "key=value") } result, _ := box.Exec(ctx, []string{"cat", "/tmp/mixed-exec"}) if strings.TrimSpace(result.Stdout) != "exec-ok" { t.Errorf("exec result: got %q, want %q", result.Stdout, "exec-ok") } }) } } // --------------------------------------------------------------------------- // HostExec tests // --------------------------------------------------------------------------- func TestRunPrepareSteps_HostExec(t *testing.T) { skipIfNoHostExec(t) for _, tgt := range hostTargets() { tgt := tgt t.Run(tgt.Name, func(t *testing.T) { m := setupHostManager(t, &tgt) host := createHost(t, m, tgt) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() t.Logf("SystemInfo: OS=%q Shell=%q TempDir=%q", host.ComputerInfo().System.OS, host.ComputerInfo().System.Shell, host.ComputerInfo().System.TempDir) isWin := tgt.Name == "win-native" var cmd string if isWin { cmd = `Write-Output 'host-ok'` } else { cmd = "echo host-ok" } steps := []types.PrepareStep{ {Action: "exec", Cmd: cmd}, } err := sandboxv2.RunPrepareSteps(ctx, steps, host, "test-host", "", "") if err != nil { t.Fatalf("RunPrepareSteps on host: %v", err) } }) } } func TestRunPrepareSteps_HostExecFile(t *testing.T) { skipIfNoHostExec(t) for _, tgt := range hostTargets() { tgt := tgt t.Run(tgt.Name, func(t *testing.T) { m := setupHostManager(t, &tgt) host := createHost(t, m, tgt) wsID := fmt.Sprintf("test-hostfile-%d", time.Now().UnixNano()) 
host.BindWorkplace(wsID) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "file", Path: "host-test.txt", Content: []byte("host-file-data")}, } err := sandboxv2.RunPrepareSteps(ctx, steps, host, "test-host", "", "") if err != nil { t.Fatalf("RunPrepareSteps file: %v", err) } ws := host.Workplace() data, err := ws.ReadFile("host-test.txt") if err != nil { t.Fatalf("ReadFile: %v", err) } if string(data) != "host-file-data" { t.Errorf("content = %q, want %q", string(data), "host-file-data") } }) } } func TestRunPrepareSteps_HostExecCopy(t *testing.T) { skipIfNoHostExec(t) for _, tgt := range hostTargets() { tgt := tgt t.Run(tgt.Name, func(t *testing.T) { m := setupHostManager(t, &tgt) host := createHost(t, m, tgt) wsID := fmt.Sprintf("test-hostcopy-%d", time.Now().UnixNano()) host.BindWorkplace(wsID) ws := host.Workplace() ws.WriteFile("copy-src.txt", []byte("copy-data"), 0644) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() steps := []types.PrepareStep{ {Action: "copy", Src: "copy-src.txt", Dst: "copy-dst.txt"}, } err := sandboxv2.RunPrepareSteps(ctx, steps, host, "test-host", "", "") if err != nil { t.Fatalf("RunPrepareSteps copy: %v", err) } data, err := ws.ReadFile("copy-dst.txt") if err != nil { t.Fatalf("ReadFile: %v", err) } if string(data) != "copy-data" { t.Errorf("content = %q, want %q", string(data), "copy-data") } }) } } func TestRunPrepareSteps_HostExecOnce(t *testing.T) { skipIfNoHostExec(t) for _, tgt := range hostTargets() { tgt := tgt t.Run(tgt.Name, func(t *testing.T) { m := setupHostManager(t, &tgt) host := createHost(t, m, tgt) wsID := fmt.Sprintf("test-hostonce-%d", time.Now().UnixNano()) host.BindWorkplace(wsID) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() isWin := tgt.Name == "win-native" var cmd string if isWin { cmd = `Write-Output 'once-ok'` } else { cmd = "echo once-ok" } steps := 
[]types.PrepareStep{ {Action: "exec", Cmd: cmd, Once: true}, } hash := "host-once-hash" aid := "host-once-aid" if err := sandboxv2.RunPrepareSteps(ctx, steps, host, aid, hash, ""); err != nil { t.Fatalf("first run: %v", err) } ws := host.Workplace() markerData, err := ws.ReadFile(".yao/prepare/" + aid + "/done") if err != nil { t.Fatalf("marker not written: %v", err) } if string(markerData) != hash { t.Errorf("marker = %q, want %q", string(markerData), hash) } }) } } ================================================ FILE: agent/sandbox/v2/runner.go ================================================ package sandboxv2 import ( "fmt" "sync" "github.com/yaoapp/yao/agent/sandbox/v2/types" ) var ( mu sync.RWMutex runners = map[string]func() types.Runner{} ) // Register adds a runner factory to the global registry. // Typically called from init() in the runner's package. func Register(name string, factory func() types.Runner) { mu.Lock() defer mu.Unlock() runners[name] = factory } // Get creates a new Runner instance from the registry. func Get(name string) (types.Runner, error) { mu.RLock() defer mu.RUnlock() factory, ok := runners[name] if !ok { return nil, fmt.Errorf("sandbox runner %q not registered", name) } return factory(), nil } ================================================ FILE: agent/sandbox/v2/shell.go ================================================ package sandboxv2 import ( "strings" infra "github.com/yaoapp/yao/sandbox/v2" ) // shellKind identifies which shell to use for command execution. type shellKind int const ( shellSh shellKind = iota // Unix: sh -c shellPwsh // Windows: pwsh -NoProfile -Command shellPS // Windows: powershell -NoProfile -Command shellCmd // Windows: cmd.exe /C (last-resort fallback) ) // shellWrap returns the Exec command slice to run a script string. 
func shellWrap(kind shellKind, script string) []string { switch kind { case shellPwsh: return []string{"pwsh", "-NoProfile", "-Command", script} case shellPS: return []string{"powershell", "-NoProfile", "-Command", script} case shellCmd: return []string{"cmd.exe", "/C", script} default: return []string{"sh", "-c", script} } } // shellFromSystem resolves shellKind from ComputerInfo().System.Shell // reported by the Tai node at registration time. func shellFromSystem(computer infra.Computer) shellKind { shell := strings.ToLower(computer.ComputerInfo().System.Shell) switch shell { case "pwsh": return shellPwsh case "powershell": return shellPS case "cmd.exe", "cmd": return shellCmd default: return shellSh } } ================================================ FILE: agent/sandbox/v2/stream.go ================================================ package sandboxv2 import ( "context" "errors" "fmt" "log" "time" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" "github.com/yaoapp/yao/agent/output/message" "github.com/yaoapp/yao/agent/sandbox/v2/types" infra "github.com/yaoapp/yao/sandbox/v2" ) // ExecuteRequest consolidates all parameters for ExecuteSandboxStream. type ExecuteRequest struct { Computer infra.Computer Runner types.Runner Config *types.SandboxConfig StreamReq *types.StreamRequest Manager *infra.Manager LoadingMsgID string } // ExecuteSandboxStream is the V2 replacement for executeSandboxStream. // It calls runner.Stream, handles interrupts, and performs cleanup/lifecycle // in defer. func ExecuteSandboxStream( ctx *agentContext.Context, req *ExecuteRequest, handler message.StreamFunc, ) (*agentContext.CompletionResponse, error) { if req.Runner == nil || req.Computer == nil { return nil, fmt.Errorf("runner and computer are required") } stdCtx := ctx.Context panicked := true // Assume panic; set false on normal exit. // Resolve stop timeout from config (default 2s). 
stopTimeout := 2 * time.Second if req.Config != nil && req.Config.StopTimeout != "" { if d, err := time.ParseDuration(req.Config.StopTimeout); err == nil { stopTimeout = d } } // Panic recovery (registered first, executes last in LIFO order). defer func() { if r := recover(); r != nil { log.Printf("[sandbox/v2] panic in stream: %v", r) cleanCtx, cancel := context.WithTimeout(context.Background(), stopTimeout) defer cancel() req.Runner.Cleanup(cleanCtx, req.Computer) LifecycleAction(cleanCtx, req.Config, req.Computer, req.Manager) } }() // Lifecycle action (registered second, executes second-to-last). defer func() { if !panicked { LifecycleAction(stdCtx, req.Config, req.Computer, req.Manager) } }() // Runner cleanup (registered last, executes first). defer func() { if !panicked { cleanCtx, cancel := context.WithTimeout(context.Background(), stopTimeout) defer cancel() req.Runner.Cleanup(cleanCtx, req.Computer) } }() // Build a cancellable runnerCtx that bridges agentContext interrupts. runnerCtx, cancelRunner := context.WithCancel(stdCtx) defer cancelRunner() // Prevent goroutine leak. 
done := make(chan struct{}) defer close(done) go func() { ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() for { select { case <-done: return case <-ticker.C: if ctx.Interrupt != nil { if sig := ctx.Interrupt.Peek(); sig != nil { cancelRunner() return } if ctx.Interrupt.IsInterrupted() { cancelRunner() return } } case <-stdCtx.Done(): cancelRunner() return } } }() if req.LoadingMsgID != "" { waitMsg := &message.Message{ MessageID: req.LoadingMsgID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]any{ "message": i18n.T(ctx.Locale, "sandbox.waiting_response"), }, } ctx.Send(waitMsg) } var textContent []byte loadingClosed := false wrappedHandler := func(chunkType message.StreamChunkType, data []byte) int { if !loadingClosed && req.LoadingMsgID != "" { if chunkType == message.ChunkText || chunkType == message.ChunkToolCall || chunkType == message.ChunkMessageStart { closeLoading(ctx, req.LoadingMsgID) loadingClosed = true } } if chunkType == message.ChunkText { textContent = append(textContent, data...) } if handler != nil { return handler(chunkType, data) } return 0 } err := req.Runner.Stream(runnerCtx, req.StreamReq, wrappedHandler) if !loadingClosed && req.LoadingMsgID != "" { closeLoading(ctx, req.LoadingMsgID) } panicked = false // Normal exit reached. 
if err != nil { if errors.Is(err, context.Canceled) { return nil, err } return nil, fmt.Errorf("runner.Stream: %w", err) } resp := &agentContext.CompletionResponse{ Role: "assistant", FinishReason: agentContext.FinishReasonStop, } if len(textContent) > 0 { resp.Content = string(textContent) } return resp, nil } func closeLoading(ctx *agentContext.Context, loadingMsgID string) { if loadingMsgID == "" || ctx == nil { return } msg := &message.Message{ MessageID: loadingMsgID, Delta: true, DeltaAction: message.DeltaReplace, Type: message.TypeLoading, Props: map[string]any{ "done": true, "message": "", }, } ctx.Send(msg) } ================================================ FILE: agent/sandbox/v2/testutils/testutils.go ================================================ package testutils import ( "context" "os" "path/filepath" "testing" agenttestutils "github.com/yaoapp/yao/agent/testutils" "github.com/yaoapp/yao/config" sandboxv2 "github.com/yaoapp/yao/sandbox/v2" "github.com/yaoapp/yao/tai" "github.com/yaoapp/yao/tai/registry" ) // Prepare initializes the full environment required for sandbox V2 E2E tests: // - agent layer (assistants, LLM, caller) // - tai registry + local node // - sandbox V2 manager func Prepare(t *testing.T) { t.Helper() agenttestutils.Prepare(t) if registry.Global() == nil { registry.Init(nil) } dataDir := filepath.Join(config.Conf.DataRoot, "workspaces") os.MkdirAll(dataDir, 0755) tai.RegisterLocal(tai.WithDataDir(dataDir)) sandboxv2.Init() if err := sandboxv2.M().Start(context.Background()); err != nil { t.Fatalf("sandbox v2 manager start: %v", err) } t.Cleanup(func() { sandboxv2.M().Close() }) } // Clean tears down the test environment. 
func Clean(t *testing.T) { t.Helper() agenttestutils.Clean(t) } ================================================ FILE: agent/sandbox/v2/testutils_remote_test.go ================================================ //go:build remote package sandboxv2_test import "os" func init() { extraNodeProviders = append(extraNodeProviders, agentRemoteNodes) } func agentRemoteNodes() []nodeConfig { addr := os.Getenv("SANDBOX_TEST_REMOTE_ADDR") if addr == "" { return nil } return []nodeConfig{{Name: "remote", Addr: addr}} } ================================================ FILE: agent/sandbox/v2/testutils_test.go ================================================ package sandboxv2_test import ( "context" "fmt" "log" "os" "strconv" "strings" "testing" "time" sandbox "github.com/yaoapp/yao/sandbox/v2" "github.com/yaoapp/yao/tai" "github.com/yaoapp/yao/tai/registry" tairuntime "github.com/yaoapp/yao/tai/runtime" "github.com/yaoapp/yao/workspace" ) // --------------------------------------------------------------------------- // Build-tag extension points (same pattern as sandbox/v2). 
// --------------------------------------------------------------------------- var ( extraNodeProviders []func() []nodeConfig extraHostExecProviders []func() []hostTarget ) // --------------------------------------------------------------------------- // Node / host configuration // --------------------------------------------------------------------------- type nodeConfig struct { Name string Addr string TaiID string DialOps []tai.DialOption } type hostTarget struct { Name string Addr string TaiID string } // --------------------------------------------------------------------------- // Environment helpers // --------------------------------------------------------------------------- func testLocalAddr() string { if addr := os.Getenv("SANDBOX_TEST_LOCAL_ADDR"); addr != "" { return addr } return "local" } func testImage() string { if img := os.Getenv("SANDBOX_TEST_IMAGE"); img != "" { return img } return "alpine:latest" } func envPort(key string, fallback int) int { if v := os.Getenv(key); v != "" { if p, err := strconv.Atoi(v); err == nil { return p } } return fallback } // --------------------------------------------------------------------------- // Node / host discovery // --------------------------------------------------------------------------- func boxNodes() []nodeConfig { nodes := []nodeConfig{ {Name: "local", Addr: testLocalAddr()}, } for _, fn := range extraNodeProviders { nodes = append(nodes, fn()...) } return nodes } func hostTargets() []hostTarget { var targets []hostTarget for _, fn := range extraHostExecProviders { targets = append(targets, fn()...) 
} return targets } // --------------------------------------------------------------------------- // Dial + Register helper (replaces old tai.New) // --------------------------------------------------------------------------- func dialForTest(addr string, dialOps ...tai.DialOption) (*tai.ConnResources, error) { if addr == "local" || addr == "" { return tai.DialLocal("", "", nil) } host, grpcPort := parseHostPort(addr) ports := tai.Ports{GRPC: grpcPort} return tai.DialRemote(host, ports, dialOps...) } func registerForTest(t testing.TB, addr string, dialOps ...tai.DialOption) (string, *tai.ConnResources) { t.Helper() if registry.Global() == nil { registry.Init(nil) } res, err := dialForTest(addr, dialOps...) if err != nil { t.Fatalf("dialForTest(%s): %v", addr, err) } taiID := taiIDFromAddr(addr) reg := registry.Global() reg.Register(®istry.TaiNode{TaiID: taiID, Mode: modeForAddr(addr)}) reg.SetResources(taiID, res) return taiID, res } func taiIDFromAddr(addr string) string { if addr == "local" || addr == "" { return "local" } addr = strings.TrimPrefix(addr, "tai://") parts := strings.SplitN(addr, ":", 2) return parts[0] } func modeForAddr(addr string) string { if addr == "local" || addr == "" { return "local" } return "direct" } func parseHostPort(addr string) (string, int) { addr = strings.TrimPrefix(addr, "tai://") parts := strings.SplitN(addr, ":", 2) h := parts[0] if len(parts) == 2 { if p, err := strconv.Atoi(parts[1]); err == nil { return h, p } } return h, 19100 } // --------------------------------------------------------------------------- // TestMain — purge stale containers from previous runs // --------------------------------------------------------------------------- func TestMain(m *testing.M) { purgeStale() os.Exit(m.Run()) } func purgeStale() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() for _, nc := range boxNodes() { res, err := dialForTest(nc.Addr, nc.DialOps...) 
if err != nil { continue } sb := res.Runtime if sb == nil { res.Close() continue } containers, _ := sb.List(ctx, tairuntime.ListOptions{All: true}) for _, c := range containers { id := c.Name if id == "" { id = c.ID } if strings.HasPrefix(id, "sb-prep-") || strings.HasPrefix(id, "sb-lc-") { sb.Remove(ctx, id, true) log.Printf("[purge] %s: removed %s", nc.Name, id) } } res.Close() } } // --------------------------------------------------------------------------- // Manager + Box helpers // --------------------------------------------------------------------------- func setupManager(t *testing.T, nc *nodeConfig) *sandbox.Manager { t.Helper() if registry.Global() == nil { registry.Init(nil) } taiID, res := registerForTest(t, nc.Addr, nc.DialOps...) nc.TaiID = taiID t.Cleanup(func() { res.Close() }) sandbox.Init() m := sandbox.M() t.Cleanup(func() { m.Close() }) return m } func createBox(t *testing.T, m *sandbox.Manager, nc nodeConfig) *sandbox.Box { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() if err := m.EnsureImage(ctx, nc.TaiID, testImage(), sandbox.ImagePullOptions{}); err != nil { t.Fatalf("EnsureImage: %v", err) } box, err := m.Create(ctx, sandbox.CreateOptions{ ID: fmt.Sprintf("sb-prep-%d", time.Now().UnixNano()), Image: testImage(), Owner: "test-prepare", NodeID: nc.TaiID, }) if err != nil { t.Fatalf("Create: %v", err) } t.Cleanup(func() { cCtx, cCancel := context.WithTimeout(context.Background(), 15*time.Second) defer cCancel() if err := m.Remove(cCtx, box.ID()); err != nil { t.Logf("cleanup Remove(%s): %v", box.ID(), err) } }) return box } func createHost(t *testing.T, m *sandbox.Manager, tgt hostTarget) *sandbox.Host { t.Helper() host, err := m.Host(context.Background(), tgt.TaiID) if err != nil { t.Skipf("Host(%s): %v", tgt.Name, err) } return host } func setupHostManager(t *testing.T, tgt *hostTarget) *sandbox.Manager { t.Helper() nc := nodeConfig{Name: tgt.Name, Addr: fmt.Sprintf("tai://%s", 
tgt.Addr)} m := setupManager(t, &nc) tgt.TaiID = nc.TaiID return m } // --------------------------------------------------------------------------- // Skip helpers // --------------------------------------------------------------------------- func skipIfNoDocker(t *testing.T) { t.Helper() if testLocalAddr() == "" { t.Skip("SANDBOX_TEST_LOCAL_ADDR not set") } } func skipIfNoHostExec(t *testing.T) { t.Helper() if len(hostTargets()) == 0 { t.Skip("no HostExec targets configured") } } func createTestWorkspace(t *testing.T, taiID, wsID string) { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := workspace.M().Create(ctx, workspace.CreateOptions{ ID: wsID, Owner: "test", Node: taiID, }) if err != nil && !strings.Contains(err.Error(), "exists") { t.Fatalf("create workspace %q: %v", wsID, err) } t.Cleanup(func() { cCtx, cCancel := context.WithTimeout(context.Background(), 10*time.Second) defer cCancel() workspace.M().Delete(cCtx, wsID, true) }) } ================================================ FILE: agent/sandbox/v2/testutils_wintest_test.go ================================================ //go:build wintest package sandboxv2_test import "os" func init() { extraHostExecProviders = append(extraHostExecProviders, agentWinHostExec) } func agentWinHostExec() []hostTarget { var targets []hostTarget if addr := os.Getenv("TAI_TEST_WIN_HOSTEXEC_LINUX"); addr != "" { targets = append(targets, hostTarget{Name: "win-linux", Addr: addr}) } if addr := os.Getenv("TAI_TEST_WIN_HOSTEXEC_NATIVE"); addr != "" { targets = append(targets, hostTarget{Name: "win-native", Addr: addr}) } return targets } ================================================ FILE: agent/sandbox/v2/token.go ================================================ package sandboxv2 import ( "fmt" "time" lrustore "github.com/yaoapp/gou/store/lru" "github.com/yaoapp/yao/agent/sandbox/v2/types" "github.com/yaoapp/yao/openapi/oauth" ) const ( accessTokenTTL = 2 * time.Hour 
refreshTokenTTL = 30 * 24 * time.Hour // 30 days tokenCacheSize = 1024 ) var tokenCache *lrustore.Cache func init() { c, err := lrustore.New(tokenCacheSize) if err != nil { panic("sandbox token cache init failed: " + err.Error()) } tokenCache = c } func cacheKey(teamID, userID string) string { if teamID == "" { return userID } return teamID + "/" + userID } func getToken(teamID, userID string) *types.SandboxToken { val, ok := tokenCache.Get(cacheKey(teamID, userID)) if !ok { return nil } tok, _ := val.(*types.SandboxToken) return tok } func setToken(teamID, userID string, tok *types.SandboxToken, ttl time.Duration) { tokenCache.Set(cacheKey(teamID, userID), tok, ttl) } // IssueSandboxToken returns a valid identity token for the given user. // Tokens are cached by (teamID, userID); a new token is only issued on // cache miss or expiry. Returns nil without error when oauth.OAuth is nil. func IssueSandboxToken(teamID, userID string) (*types.SandboxToken, error) { if tok := getToken(teamID, userID); tok != nil { return tok, nil } svc := oauth.OAuth if svc == nil { return nil, nil } subject, err := svc.Subject("__yao.sandbox", userID) if err != nil { return nil, fmt.Errorf("sandbox token: derive subject: %w", err) } extraClaims := map[string]interface{}{ "user_id": userID, } if teamID != "" { extraClaims["team_id"] = teamID } tokenStr, err := svc.MakeAccessToken("__yao.sandbox", "grpc:mcp", subject, int(accessTokenTTL.Seconds()), extraClaims) if err != nil { return nil, fmt.Errorf("sandbox token: issue access token: %w", err) } tok := &types.SandboxToken{Token: tokenStr} refreshStr, err := svc.MakeRefreshToken("__yao.sandbox", "grpc:mcp", subject, int(refreshTokenTTL.Seconds()), extraClaims) if err != nil { return nil, fmt.Errorf("sandbox token: issue refresh token: %w", err) } tok.RefreshToken = refreshStr setToken(teamID, userID, tok, accessTokenTTL) return tok, nil } ================================================ FILE: agent/sandbox/v2/types/config.go 
================================================
package types

import (
	"encoding/json"
	"fmt"
)

// Sandbox configuration version strings.
const (
	SandboxVersionV1 = "1.0"
	SandboxVersionV2 = "2.0"
)

// SandboxConfig is the V2 sandbox configuration loaded from sandbox.yao or
// the package.yao "sandbox" block when version == "2.0".
type SandboxConfig struct {
	// Version is the configuration version string (see SandboxVersionV1/V2).
	Version string `json:"version" yaml:"version"`
	// Computer describes the execution environment.
	Computer ComputerConfig `json:"computer" yaml:"computer"`
	// Runner selects the runner and its options.
	Runner RunnerConfig `json:"runner" yaml:"runner"`
	// Lifecycle, IdleTimeout and MaxLifetime are plain strings here; their
	// accepted values are interpreted elsewhere.
	Lifecycle   string `json:"lifecycle,omitempty" yaml:"lifecycle,omitempty"`
	IdleTimeout string `json:"idle_timeout,omitempty" yaml:"idle_timeout,omitempty"`
	MaxLifetime string `json:"max_lifetime,omitempty" yaml:"max_lifetime,omitempty"`
	// StopTimeout is a Go duration string (time.ParseDuration syntax); the
	// stream executor uses it to bound cleanup and defaults to 2s.
	StopTimeout string `json:"stop_timeout,omitempty" yaml:"stop_timeout,omitempty"`
	// Prepare lists the user-defined steps executed during Runner.Prepare.
	Prepare []PrepareStep `json:"prepare,omitempty" yaml:"prepare,omitempty"`
	// Environment and Secrets are string maps.
	// NOTE(review): presumably exported into the sandbox as environment
	// variables — confirm against the consumer.
	Environment map[string]string `json:"environment,omitempty" yaml:"environment,omitempty"`
	Secrets     map[string]string `json:"secrets,omitempty" yaml:"secrets,omitempty"`
	// Filter carries the query parameters for GET /computer/options.
	Filter *ComputerFilter `json:"filter,omitempty" yaml:"filter,omitempty"`

	// Populated by the framework at runtime (never serialized).
	Owner string `json:"-" yaml:"-"`
	// ID is passed as the assistant ID to the prepare-step runner.
	ID          string            `json:"-" yaml:"-"`
	Labels      map[string]string `json:"-" yaml:"-"`
	NodeID      string            `json:"-" yaml:"-"`
	Kind        string            `json:"-" yaml:"-"`
	WorkspaceID string            `json:"-" yaml:"-"`
	DisplayName string            `json:"-" yaml:"-"`
}

// ComputerFilter defines the query parameters for GET /computer/options.
// Declared in DSL sandbox.filter; frontend passes it through to the API.
type ComputerFilter struct {
	Kind  string `json:"kind,omitempty" yaml:"kind,omitempty"`
	Image string `json:"image,omitempty" yaml:"image,omitempty"`
	// VNC is a pointer, so an absent value (nil) is distinguishable from an
	// explicit false.
	VNC  *bool  `json:"vnc,omitempty" yaml:"vnc,omitempty"`
	OS   string `json:"os,omitempty" yaml:"os,omitempty"`
	Arch string `json:"arch,omitempty" yaml:"arch,omitempty"`
	// MinCPUs and MinMem express lower bounds; MinMem is a string.
	// NOTE(review): the accepted MinMem syntax is not visible here — confirm.
	MinCPUs float64           `json:"min_cpus,omitempty" yaml:"min_cpus,omitempty"`
	MinMem  string            `json:"min_mem,omitempty" yaml:"min_mem,omitempty"`
	Labels  map[string]string `json:"labels,omitempty" yaml:"labels,omitempty"`
}

// ComputerConfig describes the execution environment (container or host).
type ComputerConfig struct {
	Image string `json:"image,omitempty" yaml:"image,omitempty"`
	// VNC accepts either a bare boolean or an object in JSON
	// (see VNCConfig.UnmarshalJSON).
	VNC    VNCConfig `json:"vnc,omitempty" yaml:"vnc,omitempty"`
	Memory string    `json:"memory,omitempty" yaml:"memory,omitempty"`
	CPUs   float64   `json:"cpus,omitempty" yaml:"cpus,omitempty"`
	// Ports accepts either an int array or an object array in JSON
	// (see PortList.UnmarshalJSON).
	Ports     PortList `json:"ports,omitempty" yaml:"ports,omitempty"`
	User      string   `json:"user,omitempty" yaml:"user,omitempty"`
	WorkDir   string   `json:"work_dir,omitempty" yaml:"work_dir,omitempty"`
	MountPath string   `json:"mount_path,omitempty" yaml:"mount_path,omitempty"`
	MountMode string   `json:"mount_mode,omitempty" yaml:"mount_mode,omitempty"`
}

// RunnerConfig identifies which Runner to use and how.
type RunnerConfig struct {
	// Name is the runner identifier, e.g. "yao".
	Name string `json:"name" yaml:"name"`
	Mode string `json:"mode,omitempty" yaml:"mode,omitempty"`
	// Options is a free-form map; its keys are runner-specific.
	Options map[string]any `json:"options,omitempty" yaml:"options,omitempty"`
}

// PrepareStep is a single action executed during Runner.Prepare.
type PrepareStep struct { Action string `json:"action" yaml:"action"` Once bool `json:"once,omitempty" yaml:"once,omitempty"` IgnoreError bool `json:"ignore_error,omitempty" yaml:"ignore_error,omitempty"` // action=copy Src string `json:"src,omitempty" yaml:"src,omitempty"` Dst string `json:"dst,omitempty" yaml:"dst,omitempty"` // action=exec Cmd string `json:"cmd,omitempty" yaml:"cmd,omitempty"` Background bool `json:"background,omitempty" yaml:"background,omitempty"` // action=file (internal use by Runner.Prepare) Path string `json:"path,omitempty" yaml:"path,omitempty"` Content []byte `json:"-" yaml:"-"` // action=process (reserved) Name string `json:"name,omitempty" yaml:"name,omitempty"` Args []any `json:"args,omitempty" yaml:"args,omitempty"` } // --------------------------------------------------------------------------- // VNCConfig — supports both bool and object in JSON/YAML: // true → VNCConfig{Enabled: true} // {"enabled": true, "password": "xxx"} → full struct // --------------------------------------------------------------------------- type VNCConfig struct { Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` ViewOnly bool `json:"view_only,omitempty" yaml:"view_only,omitempty"` Password string `json:"password,omitempty" yaml:"password,omitempty"` Resolution string `json:"resolution,omitempty" yaml:"resolution,omitempty"` } func (v *VNCConfig) UnmarshalJSON(data []byte) error { var b bool if err := json.Unmarshal(data, &b); err == nil { v.Enabled = b return nil } type alias VNCConfig var a alias if err := json.Unmarshal(data, &a); err != nil { return err } *v = VNCConfig(a) return nil } // --------------------------------------------------------------------------- // PortList — supports both int array and object array in JSON: // [3000, 8080] → []PortMapping{{Port: 3000}, {Port: 8080}} // [{"port": 3000, "host_port": 9000}] → full structs // --------------------------------------------------------------------------- type PortList 
[]PortMapping type PortMapping struct { Port int `json:"port" yaml:"port"` HostPort int `json:"host_port,omitempty" yaml:"host_port,omitempty"` Protocol string `json:"protocol,omitempty" yaml:"protocol,omitempty"` } func (p *PortList) UnmarshalJSON(data []byte) error { var ints []int if err := json.Unmarshal(data, &ints); err == nil { out := make(PortList, len(ints)) for i, port := range ints { out[i] = PortMapping{Port: port} } *p = out return nil } var objs []PortMapping if err := json.Unmarshal(data, &objs); err != nil { return fmt.Errorf("ports: expected int array or object array: %w", err) } *p = objs return nil } ================================================ FILE: agent/sandbox/v2/types/runner.go ================================================ package types import ( "context" "github.com/yaoapp/gou/connector" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" infra "github.com/yaoapp/yao/sandbox/v2" ) // Runner is the interface that all sandbox runners must implement. // A Runner replaces the LLM invocation layer (executeLLMStream) when a // sandbox is configured. type Runner interface { Name() string Prepare(ctx context.Context, req *PrepareRequest) error Stream(ctx context.Context, req *StreamRequest, handler message.StreamFunc) error Cleanup(ctx context.Context, computer infra.Computer) error } // MCPServer mirrors store/types.MCPServerConfig to avoid a cyclic import // between this leaf package and agent/store/types. type MCPServer struct { ServerID string `json:"server_id,omitempty"` Resources []string `json:"resources,omitempty"` Tools []string `json:"tools,omitempty"` } // RunStepsFunc is the signature of RunPrepareSteps. Workspace is obtained // internally via computer.Workplace(). assistantDir is the absolute path to // the assistant source directory on the host; copy steps resolve relative src // paths against it. 
type RunStepsFunc func(ctx context.Context, steps []PrepareStep, computer infra.Computer, assistantID, configHash, assistantDir string) error

// PrepareRequest carries everything needed by Runner.Prepare.
type PrepareRequest struct {
	Computer     infra.Computer
	Config       *SandboxConfig
	Connector    connector.Connector
	SkillsDir    string
	AssistantDir string // absolute host path to the assistant source directory
	MCPServers   []MCPServer
	// ConfigHash is forwarded to RunSteps as its configHash argument.
	ConfigHash string
	// RunSteps executes prepare steps; a runner may leave it unused, and
	// the yao runner skips step execution when it is nil.
	RunSteps RunStepsFunc
}

// StreamRequest carries everything needed by Runner.Stream.
type StreamRequest struct {
	Computer     infra.Computer
	Config       *SandboxConfig
	Connector    connector.Connector
	Messages     []agentContext.Message
	SystemPrompt string
	ChatID       string
	Token        *SandboxToken // current user's sandbox token for MCP callbacks
}

================================================
FILE: agent/sandbox/v2/types/token.go
================================================
package types

// SandboxToken holds credentials for a sandbox execution session.
// Expiry is managed by the LRU store TTL, not stored here.
type SandboxToken struct {
	Token        string // access token → YAO_TOKEN
	RefreshToken string // refresh token → YAO_REFRESH_TOKEN
}

================================================
FILE: agent/sandbox/v2/yao/runner.go
================================================
package yao

import (
	"context"

	"github.com/yaoapp/yao/agent/output/message"
	"github.com/yaoapp/yao/agent/sandbox/v2/types"
	infra "github.com/yaoapp/yao/sandbox/v2"
)

// YaoRunner is a no-op Runner for pure Hook-driven sandbox interactions.
// When runner.name == "yao", the assistant relies entirely on Create/Next
// hooks for logic; no external CLI is invoked.
type YaoRunner struct{}

// New returns a new YaoRunner; the type carries no state.
func New() *YaoRunner { return &YaoRunner{} }

// Name returns the runner identifier "yao".
func (r *YaoRunner) Name() string { return "yao" }

// Prepare runs user-defined prepare steps (copy, exec, file) but adds
// no runner-specific steps. Connector is not required.
func (r *YaoRunner) Prepare(ctx context.Context, req *types.PrepareRequest) error { if req.RunSteps != nil && len(req.Config.Prepare) > 0 { return req.RunSteps(ctx, req.Config.Prepare, req.Computer, req.Config.ID, req.ConfigHash, req.AssistantDir) } return nil } // Stream is a no-op — hooks handle all interaction. Returns immediately // so the assistant framework proceeds to the Next hook. func (r *YaoRunner) Stream(_ context.Context, _ *types.StreamRequest, _ message.StreamFunc) error { return nil } // Cleanup is a no-op for the yao runner. func (r *YaoRunner) Cleanup(_ context.Context, _ infra.Computer) error { return nil } ================================================ FILE: agent/sandbox/v2/yao/runner_test.go ================================================ package yao_test import ( "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/caller" agentcontext "github.com/yaoapp/yao/agent/context" sandboxtestutils "github.com/yaoapp/yao/agent/sandbox/v2/testutils" oauthtypes "github.com/yaoapp/yao/openapi/oauth/types" ) func TestSandboxV2_Yao_JSAPI(t *testing.T) { sandboxtestutils.Prepare(t) defer sandboxtestutils.Clean(t) require.NotNil(t, caller.AgentGetterFunc, "AgentGetterFunc should be registered after Prepare") agent, err := caller.AgentGetterFunc("tests.sandbox-v2.jsapi-v2") require.NoError(t, err, "should load assistant tests.sandbox-v2.jsapi-v2") chatID := fmt.Sprintf("e2e-jsapi-%d", time.Now().UnixMilli()) ctx := agentcontext.New( context.Background(), &oauthtypes.AuthorizedInfo{ TeamID: "test-team-jsapi", UserID: "test-user-jsapi", }, chatID, ) messages := []agentcontext.Message{ {Role: "user", Content: "test jsapi"}, } done := make(chan struct{}) var resp *agentcontext.Response var streamErr error go func() { defer close(done) resp, streamErr = agent.Stream(ctx, messages) }() select { case <-done: case <-time.After(3 * time.Minute): t.Fatalf("timeout after 3m") } 
require.NoError(t, streamErr, "Stream should not return error") require.NotNil(t, resp, "response should not be nil") // runner=yao goes through executeLLMStream, then Next hook returns { data: results } // The Next hook result should appear in resp.Next require.NotNil(t, resp.Next, "resp.Next should not be nil (Next hook returned data)") t.Logf("resp.Next: %+v", resp.Next) nextData, ok := resp.Next.(map[string]interface{}) if !ok { t.Fatalf("resp.Next should be a map, got %T: %+v", resp.Next, resp.Next) } // The Next hook returns { data: results }, the framework unwraps .data data, hasData := nextData["data"] if hasData { nextData, ok = data.(map[string]interface{}) require.True(t, ok, "data should be a map") } t.Logf("JSAPI test results: %+v", nextData) // ── Verify ctx.computer was available ── assert.Equal(t, true, nextData["has_computer"], "ctx.computer should be available") assert.Equal(t, true, nextData["has_workspace"], "ctx.workspace should be available") // ── Verify ctx.computer.Info() ── if infoRaw, ok := nextData["computer_info"]; ok { info, ok := infoRaw.(map[string]interface{}) require.True(t, ok, "computer_info should be a map") assert.NotEmpty(t, info["kind"], "computer_info.kind should not be empty") t.Logf("computer info: kind=%v os=%v", info["kind"], info["os"]) } else { assert.Nil(t, nextData["computer_info_error"], "computer.Info() should not error") } // ── Verify ctx.computer.Exec() ── assert.Equal(t, "jsapi-v2-test", nextData["exec_stdout"], "Exec should return expected stdout") assert.Nil(t, nextData["exec_error"], "Exec should not error") if exitCode, ok := nextData["exec_exit_code"]; ok { // JS numbers come back as float64 through JSON switch v := exitCode.(type) { case float64: assert.Equal(t, float64(0), v, "exit_code should be 0") case int: assert.Equal(t, 0, v, "exit_code should be 0") } } // ── Verify ctx.workspace write/read ── assert.Equal(t, true, nextData["write_read_ok"], "workspace WriteFile+ReadFile round-trip should work") 
assert.Equal(t, "hello from jsapi v2", nextData["read_content"], "read content should match") assert.Nil(t, nextData["write_read_error"], "write/read should not error") // ── Verify ctx.workspace MkdirAll + Exists ── assert.Equal(t, true, nextData["mkdir_exists_ok"], "MkdirAll + Exists should work") assert.Nil(t, nextData["mkdir_exists_error"], "mkdir/exists should not error") // ── Verify ctx.workspace ReadDir ── assert.Nil(t, nextData["readdir_error"], "ReadDir should not error") if count, ok := nextData["readdir_count"]; ok { switch v := count.(type) { case float64: assert.Greater(t, v, float64(0), "ReadDir should return entries") } } // ── Verify ctx.workspace Stat ── assert.Equal(t, true, nextData["stat_ok"], "Stat should return correct info") assert.Nil(t, nextData["stat_error"], "Stat should not error") // ── Verify ctx.workspace Copy ── assert.Equal(t, true, nextData["copy_ok"], "Copy should work") assert.Nil(t, nextData["copy_error"], "Copy should not error") // ── Verify ctx.workspace Rename ── assert.Equal(t, true, nextData["rename_ok"], "Rename should work") assert.Nil(t, nextData["rename_error"], "Rename should not error") // ── Verify ctx.workspace Remove ── assert.Equal(t, true, nextData["remove_ok"], "Remove should work") assert.Nil(t, nextData["remove_error"], "Remove should not error") } ================================================ FILE: agent/search/DESIGN.md ================================================ # Search Module Design ## Overview The Search module provides a unified RAG (Retrieval-Augmented Generation) interface for Yao Agent, supporting three search types: | Type | Source | Use Case | | ----- | -------------- | ---------------------------------------------------- | | `web` | Internet | Real-time information, news, external knowledge | | `kb` | Knowledge Base | Documents, FAQs, internal knowledge (vector + graph) | | `db` | Database | Structured data from Yao Models (QueryDSL) | The module follows the **Handler + Registry** 
pattern consistent with the `content` module, and exposes JSAPI for flexible usage in Create/Next hooks. ## Key Features - **Unified JSAPI**: `ctx.search.Web()`, `ctx.search.KB()`, `ctx.search.DB()`, `ctx.search.Parallel()` - **Citation System**: Auto-generate citation IDs (`#ref:xxx`) for LLM reference - **Real-time Output**: Stream search progress to client - **Trace Integration**: Report search operations to user for transparency - **Reranking**: Builtin, Agent, or MCP-based result reranking - **Graceful Degradation**: Search errors don't block agent flow ## Quick Start ```typescript // In Create hook (assistants/my-assistant/index.ts) function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; // Simple web search const result = ctx.search.Web(query, { limit: 5 }); // Or parallel search across all sources const [web, kb, db] = ctx.search.Parallel([ { type: "web", query, limit: 5 }, { type: "kb", query, collections: ["docs"] }, { type: "db", query, models: ["product"] }, ]); return { messages: [{ role: "system", content: formatContext(web, kb, db) }], uses: { search: "disabled" }, // Disable auto search since hook handled it }; } ``` ## Goals 1. **Unified Interface**: Single API for web, knowledge base, and database search 2. **Flexibility**: Support built-in handlers and external tools (MCP/Agent delegation) 3. **JSAPI Support**: Enable search calls from Create/Next hooks via JavaScript 4. **Parallel Execution**: Support concurrent web + KB + DB searches 5. **Graceful Degradation**: Search failures should not block the main agent flow 6. **Real-time Feedback**: Stream search progress and results to users via output 7. **Traceability**: Report search operations to users for transparency 8. 
**Citation Support**: Enable LLM to reference search results with trackable citations ## Architecture ### Search Flow Diagram ```mermaid flowchart TD A[Stream Start] --> B{Uses.Search?} B -->|disabled| C[Skip Search] B -->|builtin/agent/mcp| D{Hook Handled?} D -->|"Yes (uses.search=disabled)"| C D -->|No| E[Auto Search] E --> F{Check Assistant Config} F --> G[Web Search] F --> H[KB Search] F --> I[DB Search] G --> J[Parallel Execute] H --> J I --> J J --> K[Merge Results] K --> L[Rerank] L --> M[Generate Citations] M --> N[Inject to System Prompt] C --> O[LLM Call] N --> O O --> P[Output with Citations] ``` ### Integration in Stream() ```mermaid sequenceDiagram participant Client participant Stream participant CreateHook participant Search participant LLM participant Output Client->>Stream: Stream(ctx, messages, options) Stream->>Stream: Initialize alt Has Create Hook Stream->>CreateHook: Create(ctx, messages, options) CreateHook-->>Stream: response (may include search results) end alt Uses.Search != "disabled" AND not handled by Hook Stream->>Search: AutoSearch(ctx, messages) Search->>Search: Web/KB/DB in parallel Search->>Search: Rerank & Citations Search->>Output: search_start, search_result, search_complete Search-->>Stream: Inject search context to messages end Stream->>LLM: Execute with search context LLM->>Output: Stream response with #ref:xxx Stream-->>Client: Complete ``` ### Directory Structure ``` agent/search/ ├── DESIGN.md # This document ├── TODO.md # Implementation plan and progress ├── search.go # Main Searcher implementation and public API ├── registry.go # Handler registry (manages web/kb/db handlers) ├── jsapi.go # JavaScript API bindings for hooks (skeleton) ├── citation.go # Citation ID generation and tracking ├── reference.go # Reference building and LLM context formatting │ ├── types/ # Type definitions (no dependencies on other search packages) │ ├── types.go # Core types (SearchType, Request, Result, ResultItem, etc.) 
│ ├── config.go # Configuration types (Config, CitationConfig, WeightsConfig, etc.) │ ├── reference.go # Reference type for unified context protocol │ └── graph.go # Graph-related types (GraphNode) │ ├── interfaces/ # Interface definitions (depends only on types/) │ ├── handler.go # Handler interface │ ├── searcher.go # Searcher interface (public API) │ ├── reranker.go # Reranker interface │ └── nlp.go # NLP interfaces (KeywordExtractor, QueryDSLGenerator) │ ├── rerank/ # Result reranking implementations (Handler + Registry pattern) ✅ │ ├── reranker.go # Main entry point (mode dispatch) │ ├── builtin.go # Builtin: weighted score sorting │ ├── agent.go # Agent mode (delegate to LLM assistant) │ └── mcp.go # MCP mode (external service) │ ├── nlp/ # Natural language processing for search │ ├── keyword/ # Keyword extraction (Handler + Registry pattern) ✅ │ │ ├── extractor.go # Main extractor (mode dispatch) │ │ ├── builtin.go # Builtin frequency-based extraction │ │ ├── agent.go # Agent mode (LLM-powered) │ │ └── mcp.go # MCP mode (external service) │ └── querydsl/ # QueryDSL generation for DB search (TODO) │ ├── generator.go # Main generator (mode dispatch) │ ├── builtin.go # Builtin template-based generation │ ├── agent.go # Agent mode (LLM-powered) │ └── mcp.go # MCP mode (external service) │ # Note: Embedding follows KB collection config, not in this package │ ├── handlers/ # Search handler implementations │ ├── web/ # Web search ✅ │ │ ├── handler.go # Web search handler (mode dispatch) │ │ ├── tavily.go # Tavily provider (builtin) │ │ ├── serper.go # Serper provider (serper.dev, builtin) │ │ ├── serpapi.go # SerpAPI provider (serpapi.com, multi-engine, builtin) │ │ ├── agent.go # Agent mode (AI Search) │ │ └── mcp.go # MCP mode (external service) │ │ │ ├── kb/ # Knowledge base search (skeleton) │ │ ├── handler.go # KB search handler │ │ ├── vector.go # Vector similarity search (TODO) │ │ └── graph.go # Graph-based association (TODO) │ │ │ └── db/ # Database search 
(skeleton) │ ├── handler.go # DB search handler │ ├── query.go # QueryDSL builder (TODO) │ └── schema.go # Model schema introspection (TODO) │ └── defaults/ # Default configuration values └── defaults.go # System built-in defaults (used by agent/load.go) # Note: Output and Trace are integrated into assistant/search.go # No separate trace.go or output.go files needed ``` ### Dependency Graph ``` ┌─────────────┐ │ types/ │ ← No internal dependencies └──────┬──────┘ │ ┌──────▼──────┐ │ interfaces/ │ ← Depends only on types/ └──────┬──────┘ │ ┌─────────────────┼─────────────────┐ │ │ │ ┌─────▼─────┐ ┌──────▼──────┐ ┌──────▼──────┐ │ rerank/ │ │ nlp/ │ │ defaults/ │ └─────┬─────┘ └──────┬──────┘ └──────┬──────┘ │ │ │ └────────┬────────┴────────┬────────┘ │ │ ┌──────▼──────┐ ┌──────▼──────┐ │ handlers/ │ │ (root pkg) │ │ web/kb/db │ │ search.go │ └──────┬──────┘ │ registry │ │ │ jsapi, etc │ └────┬─────┴─────────────┘ │ ┌─────▼─────┐ │ External │ │ Packages │ └───────────┘ ``` ### Package Import Rules 1. **`types/`** - Zero internal dependencies, only stdlib and external packages 2. **`interfaces/`** - Imports only `types/` 3. **`rerank/`**, **`nlp/`**, **`defaults/`** - Import `types/` and `interfaces/` 4. **`handlers/*`** - Import `types/`, `interfaces/`, and may use `nlp/` for NL processing 5. **Root package** - Imports all sub-packages, provides public API ### Main Searcher Implementation (`search.go`) Configuration is loaded by `agent/load.go` (global) and `agent/assistant/load.go` (assistant-level), following the existing pattern. The Search package directly uses the loaded configuration. 
```go package search import ( "sync" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/handlers/db" "github.com/yaoapp/yao/agent/search/handlers/kb" "github.com/yaoapp/yao/agent/search/handlers/web" "github.com/yaoapp/yao/agent/search/interfaces" "github.com/yaoapp/yao/agent/search/rerank" "github.com/yaoapp/yao/agent/search/types" ) // Searcher is the main search implementation type Searcher struct { config *types.Config // Merged config (global + assistant) handlers map[types.SearchType]interfaces.Handler reranker *rerank.Reranker // Uses rerank package directly citation *CitationGenerator } // Uses contains the search-specific uses configuration // These are extracted from context.Uses and search config type Uses struct { Search string // "builtin", "disabled", "<assistant-id>", "mcp:<server>.<tool>" Web string // "builtin", "<assistant-id>", "mcp:<server>.<tool>" Keyword string // "builtin", "<assistant-id>", "mcp:<server>.<tool>" QueryDSL string // "builtin", "<assistant-id>", "mcp:<server>.<tool>" Rerank string // "builtin", "<assistant-id>", "mcp:<server>.<tool>" } // New creates a new Searcher instance // cfg: merged config from agent/load.go + assistant config // uses: merged uses configuration (global → assistant → hook) func New(cfg *types.Config, uses *Uses) *Searcher { return &Searcher{ config: cfg, handlers: map[types.SearchType]interfaces.Handler{ types.SearchTypeWeb: web.NewHandler(uses.Web, cfg.Web), types.SearchTypeKB: kb.NewHandler(cfg.KB), // KB always builtin types.SearchTypeDB: db.NewHandler(uses.QueryDSL, cfg.DB), }, reranker: rerank.NewReranker(uses.Rerank, cfg.Rerank), citation: NewCitationGenerator(), } } // Search executes a single search request func (s *Searcher) Search(ctx *context.Context, req *types.Request) (*types.Result, error) { handler, ok := s.handlers[req.Type] if !ok { return &types.Result{Error: "unsupported search type"}, nil } // Execute search (handler doesn't need ctx) result, err := handler.Search(req) if err != nil { return &types.Result{Error: err.Error()}, nil } // Assign weights based on source for _, item := range result.Items { 
item.Weight = s.config.GetWeight(req.Source) } // Rerank if requested (reranker needs ctx for Agent/MCP modes) if req.Rerank != nil && s.reranker != nil { result.Items, _ = s.reranker.Rerank(ctx, req.Query, result.Items, req.Rerank) } // Generate citation IDs for _, item := range result.Items { item.CitationID = s.citation.Next() } return result, nil } // ParallelMode defines how parallel search should behave (inspired by JavaScript Promise) type ParallelMode string // ParallelMode constants (similar to Promise.all, Promise.any, Promise.race) const ( // ModeAll waits for all searches to complete, returns all results (like Promise.all) ModeAll ParallelMode = "all" // ModeAny returns as soon as any search succeeds (has results), others continue but are discarded (like Promise.any) ModeAny ParallelMode = "any" // ModeRace returns as soon as any search completes (success or empty), others continue but are discarded (like Promise.race) ModeRace ParallelMode = "race" ) // ParallelOptions configures parallel search behavior // All executes all searches and waits for all to complete (like Promise.all) func (s *Searcher) All(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { return s.parallelAll(ctx, reqs) } // Any returns as soon as any search succeeds with results (like Promise.any) func (s *Searcher) Any(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { return s.parallelAny(ctx, reqs) } // Race returns as soon as any search completes (like Promise.race) func (s *Searcher) Race(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { return s.parallelRace(ctx, reqs) } // BuildReferences converts search results to unified Reference format func (s *Searcher) BuildReferences(results []*types.Result) []*types.Reference { var refs []*types.Reference for _, result := range results { for _, item := range result.Items { refs = append(refs, &types.Reference{ ID: item.CitationID, Type: item.Type, Source: item.Source, Weight: 
item.Weight, Score: item.Score, Title: item.Title, Content: item.Content, URL: item.URL, }) } } return refs } ``` ### Registry (`registry.go`) ```go package search import ( "github.com/yaoapp/yao/agent/search/interfaces" "github.com/yaoapp/yao/agent/search/types" ) // Registry manages search handlers type Registry struct { handlers map[types.SearchType]interfaces.Handler } // NewRegistry creates a new handler registry func NewRegistry() *Registry { return &Registry{ handlers: make(map[types.SearchType]interfaces.Handler), } } // Register registers a handler for a search type func (r *Registry) Register(handler interfaces.Handler) { r.handlers[handler.Type()] = handler } // Get returns the handler for a search type func (r *Registry) Get(t types.SearchType) (interfaces.Handler, bool) { h, ok := r.handlers[t] return h, ok } ``` ## Core Interfaces All interfaces are defined in `search/interfaces/` package to prevent circular dependencies. ### Handler Interface (`interfaces/handler.go`) ```go package interfaces import ( "github.com/yaoapp/yao/agent/search/types" ) // Handler defines the interface for search implementations type Handler interface { // Type returns the search type this handler supports Type() types.SearchType // Search executes the search and returns results Search(req *types.Request) (*types.Result, error) } ``` ### Searcher Interface (`interfaces/searcher.go`) ```go package interfaces import ( "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // Searcher is the main interface exposed to external callers type Searcher interface { // Search executes a single search request Search(ctx *context.Context, req *types.Request) (*types.Result, error) // Parallel search methods - inspired by JavaScript Promise // All waits for all searches to complete (like Promise.all) All(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) // Any returns when any search succeeds with results (like Promise.any) Any(ctx *context.Context, reqs []*types.Request) 
([]*types.Result, error) // Race returns when any search completes (like Promise.race) Race(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) // BuildReferences converts search results to unified Reference format for LLM BuildReferences(results []*types.Result) []*types.Reference } ``` > **Note**: Parallel search methods follow JavaScript Promise naming: > > - `All()`: Wait for all searches to complete (like `Promise.all`) > - `Any()`: Return when any search succeeds with results (like `Promise.any`) > - `Race()`: Return when any search completes (like `Promise.race`) ### NLP Interfaces (`interfaces/nlp.go`) ```go package interfaces import ( "github.com/yaoapp/gou/model" "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // KeywordExtractor extracts keywords for web search type KeywordExtractor interface { // Extract extracts search keywords from user message // ctx is required for Agent and MCP modes, can be nil for builtin mode Extract(ctx *context.Context, content string, opts *types.KeywordOptions) ([]string, error) } // QueryDSLGenerator generates QueryDSL for DB search type QueryDSLGenerator interface { // Generate converts natural language to QueryDSL // Uses GOU types directly: model.Model and gou.QueryDSL Generate(query string, models []*model.Model) (*gou.QueryDSL, error) } // Note: Embedding is handled by KB collection's own config (embedding provider + model), // not defined here. See KB handler for details. 
``` ### Reranker Interface (`interfaces/reranker.go`) ```go package interfaces import ( "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // Reranker reorders search results by relevance type Reranker interface { // Rerank reorders results based on query relevance Rerank(ctx *context.Context, query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) } ``` ## Types All types are defined in `search/types/` package to prevent circular dependencies. ### Core Types (`types/types.go`) ```go package types import ( "github.com/yaoapp/gou/query/gou" ) // SearchType represents the type of search type SearchType string const ( SearchTypeWeb SearchType = "web" // Web/Internet search SearchTypeKB SearchType = "kb" // Knowledge base vector search SearchTypeDB SearchType = "db" // Database search (Yao Model/QueryDSL) ) // SourceType represents where the search result came from type SourceType string const ( SourceUser SourceType = "user" // User-provided DataContent (highest priority) SourceHook SourceType = "hook" // Hook ctx.search.*() results SourceAuto SourceType = "auto" // Auto search results (lowest priority) ) // Request represents a search request type Request struct { // Common fields Query string `json:"query"` // Search query (natural language) Type SearchType `json:"type"` // Search type: "web", "kb", or "db" Limit int `json:"limit,omitempty"` // Max results (default: 10) Source SourceType `json:"source"` // Source of this request (user/hook/auto) // Web search specific Sites []string `json:"sites,omitempty"` // Restrict to specific sites TimeRange string `json:"time_range,omitempty"` // "day", "week", "month", "year" // Knowledge base specific Collections []string `json:"collections,omitempty"` // KB collection IDs Threshold float64 `json:"threshold,omitempty"` // Similarity threshold (0-1) Graph bool `json:"graph,omitempty"` // Enable graph association // Database search specific // Uses GOU 
QueryDSL types directly for compatibility with Yao's query system // See: github.com/yaoapp/gou/query/gou/types.go Models []string `json:"models,omitempty"` // Model IDs (e.g., "user", "agents.mybot.product") Wheres []gou.Where `json:"wheres,omitempty"` // Pre-defined filters (optional), uses GOU QueryDSL Where Orders gou.Orders `json:"orders,omitempty"` // Sort orders (optional), uses GOU QueryDSL Orders Select []string `json:"select,omitempty"` // Fields to return (optional) // Reranking Rerank *RerankOptions `json:"rerank,omitempty"` } // RerankOptions controls result reranking // Reranker type is determined by uses.rerank in agent/agent.yml type RerankOptions struct { TopN int `json:"top_n,omitempty"` // Return top N after reranking } // Result represents the search result with all intermediate processing data type Result struct { Type SearchType `json:"type"` // Search type Query string `json:"query"` // Original query Source SourceType `json:"source"` // Source of this result Items []*ResultItem `json:"items"` // Result items Total int `json:"total"` // Total matches Duration int64 `json:"duration_ms"` // Search duration in ms Error string `json:"error,omitempty"` // Error message if failed // Intermediate processing results (for storage and debugging) Keywords []string `json:"keywords,omitempty"` // Extracted keywords (Web/NLP) DSL map[string]any `json:"dsl,omitempty"` // Generated QueryDSL (DB) Entities []Entity `json:"entities,omitempty"` // Extracted entities (Graph RAG) Relations []Relation `json:"relations,omitempty"` // Extracted relations (Graph RAG) // Graph associations (KB only, if enabled) GraphNodes []*GraphNode `json:"graph_nodes,omitempty"` } // Entity represents an extracted entity (for Graph RAG) type Entity struct { Name string `json:"name"` Type string `json:"type,omitempty"` Source string `json:"source,omitempty"` } // Relation represents an extracted relation (for Graph RAG) type Relation struct { Subject string `json:"subject"` Predicate 
string `json:"predicate"` Object string `json:"object"` Source string `json:"source,omitempty"` } // ResultItem represents a single search result item type ResultItem struct { // Citation CitationID string `json:"citation_id"` // Unique ID for LLM reference: "ref_001" // Weighting Source SourceType `json:"source"` // Source type: "user", "hook", "auto" Weight float64 `json:"weight"` // Source weight (from config) Score float64 `json:"score,omitempty"` // Relevance score (0-1) // Common fields Type SearchType `json:"type"` // Search type for this item Title string `json:"title,omitempty"` // Title/headline Content string `json:"content"` // Main content/snippet URL string `json:"url,omitempty"` // Source URL // KB specific DocumentID string `json:"document_id,omitempty"` // Source document ID Collection string `json:"collection,omitempty"` // Collection name // DB specific Model string `json:"model,omitempty"` // Model ID RecordID interface{} `json:"record_id,omitempty"` // Record primary key Data map[string]interface{} `json:"data,omitempty"` // Full record data // Metadata Metadata map[string]interface{} `json:"metadata,omitempty"` // Additional metadata } // ProcessedQuery represents a processed query ready for execution type ProcessedQuery struct { Type SearchType `json:"type"` Keywords []string `json:"keywords,omitempty"` // For web search Vector []float32 `json:"vector,omitempty"` // For KB search DSL *gou.QueryDSL `json:"dsl,omitempty"` // For DB search, uses GOU QueryDSL } ``` > **Design Note: Result with Intermediate Data** > > The `Result` type now includes intermediate processing results (`Keywords`, `DSL`, `Entities`, `Relations`) > that were previously only available during query processing. This design enables: > > 1. **Storage for Debugging**: All processing steps are captured for later analysis > 2. **System Tuning**: Analyze extracted keywords, generated DSL, and entity extraction quality > 3. 
**Unified Data Flow**: Handlers populate these fields during execution, eliminating the need > for separate data collection in `executeAutoSearch` > > **Handler Responsibilities**: > > - **Web Handler**: Populates `Keywords` from NLP extraction > - **DB Handler**: Populates `DSL` from QueryDSL generation > - **KB Handler**: Populates `Entities`, `Relations`, and `GraphNodes` from Graph RAG > > **Data Flow**: > > ``` > Request → Handler → Result (with Keywords/DSL/Entities/Relations) > ↓ > BuildReferenceContext > ↓ > saveSearch (stores all intermediate data) > ``` ```go // ProcessedQuery is DEPRECATED for external use // Handlers should populate Result.Keywords/DSL/Entities/Relations directly type ProcessedQuery struct { Type SearchType `json:"type"` Keywords []string `json:"keywords,omitempty"` // For web search Vector []float32 `json:"vector,omitempty"` // For KB search DSL *gou.QueryDSL `json:"dsl,omitempty"` // For DB search } // Note: For QueryDSL and Model types, use GOU types directly: // - github.com/yaoapp/gou/query/gou.QueryDSL // - github.com/yaoapp/gou/model.Model // - github.com/yaoapp/gou/model.Column ``` > **Note**: `Wheres` and `Orders` use GOU QueryDSL types directly (`gou.Where` and `gou.Orders`) for full compatibility with Yao's query system. See `github.com/yaoapp/gou/query/gou/types.go` for the complete type definitions. 
### Graph Types (`types/graph.go`) ```go package types // GraphNode represents a related entity from knowledge graph type GraphNode struct { ID string `json:"id"` Type string `json:"type"` // Entity type Name string `json:"name"` // Entity name Description string `json:"description,omitempty"` // Entity description Relation string `json:"relation,omitempty"` // Relationship to query Score float64 `json:"score,omitempty"` // Relevance score Metadata map[string]interface{} `json:"metadata,omitempty"` } ``` ### Reference Types (`types/reference.go`) ```go package types // Reference is the unified structure for all data sources // Used to build LLM context from search results type Reference struct { ID string `json:"id"` // Unique citation ID: "ref_001", "ref_002" Type SearchType `json:"type"` // Data type: "web", "kb", "db" Source SourceType `json:"source"` // Origin: "user", "hook", "auto" Weight float64 `json:"weight"` // Relevance weight (1.0=highest, 0.6=lowest) Score float64 `json:"score"` // Relevance score (0-1) Title string `json:"title"` // Optional title Content string `json:"content"` // Main content URL string `json:"url"` // Optional URL Meta map[string]interface{} `json:"meta"` // Additional metadata } // ReferenceContext holds the formatted references for LLM input type ReferenceContext struct { References []*Reference `json:"references"` // All references XML string `json:"xml"` // Formatted XML Prompt string `json:"prompt"` // Citation instruction prompt } ``` ### Configuration Types (`types/config.go`) ```go package types // Config represents the complete search configuration type Config struct { Web *WebConfig `json:"web,omitempty"` KB *KBConfig `json:"kb,omitempty"` DB *DBConfig `json:"db,omitempty"` Keyword *KeywordConfig `json:"keyword,omitempty"` QueryDSL *QueryDSLConfig `json:"querydsl,omitempty"` Rerank *RerankConfig `json:"rerank,omitempty"` Citation *CitationConfig `json:"citation,omitempty"` Weights *WeightsConfig `json:"weights,omitempty"` 
Options *OptionsConfig `json:"options,omitempty"` } // WebConfig for web search settings // Note: uses.web determines the mode (builtin/agent/mcp) // Provider is only used when uses.web = "builtin" type WebConfig struct { Provider string `json:"provider,omitempty"` // "tavily", "serper", or "serpapi" (for builtin mode) APIKeyEnv string `json:"api_key_env,omitempty"` // Environment variable for API key MaxResults int `json:"max_results,omitempty"` // Max results (default: 10) Engine string `json:"engine,omitempty"` // Search engine for SerpAPI: "google", "bing", "baidu", etc. (default: "google") } // KBConfig for knowledge base search settings type KBConfig struct { Collections []string `json:"collections,omitempty"` // Default collections Threshold float64 `json:"threshold,omitempty"` // Similarity threshold (default: 0.7) Graph bool `json:"graph,omitempty"` // Enable GraphRAG (default: false) } // DBConfig for database search settings type DBConfig struct { Models []string `json:"models,omitempty"` // Default models MaxResults int `json:"max_results,omitempty"` // Max results (default: 20) } // KeywordConfig for keyword extraction type KeywordConfig struct { MaxKeywords int `json:"max_keywords,omitempty"` // Max keywords (default: 10) Language string `json:"language,omitempty"` // "auto", "en", "zh", etc. 
} // KeywordOptions for keyword extraction (runtime options) type KeywordOptions struct { MaxKeywords int `json:"max_keywords,omitempty"` Language string `json:"language,omitempty"` } // QueryDSLConfig for QueryDSL generation from natural language type QueryDSLConfig struct { Strict bool `json:"strict,omitempty"` // Fail if generation fails (default: false) } // RerankConfig for reranking type RerankConfig struct { TopN int `json:"top_n,omitempty"` // Return top N (default: 10) } // CitationConfig for citation format type CitationConfig struct { Format string `json:"format,omitempty"` // Default: "#ref:{id}" AutoInjectPrompt bool `json:"auto_inject_prompt,omitempty"` // Auto-inject prompt (default: true) CustomPrompt string `json:"custom_prompt,omitempty"` // Custom prompt template } // WeightsConfig for source weighting type WeightsConfig struct { User float64 `json:"user,omitempty"` // User-provided (default: 1.0) Hook float64 `json:"hook,omitempty"` // Hook results (default: 0.8) Auto float64 `json:"auto,omitempty"` // Auto search (default: 0.6) } // OptionsConfig for search behavior type OptionsConfig struct { SkipThreshold int `json:"skip_threshold,omitempty"` // Skip auto search if user provides >= N results } ``` ### Note on Reranker Reranker type is determined by `uses.rerank` in `agent/agent.yml`: - `"builtin"` - Simple score-based sorting - `"<assistant-id>"` - Delegate to an assistant (Agent) - `"mcp:<server>.<tool>"` - Call MCP tool (e.g., `"mcp:my-server.rerank"`) ## Citation System Each search result has a unique `CitationID` for LLM reference. Citation logic is implemented in `search/citation.go`. ### Citation ID Generation Citation IDs are generated sequentially: `ref_001`, `ref_002`, etc. 
```go // citation.go package search import ( "fmt" "sync/atomic" ) // CitationGenerator generates unique citation IDs type CitationGenerator struct { counter uint64 } // NewCitationGenerator creates a new citation generator func NewCitationGenerator() *CitationGenerator { return &CitationGenerator{} } // Next generates the next citation ID func (g *CitationGenerator) Next() string { n := atomic.AddUint64(&g.counter, 1) return fmt.Sprintf("ref_%03d", n) } ``` ### Citation Config (in `types/config.go`) ```go type CitationConfig struct { Format string `json:"format,omitempty"` // Default: "#ref:{id}" AutoInjectPrompt bool `json:"auto_inject_prompt,omitempty"` // Auto-add instructions to system prompt CustomPrompt string `json:"custom_prompt,omitempty"` // Override default prompt template } ``` ### Default Citation Prompt When `AutoInjectPrompt` is enabled (default), the system prompt includes: ``` You have access to reference data in <references> tags. Each <ref> has: - id: Citation identifier - type: Data type (web/kb/db) - weight: Relevance weight (1.0=highest priority, 0.6=lowest) - source: Origin (user=user-provided, hook=assistant-searched, auto=auto-searched) Prioritize higher-weight references when answering. When citing a reference, use this exact HTML format: <a href="#ref:{id}">[{id}]</a> Example: According to the product data<a href="#ref:ref_001">[ref_001]</a>, the price is $999. ``` ### Custom Prompt in Config ```yaml # assistants/my-assistant.yml search: citation: format: "[{id}]" auto_inject_prompt: true custom_prompt: "Cite using [{id}]. Sources: ..." ``` ## Trace Integration Search operations create minimal trace nodes to report execution status to users, providing transparency about what the agent is doing. Detailed information is recorded via LOG for debugging. 
### Trace Node Structure Uses `trace/types.NodeStatus` constants: - `pending` - Node created but not started - `running` - Node is currently executing - `completed` - Node finished successfully - `failed` - Node failed with error **Single Search:** ``` search (type: "search") ├── label // i18n: "Search" / "搜索" ├── status // "pending" | "running" | "completed" | "failed" ├── input │ ├── query // Original query │ └── types // ["web"], ["kb"], ["web", "kb", "db"] └── output // (set on complete) └── result_count // Total results found ``` **Parallel Search:** ``` search (type: "search") ├── label // i18n: "Search" / "搜索" ├── status // "pending" | "running" | "completed" | "failed" ├── input │ ├── query // Original query │ └── types // ["web", "kb", "db"] └── children ├── web (type: "search_item") │ ├── label // i18n: "Web Search" / "网页搜索" │ ├── status // "pending" | "running" | "completed" | "failed" │ └── output │ └── result_count ├── kb (type: "search_item") │ └── ... └── db (type: "search_item") └── ... ``` ### Trace Logging Detailed search information is recorded via Trace node logging methods (broadcasts to client): ```go // Node logging methods (from trace/node.go): // - node.Info(message, args...) - Info level log // - node.Debug(message, args...) - Debug level log // - node.Warn(message, args...) - Warning level log // - node.Error(message, args...) 
- Error level log // Search start searchNode.Info("Starting search", map[string]any{"query": query, "types": types}) // Per-type results (on parallel search children) webNode.Debug("Web search completed", map[string]any{"count": count, "duration_ms": duration}) kbNode.Debug("KB search completed", map[string]any{"count": count, "duration_ms": duration}) dbNode.Debug("DB search completed", map[string]any{"count": count, "duration_ms": duration}) // Errors (non-blocking, search continues) webNode.Warn("Web search failed", map[string]any{"error": err.Error()}) // Final summary (on parent node) searchNode.Info("Search completed", map[string]any{"total": total, "duration_ms": duration}) ``` **Log Event Structure** (broadcasted via SSE): ```go // types.TraceLog type TraceLog struct { Timestamp int64 `json:"timestamp"` // milliseconds since epoch Level string `json:"level"` // "info", "debug", "warn", "error" Message string `json:"message"` // Log message Data any `json:"data"` // Additional data NodeID string `json:"node_id"` // Parent node ID } ``` ## Real-time Output Search progress is displayed to the client using **Loading component with Replace** pattern. Uses `ctx.Send()` and `ctx.Replace()` methods. ### Output Flow ``` 1. Send Loading Message loading_id = ctx.Send({ type: "loading", props: { message: "Searching..." } }) → Client displays loading indicator 2. Execute Search (parallel web/kb/db) 3. Replace with Result Message (shows result to user) ctx.Replace(loading_id, { type: "loading", props: { message: "Found 5 references" } }) → Client displays result message 4. Mark as Done (removes the loading after brief display) ctx.Replace(loading_id, { type: "loading", props: { message: "Found 5 references", done: true } }) → Client removes loading indicator ``` ### Implementation ```go // Send loading message loadingID := ctx.Send(map[string]any{ "type": "loading", "props": map[string]any{ "message": i18n.Tr("search.loading", locale), // "Searching..." / "正在搜索..." 
}, }) // Execute search... // Replace with result message (displayed to user) resultMessage := i18n.Tr("search.success", locale, count) // "Found 5 references" ctx.Replace(loadingID, map[string]any{ "type": "loading", "props": map[string]any{ "message": resultMessage, }, }) // Mark as done (removes loading indicator after user sees the result) ctx.Replace(loadingID, map[string]any{ "type": "loading", "props": map[string]any{ "message": resultMessage, "done": true, // Frontend will remove loading indicator }, }) ``` ### Loading Props | Prop | Type | Description | | --------- | ------ | --------------------------------------------------- | | `message` | string | Localized message to display | | `done` | bool | When `true`, frontend removes the loading indicator | ### Localized Messages | Scenario | English | Chinese | | ------------- | ---------------------------------------- | --------------------------------- | | Loading | Searching... | 正在搜索... | | Success (1) | Found 1 reference | 找到 1 条参考资料 | | Success (N) | Found N references | 找到 N 条参考资料 | | Partial Error | Found N references (some sources failed) | 找到 N 条参考资料(部分来源失败) | | All Failed | Search failed | 搜索失败 | | No Results | No references found | 未找到相关资料 | ### Client Display Example ``` Frame 1 - During search: ┌─────────────────────────────────┐ │ Searching... │ ← Loading (done: false) └─────────────────────────────────┘ Frame 2 - Result displayed: ┌─────────────────────────────────┐ │ Found 5 references │ ← Result (done: false) └─────────────────────────────────┘ Frame 3 - Removed: (loading indicator removed when done: true) ``` ## Search Result Storage Search results are stored per request to support citation click-through and history replay. ### Data Model ``` Relationships: Chat └── Request (request_id) ├── Message[] (user, assistant, tool...) 
└── SearchResult[] (one request may have multiple searches) └── Reference[] (indexed references from each search) ``` ### Citation Locating LLM output uses `` tags with index: ```xml AI is artificial intelligence, it has developed rapidly... ``` Location path: `request_id` + `index` → precisely locate reference ### Database Schema **Table: `agent_search`** | Column | Type | Description | | ---------- | ----------- | -------------------------------------- | | id | BIGINT | Auto-increment primary key | | request_id | VARCHAR(64) | Associated request ID (indexed) | | chat_id | VARCHAR(64) | Associated chat ID (indexed) | | query | TEXT | Original search query | | config | JSON | Search config used (for tuning) | | keywords | JSON | Extracted keywords (from NLP) | | entities | JSON | Extracted entities (for Graph search) | | relations | JSON | Extracted relations (for Graph search) | | dsl | JSON | Generated QueryDSL (for DB search) | | source | VARCHAR(32) | Search source: web/kb/db/auto | | references | JSON | Reference[] with global index | | graph | JSON | GraphNode[] from knowledge graph | | xml | TEXT | Formatted XML for LLM context | | prompt | TEXT | Citation instruction prompt | | duration | INT | Search duration in milliseconds | | error | TEXT | Error message if failed (nullable) | | created_at | TIMESTAMP | Creation time | | deleted_at | TIMESTAMP | Soft delete time (nullable) | **Config Field Structure:** ```json { "uses": { "search": "builtin", "web": "builtin", "keyword": "builtin", "querydsl": "builtin", "rerank": "builtin" }, "web": { "provider": "tavily", "max_results": 5 }, "kb": { "collections": ["docs", "faq"], "threshold": 0.7, "graph": true }, "db": { "models": ["product", "order"], "max_results": 20 }, "rerank": { "provider": "builtin", "top_n": 10 } } ``` ### Type Definitions ```go // store/types/types.go // Search represents stored search results for a request // Stores all intermediate processing results for debugging and replay type Search 
struct { ID int64 `json:"id"` RequestID string `json:"request_id"` ChatID string `json:"chat_id"` Query string `json:"query"` // Original query Config map[string]any `json:"config,omitempty"` // Search config used (for tuning) Keywords []string `json:"keywords,omitempty"` // Extracted keywords (Web/NLP) Entities []Entity `json:"entities,omitempty"` // Extracted entities (Graph) Relations []Relation `json:"relations,omitempty"` // Extracted relations (Graph) DSL map[string]any `json:"dsl,omitempty"` // Generated QueryDSL (DB) Source string `json:"source"` // web/kb/db/auto References []Reference `json:"references"` Graph []GraphNode `json:"graph,omitempty"` // Graph nodes from KB XML string `json:"xml,omitempty"` // Formatted XML for LLM Prompt string `json:"prompt,omitempty"` // Citation prompt Duration int64 `json:"duration_ms"` // Search duration Error string `json:"error,omitempty"` // Error if failed CreatedAt time.Time `json:"created_at"` } // Reference represents a single reference with global index (for storage) type Reference struct { Index int `json:"index"` // Global index: 1, 2, 3... Type string `json:"type"` // web/kb/db Title string `json:"title"` URL string `json:"url,omitempty"` Snippet string `json:"snippet"` Content string `json:"content,omitempty"` // Full content (optional) Metadata map[string]any `json:"metadata,omitempty"` } // Entity represents an extracted entity from query (for Graph search) type Entity struct { Name string `json:"name"` // Entity name Type string `json:"type"` // Entity type: person, org, location, etc. 
Metadata map[string]any `json:"metadata,omitempty"` } // Relation represents an extracted relation from query (for Graph search) type Relation struct { Subject string `json:"subject"` // Source entity Predicate string `json:"predicate"` // Relation type Object string `json:"object"` // Target entity Metadata map[string]any `json:"metadata,omitempty"` } // GraphNode represents a node from knowledge graph (search result) type GraphNode struct { ID string `json:"id"` Type string `json:"type"` // Entity type Name string `json:"name"` // Entity name Description string `json:"description,omitempty"` Relation string `json:"relation,omitempty"` // Relationship to query Score float64 `json:"score,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } // SearchFilter for querying search records type SearchFilter struct { RequestID string `json:"request_id,omitempty"` ChatID string `json:"chat_id,omitempty"` Source string `json:"source,omitempty"` } ``` ### Store Interface Extension ```go // store/types/store.go // ChatStore interface extension type ChatStore interface { // ... existing methods ... 
// ========================================================================== // Search Management // ========================================================================== // SaveSearch saves search record for a request // search: Search record to save // Returns: Potential error SaveSearch(search *Search) error // GetSearches retrieves search records for a request // requestID: Request ID // Returns: Search records and potential error GetSearches(requestID string) ([]*Search, error) // GetReference retrieves a single reference by request ID and index // requestID: Request ID // index: Reference index (1-based) // Returns: Reference and potential error GetReference(requestID string, index int) (*Reference, error) // DeleteSearches deletes all search records for a chat // chatID: Chat ID // Returns: Potential error DeleteSearches(chatID string) error } ``` ### Xun Implementation ```go // store/xun/search.go // SaveSearch saves a search record func (store *Xun) SaveSearch(search *Search) error { if search.RequestID == "" { return fmt.Errorf("request_id is required") } refsJSON, err := jsoniter.MarshalToString(search.References) if err != nil { return fmt.Errorf("failed to marshal references: %w", err) } row := map[string]interface{}{ "request_id": search.RequestID, "chat_id": search.ChatID, "query": search.Query, "config": search.Config, // Search config for tuning "keywords": search.Keywords, "entities": search.Entities, // Graph entities "relations": search.Relations, // Graph relations "dsl": search.DSL, "source": search.Source, "references": refsJSON, "graph": search.Graph, // Graph nodes "xml": search.XML, "prompt": search.Prompt, "duration": search.Duration, "error": search.Error, "created_at": time.Now(), } return store.newQuerySearch().Insert(row) } // GetSearches retrieves all search records for a request func (store *Xun) GetSearches(requestID string) ([]*Search, error) { rows, err := store.newQuerySearch(). Where("request_id", requestID). 
WhereNull("deleted_at"). OrderBy("created_at", "asc"). Get() // ... convert rows to Search } // GetReference retrieves a single reference func (store *Xun) GetReference(requestID string, index int) (*Reference, error) { searches, err := store.GetSearches(requestID) if err != nil { return nil, err } // Find reference by index across all search records for _, search := range searches { for _, ref := range search.References { if ref.Index == index { return &ref, nil } } } return nil, fmt.Errorf("reference %d not found in request %s", index, requestID) } ``` ### Model Definition ```json // yao/models/agent/search.mod.yao { "name": "Search", "label": "Search", "description": "Search records for citation support and debugging", "tags": ["agent", "system"], "builtin": true, "readonly": true, "table": { "name": "agent_search", "comment": "Agent search table" }, "columns": [ { "name": "id", "type": "ID", "label": "ID" }, { "name": "request_id", "type": "string", "length": 64, "nullable": false, "index": true }, { "name": "chat_id", "type": "string", "length": 64, "nullable": false, "index": true }, { "name": "query", "type": "text", "nullable": true }, { "name": "config", "type": "json", "nullable": true, "comment": "Search config used (for tuning)" }, { "name": "keywords", "type": "json", "nullable": true }, { "name": "entities", "type": "json", "nullable": true }, { "name": "relations", "type": "json", "nullable": true }, { "name": "dsl", "type": "json", "nullable": true }, { "name": "source", "type": "string", "length": 32, "nullable": false }, { "name": "references", "type": "json", "nullable": true }, { "name": "graph", "type": "json", "nullable": true }, { "name": "xml", "type": "text", "nullable": true }, { "name": "prompt", "type": "text", "nullable": true }, { "name": "duration", "type": "integer", "nullable": true }, { "name": "error", "type": "text", "nullable": true } ], "option": { "timestamps": true, "soft_deletes": true } } ``` ### Stream Integration Storage 
logic is encapsulated in `assistant/search.go` with a dedicated method: ```go // assistant/search.go // SearchExecutionResult contains all intermediate results from search execution type SearchExecutionResult struct { Query string // Original query Config map[string]any // Search config used Keywords []string // Extracted keywords (Web/NLP) Entities []storeTypes.Entity // Extracted entities (Graph) Relations []storeTypes.Relation // Extracted relations (Graph) DSL map[string]any // Generated QueryDSL (DB) Source string // web/kb/db/auto RefCtx *searchTypes.ReferenceContext // Reference context for LLM Graph []storeTypes.GraphNode // Graph nodes from KB Duration int64 // Duration in ms Error string // Error message if failed } // saveSearch saves search record to store for citation support and debugging func (ast *Assistant) saveSearch(ctx *context.Context, result *SearchExecutionResult) { if ctx.Store == nil || result == nil { return } // Skip if no references and no error if result.RefCtx == nil && result.Error == "" { return } var refs []storeTypes.Reference var xml, prompt string if result.RefCtx != nil { refs = convertReferences(result.RefCtx.References) xml = result.RefCtx.XML prompt = result.RefCtx.Prompt } search := &storeTypes.Search{ RequestID: ctx.RequestID, ChatID: ctx.ID, Query: result.Query, Config: result.Config, // Search config for tuning analysis Keywords: result.Keywords, Entities: result.Entities, // Graph entities Relations: result.Relations, // Graph relations DSL: result.DSL, Source: result.Source, References: refs, Graph: result.Graph, // Graph nodes XML: xml, Prompt: prompt, Duration: result.Duration, Error: result.Error, } if err := ctx.Store.SaveSearch(search); err != nil { ctx.Logger.Warn("Failed to save search: %v", err) } } // convertReferences converts search references to store format func convertReferences(refs []*searchTypes.Reference) []storeTypes.Reference { result := make([]storeTypes.Reference, len(refs)) for i, ref := range 
refs { result[i] = storeTypes.Reference{ Index: i + 1, // 1-based index Type: string(ref.Type), Title: ref.Title, URL: ref.URL, Snippet: ref.Content, Content: ref.Content, Metadata: ref.Meta, } } return result } // In executeAutoSearch: func (ast *Assistant) executeAutoSearch(ctx *context.Context, ...) *searchTypes.ReferenceContext { start := time.Now() // 1. Execute search (Result now contains all intermediate data) results, err := searcher.All(ctx, requests) duration := time.Since(start).Milliseconds() // 2. Prepare execution result for storage execResult := &SearchExecutionResult{ Query: query, Config: buildSearchConfig(searchConfig, searchUses), Source: "auto", Duration: duration, } if err != nil { execResult.Error = err.Error() ast.saveSearch(ctx, execResult) return nil } // 3. Extract intermediate data from results // Result.Keywords, Result.DSL, Result.Entities, Result.Relations are populated by handlers for _, result := range results { if len(result.Keywords) > 0 { execResult.Keywords = result.Keywords } if result.DSL != nil { execResult.DSL = result.DSL } if len(result.Entities) > 0 { execResult.Entities = convertEntities(result.Entities) } if len(result.Relations) > 0 { execResult.Relations = convertRelations(result.Relations) } if len(result.GraphNodes) > 0 { execResult.Graph = convertGraphNodes(result.GraphNodes) } } // 4. Build reference context refCtx := search.BuildReferenceContext(results, citationConfig) execResult.RefCtx = refCtx // 5. 
Save search record ast.saveSearch(ctx, execResult) return refCtx } ``` ### Usage Scenarios **Scenario 1: Single Search** ``` Request: req_001 └── Search: { source: "auto", references: [{index:1,...}, {index:2,...}, {index:3,...}] } ``` **Scenario 2: Multiple Searches (e.g., Tool Call triggers another search)** ``` Request: req_001 ├── Search[0]: { source: "web", references: [{index:1,...}, {index:2,...}] } └── Search[1]: { source: "kb", references: [{index:3,...}, {index:4,...}] } ``` Index is globally incremented, so `request_id + index` is always unique. ### API Endpoints ``` GET /api/chat/{chat_id}/request/{request_id}/references # Get all references for request GET /api/chat/{chat_id}/request/{request_id}/reference/{index} # Get single reference by index ``` ### Frontend Integration ```typescript // When user clicks citation [1] async function onCitationClick(requestId: string, index: number) { const ref = await api.get( `/chat/${chatId}/request/${requestId}/reference/${index}` ); showReferenceCard({ title: ref.title, url: ref.url, snippet: ref.snippet, content: ref.content, }); } ``` ## JSAPI Integration The Search module is exposed via `ctx.search` object in hook scripts. 
### Architecture To avoid circular dependency between `context` and `search` packages: ``` agent/context/jsapi_search.go agent/search/jsapi.go ┌─────────────────────────────┐ ┌─────────────────────────┐ │ SearchAPI interface │◄───────│ JSAPI struct │ │ SearchAPIFactory var │ │ (implements SearchAPI) │ │ V8 binding methods: │ │ NewJSAPI() │ │ newSearchObject() │ │ Web/KB/DB() │ │ searchWebMethod() │ │ All/Any/Race() │ │ searchKBMethod() │ │ buildRequest() │ │ searchDBMethod() │ │ parseRequests() │ │ searchAllMethod() │ │ ConfigGetter type │ │ searchAnyMethod() │ │ SetJSAPIFactory() │ │ searchRaceMethod() │ └─────────────────────────┘ └─────────────────────────────┘ │ ▲ │ │ │ └───────────────────────────────────────┘ Factory registration (with ConfigGetter in assistant/init) agent/context/jsapi.go ┌─────────────────────────────┐ │ NewObject() │ │ jsObject.Set("search", │ │ ctx.newSearchObject()) │ └─────────────────────────────┘ ``` **Key Files:** | File | Description | | -------------------------------- | ---------------------------------------------------------------- | | `context/jsapi_search.go` | SearchAPI interface + V8 binding methods | | `context/jsapi_search_test.go` | Integration tests (real V8 calls via test assistant) | | `context/jsapi.go` | Mount search object to ctx | | `search/jsapi.go` | JSAPI implementation (calls Searcher) + ConfigGetter | | `search/jsapi_test.go` | Black-box unit tests | | `assistant/assistant.go:init` | Factory registration via SetJSAPIFactory(ConfigGetter) | | `assistants/tests/search-jsapi/` | Test assistant for JSAPI integration tests (Create hook, no LLM) | ### API Methods ```typescript // In hook scripts (index.ts) // Single search methods ctx.search.Web(query: string, options?: WebOptions): Result ctx.search.KB(query: string, options?: KBOptions): Result ctx.search.DB(query: string, options?: DBOptions): Result // Parallel search methods - inspired by JavaScript Promise ctx.search.All(requests: Request[]): Result[] // Like 
Promise.all - wait for all ctx.search.Any(requests: Request[]): Result[] // Like Promise.any - first success ctx.search.Race(requests: Request[]): Result[] // Like Promise.race - first complete ``` ### Options Types ```typescript interface WebOptions { limit?: number; // Max results (default: 10) sites?: string[]; // Restrict to sites timeRange?: string; // "day", "week", "month", "year" rerank?: RerankOptions; } interface KBOptions { collections?: string[]; // Collection IDs threshold?: number; // Similarity threshold (0-1) limit?: number; // Max results graph?: boolean; // Enable graph association rerank?: RerankOptions; } interface DBOptions { models?: string[]; // Model IDs (default: use assistant's db.models) wheres?: Where[]; // Pre-defined filters, uses GOU QueryDSL Where format orders?: Order[]; // Sort orders, uses GOU QueryDSL Order format select?: string[]; // Fields to return limit?: number; // Max results (default: 10) rerank?: RerankOptions; } // GOU QueryDSL Where condition // See: github.com/yaoapp/gou/query/gou/types.go interface Where { field: Expression; // Field expression value?: any; // Match value op: string; // Operator: "=", "like", ">", "<", ">=", "<=", "in", "is null", etc. 
or?: boolean; // true for OR condition, default AND wheres?: Where[]; // Nested conditions for grouping } // GOU QueryDSL Order interface Order { field: Expression; // Field expression sort?: string; // "asc" or "desc" } // GOU Expression (simplified) interface Expression { field?: string; // Field name table?: string; // Table name (optional) } interface RerankOptions { topN?: number; // Return top N after reranking // Note: Reranker type is determined by uses.rerank in agent/agent.yml } ``` ### Usage Examples #### Example 1: Web Search ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; const result = ctx.search.Web(query, { limit: 5, timeRange: "week", }); if (result.items.length > 0) { return { messages: [ { role: "system", content: formatSearchContext(result), }, ], uses: { search: "disabled" }, // Disable auto search }; } return { messages: [] }; // Let auto search handle it } ``` #### Example 2: Knowledge Base Search with Graph ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; const result = ctx.search.KB(query, { collections: ["docs", "faq"], threshold: 0.7, limit: 10, graph: true, // Enable graph association }); if (result.items.length > 0) { return { messages: [ { role: "system", content: formatKBContext(result), }, ], uses: { search: "disabled" }, // Disable auto search }; } return { messages: [] }; // Let auto search handle it } ``` #### Example 3: Database Search ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; // Search in assistant's models (uses db.models from assistant config) const result = ctx.search.DB(query, { models: ["product", "agents.mybot.order"], // Optional: override models wheres: [{ field: "status", value: "active" }], // Pre-filter limit: 20, }); if (result.items.length > 0) { return { messages: [ { role: "system", content: formatDBContext(result), }, ], uses: { 
search: "disabled" }, // Disable auto search }; } return { messages: [] }; // Let auto search handle it } ``` #### Example 4: Parallel Search with ctx.search.All() ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; // Execute web, KB, and DB search in parallel (wait for all) - like Promise.all const [webResult, kbResult, dbResult] = ctx.search.All([ { type: "web", query: query, limit: 5 }, { type: "kb", query: query, collections: ["docs"], limit: 10 }, { type: "db", query: query, models: ["product"], limit: 10 }, ]); // Merge results const context = mergeSearchResults(webResult, kbResult, dbResult); return { messages: [ { role: "system", content: context, }, ], uses: { search: "disabled" }, // Disable auto search }; } ``` #### Example 4b: Parallel Search with ctx.search.Any() ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; // Return as soon as any search succeeds (has results) - like Promise.any const results = ctx.search.Any([ { type: "web", query: query, limit: 5 }, { type: "kb", query: query, collections: ["docs"], limit: 10 }, ]); // Use the first successful result const successResult = results.find((r) => r && r.items?.length > 0); if (successResult) { return { messages: [{ role: "system", content: formatContext(successResult) }], uses: { search: "disabled" }, }; } return { messages: [] }; } ``` #### Example 4c: Parallel Search with ctx.search.Race() ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; // Return as soon as any search completes (success or not) - like Promise.race const results = ctx.search.Race([ { type: "web", query: query, limit: 5 }, { type: "kb", query: query, collections: ["docs"], limit: 10 }, ]); // Use the first completed result const firstResult = results.find((r) => r != null); if (firstResult && firstResult.items?.length > 0) { return { messages: [{ role: 
"system", content: formatContext(firstResult) }], uses: { search: "disabled" }, }; } return { messages: [] }; } ``` #### Example 5: Custom Citation Format ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; const result = ctx.search.Web(query, { limit: 5 }); // Build custom citation prompt const refs = result.items .map((item, i) => `[${i + 1}] ${item.title} - ${item.url}`) .join("\n"); return { messages: [ { role: "system", content: `Use [N] to cite. References:\n${refs}`, }, ], uses: { search: "disabled" }, // Disable auto search citation: { autoInjectPrompt: false }, // Override citation config }; } ``` ## Configuration Configuration follows a three-layer hierarchy (later overrides earlier): 1. **System Built-in Defaults** - Hardcoded sensible defaults 2. **Global Configuration** - `agent/agent.yml` (uses) + `agent/search.yml` (search options) 3. **Assistant Configuration** - `assistants/<assistant-id>/package.yao` (uses + search options) ### Uses Configuration Processing tools are configured in `agent/agent.yml` under `uses`: ```yaml # agent/agent.yml uses: default: "yaobots" title: "workers.system.title" vision: "workers.system.vision" fetch: "workers.system.fetch" # Search processing tools (NLP) keyword: "builtin" # "builtin", "workers.nlp.keyword", "mcp:my-server.extract_keywords" querydsl: "builtin" # "builtin", "workers.nlp.querydsl", "mcp:my-server.generate_dsl" rerank: "builtin" # "builtin", "workers.rerank", "mcp:my-server.rerank" # Search handlers web: "builtin" # "builtin", "workers.search.web", "mcp:my-server.web_search" # Note: kb & db always use builtin (access internal data) # Note: embedding & entity follow KB collection config ``` Tool format: `"builtin"`, `"<assistant-id>"` (Agent), `"mcp:<server>.<tool>"` (MCP Tool) **Web Search Modes:** | Mode | Example | Description | | --------- | ---------------------------- | -------------------------------------------------------------------------- | | `builtin` | `"builtin"` | Use 
built-in providers (Tavily, Serper, SerpAPI) | | Agent | `"workers.search.web"` | AI-powered search: understand intent → optimize query → search → summarize | | MCP | `"mcp:my-server.web_search"` | External search tool via MCP protocol | **Why Agent for Web Search (AI Search)?** When `uses.web` is set to an assistant ID, the search flow becomes: ``` User Query: "What's the best laptop for programming in 2024?" │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ Agent (workers.search.web) │ │ 1. Understand intent: laptop recommendations for coding │ │ 2. Generate optimized queries: │ │ - "best programming laptop 2024 review" │ │ - "developer laptop comparison 2024" │ │ 3. Execute multiple searches │ │ 4. Analyze & deduplicate results │ │ 5. Return structured, relevant results │ └─────────────────────────────────────────────────────────────┘ │ ▼ High-quality, intent-aware search results ``` ### System Built-in Defaults (`defaults/defaults.go`) These are the hardcoded defaults, used by `agent/load.go` when loading configuration: ```go package defaults import "github.com/yaoapp/yao/agent/search/types" // SystemDefaults provides hardcoded default values // Used by agent/load.go for merging with agent/search.yml var SystemDefaults = &types.Config{ // Web search defaults Web: &types.WebConfig{ Provider: "tavily", MaxResults: 10, }, // KB search defaults KB: &types.KBConfig{ Threshold: 0.7, Graph: false, }, // DB search defaults DB: &types.DBConfig{ MaxResults: 20, }, // Keyword extraction options (uses.keyword) Keyword: &types.KeywordConfig{ MaxKeywords: 10, Language: "auto", }, // QueryDSL generation options (uses.querydsl) QueryDSL: &types.QueryDSLConfig{ Strict: false, }, // Rerank options (uses.rerank) Rerank: &types.RerankConfig{ TopN: 10, }, // Citation Citation: &types.CitationConfig{ Format: "#ref:{id}", AutoInjectPrompt: true, }, // Source weights Weights: &types.WeightsConfig{ User: 1.0, Hook: 0.8, Auto: 0.6, }, // Behavior options Options: 
&types.OptionsConfig{ SkipThreshold: 5, }, } // GetWeight returns the weight for a source type func GetWeight(cfg *types.Config, source types.SourceType) float64 { if cfg == nil || cfg.Weights == nil { switch source { case types.SourceUser: return 1.0 case types.SourceHook: return 0.8 default: return 0.6 } } switch source { case types.SourceUser: return cfg.Weights.User case types.SourceHook: return cfg.Weights.Hook case types.SourceAuto: return cfg.Weights.Auto default: return 0.6 } } ``` ### Configuration Loading (in `agent/load.go`) Configuration loading follows the existing pattern in `agent/load.go`: ```go // agent/load.go import ( searchDefaults "github.com/yaoapp/yao/agent/search/defaults" searchTypes "github.com/yaoapp/yao/agent/search/types" ) var searchConfig *searchTypes.Config // initSearchConfig initialize the search configuration from agent/search.yml func initSearchConfig() error { // Start with system defaults searchConfig = searchDefaults.SystemDefaults path := filepath.Join("agent", "search.yml") if exists, _ := application.App.Exists(path); !exists { return nil // Use defaults } // Read and merge with defaults bytes, err := application.App.Read(path) if err != nil { return err } var cfg searchTypes.Config err = application.Parse("search.yml", bytes, &cfg) if err != nil { return err } // Merge: defaults < global config searchConfig = mergeSearchConfig(searchDefaults.SystemDefaults, &cfg) return nil } // GetSearchConfig returns the global search configuration func GetSearchConfig() *searchTypes.Config { return searchConfig } ``` ### Assistant-level Config Merge (in `agent/assistant/load.go`) Assistant-specific search config is merged in `assistant/load.go`: ```go // agent/assistant/load.go // GetMergedSearchConfig returns merged search config for this assistant func (ast *Assistant) GetMergedSearchConfig() *searchTypes.Config { globalCfg := agent.GetSearchConfig() if ast.Search == nil { return globalCfg } // Merge: global < assistant return 
mergeSearchConfig(globalCfg, ast.Search.ToConfig()) } ``` ### Global Configuration `agent/search.yml` - Override system defaults for all assistants: ```yaml # Global Search Configuration # These settings apply to all assistants unless overridden by assistant-specific configurations. # Web search settings web: provider: "tavily" # "tavily", "serper", or "serpapi" (builtin providers only) api_key_env: "TAVILY_API_KEY" max_results: 10 # engine: "google" # For SerpAPI only: "google", "bing", "baidu", "yandex", etc. # Knowledge base search settings kb: threshold: 0.7 # Similarity threshold graph: false # Enable GraphRAG association # Database search settings db: max_results: 20 # Keyword extraction options (uses.keyword) keyword: max_keywords: 10 language: "auto" # "auto", "en", "zh", etc. # QueryDSL generation options (uses.querydsl) querydsl: strict: false # Strict mode: fail if generation fails # Rerank options (uses.rerank) rerank: top_n: 10 # Return top N results after reranking # Citation format for LLM references citation: format: "#ref:{id}" auto_inject_prompt: true # Auto-inject citation instructions to system prompt # Source weighting for result merging weights: user: 1.0 # User-provided DataContent (highest priority) hook: 0.8 # Hook ctx.search.*() results auto: 0.6 # Auto search results # Search behavior options options: skip_threshold: 5 # Skip auto search if user provides >= N results ``` ### Assistant Configuration `assistants//package.yao` - Override for specific assistant: ```jsonc { "name": "My Assistant", "connector": "openai", // Overrides global uses (agent/agent.yml) "uses": { "search": "builtin", // "builtin", "disabled", "", "mcp:." "web": "builtin", // "builtin", "", "mcp:." 
"keyword": "workers.nlp.keyword", // Use LLM for keyword extraction "querydsl": "workers.nlp.querydsl", // Use LLM for QueryDSL generation "rerank": "mcp:my-server.rerank" // Use MCP tool for reranking }, // Search configuration (overrides agent/search.yml) "search": { // Overrides global web settings "web": { "provider": "tavily", "max_results": 5 }, // Overrides global kb settings "kb": { "collections": ["docs", "faq"], // Specific collections to search "threshold": 0.7, "graph": true }, // Overrides global db settings "db": { "models": ["product", "order"], // Uses db.models if not set "max_results": 20 }, // Overrides global keyword options "keyword": { "max_keywords": 5 }, // Overrides global querydsl options "querydsl": { "strict": true }, // Overrides global rerank options "rerank": { "top_n": 5 }, // Overrides global citation settings "citation": { "format": "#ref:{id}", "auto_inject_prompt": true } }, // Knowledge base collections available to this assistant "kb": { "collections": ["docs", "faq"] }, // Database models available to this assistant "db": { "models": ["product", "order", "customer"] } } ``` ## Execution Flow ### Search Flow ## Execution Modes ### Stream() Execution with Search ``` Stream(ctx, messages, options) │ ├── 1. Initialize │ ├── 2. Create Hook (optional) │ └── Can call ctx.search.* and return search results │ ├── 3. BuildRequest + BuildContent │ ├── 4. Auto Search Decision (shouldAutoSearch) │ ├── IF Uses.Search == "disabled" → SKIP │ ├── IF Create Hook returned uses.search="disabled" → SKIP │ └── ELSE → Execute Auto Search (executeAutoSearch) │ ├── Read assistant's search config (GetMergedSearchConfig) │ ├── Extract keywords (if uses.keyword && !Skip.Keyword) │ ├── Build search requests (buildSearchRequests) │ ├── Execute web/kb/db in parallel (searcher.All) │ ├── Build reference context (BuildReferenceContext) │ └── Inject search context to messages (injectSearchContext) │ ├── 5. LLM Call (with search context if any) │ ├── 6. 
Next Hook (optional) │ └── 7. Output (response may contain #ref:xxx citations) ``` **Implementation Files:** | File | Description | | --------------------- | ----------------------------------------------- | | `assistant/search.go` | Core integration logic (shouldAutoSearch, etc.) | | `assistant/agent.go` | Stream() integration point (after BuildContent) | | `search/reference.go` | BuildReferenceContext, FormatReferencesXML | **Key Functions (`assistant/search.go`):** ```go // shouldAutoSearch determines if auto search should be executed func (ast *Assistant) shouldAutoSearch(ctx *context.Context, createResponse *context.HookCreateResponse) bool // executeAutoSearch executes auto search based on configuration // opts is optional, used to check Skip.Keyword for keyword extraction func (ast *Assistant) executeAutoSearch(ctx *context.Context, messages []context.Message, createResponse *context.HookCreateResponse, opts ...*context.Options) *searchTypes.ReferenceContext // injectSearchContext injects search results into messages func (ast *Assistant) injectSearchContext(messages []context.Message, refCtx *searchTypes.ReferenceContext) []context.Message // getMergedSearchUses returns the merged uses configuration for search func (ast *Assistant) getMergedSearchUses(createResponse *context.HookCreateResponse) *context.Uses // buildSearchRequests builds search requests based on assistant configuration func (ast *Assistant) buildSearchRequests(query string, config *searchTypes.Config) []*searchTypes.Request ``` **Keyword Extraction in executeAutoSearch:** When `uses.keyword` is configured and `opts.Skip.Keyword` is not true, keyword extraction is performed before web search: ```go // Extract keywords for web search if: // 1. uses.keyword is configured (not empty) // 2. Skip.Keyword is not true // 3. 
Web search is enabled if webSearchEnabled && !skipKeyword && searchUses.Keyword != "" { extractor := keyword.NewExtractor(searchUses.Keyword, searchConfig.Keyword) keywords, err := extractor.Extract(ctx, query, nil) if err == nil && len(keywords) > 0 { query = strings.Join(keywords, " ") } } ``` **Integration in agent.go:** ```go // In Stream(), after BuildContent: if ast.shouldAutoSearch(ctx, createResponse) { refCtx := ast.executeAutoSearch(ctx, completionMessages, createResponse, opts) if refCtx != nil && len(refCtx.References) > 0 { completionMessages = ast.injectSearchContext(completionMessages, refCtx) } } ``` **Skip.Keyword Option (`context.Options.Skip`):** ```go type Skip struct { History bool `json:"history"` // Skip saving chat history Trace bool `json:"trace"` // Skip trace logging Output bool `json:"output"` // Skip output to client Keyword bool `json:"keyword"` // Skip keyword extraction for web search } ``` Use `Skip.Keyword = true` when you want to use the raw query directly without keyword extraction. 
### Control via Uses.Search Search is controlled via the `Uses` mechanism, following the merge hierarchy: ``` Global (agent/agent.yml) → Assistant (package.yao) → CreateHook (return uses) → Request (options.uses) ``` | Uses.Search | Behavior | | ----------------------- | ------------------------------------ | | `"builtin"` | Use builtin auto search | | `"disabled"` | Disable auto search | | `"<assistant-id>"` | Delegate to an assistant (AI Search) | | `"mcp:<server>.<tool>"` | Use MCP tool for search | | `undefined` | Follow upper layer config (default) | **Go:** ```go // Use builtin auto search uses := &context.Uses{Search: "builtin"} // Disable auto search uses := &context.Uses{Search: "disabled"} // Delegate to AI Search assistant uses := &context.Uses{Search: "workers.search.ai"} // Follow assistant config (default) uses := &context.Uses{Search: ""} // or nil ``` **API Request:** ```json { "messages": [...], "uses": { "search": "builtin" } } ``` ### Hook-Controlled Search Search is controlled via the `Uses` mechanism, same as Vision/Audio. 
The merge hierarchy is: ``` Global (agent/agent.yml) → Assistant (package.yao) → CreateHook (return uses) ``` When you need custom search logic, handle it in Create Hook and return `uses.search` to control auto search: ```typescript function Create(ctx, messages, options) { const query = messages[messages.length - 1].content; // Custom logic: only search for certain queries if (needsSearch(query)) { const result = ctx.search.Web(query, { limit: 5 }); return { messages: [{ role: "system", content: formatContext(result) }], uses: { search: "disabled" }, // Disable auto search (hook handled it) }; } // Let auto search handle it (follow assistant config) return { messages: [] }; } ``` **Uses.Search Values:** | Value | Behavior | | ----------------------- | ------------------------------------ | | `"builtin"` | Use builtin auto search | | `"disabled"` | Disable auto search | | `""` | Delegate to an assistant (AI Search) | | `"mcp:."` | Use MCP tool for search | | `undefined` | Follow upper layer config (default) | **Uses Merge Hierarchy:** ``` ┌─────────────────────────────────────────────────────────────┐ │ 1. Global Config (agent/agent.yml) │ │ uses: │ │ search: "builtin" │ └─────────────────────────────────────────────────────────────┘ ↓ merge ┌─────────────────────────────────────────────────────────────┐ │ 2. Assistant Config (assistants//package.yao) │ │ uses: │ │ search: "workers.search.web" # Override to AI Search │ └─────────────────────────────────────────────────────────────┘ ↓ merge ┌─────────────────────────────────────────────────────────────┐ │ 3. CreateHook Return │ │ return { │ │ uses: { search: "disabled" } # Hook handled it │ │ } │ └─────────────────────────────────────────────────────────────┘ ``` > **Note**: The `Uses` struct in `context/types_llm.go` already has a `Search` field. > The value `"disabled"` is a special value to disable auto search when hook handles it. 
## Search Flow ``` Request → Trace Start → Query Process → Search → Rerank → Citations → Output → Return ``` ### Query Processing | Type | Process | Tool Config | | ---- | ----------------------------------------------------- | -------------------- | | Web | Extract keywords → Build query | `uses.keyword` | | KB | Get collection's embedding model → Generate embedding | KB collection config | | DB | Parse query → Build QueryDSL → Execute on models | `uses.querydsl` | #### Processing Methods Configure via `uses.*` in `agent/agent.yml`: | Format | Description | Use Case | | --------------------- | ----------------------------------------- | ------------------------------ | | `builtin` | Rule-based, template-driven (no LLM call) | Fast, low cost, simple queries | | `` | Delegate to an assistant (Agent) | LLM-based, custom logic | | `mcp:.` | Call MCP tool | External services integration | #### Keyword Extraction (`nlp/keyword/`) Configure via `uses.keyword`. The keyword extraction module follows the Handler + Registry pattern with three modes: | Mode | Value | Description | | ------- | ---------------------------- | --------------------------------------------- | | Builtin | `"builtin"` | Frequency-based extraction (no external deps) | | Agent | `"workers.nlp.keyword"` | LLM-powered semantic extraction | | MCP | `"mcp:nlp.extract_keywords"` | External service via MCP | **Directory Structure:** ``` nlp/keyword/ ├── extractor.go # Main entry point (mode dispatch) ├── builtin.go # Builtin: frequency-based, stopword filtering ├── agent.go # Agent: delegate to LLM assistant └── mcp.go # MCP: call external tool ``` **Usage:** ```go // nlp/keyword/extractor.go package keyword // Extractor extracts keywords from text type Extractor struct { usesKeyword string // "builtin", "", "mcp:." 
config *types.KeywordConfig } // NewExtractor creates a new keyword extractor func NewExtractor(usesKeyword string, cfg *types.KeywordConfig) *Extractor // Extract extracts keywords based on configured mode func (e *Extractor) Extract(ctx *context.Context, content string, opts *types.KeywordOptions) ([]string, error) ``` **Builtin Implementation:** The builtin extractor uses simple frequency-based extraction with no external dependencies: - Tokenization (handles English and Chinese) - Stop word filtering (common English and Chinese stop words) - Frequency counting and ranking - Returns top N keywords by frequency > **Note**: For production use cases requiring high accuracy (semantic understanding, phrase extraction), use Agent or MCP mode. **Example:** ``` "I want to find the best wireless headphones under $100" ↓ builtin: tokenization + stopword removal + frequency ranking → ["wireless", "headphones", "find", "best"] ↓ agent: LLM semantic extraction → ["wireless headphones", "under $100", "best"] ``` #### Embedding (KB Collection Config) Embedding is **not** part of the `nlp/` package. It follows KB collection's own configuration: - Each KB collection defines its own embedding provider and model - The KB handler (`handlers/kb/`) calls the collection's embedding API directly - Entity types for GraphRAG are also defined per collection ```go // handlers/kb/handler.go func (h *Handler) Search(ctx *context.Context, req *types.Request) (*types.Result, error) { // 1. Get collection config (embedding provider, model) collection := h.getCollection(req.Collections[0]) // 2. Generate embedding using collection's config vector, err := collection.Embed(ctx, req.Query) // 3. Vector search // ... } ``` #### QueryDSL Generation (`nlp/querydsl/`) Configure via `uses.querydsl`. 
The QueryDSL generation module follows the same pattern as keyword extraction: | Mode | Value | Description | | ------- | ----------------------------- | ------------------------------------------- | | Builtin | `"builtin"` | Template-based generation from model schema | | Agent | `"workers.nlp.querydsl"` | LLM-powered semantic query generation | | MCP | `"mcp:nlp.generate_querydsl"` | External service via MCP | **Directory Structure:** ``` nlp/querydsl/ ├── generator.go # Main entry point (mode dispatch) ├── builtin.go # Builtin: template-based generation ├── agent.go # Agent: delegate to LLM assistant └── mcp.go # MCP: call external tool ``` **Usage:** ```go // nlp/querydsl/generator.go package querydsl // Generator generates QueryDSL from natural language type Generator struct { usesQueryDSL string config *types.QueryDSLConfig } // NewGenerator creates a new QueryDSL generator func NewGenerator(usesQueryDSL string, cfg *types.QueryDSLConfig) *Generator // Generate converts natural language to QueryDSL // Uses GOU types directly: model.Model and gou.QueryDSL func (g *Generator) Generate(query string, models []*model.Model) (*gou.QueryDSL, error) ``` **Example:** ``` "Products cheaper than $100 from Apple" ↓ builtin: template matching against model schema → QueryDSL with simple keyword matching ↓ agent: LLM generates DSL from NL + schema → QueryDSL: {"wheres": [{"column": "price", "op": "<", "value": 100}, {"column": "brand", "value": "Apple"}]} ``` ## Handlers & Providers All handler implementations are in `search/handlers/` directory. 
### Web Search (`handlers/web/`) Web search supports three modes via `uses.web`: | Mode | Value | Description | | ------- | ---------------------------- | ------------------------------------------- | | Builtin | `"builtin"` | Direct API calls to Tavily/Serper/SerpAPI | | Agent | `"workers.search.web"` | AI-powered search with intent understanding | | MCP | `"mcp:my-server.web_search"` | External search tool via MCP | ```go // handlers/web/handler.go package web import ( "strings" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // Handler implements web search type Handler struct { usesWeb string // "builtin", "", "mcp:." config *types.WebConfig } // NewHandler creates a new web search handler func NewHandler(usesWeb string, cfg *types.WebConfig) *Handler // Type returns the search type this handler supports func (h *Handler) Type() types.SearchType // Search implements interfaces.Handler (without context) func (h *Handler) Search(req *types.Request) (*types.Result, error) // SearchWithContext executes web search with context (for Agent/MCP modes) func (h *Handler) SearchWithContext(ctx *agentContext.Context, req *types.Request) (*types.Result, error) ``` **Directory Structure:** ``` handlers/web/ ├── handler.go # Main entry point (mode dispatch) ├── tavily.go # Tavily provider (builtin) ├── serper.go # Serper provider (serper.dev) ├── serpapi.go # SerpAPI provider (serpapi.com, multi-engine) ├── agent.go # Agent mode (AI Search) └── mcp.go # MCP mode (external service) ``` **Built-in Providers (when `uses.web = "builtin"`):** | Provider | File | Notes | | -------- | ------------ | ----------------------------------------------- | | Tavily | `tavily.go` | Recommended for AI applications | | Serper | `serper.go` | Google search via serper.dev (POST + X-API-KEY) | | SerpAPI | `serpapi.go` | Multi-engine search via serpapi.com (GET + URL) | **SerpAPI Engine Support:** SerpAPI supports multiple search engines via the 
`engine` config: | Engine | Description | | ------------ | ---------------------------- | | `google` | Google Search (default) | | `bing` | Bing Search | | `baidu` | Baidu Search (Chinese) | | `yandex` | Yandex Search | | `yahoo` | Yahoo Search | | `duckduckgo` | DuckDuckGo Search | | `naver` | Naver Search (Korean) | | `ecosia` | Ecosia Search (eco-friendly) | | `seznam` | Seznam Search (Czech) | See [SerpAPI Documentation](https://serpapi.com/search-api) for the full list of supported engines. Configuration example: ```yaml # agent/search.yml web: provider: "serpapi" api_key_env: "SERPAPI_API_KEY" engine: "bing" # Use Bing instead of Google max_results: 10 ``` **Agent Mode (AI Search):** When `uses.web` is set to an assistant ID (e.g., `"workers.search.web"`), the assistant can: 1. **Understand user intent** - Parse complex queries, identify what user really wants 2. **Generate multiple queries** - Create optimized search terms for better coverage 3. **Multi-source search** - Search multiple providers or sources 4. **Result analysis** - Deduplicate, rank, and summarize results 5. **Context-aware** - Use conversation context to improve search relevance ``` User Query: "What's the best laptop for programming in 2024?" │ ▼ ┌─────────────────────────────────────────────────────────────┐ │ Agent (workers.search.web) │ │ 1. Understand intent: laptop recommendations for coding │ │ 2. Generate optimized queries: │ │ - "best programming laptop 2024 review" │ │ - "developer laptop comparison 2024" │ │ 3. Execute multiple searches via builtin providers │ │ 4. Analyze & deduplicate results │ │ 5. Return structured, relevant results │ └─────────────────────────────────────────────────────────────┘ │ ▼ High-quality, intent-aware search results ``` **Example AI Search Assistant:** ```typescript // assistants/workers/search/web/src/index.ts function Create(ctx, messages, options) { const userQuery = messages[messages.length - 1].content; // 1. 
Analyze intent (this assistant has access to LLM) const intent = analyzeIntent(ctx, userQuery); // 2. Generate optimized queries const queries = generateQueries(intent); // 3. Execute searches using builtin provider const allResults = []; for (const q of queries) { const result = ctx.search.Web(q, { provider: "tavily", // Use builtin provider limit: 5, }); allResults.push(...result.items); } // 4. Merge, deduplicate, and rank results const merged = mergeAndRank(allResults, intent); return { type: "search_result", items: merged, }; } ``` ### Knowledge Base (`handlers/kb/`) ```go // handlers/kb/handler.go package kb import ( "github.com/yaoapp/yao/agent/search/types" ) // Handler implements KB search type Handler struct { config *types.KBConfig } // NewHandler creates a new KB search handler func NewHandler(cfg *types.KBConfig) *Handler // Type returns the search type this handler supports func (h *Handler) Type() types.SearchType // Search executes vector search and optional graph association // TODO: Implement actual search logic func (h *Handler) Search(req *types.Request) (*types.Result, error) ``` | File | Description | | ------------ | ---------------------------------- | | `handler.go` | Main KB handler implementation | | `vector.go` | Vector similarity search | | `graph.go` | Graph-based association (GraphRAG) | ### Database Search (`handlers/db/`) ```go // handlers/db/handler.go package db import ( "github.com/yaoapp/yao/agent/search/types" ) // Handler implements DB search type Handler struct { usesQueryDSL string // "builtin", "", "mcp:." 
config *types.DBConfig } // NewHandler creates a new DB search handler func NewHandler(usesQueryDSL string, cfg *types.DBConfig) *Handler // Type returns the search type this handler supports func (h *Handler) Type() types.SearchType // Search converts NL to QueryDSL and executes // TODO: Implement actual search logic func (h *Handler) Search(req *types.Request) (*types.Result, error) ``` | File | Description | | ------------ | ------------------------------ | | `handler.go` | Main DB handler implementation | | `query.go` | QueryDSL builder utilities | | `schema.go` | Model schema introspection | Integrates with Yao's Model/QueryDSL system: - Natural language → QueryDSL conversion (via LLM) - Model schema introspection for query building - Support for: - Global models (`models/*.mod.yao`) - Assistant-specific models (`assistants/{id}/models/*.mod.yao` → `agents.{id}.*`) - Permission-aware queries (respects `__yao_*` permission fields) ### Reranking (`rerank/`) The rerank module follows the Handler + Registry pattern, consistent with `keyword/` and `web/`. ```go // rerank/reranker.go package rerank import ( "strings" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // Reranker reorders search results by relevance // Mode is determined by uses.rerank configuration type Reranker struct { usesRerank string // "builtin", "", "mcp:." 
config *types.RerankConfig } // NewReranker creates a new reranker func NewReranker(usesRerank string, cfg *types.RerankConfig) *Reranker // Rerank reorders results based on configured mode func (r *Reranker) Rerank(ctx *context.Context, query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) ``` **Directory Structure:** ``` rerank/ ├── reranker.go # Main entry point (mode dispatch) ├── builtin.go # Builtin: weighted score sorting (score * weight) ├── agent.go # Agent mode (delegate to LLM assistant) └── mcp.go # MCP mode (external service) ``` **Builtin Implementation:** The builtin reranker uses weighted score sorting: - Calculate `weightedScore = score * weight` - Sort items by weighted score descending - Return top N items > **Note**: For production use cases requiring semantic understanding, use Agent or MCP mode. **Agent Response Format:** The agent should return reordered items in one of these formats: ```json // Format 1: Order list (recommended) { "order": ["ref_003", "ref_001", "ref_002"] } // Format 2: Items list with citation_id { "items": [{ "citation_id": "ref_003" }, { "citation_id": "ref_001" }] } ``` | File | Description | | ------------- | ---------------------------------------- | | `reranker.go` | Main entry point and mode dispatch | | `builtin.go` | Weighted score sorting (score \* weight) | | `agent.go` | Delegate to LLM assistant for reranking | | `mcp.go` | Call external MCP tool for reranking | Configure via `uses.rerank` in `agent/agent.yml`: | Value | Notes | | ---------------------- | -------------------------------- | | `builtin` | Simple score sorting (default) | | `workers.rerank` | Delegate to an assistant (Agent) | | `mcp:my-server.rerank` | Call MCP tool for reranking | ## Error Handling Search errors don't block the agent flow. 
Errors are returned in `Result.Error`: ```typescript const result = ctx.search.Web(query); if (result.error) { // Handle gracefully or fallback console.warn("Search failed:", result.error); } ``` ## Configuration Priority Configuration is merged with later layers overriding earlier ones: 1. **System Built-in** - Hardcoded defaults (lowest priority) 2. **Global-level** - `agent/agent.yml` (uses) + `agent/search.yml` (search options) 3. **Assistant-level** - `assistants/<assistant-id>/package.yao` (uses + search) 4. **Hook-level** - CreateHook return `uses.search` value 5. **Request-level** - `options.uses.search` in Stream() call (highest priority) - `"builtin"`: Use builtin auto search - `"disabled"`: Disable auto search - `"<assistant-id>"`: Delegate to AI Search assistant - `"mcp:<server>.<tool>"`: Use MCP tool for search ## DB Search Details ### Query Processing Flow ``` Natural Language Query │ ▼ ┌─────────────────────────────────┐ │ Get Model Schemas │ ← Introspect models from db.models config │ (fields, types, relations) │ └─────────────────────────────────┘ │ ▼ ┌─────────────────────────────────┐ │ LLM: Generate QueryDSL │ ← Convert NL to Yao QueryDSL │ (select, wheres, orders) │ └─────────────────────────────────┘ │ ▼ ┌─────────────────────────────────┐ │ Execute Query on Each Model │ ← model.Find() with QueryDSL └─────────────────────────────────┘ │ ▼ Results ``` ### Model ID Formats | Format | Example | Description | | ------ | -------------------- | --------------------------------------------------------------------- | | Global | `product` | Global model from `models/product.mod.yao` | | System | `__yao.user` | Yao system model | | Agent | `agents.mybot.order` | Assistant-specific model from `assistants/mybot/models/order.mod.yao` | ### QueryDSL Generation Prompt The DB handler uses LLM to convert natural language to QueryDSL: ``` Given the following model schemas: - product: { id, name, price, category, status, created_at } - order: { id, product_id, quantity, total, customer_id, status } User 
query: "find all active products under $100 in electronics category" Generate Yao QueryDSL: { "model": "product", "wheres": [ { "field": "status", "op": "=", "value": "active" }, { "field": "price", "op": "<", "value": 100 }, { "field": "category", "op": "=", "value": "electronics" } ], "orders": [{ "field": "price", "order": "asc" }], "limit": 10 } ``` ## Content Module Integration User messages may contain `type="data"` ContentParts with data source references. The `content` module processes these before LLM call. ### DataSource Types (from `context/types.go`) ```go const ( DataSourceModel DataSourceType = "model" // DB model query DataSourceKBCollection DataSourceType = "kb_collection" // KB collection search DataSourceKBDocument DataSourceType = "kb_document" // KB document retrieval DataSourceTable DataSourceType = "table" // Direct table query DataSourceAPI DataSourceType = "api" // External API DataSourceMCPResource DataSourceType = "mcp_resource" // MCP resource ) ``` ### Message with Data Reference User only specifies data source IDs. Filters are generated by Search module from natural language. ```json { "role": "user", "content": [ { "type": "text", "text": "Show me products under $100" }, { "type": "data", "data": { "sources": [ { "type": "model", "name": "product" }, { "type": "kb_collection", "name": "product-docs" } ] } } ] } ``` The Search module will: 1. Extract query from text: "products under $100" 2. For `model:product` → Generate QueryDSL: `{ "wheres": [{ "field": "price", "op": "<", "value": 100 }] }` 3. For `kb_collection:product-docs` → Vector search with query embedding ### Source Weighting & LLM Context Search results carry `source` and `weight` fields, which are used to build weighted context for LLM. 
**Source Types:** | Source | Weight | Description | | ------ | ------ | -------------------------------- | | `user` | 1.0 | Explicitly referenced in message | | `hook` | 0.8 | Called in Create/Next hook | | `auto` | 0.6 | Triggered by assistant config | **ResultItem with Weight:** ```go type ResultItem struct { CitationID string `json:"citation_id"` // "#ref:xxx" Source string `json:"source"` // "user", "hook", "auto" Weight float64 `json:"weight"` // 1.0, 0.8, 0.6 Score float64 `json:"score"` // Relevance score // ... other fields } ``` ### Unified Context Protocol All data sources (Content module, Hook, Auto-Search) produce the same `Reference` structure. The final LLM input uses a unified `` format. **Reference (Internal Structure):** ```go // Reference is the unified structure for all data sources type Reference struct { ID string `json:"id"` // Unique citation ID: "ref_001", "ref_002" Type string `json:"type"` // "web", "kb", "db" Source string `json:"source"` // "user", "hook", "auto" Weight float64 `json:"weight"` // 1.0, 0.8, 0.6 Score float64 `json:"score"` // Relevance score (0-1) Title string `json:"title"` // Optional title Content string `json:"content"` // Main content URL string `json:"url"` // Optional URL Meta map[string]interface{} `json:"meta"` // Additional metadata } ``` **Data Flow:** ```mermaid flowchart TD subgraph Sources ["Data Sources"] CM["Content Module
(db:xxx kb:xxx)"] HS["Hook Search
ctx.search.*()"] AS["Auto Search
(assistant config)"] end CM -->|"source=user
weight=1.0"| REF HS -->|"source=hook
weight=0.8"| REF AS -->|"source=auto
weight=0.6"| REF REF["[]Reference
(Unified Structure)"] REF --> MERGE["Merge & Deduplicate
Rerank by score × weight"] MERGE --> BUILD["Build &lt;references&gt; XML"] BUILD --> LLM["LLM Input"] ``` **LLM References Format:** ```xml <references> <ref id="ref_001" type="db" weight="1.0" source="user"> Product: iPhone 15 Pro Price: $999 Category: Electronics </ref> <ref id="ref_002" type="kb" weight="0.8" source="hook"> The iPhone 15 Pro features the A17 Pro chip with improved performance... URL: https://example.com/iphone-review </ref> <ref id="ref_003" type="web" weight="0.6" source="auto"> Apple announced the iPhone 15 series in September 2023... URL: https://news.example.com/apple-iphone-15 </ref> </references> ``` **LLM System Prompt (auto-injected):** ``` You have access to reference data in <references> tags. Each <ref> has: - id: Citation identifier - type: Data type (web/kb/db) - weight: Relevance weight (1.0=highest priority, 0.6=lowest) - source: Origin (user=user-provided, hook=assistant-searched, auto=auto-searched) Prioritize higher-weight references when answering. When citing a reference, use this exact HTML format:
[{id}] Example: According to the product data[ref_001], the price is $999. ``` **Citation Output Format:** LLM outputs citations as HTML links that can be parsed and rendered by frontend: ```html The iPhone 15 Pro[ref_001] features the A17 Pro chip[ref_002]. ``` **Citation Link Attributes:** | Attribute | Description | Example | | --------------- | ----------------------- | ----------------------- | | `class` | Fixed class for styling | `"ref"` | | `data-ref-id` | Reference ID | `"ref_001"` | | `data-ref-type` | Data type | `"db"`, `"kb"`, `"web"` | | `href` | Anchor link | `"#ref:ref_001"` | **Conversion Examples:** | Module | Input | Output Reference | | ------- | ---------------------------------- | ---------------------------------------------- | | Content | `db:product` (user message) | `{source:"user", weight:1.0, type:"db", ...}` | | Content | `kb:docs` (user message) | `{source:"user", weight:1.0, type:"kb", ...}` | | Hook | `ctx.search.Web(query)` | `{source:"hook", weight:0.8, type:"web", ...}` | | Hook | `ctx.search.KB(query)` | `{source:"hook", weight:0.8, type:"kb", ...}` | | Hook | `ctx.search.DB(query)` | `{source:"hook", weight:0.8, type:"db", ...}` | | Auto | Assistant config `search.web=true` | `{source:"auto", weight:0.6, type:"web", ...}` | | Auto | Assistant config `search.kb=true` | `{source:"auto", weight:0.6, type:"kb", ...}` | ### Processing Flow ``` Stream() │ ├── 1. Collect search results from all sources │ ├── User DataContent → source="user", weight=1.0 │ ├── Hook ctx.search.*() → source="hook", weight=0.8 │ └── Auto search → source="auto", weight=0.6 │ ├── 2. Merge, deduplicate, rerank by (score * weight) │ ├── 3. Build ... format │ └── 4. Inject references into messages for LLM ``` **Behavior Rules:** 1. **User data sufficient**: If user provides enough data (≥ skip_threshold), skip auto search 2. **Deduplication**: Same record from different sources → keep highest weight version 3. 
**Final ranking**: Sort by `score * weight` after reranking **Configuration:** Global defaults (`agent/search.yml`): ```yaml weights: user: 1.0 # User-provided DataContent hook: 0.8 # Hook ctx.search.*() results auto: 0.6 # Auto search results options: skip_threshold: 5 # Skip auto search if user provides >= N results ``` Assistant-level override (`assistants/<assistant-id>/package.yao`): ```jsonc { "search": { "weights": { "user": 1.0, "hook": 0.9, // Higher weight for hook results "auto": 0.5 // Lower weight for auto results }, "options": { "skip_threshold": 10 // Need more user results to skip auto search } } } ``` **System Auto-Processing:** The weighting and context building are handled automatically by the system: ``` Stream() │ ├── 1. Parse user message for DataContent sources │ └── If found → Mark as source="user", weight=1.0 │ ├── 2. Create Hook (optional) │ └── If hook calls ctx.search.*() → Mark as source="hook", weight=0.8 │ ├── 3. Auto Search Decision │ ├── Count user-provided results │ ├── IF user_results >= skip_threshold → SKIP auto search │ └── ELSE → Execute auto search with source="auto", weight=0.6 │ ├── 4. Merge & Rerank (automatic) │ ├── Collect all results with their weights │ ├── Deduplicate (keep highest weight version) │ └── Calculate finalScore = baseScore * weight │ └── 5. Inject to LLM context ``` Users don't need to handle weights in hooks - the system manages this automatically. ### Processing Flow in content.Vision() ``` content.Vision() ├── type="text" → Pass through ├── type="image_url" → Image processing ├── type="file" → File processing └── type="data" → processDataContent() ├── DataSourceModel → Query via model.Find() → Format as text ├── DataSourceKBCollection → search.KB() → Format as text ├── DataSourceKBDocument → Retrieve document → Format as text └── DataSourceMCPResource → MCP resource read → Format as text ``` ### Implementation Location The `processDataContent()` function in `content/content.go` should: 1. 
**For `model` type**: Call search module's DB handler or direct model query 2. **For `kb_collection` type**: Call search module's KB handler 3. **For `kb_document` type**: Retrieve specific document from KB 4. **For `mcp_resource` type**: Read MCP resource This allows the search module to be reused for both: - **Auto Search**: Triggered when `Uses.Search != "disabled"` - **Data ContentPart**: User explicitly references data sources in message ## Related Files ### Internal Dependencies - `agent/search/types/` - All type definitions (no circular dependencies) - `agent/search/interfaces/` - All interface definitions - `agent/search/defaults/` - System default configuration values - `agent/search/handlers/` - Handler implementations (web, kb, db) - `agent/search/rerank/` - Reranker implementations - `agent/search/nlp/` - NLP implementations (keyword, querydsl) ### External Dependencies - `agent/context/jsapi.go` - JSAPI base implementation - `agent/context/types.go` - DataSource, DataContent types - `agent/context/types_llm.go` - Uses configuration (Search field) - `agent/assistant/types.go` - SearchOption definition - `agent/store/types/types.go` - KnowledgeBase, Database config - `agent/output/message/types.go` - Output message types - `agent/content/content.go` - Content processing (Vision function) - `model/model.go` - Yao Model loading (global, system, assistant models) ## See Also - `agent/context/JSAPI.md` - Full JSAPI documentation - `agent/context/RESOURCE_MANAGEMENT.md` - Context lifecycle and resource management - `agent/output/README.md` - Output system documentation - `agent/store/CHAT_STORAGE_DESIGN.md` - Chat storage design ================================================ FILE: agent/search/citation.go ================================================ package search import ( "sync/atomic" ) // CitationGenerator generates unique citation IDs (1-based integers) // Thread-safe for concurrent use within a single request type CitationGenerator struct { counter 
uint64 } // NewCitationGenerator creates a new citation generator func NewCitationGenerator() *CitationGenerator { return &CitationGenerator{} } // Next generates the next citation ID (1, 2, 3, ...) func (g *CitationGenerator) Next() string { n := atomic.AddUint64(&g.counter, 1) return uint64ToString(n) } // NextInt generates the next citation ID as integer func (g *CitationGenerator) NextInt() int { return int(atomic.AddUint64(&g.counter, 1)) } // Current returns the current counter value without incrementing func (g *CitationGenerator) Current() int { return int(atomic.LoadUint64(&g.counter)) } // Reset resets the counter (for testing) func (g *CitationGenerator) Reset() { atomic.StoreUint64(&g.counter, 0) } // uint64ToString converts uint64 to string without fmt package func uint64ToString(n uint64) string { if n == 0 { return "0" } var buf [20]byte // max uint64 is 20 digits i := len(buf) for n > 0 { i-- buf[i] = byte('0' + n%10) n /= 10 } return string(buf[i:]) } ================================================ FILE: agent/search/citation_test.go ================================================ package search import ( "sync" "testing" "github.com/stretchr/testify/assert" ) func TestCitationGenerator_Next(t *testing.T) { gen := NewCitationGenerator() // First ID should be "1" id1 := gen.Next() assert.Equal(t, "1", id1) // Second ID should be "2" id2 := gen.Next() assert.Equal(t, "2", id2) // Third ID should be "3" id3 := gen.Next() assert.Equal(t, "3", id3) } func TestCitationGenerator_NextInt(t *testing.T) { gen := NewCitationGenerator() // First ID should be 1 id1 := gen.NextInt() assert.Equal(t, 1, id1) // Second ID should be 2 id2 := gen.NextInt() assert.Equal(t, 2, id2) } func TestCitationGenerator_Current(t *testing.T) { gen := NewCitationGenerator() // Initial should be 0 assert.Equal(t, 0, gen.Current()) // After one Next, should be 1 gen.Next() assert.Equal(t, 1, gen.Current()) // Current doesn't increment assert.Equal(t, 1, gen.Current()) } func 
TestCitationGenerator_Reset(t *testing.T) { gen := NewCitationGenerator() // Generate some IDs gen.Next() gen.Next() gen.Next() // Reset gen.Reset() // Next ID should be "1" again id := gen.Next() assert.Equal(t, "1", id) } func TestCitationGenerator_LargeNumbers(t *testing.T) { gen := NewCitationGenerator() // Generate 999 IDs for i := 0; i < 999; i++ { gen.Next() } // 1000th ID should be "1000" id := gen.Next() assert.Equal(t, "1000", id) } func TestCitationGenerator_Concurrent(t *testing.T) { gen := NewCitationGenerator() // Run 100 goroutines, each generating 10 IDs var wg sync.WaitGroup ids := make(chan string, 1000) for i := 0; i < 100; i++ { wg.Add(1) go func() { defer wg.Done() for j := 0; j < 10; j++ { ids <- gen.Next() } }() } wg.Wait() close(ids) // Collect all IDs idSet := make(map[string]bool) for id := range ids { idSet[id] = true } // All 1000 IDs should be unique assert.Equal(t, 1000, len(idSet)) } func TestNewCitationGenerator(t *testing.T) { gen := NewCitationGenerator() assert.NotNil(t, gen) } func TestUint64ToString(t *testing.T) { tests := []struct { input uint64 expected string }{ {0, "0"}, {1, "1"}, {10, "10"}, {100, "100"}, {999, "999"}, {1000, "1000"}, {18446744073709551615, "18446744073709551615"}, // max uint64 } for _, tt := range tests { result := uint64ToString(tt.input) assert.Equal(t, tt.expected, result, "uint64ToString(%d)", tt.input) } } ================================================ FILE: agent/search/defaults/defaults.go ================================================ package defaults import "github.com/yaoapp/yao/agent/search/types" // SystemDefaults provides hardcoded default values // Used by agent/load.go for merging with agent/search.yao var SystemDefaults = &types.Config{ // Web search defaults Web: &types.WebConfig{ Provider: "tavily", MaxResults: 10, }, // KB search defaults KB: &types.KBConfig{ Threshold: 0.7, Graph: false, }, // DB search defaults DB: &types.DBConfig{ MaxResults: 20, }, // Keyword extraction options 
(uses.keyword) Keyword: &types.KeywordConfig{ MaxKeywords: 10, Language: "auto", }, // QueryDSL generation options (uses.querydsl) QueryDSL: &types.QueryDSLConfig{ Strict: false, }, // Rerank options (uses.rerank) Rerank: &types.RerankConfig{ TopN: 10, }, // Citation Citation: &types.CitationConfig{ Format: "#ref:{id}", AutoInjectPrompt: true, }, // Source weights Weights: &types.WeightsConfig{ User: 1.0, Hook: 0.8, Auto: 0.6, }, // Behavior options Options: &types.OptionsConfig{ SkipThreshold: 5, }, } // GetWeight returns the weight for a source type using default config func GetWeight(source types.SourceType) float64 { return SystemDefaults.GetWeight(source) } ================================================ FILE: agent/search/handlers/db/handler.go ================================================ package db import ( "encoding/json" "fmt" "time" "github.com/yaoapp/gou/model" "github.com/yaoapp/gou/query" "github.com/yaoapp/gou/query/gou" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/nlp/querydsl" "github.com/yaoapp/yao/agent/search/types" ) // Handler implements DB search type Handler struct { usesQueryDSL string // "builtin", "", "mcp:." 
config *types.DBConfig // DB search configuration } // NewHandler creates a new DB search handler func NewHandler(usesQueryDSL string, cfg *types.DBConfig) *Handler { return &Handler{usesQueryDSL: usesQueryDSL, config: cfg} } // Type returns the search type this handler supports func (h *Handler) Type() types.SearchType { return types.SearchTypeDB } // Search converts NL to QueryDSL and executes // Note: This method doesn't have context, use SearchWithContext for full functionality func (h *Handler) Search(req *types.Request) (*types.Result, error) { return h.SearchWithContext(nil, req) } // SearchWithContext executes DB search with context (required for QueryDSL generation) func (h *Handler) SearchWithContext(ctx *agentContext.Context, req *types.Request) (*types.Result, error) { start := time.Now() // Validate request if req.Query == "" { return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: "query is required", }, nil } // Get models from request or config modelIDs := req.Models if len(modelIDs) == 0 && h.config != nil { modelIDs = h.config.Models } // If no models specified, return empty result if len(modelIDs) == 0 { return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: "no models specified", }, nil } // Get max results maxResults := req.Limit if maxResults == 0 && h.config != nil && h.config.MaxResults > 0 { maxResults = h.config.MaxResults } if maxResults == 0 { maxResults = 20 // default } // Context is required for QueryDSL generation if ctx == nil { return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: "context is required for DB search", }, nil } // 1. 
Load all models and build combined schema models := make(map[string]*model.Model) schemas := make([]map[string]interface{}, 0, len(modelIDs)) for _, modelID := range modelIDs { mod, err := model.Get(modelID) if err != nil { continue // Skip non-existent models } models[modelID] = mod schemas = append(schemas, h.buildModelSchema(mod)) } if len(schemas) == 0 { return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: "no valid models found", }, nil } // 2. Generate QueryDSL with all schemas generator := querydsl.NewGenerator(h.usesQueryDSL, nil) input := &querydsl.Input{ Query: req.Query, ModelIDs: modelIDs, Scenario: req.Scenario, // Pass scenario: filter, aggregation, join, complex Limit: maxResults, } // Build schema input: single schema or array of schemas var schemaInput interface{} if len(schemas) == 1 { schemaInput = schemas[0] } else { schemaInput = schemas } input.ExtraParams = map[string]interface{}{ "schema": schemaInput, } result, err := generator.Generate(ctx, input) if err != nil { return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: fmt.Sprintf("QueryDSL generation failed: %v", err), }, nil } if result == nil || result.DSL == nil { return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: "no QueryDSL generated", }, nil } // 3. Sanitize generated DSL (remove unsupported wildcards like "*") h.sanitizeDSL(result.DSL) // 4. Merge preset conditions into generated DSL h.mergeDSLConditions(result.DSL, req) // 5. 
Execute QueryDSL using gou query engine records, err := h.executeDSL(result.DSL) if err != nil { return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: fmt.Sprintf("query execution failed: %v", err), }, nil } // 6. Determine the primary model for result formatting // Use the "from" table from DSL, or first model primaryModelID := modelIDs[0] if result.DSL.From != nil && result.DSL.From.Name != "" { // Find model by table name for id, mod := range models { if mod.MetaData.Table.Name == result.DSL.From.Name { primaryModelID = id break } } } primaryModel := models[primaryModelID] if primaryModel == nil { primaryModel, _ = model.Get(primaryModelID) // May be nil, that's ok } // 7. Convert records to ResultItems items := h.convertToResultItems(records, primaryModelID, primaryModel, req.Source) // Apply limit if len(items) > maxResults { items = items[:maxResults] } // 8. Convert DSL to map for storage dslMap := h.dslToMap(result.DSL) return &types.Result{ Type: types.SearchTypeDB, Query: req.Query, Source: req.Source, Items: items, Total: len(items), Duration: time.Since(start).Milliseconds(), DSL: dslMap, }, nil } // mergeDSLConditions merges preset conditions from request into generated DSL func (h *Handler) mergeDSLConditions(dsl *gou.QueryDSL, req *types.Request) { if dsl == nil { return } // Merge preset Wheres (prepend to ensure they take priority) if len(req.Wheres) > 0 { dsl.Wheres = append(req.Wheres, dsl.Wheres...) } // Merge preset Orders (prepend to ensure they take priority) if len(req.Orders) > 0 { dsl.Orders = append(req.Orders, dsl.Orders...) 
} // Merge preset Select fields if len(req.Select) > 0 { // Convert string fields to Expression selectExprs := make([]gou.Expression, 0, len(req.Select)) for _, field := range req.Select { selectExprs = append(selectExprs, gou.Expression{Field: field}) } // If DSL has no select, use preset; otherwise merge if len(dsl.Select) == 0 { dsl.Select = selectExprs } else { // Prepend preset fields dsl.Select = append(selectExprs, dsl.Select...) } } // Ensure limit is set if dsl.Limit == 0 && req.Limit > 0 { dsl.Limit = req.Limit } } // buildModelSchema builds a simplified schema for QueryDSL generator func (h *Handler) buildModelSchema(mod *model.Model) map[string]interface{} { columns := make([]map[string]interface{}, 0, len(mod.Columns)) for _, col := range mod.Columns { colInfo := map[string]interface{}{ "name": col.Name, "type": col.Type, } if col.Label != "" { colInfo["label"] = col.Label } if col.Description != "" { colInfo["description"] = col.Description } columns = append(columns, colInfo) } return map[string]interface{}{ "name": mod.MetaData.Table.Name, "columns": columns, } } // sanitizeDSL cleans up LLM-generated DSL to remove unsupported constructs. // The QueryDSL engine does not support wildcard "*" in select fields; // an empty select list naturally returns all columns. func (h *Handler) sanitizeDSL(dsl *gou.QueryDSL) { if dsl == nil { return } if len(dsl.Select) > 0 { cleaned := make([]gou.Expression, 0, len(dsl.Select)) for _, expr := range dsl.Select { if expr.Field != "*" { cleaned = append(cleaned, expr) } } dsl.Select = cleaned } } // executeDSL executes the QueryDSL and returns records. // Uses recover to convert panics from MustGet into errors. 
func (h *Handler) executeDSL(dsl interface{}) (records []map[string]interface{}, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("query execution panic: %v", r) } }() engine, err := query.Select("default") if err != nil { return nil, fmt.Errorf("query engine not found: %w", err) } dslJSON, err := json.Marshal(dsl) if err != nil { return nil, fmt.Errorf("failed to marshal DSL: %w", err) } q, err := engine.Load(json.RawMessage(dslJSON)) if err != nil { return nil, fmt.Errorf("failed to load DSL: %w", err) } rawRecords := q.Get(nil) records = make([]map[string]interface{}, 0, len(rawRecords)) for _, rec := range rawRecords { records = append(records, map[string]interface{}(rec)) } return records, nil } // convertToResultItems converts query results to ResultItems func (h *Handler) convertToResultItems(records []map[string]interface{}, modelID string, mod *model.Model, source types.SourceType) []*types.ResultItem { items := make([]*types.ResultItem, 0, len(records)) primaryKey := "id" if mod != nil && mod.PrimaryKey != "" { primaryKey = mod.PrimaryKey } for _, rec := range records { item := &types.ResultItem{ Type: types.SearchTypeDB, Source: source, Model: modelID, Data: rec, } // Try to extract title from common fields item.Title = h.extractTitle(rec, mod) // Try to extract content/description item.Content = h.extractContent(rec, mod) // Try to extract record ID if id, ok := rec[primaryKey]; ok { item.RecordID = id } items = append(items, item) } return items } // extractTitle tries to extract a title from the record func (h *Handler) extractTitle(rec map[string]interface{}, mod *model.Model) string { // Common title fields titleFields := []string{"title", "name", "subject", "label"} for _, field := range titleFields { if val, ok := rec[field]; ok { if str, ok := val.(string); ok && str != "" { return str } } } return "" } // extractContent tries to extract content from the record func (h *Handler) extractContent(rec map[string]interface{}, 
mod *model.Model) string { // Common content fields contentFields := []string{"content", "description", "summary", "text", "body"} for _, field := range contentFields { if val, ok := rec[field]; ok { if str, ok := val.(string); ok && str != "" { return str } } } // Fallback: serialize first few fields as content content, _ := json.Marshal(rec) if len(content) > 500 { content = content[:500] } return string(content) } // dslToMap converts QueryDSL to map for storage func (h *Handler) dslToMap(dsl *gou.QueryDSL) map[string]interface{} { if dsl == nil { return nil } // Marshal and unmarshal to get a clean map data, err := json.Marshal(dsl) if err != nil { return nil } var result map[string]interface{} if err := json.Unmarshal(data, &result); err != nil { return nil } return result } ================================================ FILE: agent/search/handlers/db/handler_integration_test.go ================================================ package db_test import ( "testing" "github.com/yaoapp/gou/model" "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/handlers/db" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // ============================================================================ // Integration Tests - Requires database and models // ============================================================================ func TestHandler_Search_Integration(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment (loads models, database, query engine, etc.) 
testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newTestContext(t) // Verify __yao.role model is loaded mod := model.Select("__yao.role") require.NotNil(t, mod, "__yao.role model should be loaded") t.Run("search_role_model_with_results", func(t *testing.T) { // First, ensure there's at least one role in the database ensureTestRole(t, mod) // Create handler with builtin QueryDSL generator h := db.NewHandler("builtin", &types.DBConfig{ Models: []string{"__yao.role"}, MaxResults: 10, }) req := &types.Request{ Type: types.SearchTypeDB, Query: "查询所有角色", Source: types.SourceAuto, Models: []string{"__yao.role"}, Scenario: types.ScenarioFilter, Limit: 10, } result, err := h.SearchWithContext(ctx, req) require.NoError(t, err) require.NotNil(t, result) // Verify result structure assert.Equal(t, types.SearchTypeDB, result.Type) assert.Equal(t, "查询所有角色", result.Query) assert.Equal(t, types.SourceAuto, result.Source) assert.GreaterOrEqual(t, result.Duration, int64(0)) // Should have results if result.Error != "" { t.Logf("Search error: %s", result.Error) } assert.Empty(t, result.Error, "Search should not return error") assert.Greater(t, len(result.Items), 0, "Should have at least one result") assert.Equal(t, len(result.Items), result.Total) // Verify result items for _, item := range result.Items { assert.Equal(t, types.SearchTypeDB, item.Type) assert.Equal(t, types.SourceAuto, item.Source) assert.Equal(t, "__yao.role", item.Model) assert.NotNil(t, item.Data, "Data should not be nil") assert.NotNil(t, item.RecordID, "RecordID should not be nil") } }) t.Run("search_with_filter_scenario", func(t *testing.T) { h := db.NewHandler("builtin", nil) req := &types.Request{ Type: types.SearchTypeDB, Query: "查询系统角色", Source: types.SourceHook, Models: []string{"__yao.role"}, Scenario: types.ScenarioFilter, Limit: 5, } result, err := h.SearchWithContext(ctx, req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeDB, result.Type) 
assert.Equal(t, types.SourceHook, result.Source) assert.LessOrEqual(t, len(result.Items), 5, "Should respect limit") }) t.Run("search_with_preset_wheres", func(t *testing.T) { h := db.NewHandler("builtin", nil) req := &types.Request{ Type: types.SearchTypeDB, Query: "查询角色", Source: types.SourceAuto, Models: []string{"__yao.role"}, Wheres: []gou.Where{ {Condition: gou.Condition{Field: &gou.Expression{Field: "is_active"}, Value: true, OP: "="}}, }, Limit: 10, } result, err := h.SearchWithContext(ctx, req) require.NoError(t, err) require.NotNil(t, result) // All results should have is_active = true (due to preset where) for _, item := range result.Items { if data, ok := item.Data["is_active"]; ok { // is_active could be bool or int depending on driver switch v := data.(type) { case bool: assert.True(t, v) case int64: assert.Equal(t, int64(1), v) case float64: assert.Equal(t, float64(1), v) } } } }) t.Run("search_nonexistent_model_graceful", func(t *testing.T) { h := db.NewHandler("builtin", nil) req := &types.Request{ Type: types.SearchTypeDB, Query: "查询文章", Source: types.SourceAuto, Models: []string{"nonexistent_model", "article", "fake_model"}, Limit: 10, } // Should NOT panic, should return gracefully with error result, err := h.SearchWithContext(ctx, req) require.NoError(t, err) require.NotNil(t, result) // Should have error message about no valid models assert.Equal(t, types.SearchTypeDB, result.Type) assert.Equal(t, "no valid models found", result.Error) assert.Empty(t, result.Items) }) t.Run("search_mixed_models_partial_exist", func(t *testing.T) { h := db.NewHandler("builtin", nil) req := &types.Request{ Type: types.SearchTypeDB, Query: "查询角色", Source: types.SourceAuto, Models: []string{"nonexistent_model", "__yao.role", "fake_model"}, // Only __yao.role exists Limit: 10, } // Should NOT panic, should work with the existing model result, err := h.SearchWithContext(ctx, req) require.NoError(t, err) require.NotNil(t, result) // Should succeed with partial models 
assert.Equal(t, types.SearchTypeDB, result.Type) if result.Error == "" { // If no error, should have results from __yao.role assert.GreaterOrEqual(t, len(result.Items), 0) } }) } // newTestContext creates a test context with required fields func newTestContext(t *testing.T) *context.Context { t.Helper() authorized := &oauthTypes.AuthorizedInfo{ UserID: "test-user", } chatID := "test-chat-db-search" ctx := context.New(t.Context(), authorized, chatID) return ctx } // ensureTestRole ensures there's at least one role in the database for testing func ensureTestRole(t *testing.T, mod *model.Model) { t.Helper() // Try to find existing roles rows, err := mod.Get(model.QueryParam{Limit: 1}) if err == nil && len(rows) > 0 { return // Already have roles } // Create a test role _, err = mod.Create(map[string]interface{}{ "role_id": "test_role", "name": "Test Role", "description": "A test role for unit testing", "is_active": true, "is_system": false, "level": 1, }) if err != nil { t.Logf("Note: Could not create test role: %v", err) } } ================================================ FILE: agent/search/handlers/db/handler_test.go ================================================ package db import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/model" "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/yao/agent/search/types" ) func TestNewHandler(t *testing.T) { t.Run("with nil config", func(t *testing.T) { h := NewHandler("builtin", nil) assert.NotNil(t, h) assert.Equal(t, "builtin", h.usesQueryDSL) assert.Nil(t, h.config) }) t.Run("with config", func(t *testing.T) { cfg := &types.DBConfig{ Models: []string{"product", "order"}, MaxResults: 50, } h := NewHandler("workers.nlp.querydsl", cfg) assert.NotNil(t, h) assert.Equal(t, "workers.nlp.querydsl", h.usesQueryDSL) assert.Equal(t, cfg, h.config) }) t.Run("with mcp mode", func(t *testing.T) { h := NewHandler("mcp:nlp.generate_querydsl", nil) assert.NotNil(t, h) assert.Equal(t, "mcp:nlp.generate_querydsl", 
h.usesQueryDSL) }) } func TestHandler_Type(t *testing.T) { h := NewHandler("builtin", nil) assert.Equal(t, types.SearchTypeDB, h.Type()) } func TestHandler_Search_Validation(t *testing.T) { tests := []struct { name string usesQueryDSL string config *types.DBConfig req *types.Request expectError string }{ { name: "empty query", usesQueryDSL: "builtin", config: nil, req: &types.Request{ Type: types.SearchTypeDB, Query: "", }, expectError: "query is required", }, { name: "no models in request or config", usesQueryDSL: "builtin", config: nil, req: &types.Request{ Type: types.SearchTypeDB, Query: "find products under $100", }, expectError: "no models specified", }, { name: "context required for DB search", usesQueryDSL: "builtin", config: &types.DBConfig{ Models: []string{"product"}, }, req: &types.Request{ Type: types.SearchTypeDB, Query: "find products under $100", Models: []string{"product"}, }, expectError: "context is required for DB search", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := NewHandler(tt.usesQueryDSL, tt.config) result, err := h.Search(tt.req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.SearchTypeDB, result.Type) assert.Equal(t, tt.expectError, result.Error) assert.Equal(t, 0, len(result.Items)) assert.GreaterOrEqual(t, result.Duration, int64(0)) }) } } func TestHandler_Search_SourcePreserved(t *testing.T) { h := NewHandler("builtin", &types.DBConfig{Models: []string{"product"}}) sources := []types.SourceType{types.SourceUser, types.SourceHook, types.SourceAuto} for _, source := range sources { req := &types.Request{ Type: types.SearchTypeDB, Query: "test", Source: source, Models: []string{"product"}, } result, err := h.Search(req) assert.NoError(t, err) assert.Equal(t, source, result.Source) } } func TestHandler_BuildModelSchema(t *testing.T) { h := NewHandler("builtin", nil) // Create a mock model for testing mod := &model.Model{ MetaData: model.MetaData{ Table: model.Table{ Name: "test_products", 
}, }, Columns: map[string]*model.Column{ "id": { Name: "id", Type: "ID", Label: "ID", }, "name": { Name: "name", Type: "string", Label: "Name", Description: "Product name", }, "price": { Name: "price", Type: "decimal", Label: "Price", }, }, } schema := h.buildModelSchema(mod) assert.NotNil(t, schema) assert.Equal(t, "test_products", schema["name"]) columns, ok := schema["columns"].([]map[string]interface{}) assert.True(t, ok) assert.Len(t, columns, 3) // Verify columns have required fields for _, col := range columns { assert.NotEmpty(t, col["name"]) assert.NotEmpty(t, col["type"]) } } func TestHandler_BuildModelSchema_MultipleModels(t *testing.T) { h := NewHandler("builtin", nil) // Create mock models for testing joins productMod := &model.Model{ MetaData: model.MetaData{ Table: model.Table{Name: "products"}, }, Columns: map[string]*model.Column{ "id": {Name: "id", Type: "ID"}, "name": {Name: "name", Type: "string"}, "category_id": {Name: "category_id", Type: "integer"}, }, } categoryMod := &model.Model{ MetaData: model.MetaData{ Table: model.Table{Name: "categories"}, }, Columns: map[string]*model.Column{ "id": {Name: "id", Type: "ID"}, "name": {Name: "name", Type: "string"}, }, } productSchema := h.buildModelSchema(productMod) categorySchema := h.buildModelSchema(categoryMod) assert.Equal(t, "products", productSchema["name"]) assert.Equal(t, "categories", categorySchema["name"]) // Verify both schemas can be combined into an array schemas := []map[string]interface{}{productSchema, categorySchema} assert.Len(t, schemas, 2) } func TestHandler_ExtractTitle(t *testing.T) { h := NewHandler("builtin", nil) mod := &model.Model{} tests := []struct { name string record map[string]interface{} expected string }{ { name: "title field", record: map[string]interface{}{"title": "Test Title", "id": 1}, expected: "Test Title", }, { name: "name field", record: map[string]interface{}{"name": "Test Name", "id": 1}, expected: "Test Name", }, { name: "subject field", record: 
map[string]interface{}{"subject": "Test Subject", "id": 1}, expected: "Test Subject", }, { name: "label field", record: map[string]interface{}{"label": "Test Label", "id": 1}, expected: "Test Label", }, { name: "no title field", record: map[string]interface{}{"id": 1, "price": 100}, expected: "", }, { name: "empty title", record: map[string]interface{}{"title": "", "name": "Fallback"}, expected: "Fallback", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { title := h.extractTitle(tt.record, mod) assert.Equal(t, tt.expected, title) }) } } func TestHandler_ExtractContent(t *testing.T) { h := NewHandler("builtin", nil) mod := &model.Model{} tests := []struct { name string record map[string]interface{} expectEmpty bool }{ { name: "content field", record: map[string]interface{}{"content": "Test Content"}, expectEmpty: false, }, { name: "description field", record: map[string]interface{}{"description": "Test Description"}, expectEmpty: false, }, { name: "summary field", record: map[string]interface{}{"summary": "Test Summary"}, expectEmpty: false, }, { name: "fallback to JSON", record: map[string]interface{}{"id": 1, "price": 100}, expectEmpty: false, // Should return JSON representation }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { content := h.extractContent(tt.record, mod) if tt.expectEmpty { assert.Empty(t, content) } else { assert.NotEmpty(t, content) } }) } } func TestHandler_ConvertToResultItems(t *testing.T) { h := NewHandler("builtin", nil) mod := &model.Model{ PrimaryKey: "id", } records := []map[string]interface{}{ { "id": 1, "name": "Product 1", "description": "Description 1", "price": 99.99, }, { "id": 2, "title": "Product 2", "content": "Content 2", }, } items := h.convertToResultItems(records, "product", mod, types.SourceAuto) assert.Len(t, items, 2) // First item assert.Equal(t, types.SearchTypeDB, items[0].Type) assert.Equal(t, types.SourceAuto, items[0].Source) assert.Equal(t, "product", items[0].Model) 
assert.Equal(t, 1, items[0].RecordID) assert.Equal(t, "Product 1", items[0].Title) assert.Equal(t, "Description 1", items[0].Content) assert.NotNil(t, items[0].Data) // Second item assert.Equal(t, 2, items[1].RecordID) assert.Equal(t, "Product 2", items[1].Title) assert.Equal(t, "Content 2", items[1].Content) } func TestHandler_ConvertToResultItems_NilModel(t *testing.T) { h := NewHandler("builtin", nil) records := []map[string]interface{}{ {"id": 1, "name": "Test"}, } // Should use default primary key "id" when model is nil items := h.convertToResultItems(records, "test", nil, types.SourceHook) assert.Len(t, items, 1) assert.Equal(t, 1, items[0].RecordID) assert.Equal(t, "Test", items[0].Title) } func TestHandler_Search_ScenarioTypes(t *testing.T) { // Test that all scenario types are valid scenarios := []types.ScenarioType{ types.ScenarioFilter, types.ScenarioAggregation, types.ScenarioJoin, types.ScenarioComplex, } for _, scenario := range scenarios { t.Run(string(scenario), func(t *testing.T) { h := NewHandler("builtin", &types.DBConfig{Models: []string{"product"}}) req := &types.Request{ Type: types.SearchTypeDB, Query: "test query", Source: types.SourceAuto, Models: []string{"product"}, Scenario: scenario, } // Without context, should return error (but scenario should be preserved in request) result, err := h.Search(req) assert.NoError(t, err) assert.NotNil(t, result) // Verify request scenario is set correctly assert.Equal(t, scenario, req.Scenario) }) } } func TestScenarioTypeConstants(t *testing.T) { // Verify scenario type constants match expected values assert.Equal(t, types.ScenarioType("filter"), types.ScenarioFilter) assert.Equal(t, types.ScenarioType("aggregation"), types.ScenarioAggregation) assert.Equal(t, types.ScenarioType("join"), types.ScenarioJoin) assert.Equal(t, types.ScenarioType("complex"), types.ScenarioComplex) } func TestHandler_MergeDSLConditions(t *testing.T) { h := NewHandler("builtin", nil) t.Run("merge wheres", func(t *testing.T) { 
dsl := &gou.QueryDSL{ From: &gou.Table{Name: "users"}, Wheres: []gou.Where{ {Condition: gou.Condition{Field: &gou.Expression{Field: "status"}, Value: "active", OP: "="}}, }, } req := &types.Request{ Wheres: []gou.Where{ {Condition: gou.Condition{Field: &gou.Expression{Field: "tenant_id"}, Value: 1, OP: "="}}, }, } h.mergeDSLConditions(dsl, req) // Preset wheres should be prepended assert.Len(t, dsl.Wheres, 2) assert.Equal(t, "tenant_id", dsl.Wheres[0].Field.Field) assert.Equal(t, "status", dsl.Wheres[1].Field.Field) }) t.Run("merge orders", func(t *testing.T) { dsl := &gou.QueryDSL{ From: &gou.Table{Name: "products"}, Orders: gou.Orders{ {Field: &gou.Expression{Field: "name"}, Sort: "asc"}, }, } req := &types.Request{ Orders: gou.Orders{ {Field: &gou.Expression{Field: "created_at"}, Sort: "desc"}, }, } h.mergeDSLConditions(dsl, req) // Preset orders should be prepended assert.Len(t, dsl.Orders, 2) assert.Equal(t, "created_at", dsl.Orders[0].Field.Field) assert.Equal(t, "name", dsl.Orders[1].Field.Field) }) t.Run("merge select fields", func(t *testing.T) { dsl := &gou.QueryDSL{ From: &gou.Table{Name: "orders"}, Select: []gou.Expression{ {Field: "amount"}, }, } req := &types.Request{ Select: []string{"id", "status"}, } h.mergeDSLConditions(dsl, req) // Preset select should be prepended assert.Len(t, dsl.Select, 3) assert.Equal(t, "id", dsl.Select[0].Field) assert.Equal(t, "status", dsl.Select[1].Field) assert.Equal(t, "amount", dsl.Select[2].Field) }) t.Run("set limit from request", func(t *testing.T) { dsl := &gou.QueryDSL{ From: &gou.Table{Name: "users"}, Limit: 0, } req := &types.Request{ Limit: 50, } h.mergeDSLConditions(dsl, req) assert.Equal(t, 50, dsl.Limit) }) t.Run("preserve dsl limit if set", func(t *testing.T) { dsl := &gou.QueryDSL{ From: &gou.Table{Name: "users"}, Limit: 10, } req := &types.Request{ Limit: 50, } h.mergeDSLConditions(dsl, req) // DSL limit should be preserved assert.Equal(t, 10, dsl.Limit) }) t.Run("nil dsl", func(t *testing.T) { req := 
&types.Request{ Wheres: []gou.Where{ {Condition: gou.Condition{Field: &gou.Expression{Field: "id"}, Value: 1}}, }, } // Should not panic h.mergeDSLConditions(nil, req) }) } ================================================ FILE: agent/search/handlers/kb/handler.go ================================================ package kb import ( "context" "fmt" "time" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/kb" kbapi "github.com/yaoapp/yao/kb/api" ) // Handler implements KB search using the KB API type Handler struct { config *types.KBConfig // KB search configuration } // NewHandler creates a new KB search handler func NewHandler(cfg *types.KBConfig) *Handler { return &Handler{config: cfg} } // Type returns the search type this handler supports func (h *Handler) Type() types.SearchType { return types.SearchTypeKB } // Search executes vector search and optional graph association func (h *Handler) Search(req *types.Request) (*types.Result, error) { start := time.Now() // Validate request if req.Query == "" { return &types.Result{ Type: types.SearchTypeKB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: "query is required", }, nil } // Check if KB API is available if kb.API == nil { return &types.Result{ Type: types.SearchTypeKB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: "knowledge base not initialized", }, nil } // Get collections from request or config collections := req.Collections if len(collections) == 0 && h.config != nil { collections = h.config.Collections } // If no collections specified, return empty result if len(collections) == 0 { return &types.Result{ Type: types.SearchTypeKB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), }, nil } // Get threshold from request or config threshold := req.Threshold if 
threshold == 0 && h.config != nil && h.config.Threshold > 0 { threshold = h.config.Threshold } if threshold == 0 { threshold = 0.7 // default } // Get limit limit := req.Limit if limit == 0 { limit = 10 // default } // Determine search mode mode := kbapi.SearchModeVector if req.Graph { mode = kbapi.SearchModeExpand } if h.config != nil && h.config.Graph { mode = kbapi.SearchModeExpand } // Build KB API queries - one per collection var queries []kbapi.Query for _, collectionID := range collections { queries = append(queries, kbapi.Query{ CollectionID: collectionID, Input: req.Query, Mode: mode, Threshold: threshold, PageSize: limit, Metadata: req.Metadata, }) } // Execute search using KB API ctx := context.Background() searchResult, err := kb.API.Search(ctx, queries) if err != nil { return &types.Result{ Type: types.SearchTypeKB, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(start).Milliseconds(), Error: fmt.Sprintf("search failed: %v", err), }, nil } // Convert segments to result items // Note: MinScore filtering is already done by KB API, no need to filter again items := make([]*types.ResultItem, 0, len(searchResult.Segments)) for _, seg := range searchResult.Segments { item := &types.ResultItem{ Type: types.SearchTypeKB, Source: req.Source, Score: seg.Score, Content: seg.Text, DocumentID: seg.DocumentID, Collection: seg.CollectionID, Metadata: seg.Metadata, } // Extract title from metadata if available if seg.Metadata != nil { if title, ok := seg.Metadata["title"].(string); ok { item.Title = title } } items = append(items, item) } // Convert graph data if available var graphNodes []*types.GraphNode if searchResult.Graph != nil { for _, node := range searchResult.Graph.Nodes { // Extract name from properties if available name := "" if node.Properties != nil { if n, ok := node.Properties["name"].(string); ok { name = n } } graphNodes = append(graphNodes, &types.GraphNode{ ID: node.ID, Type: node.EntityType, Name: 
name, Metadata: node.Properties, }) } } result := &types.Result{ Type: types.SearchTypeKB, Query: req.Query, Source: req.Source, Items: items, Total: len(items), Duration: time.Since(start).Milliseconds(), GraphNodes: graphNodes, } return result, nil } ================================================ FILE: agent/search/handlers/kb/handler_test.go ================================================ package kb import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/search/types" ) func TestNewHandler(t *testing.T) { t.Run("with nil config", func(t *testing.T) { h := NewHandler(nil) assert.NotNil(t, h) assert.Nil(t, h.config) }) t.Run("with config", func(t *testing.T) { cfg := &types.KBConfig{ Collections: []string{"docs", "faq"}, Threshold: 0.8, Graph: true, } h := NewHandler(cfg) assert.NotNil(t, h) assert.Equal(t, cfg, h.config) }) } func TestHandler_Type(t *testing.T) { h := NewHandler(nil) assert.Equal(t, types.SearchTypeKB, h.Type()) } func TestHandler_Search_Validation(t *testing.T) { tests := []struct { name string config *types.KBConfig req *types.Request expectError string }{ { name: "empty query", config: nil, req: &types.Request{ Type: types.SearchTypeKB, Query: "", }, expectError: "query is required", }, { name: "no collections - KB not initialized", config: nil, req: &types.Request{ Type: types.SearchTypeKB, Query: "test query", }, expectError: "knowledge base not initialized", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := NewHandler(tt.config) result, err := h.Search(tt.req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.SearchTypeKB, result.Type) assert.Equal(t, tt.req.Query, result.Query) if tt.expectError != "" { assert.Equal(t, tt.expectError, result.Error) } else { assert.Empty(t, result.Error) } // Duration should be set assert.GreaterOrEqual(t, result.Duration, int64(0)) }) } } func TestHandler_Search_SourcePreserved(t *testing.T) { h := 
NewHandler(&types.KBConfig{Collections: []string{"docs"}}) sources := []types.SourceType{types.SourceUser, types.SourceHook, types.SourceAuto} for _, source := range sources { req := &types.Request{ Type: types.SearchTypeKB, Query: "test", Source: source, Collections: []string{"docs"}, } result, err := h.Search(req) assert.NoError(t, err) assert.Equal(t, source, result.Source) } } func TestHandler_Search_CollectionsFromConfig(t *testing.T) { cfg := &types.KBConfig{ Collections: []string{"docs", "faq"}, Threshold: 0.7, } h := NewHandler(cfg) // Request without collections should use config collections req := &types.Request{ Type: types.SearchTypeKB, Query: "test query", } result, err := h.Search(req) assert.NoError(t, err) assert.NotNil(t, result) // Without KB initialized, we get "knowledge base not initialized" error assert.Equal(t, "knowledge base not initialized", result.Error) } func TestHandler_Search_CollectionsFromRequest(t *testing.T) { h := NewHandler(nil) // Request with collections req := &types.Request{ Type: types.SearchTypeKB, Query: "test query", Collections: []string{"docs", "faq"}, } result, err := h.Search(req) assert.NoError(t, err) assert.NotNil(t, result) // Without KB initialized, we get "knowledge base not initialized" error assert.Equal(t, "knowledge base not initialized", result.Error) } func TestHandler_Search_ThresholdHandling(t *testing.T) { tests := []struct { name string configThreshold float64 reqThreshold float64 }{ { name: "threshold from request", configThreshold: 0.7, reqThreshold: 0.9, }, { name: "threshold from config", configThreshold: 0.8, reqThreshold: 0, }, { name: "default threshold", configThreshold: 0, reqThreshold: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var cfg *types.KBConfig if tt.configThreshold > 0 { cfg = &types.KBConfig{ Collections: []string{"docs"}, Threshold: tt.configThreshold, } } h := NewHandler(cfg) req := &types.Request{ Type: types.SearchTypeKB, Query: "test query", 
// AgentProvider implements web search using another agent (AI Search)
type AgentProvider struct {
	agentID string // Agent/Assistant ID (e.g., "workers.search.web")
}

// NewAgentProvider creates a new Agent provider bound to agentID.
// The ID is only stored here; it is looked up through caller.AgentGetterFunc
// when Search runs.
func NewAgentProvider(agentID string) *AgentProvider {
	provider := new(AgentProvider)
	provider.agentID = agentID
	return provider
}

// Search executes web
search via agent delegation // The agent can understand intent, generate optimized queries, and return structured results func (p *AgentProvider) Search(ctx *agentContext.Context, req *types.Request) (*types.Result, error) { startTime := time.Now() // Check if context is provided if ctx == nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: "Agent mode requires context", }, nil } // Check if AgentGetterFunc is initialized if caller.AgentGetterFunc == nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: "AgentGetterFunc not initialized", }, nil } // Get the agent agent, err := caller.AgentGetterFunc(p.agentID) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: fmt.Sprintf("Agent '%s' not found: %v", p.agentID, err), }, nil } // Build message for the agent // Include search parameters in the message content searchParams := map[string]interface{}{ "query": req.Query, "type": "web", "source": string(req.Source), } if req.Limit > 0 { searchParams["limit"] = req.Limit } if len(req.Sites) > 0 { searchParams["sites"] = req.Sites } if req.TimeRange != "" { searchParams["time_range"] = req.TimeRange } // Convert to JSON for the message paramsJSON, err := json.Marshal(searchParams) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: fmt.Sprintf("Failed to serialize search params: %v", err), }, nil } // Create message for the agent message := agentContext.Message{ Role: "user", Content: string(paramsJSON), } // Call the 
agent with skip options (no history, no output) opts := &agentContext.Options{ Skip: &agentContext.Skip{ History: true, Output: true, }, } response, err := agent.Stream(ctx, []agentContext.Message{message}, opts) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: fmt.Sprintf("Agent call failed: %v", err), }, nil } // Parse the agent response items, total, parseErr := p.parseAgentResponse(response, req.Source) if parseErr != "" { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: parseErr, }, nil } return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: items, Total: total, Duration: time.Since(startTime).Milliseconds(), }, nil } // parseAgentResponse parses the agent's *context.Response into search result items // Now that agent.Stream() returns *context.Response directly, // we can access fields without type assertions. 
// // The agent returns search results in response.Next field func (p *AgentProvider) parseAgentResponse(response *agentContext.Response, source types.SourceType) ([]*types.ResultItem, int, string) { if response == nil || response.Next == nil { return nil, 0, "Agent returned nil response" } // Extract data from Next field data := extractNextData(response.Next) if data == nil { return nil, 0, "Failed to extract data from agent response" } // Extract items from data items := []*types.ResultItem{} total := 0 if itemsData, ok := data["items"].([]interface{}); ok { for _, itemData := range itemsData { if item, ok := itemData.(map[string]interface{}); ok { resultItem := &types.ResultItem{ Type: types.SearchTypeWeb, Source: source, } if title, ok := item["title"].(string); ok { resultItem.Title = title } if content, ok := item["content"].(string); ok { resultItem.Content = content } if url, ok := item["url"].(string); ok { resultItem.URL = url } if score, ok := item["score"].(float64); ok { resultItem.Score = score } items = append(items, resultItem) } } } if totalVal, ok := data["total"].(float64); ok { total = int(totalVal) } else { total = len(items) } return items, total, "" } // extractNextData extracts the actual data from response.Next field // Handles nested structures like { "data": { ... 
} } func extractNextData(next interface{}) map[string]interface{} { if next == nil { return nil } switch v := next.(type) { case map[string]interface{}: // Check for "data" wrapper if data, ok := v["data"].(map[string]interface{}); ok { return data } return v case string: // Try to parse as JSON var data map[string]interface{} if err := json.Unmarshal([]byte(v), &data); err == nil { return extractNextData(data) } } // Try to handle other types by converting to JSON and back if bytes, err := json.Marshal(next); err == nil { var data map[string]interface{} if err := json.Unmarshal(bytes, &data); err == nil { return extractNextData(data) } } return nil } ================================================ FILE: agent/search/handlers/web/agent_test.go ================================================ package web_test import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/handlers/web" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) // TestAgentProviderWithAssistantConfig tests AgentProvider using web-agent-caller assistant config func TestAgentProviderWithAssistantConfig(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-agent-caller test assistant to get its config ast, err := assistant.LoadPath("/assistants/tests/web-agent-caller") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.Uses) // Verify assistant config assert.Equal(t, "tests.web-agent-caller", ast.ID) assert.Equal(t, "tests.web-agent", ast.Uses.Web) // Create AgentProvider from uses.web provider := web.NewAgentProvider(ast.Uses.Web) require.NotNil(t, provider) // Create a mock context ctx := createTestContext(t) // Execute search req := 
&types.Request{ Query: "Yao App Engine", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 5, } result, err := provider.Search(ctx, req) require.NoError(t, err) require.NotNil(t, result) // Verify result structure assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "Yao App Engine", result.Query) assert.Equal(t, types.SourceAuto, result.Source) // Agent should return mock results from Next hook if result.Error == "" { assert.Greater(t, result.Total, 0) assert.NotEmpty(t, result.Items) assert.Greater(t, result.Duration, int64(0)) // Verify result item structure for _, item := range result.Items { assert.Equal(t, types.SearchTypeWeb, item.Type) assert.Equal(t, types.SourceAuto, item.Source) assert.NotEmpty(t, item.Title) assert.NotEmpty(t, item.URL) } t.Logf("Agent search returned %d results in %dms", result.Total, result.Duration) } else { t.Logf("Agent search returned error: %s", result.Error) } } // TestAgentProviderWithSiteRestriction tests AgentProvider with domain restriction func TestAgentProviderWithSiteRestriction(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create AgentProvider provider := web.NewAgentProvider("tests.web-agent") // Create a mock context ctx := createTestContext(t) // Execute search with site restriction req := &types.Request{ Query: "documentation", Type: types.SearchTypeWeb, Source: types.SourceHook, Sites: []string{"github.com"}, Limit: 3, } result, err := provider.Search(ctx, req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, types.SourceHook, result.Source) if result.Error == "" { // All results should be from github.com (mock data respects sites) for _, item := range result.Items { assert.Contains(t, item.URL, "github.com", "Result URL should be from github.com") } t.Logf("Site-restricted agent search returned %d results", result.Total) } else { t.Logf("Agent search 
returned error: %s", result.Error) } } // TestAgentProviderWithTimeRange tests AgentProvider with time range filter func TestAgentProviderWithTimeRange(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create AgentProvider provider := web.NewAgentProvider("tests.web-agent") // Create a mock context ctx := createTestContext(t) // Execute search with time range req := &types.Request{ Query: "artificial intelligence news", Type: types.SearchTypeWeb, Source: types.SourceAuto, TimeRange: "week", Limit: 5, } result, err := provider.Search(ctx, req) require.NoError(t, err) require.NotNil(t, result) if result.Error == "" { t.Logf("Time-ranged agent search (last week) returned %d results in %dms", result.Total, result.Duration) } else { t.Logf("Agent search returned error: %s", result.Error) } } // TestAgentProviderNotFound tests AgentProvider when agent is not found func TestAgentProviderNotFound(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create AgentProvider with non-existent agent provider := web.NewAgentProvider("nonexistent.agent") // Create a mock context ctx := createTestContext(t) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceAuto, } result, err := provider.Search(ctx, req) // Should not return error, but result should have error message require.NoError(t, err) require.NotNil(t, result) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "not found") } // TestAgentProviderWithoutContext tests AgentProvider without context func TestAgentProviderWithoutContext(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create AgentProvider provider := web.NewAgentProvider("tests.web-agent") req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceAuto, } // Call 
without context (nil) result, err := provider.Search(nil, req) // Should still work - agent provider handles nil context require.NoError(t, err) require.NotNil(t, result) // May have error if context is required for agent call t.Logf("Agent search without context: error=%s, total=%d", result.Error, result.Total) } // TestWebHandlerAgentMode tests the web handler in agent mode func TestWebHandlerAgentMode(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create handler with agent mode handler := web.NewHandler("tests.web-agent", nil) require.NotNil(t, handler) // Verify type assert.Equal(t, types.SearchTypeWeb, handler.Type()) // Create a mock context ctx := createTestContext(t) // Execute search with context req := &types.Request{ Query: "Yao framework", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 5, } result, err := handler.SearchWithContext(ctx, req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "Yao framework", result.Query) if result.Error == "" { t.Logf("Handler agent mode returned %d results", result.Total) } else { t.Logf("Handler agent mode returned error: %s", result.Error) } } // TestWebHandlerAgentModeWithoutContext tests the web handler in agent mode without context func TestWebHandlerAgentModeWithoutContext(t *testing.T) { // Create handler with agent mode handler := web.NewHandler("tests.web-agent", nil) require.NotNil(t, handler) req := &types.Request{ Query: "test", Type: types.SearchTypeWeb, Source: types.SourceAuto, } // Call Search() without context (uses SearchWithContext with nil) result, err := handler.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "requires context") } // createTestContext creates a test context for agent calls func createTestContext(t *testing.T) *agentContext.Context { authorized := 
&oauthTypes.AuthorizedInfo{ UserID: "test-user", TenantID: "test-tenant", } ctx := agentContext.New(context.Background(), authorized, "test-chat-id") ctx.AssistantID = "tests.web-agent-caller" return ctx } ================================================ FILE: agent/search/handlers/web/handler.go ================================================ package web import ( "fmt" "strings" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // Handler implements web search type Handler struct { usesWeb string // "builtin", "", "mcp:." config *types.WebConfig // Web search configuration } // NewHandler creates a new web search handler func NewHandler(usesWeb string, cfg *types.WebConfig) *Handler { return &Handler{usesWeb: usesWeb, config: cfg} } // Type returns the search type this handler supports func (h *Handler) Type() types.SearchType { return types.SearchTypeWeb } // Search executes web search based on uses.web mode // ctx is optional and only required for agent mode func (h *Handler) Search(req *types.Request) (*types.Result, error) { return h.SearchWithContext(nil, req) } // SearchWithContext executes web search with optional agent context // ctx is required for agent mode, optional for builtin and MCP modes func (h *Handler) SearchWithContext(ctx *agentContext.Context, req *types.Request) (*types.Result, error) { switch { case h.usesWeb == "builtin" || h.usesWeb == "": return h.builtinSearch(req) case strings.HasPrefix(h.usesWeb, "mcp:"): return h.mcpSearch(req) default: // Agent mode: delegate to assistant for AI-powered search if ctx == nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Error: "Agent mode requires context", }, nil } return h.agentSearch(ctx, req) } } // builtinSearch uses Tavily/Serper/SerpAPI directly func (h *Handler) builtinSearch(req *types.Request) (*types.Result, error) { // Determine provider from config 
providerName := "tavily" // default if h.config != nil && h.config.Provider != "" { providerName = h.config.Provider } switch providerName { case "tavily": return NewTavilyProvider(h.config).Search(req) case "serper": // Serper (serper.dev) - POST request with X-API-KEY header return NewSerperProvider(h.config).Search(req) case "serpapi": // SerpAPI (serpapi.com) - GET request with api_key parameter return NewSerpAPIProvider(h.config).Search(req) default: return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Error: fmt.Sprintf("Unknown provider: %s (supported: tavily, serper, serpapi)", providerName), }, nil } } // agentSearch delegates to an assistant for AI-powered search func (h *Handler) agentSearch(ctx *agentContext.Context, req *types.Request) (*types.Result, error) { provider := NewAgentProvider(h.usesWeb) return provider.Search(ctx, req) } // mcpSearch calls external MCP tool func (h *Handler) mcpSearch(req *types.Request) (*types.Result, error) { // Parse "mcp:server.tool" mcpRef := strings.TrimPrefix(h.usesWeb, "mcp:") provider, err := NewMCPProvider(mcpRef) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Error: fmt.Sprintf("Invalid MCP format: %v", err), }, nil } return provider.Search(req) } ================================================ FILE: agent/search/handlers/web/mcp.go ================================================ package web import ( "context" "encoding/json" "fmt" "strings" "time" "github.com/yaoapp/gou/mcp" gouMCPTypes "github.com/yaoapp/gou/mcp/types" "github.com/yaoapp/yao/agent/search/types" ) // MCPProvider implements web search using MCP tool type MCPProvider struct { serverID string // MCP server ID (e.g., "search") toolName string // MCP tool name (e.g., "web_search") } // NewMCPProvider creates a new MCP provider from "mcp:server.tool" format func 
NewMCPProvider(mcpRef string) (*MCPProvider, error) { // Parse "server.tool" format parts := strings.SplitN(mcpRef, ".", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid MCP format, expected 'server.tool', got '%s'", mcpRef) } return &MCPProvider{ serverID: parts[0], toolName: parts[1], }, nil } // Search executes web search via MCP tool func (p *MCPProvider) Search(req *types.Request) (*types.Result, error) { startTime := time.Now() // Select MCP client client, err := mcp.Select(p.serverID) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: fmt.Sprintf("MCP client '%s' not found: %v", p.serverID, err), }, nil } // Build MCP tool arguments args := map[string]interface{}{ "query": req.Query, } if req.Limit > 0 { args["limit"] = req.Limit } if len(req.Sites) > 0 { args["sites"] = req.Sites } if req.TimeRange != "" { args["time_range"] = req.TimeRange } // Call MCP tool ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() result, err := client.CallTool(ctx, p.toolName, args) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: fmt.Sprintf("MCP tool call failed: %v", err), }, nil } // Parse MCP result items, total, parseErr := p.parseResult(result, req.Source) if parseErr != "" { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: parseErr, }, nil } return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: items, Total: total, Duration: time.Since(startTime).Milliseconds(), }, nil } // parseResult parses MCP tool result into search result items func (p *MCPProvider) 
// parseJSON decodes s as a JSON object.
//
// Surrounding whitespace is ignored. The text is only handed to
// encoding/json when it is wrapped in curly braces, so arrays, scalars and
// plain text are rejected without a decode attempt. On success the decoded
// object and true are returned; in every other case the result is (nil, false).
func parseJSON(s string) (map[string]interface{}, bool) {
	trimmed := strings.TrimSpace(s)

	// Cheap shape check: a JSON object must open with '{' and close with '}'.
	looksLikeObject := strings.HasPrefix(trimmed, "{") && strings.HasSuffix(trimmed, "}")
	if !looksLikeObject {
		return nil, false
	}

	var decoded map[string]interface{}
	err := json.Unmarshal([]byte(trimmed), &decoded)
	if err == nil {
		return decoded, true
	}

	// Malformed payloads are reported as "not JSON" rather than as an error.
	return nil, false
}
t.Logf("MCP search returned %d results in %dms", result.Total, result.Duration) } else { t.Logf("MCP search returned error (expected if MCP not loaded): %s", result.Error) } } // TestMCPProviderWithSiteRestriction tests MCPProvider with domain restriction func TestMCPProviderWithSiteRestriction(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create MCPProvider provider, err := web.NewMCPProvider("search.web_search") require.NoError(t, err) // Execute search with site restriction req := &types.Request{ Query: "documentation", Type: types.SearchTypeWeb, Source: types.SourceHook, Sites: []string{"github.com"}, Limit: 3, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, types.SourceHook, result.Source) if result.Error == "" { t.Logf("Site-restricted MCP search returned %d results", result.Total) } else { t.Logf("MCP search returned error (expected if MCP not loaded): %s", result.Error) } } // TestMCPProviderWithTimeRange tests MCPProvider with time range filter func TestMCPProviderWithTimeRange(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create MCPProvider provider, err := web.NewMCPProvider("search.web_search") require.NoError(t, err) // Execute search with time range req := &types.Request{ Query: "artificial intelligence news", Type: types.SearchTypeWeb, Source: types.SourceAuto, TimeRange: "week", Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) if result.Error == "" { t.Logf("Time-ranged MCP search (last week) returned %d results in %dms", result.Total, result.Duration) } else { t.Logf("MCP search returned error (expected if MCP not loaded): %s", result.Error) } } // TestMCPProviderInvalidFormat tests MCPProvider with invalid format func 
TestMCPProviderInvalidFormat(t *testing.T) { // Test invalid format without dot _, err := web.NewMCPProvider("invalid") require.Error(t, err) assert.Contains(t, err.Error(), "invalid MCP format") // Test empty string _, err = web.NewMCPProvider("") require.Error(t, err) assert.Contains(t, err.Error(), "invalid MCP format") } // TestMCPProviderNotFound tests MCPProvider when MCP server is not found func TestMCPProviderNotFound(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create MCPProvider with non-existent server provider, err := web.NewMCPProvider("nonexistent.web_search") require.NoError(t, err) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceAuto, } result, err := provider.Search(req) // Should not return error, but result should have error message require.NoError(t, err) require.NotNil(t, result) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "not found") } // TestWebHandlerMCPMode tests the web handler in MCP mode func TestWebHandlerMCPMode(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Create handler with MCP mode handler := web.NewHandler("mcp:search.web_search", nil) require.NotNil(t, handler) // Verify type assert.Equal(t, types.SearchTypeWeb, handler.Type()) // Execute search req := &types.Request{ Query: "Yao framework", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 5, } result, err := handler.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "Yao framework", result.Query) if result.Error == "" { t.Logf("Handler MCP mode returned %d results", result.Total) } else { t.Logf("Handler MCP mode returned error (expected if MCP not loaded): %s", result.Error) } } // TestWebHandlerInvalidMCPFormat tests the web handler with invalid MCP format func 
TestWebHandlerInvalidMCPFormat(t *testing.T) { // Create handler with invalid MCP format handler := web.NewHandler("mcp:invalid", nil) require.NotNil(t, handler) req := &types.Request{ Query: "test", Type: types.SearchTypeWeb, Source: types.SourceAuto, } result, err := handler.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "Invalid MCP format") } ================================================ FILE: agent/search/handlers/web/serpapi.go ================================================ package web import ( "encoding/json" "fmt" "io" "net/http" "net/url" "os" "time" "github.com/yaoapp/yao/agent/search/types" ) const ( serpAPIURL = "https://serpapi.com/search.json" serpAPITimeout = 30 * time.Second ) // SerpAPIProvider implements web search using SerpAPI (supports multiple search engines) type SerpAPIProvider struct { apiKey string maxResults int engine string // Search engine: "google", "bing", "baidu", "yandex", "duckduckgo", etc. } // NewSerpAPIProvider creates a new SerpAPI provider func NewSerpAPIProvider(cfg *types.WebConfig) *SerpAPIProvider { apiKey := "" if cfg != nil && cfg.APIKeyEnv != "" { // Support both "$ENV.VAR_NAME" and "VAR_NAME" formats envName := cfg.APIKeyEnv if len(envName) > 5 && envName[:5] == "$ENV." 
{ envName = envName[5:] } apiKey = os.Getenv(envName) } maxResults := 10 if cfg != nil && cfg.MaxResults > 0 { maxResults = cfg.MaxResults } engine := "google" // Default to Google if cfg != nil && cfg.Engine != "" { engine = cfg.Engine } return &SerpAPIProvider{ apiKey: apiKey, maxResults: maxResults, engine: engine, } } // serpAPIResponse represents the response from SerpAPI type serpAPIResponse struct { SearchMetadata serpAPIMetadata `json:"search_metadata"` SearchParameters serpAPIParams `json:"search_parameters"` SearchInformation serpAPIInfo `json:"search_information"` OrganicResults []serpAPIResult `json:"organic_results"` AnswerBox *serpAPIAnswerBox `json:"answer_box,omitempty"` KnowledgeGraph *serpAPIKnowledge `json:"knowledge_graph,omitempty"` RelatedSearches []serpAPIRelated `json:"related_searches,omitempty"` RelatedQuestions []serpAPIQuestion `json:"related_questions,omitempty"` } // serpAPIMetadata contains metadata from response type serpAPIMetadata struct { ID string `json:"id"` Status string `json:"status"` CreatedAt string `json:"created_at"` ProcessedAt string `json:"processed_at"` TotalTimeTaken float64 `json:"total_time_taken"` } // serpAPIParams contains search parameters from response type serpAPIParams struct { Engine string `json:"engine"` Q string `json:"q"` Location string `json:"location_used"` GoogleDomain string `json:"google_domain"` HL string `json:"hl"` GL string `json:"gl"` Device string `json:"device"` } // serpAPIInfo contains search information type serpAPIInfo struct { QueryDisplayed string `json:"query_displayed"` TotalResults int64 `json:"total_results"` TimeTakenDisplayed float64 `json:"time_taken_displayed"` OrganicResultsState string `json:"organic_results_state"` } // serpAPIResult represents a single organic search result type serpAPIResult struct { Position int `json:"position"` Title string `json:"title"` Link string `json:"link"` RedirectLink string `json:"redirect_link,omitempty"` DisplayedLink string 
// serpAPIKnowledge represents knowledge graph data.
//
// It is the decode target for the optional "knowledge_graph" key of
// serpAPIResponse. Every field is tagged omitempty, so a missing key simply
// leaves the zero value.
type serpAPIKnowledge struct {
	Title string `json:"title,omitempty"`
	// Type is left untyped so that encoding/json accepts either payload shape
	// instead of failing the decode of the whole response.
	Type        interface{} `json:"type,omitempty"` // Can be string or object depending on query
	Description string      `json:"description,omitempty"`
}
req.TimeRange != "" { tbs := convertSerpAPITimeRange(req.TimeRange) if tbs != "" { params.Set("tbs", tbs) } } // Execute API call serpResp, err := p.callAPI(params) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: fmt.Sprintf("SerpAPI error: %v", err), }, nil } // Convert results items := make([]*types.ResultItem, 0, len(serpResp.OrganicResults)) // Add answer box as first result if available if serpResp.AnswerBox != nil && serpResp.AnswerBox.Snippet != "" { items = append(items, &types.ResultItem{ Type: types.SearchTypeWeb, Title: serpResp.AnswerBox.Title, Content: serpResp.AnswerBox.Snippet, URL: serpResp.AnswerBox.Link, Score: 1.0, // Featured snippet gets highest score Source: req.Source, Metadata: map[string]interface{}{ "type": "answer_box", }, }) } // Add organic results for _, r := range serpResp.OrganicResults { // Calculate score based on position (1st = 0.95, 2nd = 0.90, etc.) score := 1.0 - float64(r.Position)*0.05 if score < 0.1 { score = 0.1 } items = append(items, &types.ResultItem{ Type: types.SearchTypeWeb, Title: r.Title, Content: r.Snippet, URL: r.Link, Score: score, Source: req.Source, }) } return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: items, Total: len(items), Duration: time.Since(startTime).Milliseconds(), }, nil } // callAPI makes the HTTP GET request to SerpAPI func (p *SerpAPIProvider) callAPI(params url.Values) (*serpAPIResponse, error) { // Build URL with query parameters reqURL := serpAPIURL + "?" 
+ params.Encode() // Create HTTP request httpReq, err := http.NewRequest(http.MethodGet, reqURL, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } // Execute request client := &http.Client{Timeout: serpAPITimeout} resp, err := client.Do(httpReq) if err != nil { return nil, fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() // Read response body respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } // Check status code if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(respBody)) } // Parse response var serpResp serpAPIResponse if err := json.Unmarshal(respBody, &serpResp); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } return &serpResp, nil } // convertSerpAPITimeRange converts time range to SerpAPI tbs format func convertSerpAPITimeRange(timeRange string) string { switch timeRange { case "hour": return "qdr:h" case "day": return "qdr:d" case "week": return "qdr:w" case "month": return "qdr:m" case "year": return "qdr:y" default: return "" } } ================================================ FILE: agent/search/handlers/web/serpapi_test.go ================================================ package web_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/search/handlers/web" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" ) // TestSerpAPIProviderWithAssistantConfig tests SerpAPIProvider using web-serpapi assistant config func TestSerpAPIProviderWithAssistantConfig(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serpapi test assistant to get its config ast, err := assistant.LoadPath("/assistants/tests/web-serpapi") require.NoError(t, err) 
require.NotNil(t, ast) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Verify assistant config assert.Equal(t, "tests.web-serpapi", ast.ID) assert.Equal(t, "serpapi", ast.Search.Web.Provider) assert.Equal(t, "$ENV.SERPAPI_API_KEY", ast.Search.Web.APIKeyEnv) assert.Equal(t, 10, ast.Search.Web.MaxResults) // Create SerpAPIProvider with assistant's web config provider := web.NewSerpAPIProvider(ast.Search.Web) require.NotNil(t, provider) // Execute search req := &types.Request{ Query: "golang programming language", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // Verify result structure assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "golang programming language", result.Query) assert.Equal(t, types.SourceAuto, result.Source) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) // Verify we got results assert.Greater(t, result.Total, 0) assert.NotEmpty(t, result.Items) assert.Greater(t, result.Duration, int64(0)) // Verify result item structure for _, item := range result.Items { assert.Equal(t, types.SearchTypeWeb, item.Type) assert.Equal(t, types.SourceAuto, item.Source) assert.NotEmpty(t, item.Title) assert.NotEmpty(t, item.URL) assert.Greater(t, item.Score, 0.0) } t.Logf("Search returned %d results in %dms", result.Total, result.Duration) } // TestSerpAPIProviderWithSiteRestriction tests SerpAPIProvider with domain restriction func TestSerpAPIProviderWithSiteRestriction(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serpapi test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serpapi") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerpAPIProvider provider := 
web.NewSerpAPIProvider(ast.Search.Web) // Execute search with site restriction req := &types.Request{ Query: "documentation", Type: types.SearchTypeWeb, Source: types.SourceHook, Sites: []string{"github.com"}, Limit: 3, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, types.SourceHook, result.Source) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) require.NotEmpty(t, result.Items, "Search should return results") // All results should be from github.com for _, item := range result.Items { assert.Contains(t, item.URL, "github.com", "Result URL should be from github.com") } t.Logf("Site-restricted search returned %d results from github.com", result.Total) } // TestSerpAPIProviderWithMultipleSites tests SerpAPIProvider with multiple domain restrictions func TestSerpAPIProviderWithMultipleSites(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serpapi test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serpapi") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerpAPIProvider provider := web.NewSerpAPIProvider(ast.Search.Web) // Execute search with multiple site restrictions req := &types.Request{ Query: "golang tutorial", Type: types.SearchTypeWeb, Source: types.SourceAuto, Sites: []string{"github.com", "golang.org"}, Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) require.NotEmpty(t, result.Items, "Search should return results") // Results should be from either github.com or golang.org for _, item := range 
result.Items { isValidSite := false for _, site := range req.Sites { if containsSite(item.URL, site) { isValidSite = true break } } assert.True(t, isValidSite, "Result URL should be from github.com or golang.org: %s", item.URL) } t.Logf("Multi-site search returned %d results", result.Total) } // TestSerpAPIProviderWithTimeRange tests SerpAPIProvider with time range filter func TestSerpAPIProviderWithTimeRange(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serpapi test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serpapi") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerpAPIProvider provider := web.NewSerpAPIProvider(ast.Search.Web) // Execute search with time range req := &types.Request{ Query: "artificial intelligence news", Type: types.SearchTypeWeb, Source: types.SourceAuto, TimeRange: "week", // Last week Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) t.Logf("Time-ranged search (last week) returned %d results in %dms", result.Total, result.Duration) } // TestSerpAPIProviderWithoutAPIKey tests graceful degradation when API key is missing func TestSerpAPIProviderWithoutAPIKey(t *testing.T) { // Create provider with nil config (no API key) provider := web.NewSerpAPIProvider(nil) require.NotNil(t, provider) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceAuto, } result, err := provider.Search(req) // Should not return error, but result should have error message require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "test query", result.Query) assert.NotEmpty(t, result.Error) assert.Contains(t, 
result.Error, "API key") assert.Empty(t, result.Items) assert.Equal(t, 0, result.Total) } // TestSerpAPIProviderWithEmptyConfig tests provider with empty config func TestSerpAPIProviderWithEmptyConfig(t *testing.T) { // Create provider with empty config cfg := &types.WebConfig{} provider := web.NewSerpAPIProvider(cfg) require.NotNil(t, provider) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceUser, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "API key") } // TestSerpAPIProviderMaxResults tests that max_results from config is respected func TestSerpAPIProviderMaxResults(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serpapi test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serpapi") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerpAPIProvider provider := web.NewSerpAPIProvider(ast.Search.Web) // Execute search without limit (should use config's max_results) req := &types.Request{ Query: "machine learning", Type: types.SearchTypeWeb, Source: types.SourceAuto, // No Limit set, should use config's max_results (10) } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) // Should respect max_results from config (+1 for possible answer box) assert.LessOrEqual(t, result.Total, ast.Search.Web.MaxResults+1) t.Logf("Search without limit returned %d results (max: %d)", result.Total, ast.Search.Web.MaxResults) // Execute search with explicit limit req2 := &types.Request{ Query: "machine learning", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 3, // Override config's 
max_results } result2, err := provider.Search(req2) require.NoError(t, err) require.NotNil(t, result2) // API key must be valid - search should succeed require.Empty(t, result2.Error, "Search should succeed with valid API key, got error: %s", result2.Error) // Should respect request's limit (+1 for possible answer box) assert.LessOrEqual(t, result2.Total, 4) t.Logf("Search with limit=3 returned %d results", result2.Total) } // TestSerpAPIProviderWithBingEngine tests SerpAPIProvider with Bing search engine func TestSerpAPIProviderWithBingEngine(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serpapi test assistant to get base config ast, err := assistant.LoadPath("/assistants/tests/web-serpapi") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create config with Bing engine bingConfig := &types.WebConfig{ Provider: "serpapi", APIKeyEnv: ast.Search.Web.APIKeyEnv, MaxResults: 5, Engine: "bing", } // Create SerpAPIProvider with Bing engine provider := web.NewSerpAPIProvider(bingConfig) require.NotNil(t, provider) // Execute search req := &types.Request{ Query: "Golang programming", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // Verify result structure assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "Golang programming", result.Query) // API key must be valid - search should succeed require.Empty(t, result.Error, "Bing search should succeed with valid API key, got error: %s", result.Error) // Verify we got results assert.Greater(t, result.Total, 0) assert.NotEmpty(t, result.Items) t.Logf("Bing search returned %d results in %dms", result.Total, result.Duration) } // TestSerpAPIProviderEngineDefault tests that default engine is Google func TestSerpAPIProviderEngineDefault(t *testing.T) { // Create provider with config 
// containsHelper reports whether substr occurs anywhere in s.
//
// It slides a window of len(substr) bytes across s and compares each window
// with substr. An empty substr matches any s, including the empty string; a
// substr longer than s never matches.
func containsHelper(s, substr string) bool {
	width := len(substr)
	for start, end := 0, width; end <= len(s); start, end = start+1, end+1 {
		if s[start:end] == substr {
			return true
		}
	}
	return false
}
// serperRequest represents the request body for Serper API.
//
// callAPI marshals it to JSON and sends it as the POST body. Fields tagged
// omitempty are left out of the payload when they hold their zero value;
// AutoCor has no omitempty and is therefore always sent.
// Search in this file populates only Q, Num, AutoCor and, when a time range
// is requested, TBS.
type serperRequest struct {
	Q       string `json:"q"`             // Search query
	Num     int    `json:"num,omitempty"` // Number of results (default: 10, max: 100)
	GL      string `json:"gl,omitempty"`  // Country code (e.g., "us", "cn")
	HL      string `json:"hl,omitempty"`  // Language code (e.g., "en", "zh-cn")
	TBS     string `json:"tbs,omitempty"` // Time-based search (qdr:h, qdr:d, qdr:w, qdr:m, qdr:y)
	Page    int    `json:"page,omitempty"` // Page number (default: 1)
	AutoCor bool   `json:"autocorrect"`   // Auto-correct spelling
}
// serperRelated represents related searches.
//
// It is the element type of serperResponse.RelatedSearches, decoded from the
// "relatedSearches" array of the response.
type serperRelated struct {
	Query string `json:"query"` // decoded from the "query" key of each entry
}
// convertSerperTimeRange maps a time range name to the value Serper expects
// in its "tbs" request field.
//
// Recognised names are "hour", "day", "week", "month" and "year"; any other
// input, including the empty string, yields "".
func convertSerperTimeRange(timeRange string) string {
	tbsByRange := map[string]string{
		"hour":  "qdr:h",
		"day":   "qdr:d",
		"week":  "qdr:w",
		"month": "qdr:m",
		"year":  "qdr:y",
	}

	// A missing key returns the zero value "", which stands for "no filter".
	return tbsByRange[timeRange]
}
require.NotNil(t, result) // Verify result structure assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "Yao App Engine", result.Query) assert.Equal(t, types.SourceAuto, result.Source) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) // Verify we got results assert.Greater(t, result.Total, 0) assert.NotEmpty(t, result.Items) assert.Greater(t, result.Duration, int64(0)) // Verify result item structure for _, item := range result.Items { assert.Equal(t, types.SearchTypeWeb, item.Type) assert.Equal(t, types.SourceAuto, item.Source) assert.NotEmpty(t, item.Title) assert.NotEmpty(t, item.URL) assert.Greater(t, item.Score, 0.0) } t.Logf("Search returned %d results in %dms", result.Total, result.Duration) } // TestSerperProviderWithSiteRestriction tests SerperProvider with domain restriction func TestSerperProviderWithSiteRestriction(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } skipIfNoSerperKey(t) testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerperProvider provider := web.NewSerperProvider(ast.Search.Web) // Execute search with site restriction req := &types.Request{ Query: "documentation", Type: types.SearchTypeWeb, Source: types.SourceHook, Sites: []string{"github.com"}, Limit: 3, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, types.SourceHook, result.Source) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) require.NotEmpty(t, result.Items, "Search should return results") // All results should be from github.com for _, 
item := range result.Items { assert.Contains(t, item.URL, "github.com", "Result URL should be from github.com") } t.Logf("Site-restricted search returned %d results from github.com", result.Total) } // TestSerperProviderWithMultipleSites tests SerperProvider with multiple domain restrictions func TestSerperProviderWithMultipleSites(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } skipIfNoSerperKey(t) testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerperProvider provider := web.NewSerperProvider(ast.Search.Web) // Execute search with multiple site restrictions req := &types.Request{ Query: "golang tutorial", Type: types.SearchTypeWeb, Source: types.SourceAuto, Sites: []string{"github.com", "golang.org"}, Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) require.NotEmpty(t, result.Items, "Search should return results") // Results should be from either github.com or golang.org for _, item := range result.Items { isValidSite := false for _, site := range req.Sites { if contains(item.URL, site) { isValidSite = true break } } assert.True(t, isValidSite, "Result URL should be from github.com or golang.org: %s", item.URL) } t.Logf("Multi-site search returned %d results", result.Total) } // TestSerperProviderWithTimeRange tests SerperProvider with time range filter func TestSerperProviderWithTimeRange(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } skipIfNoSerperKey(t) testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") 
require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerperProvider provider := web.NewSerperProvider(ast.Search.Web) // Execute search with time range req := &types.Request{ Query: "artificial intelligence news", Type: types.SearchTypeWeb, Source: types.SourceAuto, TimeRange: "week", // Last week Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) t.Logf("Time-ranged search (last week) returned %d results in %dms", result.Total, result.Duration) } // TestSerperProviderWithoutAPIKey tests graceful degradation when API key is missing func TestSerperProviderWithoutAPIKey(t *testing.T) { // Create provider with nil config (no API key) provider := web.NewSerperProvider(nil) require.NotNil(t, provider) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceAuto, } result, err := provider.Search(req) // Should not return error, but result should have error message require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "test query", result.Query) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "API key") assert.Empty(t, result.Items) assert.Equal(t, 0, result.Total) } // TestSerperProviderWithEmptyConfig tests provider with empty config func TestSerperProviderWithEmptyConfig(t *testing.T) { // Create provider with empty config cfg := &types.WebConfig{} provider := web.NewSerperProvider(cfg) require.NotNil(t, provider) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceUser, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "API key") } // TestSerperProviderMaxResults tests that 
max_results from config is respected func TestSerperProviderMaxResults(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } skipIfNoSerperKey(t) testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create SerperProvider provider := web.NewSerperProvider(ast.Search.Web) // Execute search without limit (should use config's max_results) req := &types.Request{ Query: "machine learning", Type: types.SearchTypeWeb, Source: types.SourceAuto, // No Limit set, should use config's max_results (10) } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) // Should respect max_results from config (+1 for possible answer box) assert.LessOrEqual(t, result.Total, ast.Search.Web.MaxResults+1) t.Logf("Search without limit returned %d results (max: %d)", result.Total, ast.Search.Web.MaxResults) // Execute search with explicit limit req2 := &types.Request{ Query: "machine learning", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 3, // Override config's max_results } result2, err := provider.Search(req2) require.NoError(t, err) require.NotNil(t, result2) // API key must be valid - search should succeed require.Empty(t, result2.Error, "Search should succeed with valid API key, got error: %s", result2.Error) // Should respect request's limit (+1 for possible answer box) assert.LessOrEqual(t, result2.Total, 4) t.Logf("Search with limit=3 returned %d results", result2.Total) } // contains checks if s contains substr (uses containsSite from serpapi_test.go) func contains(s, substr string) bool { return containsSite(s, substr) } ================================================ FILE: 
agent/search/handlers/web/tavily.go ================================================ package web import ( "bytes" "encoding/json" "fmt" "io" "net/http" "os" "time" "github.com/yaoapp/yao/agent/search/types" ) const ( tavilyAPIURL = "https://api.tavily.com/search" tavilyAPITimeout = 30 * time.Second ) // TavilyProvider implements web search using Tavily API type TavilyProvider struct { apiKey string maxResults int } // NewTavilyProvider creates a new Tavily provider func NewTavilyProvider(cfg *types.WebConfig) *TavilyProvider { apiKey := "" if cfg != nil && cfg.APIKeyEnv != "" { // Support both "$ENV.VAR_NAME" and "VAR_NAME" formats envName := cfg.APIKeyEnv if len(envName) > 5 && envName[:5] == "$ENV." { envName = envName[5:] } apiKey = os.Getenv(envName) } maxResults := 10 if cfg != nil && cfg.MaxResults > 0 { maxResults = cfg.MaxResults } return &TavilyProvider{ apiKey: apiKey, maxResults: maxResults, } } // tavilyRequest represents the request body for Tavily API type tavilyRequest struct { APIKey string `json:"api_key"` Query string `json:"query"` SearchDepth string `json:"search_depth,omitempty"` // "basic" or "advanced" IncludeAnswer bool `json:"include_answer,omitempty"` // Include AI-generated answer IncludeRawContent bool `json:"include_raw_content,omitempty"` // Include raw HTML content MaxResults int `json:"max_results,omitempty"` // Max number of results IncludeDomains []string `json:"include_domains,omitempty"` // Limit to specific domains ExcludeDomains []string `json:"exclude_domains,omitempty"` // Exclude specific domains } // tavilyResponse represents the response from Tavily API type tavilyResponse struct { Query string `json:"query"` Answer string `json:"answer,omitempty"` Results []tavilyResult `json:"results"` } // tavilyResult represents a single search result from Tavily type tavilyResult struct { Title string `json:"title"` URL string `json:"url"` Content string `json:"content"` Score float64 `json:"score"` RawContent string 
`json:"raw_content,omitempty"` } // Search executes a web search using Tavily API func (p *TavilyProvider) Search(req *types.Request) (*types.Result, error) { startTime := time.Now() // Validate API key if p.apiKey == "" { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Error: "Tavily API key not configured", }, nil } // Determine max results maxResults := p.maxResults if req.Limit > 0 { maxResults = req.Limit } // Build request body tavilyReq := tavilyRequest{ APIKey: p.apiKey, Query: req.Query, SearchDepth: "basic", IncludeAnswer: false, MaxResults: maxResults, } // Add domain restrictions if specified if len(req.Sites) > 0 { tavilyReq.IncludeDomains = req.Sites } // Execute API call tavilyResp, err := p.callAPI(&tavilyReq) if err != nil { return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: []*types.ResultItem{}, Total: 0, Duration: time.Since(startTime).Milliseconds(), Error: fmt.Sprintf("Tavily API error: %v", err), }, nil } // Convert results items := make([]*types.ResultItem, 0, len(tavilyResp.Results)) for _, r := range tavilyResp.Results { items = append(items, &types.ResultItem{ Type: types.SearchTypeWeb, Title: r.Title, Content: r.Content, URL: r.URL, Score: r.Score, Source: req.Source, }) } return &types.Result{ Type: types.SearchTypeWeb, Query: req.Query, Source: req.Source, Items: items, Total: len(items), Duration: time.Since(startTime).Milliseconds(), }, nil } // callAPI makes the HTTP request to Tavily API func (p *TavilyProvider) callAPI(req *tavilyRequest) (*tavilyResponse, error) { // Serialize request body body, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } // Create HTTP request httpReq, err := http.NewRequest(http.MethodPost, tavilyAPIURL, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } 
httpReq.Header.Set("Content-Type", "application/json") // Execute request client := &http.Client{Timeout: tavilyAPITimeout} resp, err := client.Do(httpReq) if err != nil { return nil, fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() // Read response body respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } // Check status code if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(respBody)) } // Parse response var tavilyResp tavilyResponse if err := json.Unmarshal(respBody, &tavilyResp); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } return &tavilyResp, nil } ================================================ FILE: agent/search/handlers/web/tavily_test.go ================================================ package web_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/search/handlers/web" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" ) // TestTavilyProviderWithAssistantConfig tests TavilyProvider using web-tavily assistant config func TestTavilyProviderWithAssistantConfig(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-tavily test assistant to get its config ast, err := assistant.LoadPath("/assistants/tests/web-tavily") require.NoError(t, err) require.NotNil(t, ast) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Verify assistant config assert.Equal(t, "tests.web-tavily", ast.ID) assert.Equal(t, "tavily", ast.Search.Web.Provider) assert.Equal(t, "$ENV.TAVILY_API_KEY", ast.Search.Web.APIKeyEnv) assert.Equal(t, 10, ast.Search.Web.MaxResults) // Create TavilyProvider with assistant's web config provider := web.NewTavilyProvider(ast.Search.Web) 
require.NotNil(t, provider) // Execute search req := &types.Request{ Query: "Yao App Engine", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 5, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // Verify result structure assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "Yao App Engine", result.Query) assert.Equal(t, types.SourceAuto, result.Source) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) // Verify we got results assert.Greater(t, result.Total, 0) assert.NotEmpty(t, result.Items) assert.Greater(t, result.Duration, int64(0)) // Verify result item structure for _, item := range result.Items { assert.Equal(t, types.SearchTypeWeb, item.Type) assert.Equal(t, types.SourceAuto, item.Source) assert.NotEmpty(t, item.Title) assert.NotEmpty(t, item.URL) // Content may be empty for some results } t.Logf("Search returned %d results in %dms", result.Total, result.Duration) } // TestTavilyProviderWithSiteRestriction tests TavilyProvider with domain restriction func TestTavilyProviderWithSiteRestriction(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-tavily test assistant ast, err := assistant.LoadPath("/assistants/tests/web-tavily") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create TavilyProvider provider := web.NewTavilyProvider(ast.Search.Web) // Execute search with site restriction req := &types.Request{ Query: "documentation", Type: types.SearchTypeWeb, Source: types.SourceHook, Sites: []string{"github.com"}, Limit: 3, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, types.SourceHook, result.Source) // API key must be valid - search should succeed 
require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) require.NotEmpty(t, result.Items, "Search should return results") // All results should be from github.com for _, item := range result.Items { assert.Contains(t, item.URL, "github.com", "Result URL should be from github.com") } t.Logf("Site-restricted search returned %d results from github.com", result.Total) } // TestTavilyProviderWithoutAPIKey tests graceful degradation when API key is missing func TestTavilyProviderWithoutAPIKey(t *testing.T) { // Create provider with nil config (no API key) provider := web.NewTavilyProvider(nil) require.NotNil(t, provider) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceAuto, } result, err := provider.Search(req) // Should not return error, but result should have error message require.NoError(t, err) require.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "test query", result.Query) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "API key") assert.Empty(t, result.Items) assert.Equal(t, 0, result.Total) } // TestTavilyProviderWithEmptyConfig tests provider with empty config func TestTavilyProviderWithEmptyConfig(t *testing.T) { // Create provider with empty config cfg := &types.WebConfig{} provider := web.NewTavilyProvider(cfg) require.NotNil(t, provider) req := &types.Request{ Query: "test query", Type: types.SearchTypeWeb, Source: types.SourceUser, } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) assert.NotEmpty(t, result.Error) assert.Contains(t, result.Error, "API key") } // TestTavilyProviderMaxResults tests that max_results from config is respected func TestTavilyProviderMaxResults(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test") } testutils.Prepare(t) defer testutils.Clean(t) // Load the web-tavily test assistant ast, err := 
assistant.LoadPath("/assistants/tests/web-tavily") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Create TavilyProvider provider := web.NewTavilyProvider(ast.Search.Web) // Execute search without limit (should use config's max_results) req := &types.Request{ Query: "artificial intelligence", Type: types.SearchTypeWeb, Source: types.SourceAuto, // No Limit set, should use config's max_results (10) } result, err := provider.Search(req) require.NoError(t, err) require.NotNil(t, result) // API key must be valid - search should succeed require.Empty(t, result.Error, "Search should succeed with valid API key, got error: %s", result.Error) // Should respect max_results from config assert.LessOrEqual(t, result.Total, ast.Search.Web.MaxResults) t.Logf("Search without limit returned %d results (max: %d)", result.Total, ast.Search.Web.MaxResults) // Execute search with explicit limit req2 := &types.Request{ Query: "artificial intelligence", Type: types.SearchTypeWeb, Source: types.SourceAuto, Limit: 3, // Override config's max_results } result2, err := provider.Search(req2) require.NoError(t, err) require.NotNil(t, result2) // API key must be valid - search should succeed require.Empty(t, result2.Error, "Search should succeed with valid API key, got error: %s", result2.Error) // Should respect request's limit assert.LessOrEqual(t, result2.Total, 3) t.Logf("Search with limit=3 returned %d results", result2.Total) } ================================================ FILE: agent/search/interfaces/handler.go ================================================ package interfaces import ( "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // Handler defines the interface for search implementations type Handler interface { // Type returns the search type this handler supports Type() types.SearchType // Search executes the search and returns results Search(req *types.Request) (*types.Result, error) } // 
// ContextHandler extends Handler with context support
// Handlers that need context (e.g., DB handler for QueryDSL generation) should implement this
type ContextHandler interface {
	Handler

	// SearchWithContext executes the search with context and returns results.
	// It takes the same request as Handler.Search plus the agent context.
	SearchWithContext(ctx *context.Context, req *types.Request) (*types.Result, error)
}

================================================
FILE: agent/search/interfaces/nlp.go
================================================
package interfaces

import (
	"github.com/yaoapp/gou/model"
	"github.com/yaoapp/gou/query/gou"
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/search/types"
)

// KeywordExtractor extracts keywords for web search
type KeywordExtractor interface {
	// Extract extracts search keywords from user message
	// ctx is required for Agent and MCP modes, can be nil for builtin mode
	// content is the text to extract from; opts carries the extraction options.
	Extract(ctx *context.Context, content string, opts *types.KeywordOptions) ([]string, error)
}

// QueryDSLGenerator generates QueryDSL for DB search
type QueryDSLGenerator interface {
	// Generate converts natural language to QueryDSL
	// models lists the models the generated query may target.
	Generate(query string, models []*model.Model) (*gou.QueryDSL, error)
}

// Note: Embedding is handled by KB collection's own config (embedding provider + model),
// not defined here. See KB handler for details.
================================================
FILE: agent/search/interfaces/reranker.go
================================================
package interfaces

import (
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/search/types"
)

// Reranker reorders search results by relevance
type Reranker interface {
	// Rerank reorders results based on query relevance
	// ctx is required for Agent and MCP modes, can be nil for builtin mode
	// It returns the reordered items, or an error.
	Rerank(ctx *context.Context, query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error)
}

================================================
FILE: agent/search/interfaces/searcher.go
================================================
package interfaces

import (
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/search/types"
)

// Searcher is the main interface exposed to external callers
type Searcher interface {
	// Search executes a single search request
	Search(ctx *context.Context, req *types.Request) (*types.Result, error)

	// Parallel search methods - inspired by JavaScript Promise

	// All waits for all searches to complete (like Promise.all)
	All(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error)

	// Any returns when any search succeeds with results (like Promise.any)
	Any(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error)

	// Race returns when any search completes (like Promise.race)
	Race(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error)

	// BuildReferences converts search results to unified Reference format for LLM
	BuildReferences(results []*types.Result) []*types.Reference
}

================================================
FILE: agent/search/jsapi.go
================================================
package search

import (
	"github.com/yaoapp/yao/agent/context"
	"github.com/yaoapp/yao/agent/search/types"
)

// JSAPI implements context.SearchAPI interface
// Provides ctx.search.Web(), ctx.search.KB(), ctx.search.DB(),
ctx.search.All(), ctx.search.Any(), ctx.search.Race() type JSAPI struct { ctx *context.Context searcher *Searcher } // NewJSAPI creates a new search JSAPI instance func NewJSAPI(ctx *context.Context, config *types.Config, uses *Uses) *JSAPI { return &JSAPI{ ctx: ctx, searcher: New(config, uses), } } // Web executes web search // Options: // - limit: int - max results (default: 10) // - sites: []string - restrict to specific sites // - time_range: string - "day", "week", "month", "year" // - rerank: map[string]interface{} - rerank options func (api *JSAPI) Web(query string, opts map[string]interface{}) interface{} { req := api.buildRequest(types.SearchTypeWeb, query, opts) result, _ := api.searcher.Search(api.ctx, req) return result } // KB executes knowledge base search // Options: // - collections: []string - collection IDs // - threshold: float64 - similarity threshold (0-1) // - limit: int - max results // - graph: bool - enable graph association // - rerank: map[string]interface{} - rerank options func (api *JSAPI) KB(query string, opts map[string]interface{}) interface{} { req := api.buildRequest(types.SearchTypeKB, query, opts) result, _ := api.searcher.Search(api.ctx, req) return result } // DB executes database search // Options: // - models: []string - model IDs // - wheres: []map[string]interface{} - pre-defined filters (GOU QueryDSL Where format) // - orders: []map[string]interface{} - sort orders (GOU QueryDSL Order format) // - select: []string - fields to return // - limit: int - max results // - rerank: map[string]interface{} - rerank options func (api *JSAPI) DB(query string, opts map[string]interface{}) interface{} { req := api.buildRequest(types.SearchTypeDB, query, opts) result, _ := api.searcher.Search(api.ctx, req) return result } // All executes all searches and waits for all to complete (like Promise.all) // Each request should have: // - type: string - "web", "kb", or "db" // - query: string - search query // - ... 
other type-specific options func (api *JSAPI) All(requests []interface{}) []interface{} { reqs := api.parseRequests(requests) results, _ := api.searcher.All(api.ctx, reqs) return api.convertResults(results) } // Any returns as soon as any search succeeds with results (like Promise.any) // Each request should have: // - type: string - "web", "kb", or "db" // - query: string - search query // - ... other type-specific options func (api *JSAPI) Any(requests []interface{}) []interface{} { reqs := api.parseRequests(requests) results, _ := api.searcher.Any(api.ctx, reqs) return api.convertResults(results) } // Race returns as soon as any search completes (like Promise.race) // Each request should have: // - type: string - "web", "kb", or "db" // - query: string - search query // - ... other type-specific options func (api *JSAPI) Race(requests []interface{}) []interface{} { reqs := api.parseRequests(requests) results, _ := api.searcher.Race(api.ctx, reqs) return api.convertResults(results) } // buildRequest builds a Request from query and options func (api *JSAPI) buildRequest(searchType types.SearchType, query string, opts map[string]interface{}) *types.Request { req := &types.Request{ Type: searchType, Query: query, Source: types.SourceHook, // JSAPI calls are from hooks } if opts == nil { return req } // Common options if limit, ok := opts["limit"].(float64); ok { req.Limit = int(limit) } else if limit, ok := opts["limit"].(int); ok { req.Limit = limit } // Web-specific options if searchType == types.SearchTypeWeb { if sites, ok := opts["sites"].([]interface{}); ok { req.Sites = toStringSlice(sites) } if timeRange, ok := opts["time_range"].(string); ok { req.TimeRange = timeRange } } // KB-specific options if searchType == types.SearchTypeKB { if collections, ok := opts["collections"].([]interface{}); ok { req.Collections = toStringSlice(collections) } if threshold, ok := opts["threshold"].(float64); ok { req.Threshold = threshold } if graph, ok := 
opts["graph"].(bool); ok { req.Graph = graph } } // DB-specific options if searchType == types.SearchTypeDB { if models, ok := opts["models"].([]interface{}); ok { req.Models = toStringSlice(models) } if selectFields, ok := opts["select"].([]interface{}); ok { req.Select = toStringSlice(selectFields) } // Note: wheres and orders are more complex, handled by QueryDSL generator } // Rerank options if rerankOpts, ok := opts["rerank"].(map[string]interface{}); ok { req.Rerank = &types.RerankOptions{} if topN, ok := rerankOpts["top_n"].(float64); ok { req.Rerank.TopN = int(topN) } else if topN, ok := rerankOpts["top_n"].(int); ok { req.Rerank.TopN = topN } } return req } // parseRequests parses an array of request objects into typed Requests func (api *JSAPI) parseRequests(requests []interface{}) []*types.Request { reqs := make([]*types.Request, 0, len(requests)) for _, r := range requests { reqMap, ok := r.(map[string]interface{}) if !ok { continue } // Get type typeStr, ok := reqMap["type"].(string) if !ok { continue } searchType := types.SearchType(typeStr) // Get query query, ok := reqMap["query"].(string) if !ok { continue } // Build request with remaining options req := api.buildRequest(searchType, query, reqMap) reqs = append(reqs, req) } return reqs } // convertResults converts typed Results to interface slice for JS func (api *JSAPI) convertResults(results []*types.Result) []interface{} { out := make([]interface{}, len(results)) for i, r := range results { out[i] = r } return out } // toStringSlice converts []interface{} to []string func toStringSlice(arr []interface{}) []string { result := make([]string, 0, len(arr)) for _, v := range arr { if s, ok := v.(string); ok { result = append(result, s) } } return result } // ConfigGetter is a function type that retrieves search config and uses for an assistant type ConfigGetter func(assistantID string) (*types.Config, *Uses) // configGetter is set by assistant package during initialization var configGetter 
ConfigGetter // SetJSAPIFactory sets the factory function for creating SearchAPI instances // Called by assistant package during initialization // getter: function to get search config and uses from assistant ID func SetJSAPIFactory(getter ConfigGetter) { configGetter = getter context.SearchAPIFactory = func(ctx *context.Context) context.SearchAPI { var config *types.Config var uses *Uses if configGetter != nil && ctx.AssistantID != "" { config, uses = configGetter(ctx.AssistantID) } return NewJSAPI(ctx, config, uses) } } ================================================ FILE: agent/search/jsapi_db_test.go ================================================ package search_test import ( "testing" "github.com/yaoapp/gou/model" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // ============================================================================ // DB Search JSAPI Integration Tests // ============================================================================ func TestJSAPI_DB_Integration(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment (loads models, database, query engine, etc.) 
testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newJSAPITestContext(t) // Verify __yao.role model is loaded mod := model.Select("__yao.role") require.NotNil(t, mod, "__yao.role model should be loaded") // Ensure test data exists ensureJSAPITestRole(t, mod) t.Run("db_search_with_context", func(t *testing.T) { api := search.NewJSAPI(ctx, &types.Config{ DB: &types.DBConfig{ Models: []string{"__yao.role"}, MaxResults: 10, }, }, &search.Uses{QueryDSL: "builtin"}) result := api.DB("查询所有角色", map[string]interface{}{ "models": []interface{}{"__yao.role"}, "limit": float64(10), }) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeDB, r.Type) assert.Equal(t, "查询所有角色", r.Query) assert.Equal(t, types.SourceHook, r.Source) if r.Error != "" { t.Logf("Search error: %s", r.Error) } assert.Empty(t, r.Error, "Should not have error") assert.Greater(t, len(r.Items), 0, "Should have results") }) t.Run("db_search_with_scenario", func(t *testing.T) { api := search.NewJSAPI(ctx, &types.Config{ DB: &types.DBConfig{ Models: []string{"__yao.role"}, MaxResults: 5, }, }, &search.Uses{QueryDSL: "builtin"}) result := api.DB("查询系统角色", map[string]interface{}{ "models": []interface{}{"__yao.role"}, "scenario": "filter", "limit": float64(5), }) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeDB, r.Type) assert.LessOrEqual(t, len(r.Items), 5, "Should respect limit") }) t.Run("db_search_with_select_fields", func(t *testing.T) { api := search.NewJSAPI(ctx, &types.Config{ DB: &types.DBConfig{ Models: []string{"__yao.role"}, MaxResults: 10, }, }, &search.Uses{QueryDSL: "builtin"}) result := api.DB("查询角色名称", map[string]interface{}{ "models": []interface{}{"__yao.role"}, "select": []interface{}{"id", "name", "description"}, "limit": float64(10), }) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeDB, 
r.Type) if r.Error == "" && len(r.Items) > 0 { // Verify items have data for _, item := range r.Items { assert.NotNil(t, item.Data) assert.Equal(t, "__yao.role", item.Model) } } }) t.Run("db_search_all_with_multiple_types", func(t *testing.T) { api := search.NewJSAPI(ctx, &types.Config{ KB: &types.KBConfig{Collections: []string{"docs"}}, DB: &types.DBConfig{ Models: []string{"__yao.role"}, MaxResults: 10, }, }, &search.Uses{QueryDSL: "builtin"}) requests := []interface{}{ map[string]interface{}{ "type": "db", "query": "查询角色", "models": []interface{}{"__yao.role"}, "limit": float64(5), }, map[string]interface{}{ "type": "kb", "query": "知识库查询", "collections": []interface{}{"docs"}, "limit": float64(5), }, } results := api.All(requests) require.Len(t, results, 2) // DB result r0, ok := results[0].(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeDB, r0.Type) // KB result r1, ok := results[1].(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeKB, r1.Type) }) } // newJSAPITestContext creates a test context for JSAPI tests func newJSAPITestContext(t *testing.T) *context.Context { t.Helper() authorized := &oauthTypes.AuthorizedInfo{ UserID: "test-user-jsapi", } chatID := "test-chat-jsapi-db" ctx := context.New(t.Context(), authorized, chatID) return ctx } // ensureJSAPITestRole ensures there's at least one role in the database func ensureJSAPITestRole(t *testing.T, mod *model.Model) { t.Helper() // Try to find existing roles rows, err := mod.Get(model.QueryParam{Limit: 1}) if err == nil && len(rows) > 0 { return } // Create a test role _, err = mod.Create(map[string]interface{}{ "role_id": "jsapi_test_role", "name": "JSAPI Test Role", "description": "A test role for JSAPI unit testing", "is_active": true, "is_system": false, "level": 1, }) if err != nil { t.Logf("Note: Could not create test role: %v", err) } } ================================================ FILE: agent/search/jsapi_test.go 
================================================ package search_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" ) func TestNewJSAPI(t *testing.T) { api := search.NewJSAPI(nil, nil, nil) require.NotNil(t, api) } func TestJSAPI_Web(t *testing.T) { api := search.NewJSAPI(nil, &types.Config{ Web: &types.WebConfig{Provider: "tavily"}, }, &search.Uses{Web: "builtin"}) result := api.Web("test query", nil) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeWeb, r.Type) assert.Equal(t, "test query", r.Query) assert.Equal(t, types.SourceHook, r.Source) } func TestJSAPI_Web_WithOptions(t *testing.T) { api := search.NewJSAPI(nil, &types.Config{ Web: &types.WebConfig{Provider: "tavily"}, }, &search.Uses{Web: "builtin"}) opts := map[string]interface{}{ "limit": float64(5), "sites": []interface{}{"github.com", "stackoverflow.com"}, "time_range": "week", } result := api.Web("golang concurrency", opts) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeWeb, r.Type) assert.Equal(t, "golang concurrency", r.Query) } func TestJSAPI_KB(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) api := search.NewJSAPI(nil, &types.Config{ KB: &types.KBConfig{Collections: []string{"docs"}}, }, nil) result := api.KB("test query", nil) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeKB, r.Type) assert.Equal(t, "test query", r.Query) assert.Equal(t, types.SourceHook, r.Source) } func TestJSAPI_KB_WithOptions(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) api := search.NewJSAPI(nil, &types.Config{ KB: &types.KBConfig{Collections: []string{"docs"}}, }, nil) opts := map[string]interface{}{ 
"collections": []interface{}{"docs", "faq"}, "threshold": 0.8, "limit": float64(10), "graph": true, } result := api.KB("knowledge base query", opts) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeKB, r.Type) assert.Equal(t, "knowledge base query", r.Query) } func TestJSAPI_DB(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) api := search.NewJSAPI(nil, &types.Config{ DB: &types.DBConfig{Models: []string{"product"}}, }, &search.Uses{QueryDSL: "builtin"}) result := api.DB("test query", nil) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeDB, r.Type) assert.Equal(t, "test query", r.Query) assert.Equal(t, types.SourceHook, r.Source) } func TestJSAPI_DB_WithOptions(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) api := search.NewJSAPI(nil, &types.Config{ DB: &types.DBConfig{Models: []string{"product"}}, }, &search.Uses{QueryDSL: "builtin"}) opts := map[string]interface{}{ "models": []interface{}{"product", "order"}, "select": []interface{}{"id", "name", "price"}, "limit": float64(20), } result := api.DB("database query", opts) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeDB, r.Type) assert.Equal(t, "database query", r.Query) } func TestJSAPI_All(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) api := search.NewJSAPI(nil, &types.Config{ KB: &types.KBConfig{Collections: []string{"docs"}}, DB: &types.DBConfig{Models: []string{"product"}}, }, nil) requests := []interface{}{ map[string]interface{}{ "type": "kb", "query": "KB query", }, map[string]interface{}{ "type": "db", "query": "DB query", }, } results := api.All(requests) require.Len(t, results, 2) // First result (KB) r0, ok := results[0].(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeKB, r0.Type) assert.Equal(t, "KB query", r0.Query) // Second result (DB) r1, ok := 
results[1].(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeDB, r1.Type) assert.Equal(t, "DB query", r1.Query) } func TestJSAPI_Any(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) api := search.NewJSAPI(nil, &types.Config{ KB: &types.KBConfig{Collections: []string{"docs"}}, DB: &types.DBConfig{Models: []string{"product"}}, }, nil) requests := []interface{}{ map[string]interface{}{ "type": "kb", "query": "KB query", }, map[string]interface{}{ "type": "db", "query": "DB query", }, } results := api.Any(requests) require.Len(t, results, 2) // At least one result should be present hasResult := false for _, r := range results { if r != nil { hasResult = true break } } assert.True(t, hasResult) } func TestJSAPI_Race(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) api := search.NewJSAPI(nil, &types.Config{ KB: &types.KBConfig{Collections: []string{"docs"}}, DB: &types.DBConfig{Models: []string{"product"}}, }, nil) requests := []interface{}{ map[string]interface{}{ "type": "kb", "query": "KB query", }, map[string]interface{}{ "type": "db", "query": "DB query", }, } results := api.Race(requests) require.Len(t, results, 2) // At least one result should be present hasResult := false for _, r := range results { if r != nil { hasResult = true break } } assert.True(t, hasResult) } func TestJSAPI_All_Empty(t *testing.T) { api := search.NewJSAPI(nil, nil, nil) results := api.All([]interface{}{}) assert.Len(t, results, 0) } func TestJSAPI_Any_Empty(t *testing.T) { api := search.NewJSAPI(nil, nil, nil) results := api.Any([]interface{}{}) assert.Len(t, results, 0) } func TestJSAPI_Race_Empty(t *testing.T) { api := search.NewJSAPI(nil, nil, nil) results := api.Race([]interface{}{}) assert.Len(t, results, 0) } func TestJSAPI_Web_WithRerank(t *testing.T) { api := search.NewJSAPI(nil, &types.Config{ Web: &types.WebConfig{Provider: "tavily"}, }, &search.Uses{Web: "builtin"}) opts := map[string]interface{}{ "limit": float64(10), "rerank": 
map[string]interface{}{ "top_n": float64(5), }, } result := api.Web("test query", opts) require.NotNil(t, result) r, ok := result.(*types.Result) require.True(t, ok) assert.Equal(t, types.SearchTypeWeb, r.Type) } func TestJSAPI_All_InvalidRequests(t *testing.T) { api := search.NewJSAPI(nil, &types.Config{ Web: &types.WebConfig{Provider: "tavily"}, }, &search.Uses{Web: "builtin"}) // Mix of invalid and valid requests requests := []interface{}{ "invalid", // Not a map map[string]interface{}{ "query": "no type", // Missing type }, map[string]interface{}{ "type": "web", // Missing query }, map[string]interface{}{ "type": "web", "query": "valid query", }, } results := api.All(requests) // Only the valid request should produce a result assert.Len(t, results, 1) } func TestSetJSAPIFactory(t *testing.T) { // Reset factory context.SearchAPIFactory = nil // Set factory with nil getter (uses defaults) search.SetJSAPIFactory(nil) // Verify factory is set require.NotNil(t, context.SearchAPIFactory) // Create a mock context ctx := context.New(nil, nil, "test-chat") // Get search API searchAPI := context.SearchAPIFactory(ctx) require.NotNil(t, searchAPI) } func TestSetJSAPIFactory_WithGetter(t *testing.T) { // Reset factory context.SearchAPIFactory = nil // Set factory with custom getter search.SetJSAPIFactory(func(assistantID string) (*types.Config, *search.Uses) { if assistantID == "test-assistant" { return &types.Config{ Web: &types.WebConfig{Provider: "tavily"}, }, &search.Uses{Web: "builtin"} } return nil, nil }) // Verify factory is set require.NotNil(t, context.SearchAPIFactory) // Create a context with assistant ID ctx := context.New(nil, nil, "test-chat") ctx.AssistantID = "test-assistant" // Get search API searchAPI := context.SearchAPIFactory(ctx) require.NotNil(t, searchAPI) } func TestJSAPI_ImplementsSearchAPI(t *testing.T) { // Verify JSAPI implements context.SearchAPI interface var _ context.SearchAPI = search.NewJSAPI(nil, nil, nil) } 
================================================ FILE: agent/search/nlp/keyword/agent.go ================================================ package keyword import ( "encoding/json" "fmt" "github.com/yaoapp/yao/agent/caller" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // AgentProvider delegates keyword extraction to an LLM-powered assistant // The assistant can understand context and extract semantically relevant keywords type AgentProvider struct { agentID string // Assistant ID to delegate to } // NewAgentProvider creates a new agent-based keyword extractor func NewAgentProvider(agentID string) *AgentProvider { return &AgentProvider{ agentID: agentID, } } // Extract extracts keywords by calling the target agent // The agent receives the content and returns extracted keywords with weights func (p *AgentProvider) Extract(ctx *agentContext.Context, content string, opts *types.KeywordOptions) ([]types.Keyword, error) { if ctx == nil { return nil, fmt.Errorf("context is required for agent keyword extraction") } // Check if AgentGetterFunc is initialized if caller.AgentGetterFunc == nil { return nil, fmt.Errorf("AgentGetterFunc not initialized") } // Get the agent agent, err := caller.AgentGetterFunc(p.agentID) if err != nil { return nil, fmt.Errorf("failed to get agent %s: %w", p.agentID, err) } // Build the request message requestData := map[string]interface{}{ "content": content, "max_keywords": opts.MaxKeywords, "language": opts.Language, } requestJSON, _ := json.Marshal(requestData) // Create message for the agent messages := []agentContext.Message{ { Role: "user", Content: string(requestJSON), }, } // Call the agent with skip options (no history, no output) options := &agentContext.Options{ Skip: &agentContext.Skip{ History: true, Output: true, }, } response, err := agent.Stream(ctx, messages, options) if err != nil { return nil, fmt.Errorf("agent call failed: %w", err) } // Parse the result from response.Next 
return p.parseResponse(response) } // parseResponse extracts keywords from the agent's *context.Response // Now that agent.Stream() returns *context.Response directly, // we can access fields without type assertions. // // The agent returns keywords in response.Next field as {data: {keywords: [{k, w}, ...]}} func (p *AgentProvider) parseResponse(response *agentContext.Response) ([]types.Keyword, error) { if response == nil || response.Next == nil { return []types.Keyword{}, nil } return p.parseNextData(response.Next) } // parseNextData extracts keywords from Next hook data // Expected format: {data: {keywords: [{k: "keyword", w: 0.9}, ...]}} func (p *AgentProvider) parseNextData(next interface{}) ([]types.Keyword, error) { if next == nil { return []types.Keyword{}, nil } // Try to convert to map first (most common case) var data map[string]interface{} switch v := next.(type) { case map[string]interface{}: data = v case string: // Try to parse as JSON if err := json.Unmarshal([]byte(v), &data); err != nil { // Not a JSON object, try as array of keywords var keywords []types.Keyword if err := json.Unmarshal([]byte(v), &keywords); err == nil { return keywords, nil } // Return as single keyword with default weight return []types.Keyword{{K: v, W: 0.5}}, nil } case []types.Keyword: return v, nil case []interface{}: return p.extractKeywordsFromArray(v) default: // Try to marshal and unmarshal jsonBytes, err := json.Marshal(next) if err != nil { return []types.Keyword{}, nil } if err := json.Unmarshal(jsonBytes, &data); err != nil { return []types.Keyword{}, nil } } // Extract keywords from data // Try common field names: "keywords", "data", "data.keywords" if kw, ok := data["keywords"]; ok { return p.extractKeywordsFromValue(kw) } if d, ok := data["data"]; ok { if dm, ok := d.(map[string]interface{}); ok { if kw, ok := dm["keywords"]; ok { return p.extractKeywordsFromValue(kw) } } return p.extractKeywordsFromValue(d) } return []types.Keyword{}, nil } // 
extractKeywordsFromValue extracts Keyword array from various types func (p *AgentProvider) extractKeywordsFromValue(v interface{}) ([]types.Keyword, error) { switch kw := v.(type) { case []types.Keyword: return kw, nil case []interface{}: return p.extractKeywordsFromArray(kw) case string: var keywords []types.Keyword if err := json.Unmarshal([]byte(kw), &keywords); err == nil { return keywords, nil } return []types.Keyword{{K: kw, W: 0.5}}, nil } return []types.Keyword{}, nil } // extractKeywordsFromArray extracts keywords from []interface{} // Handles both {k, w} objects and plain strings func (p *AgentProvider) extractKeywordsFromArray(items []interface{}) ([]types.Keyword, error) { keywords := make([]types.Keyword, 0, len(items)) for _, item := range items { switch v := item.(type) { case map[string]interface{}: // Handle {k: "keyword", w: 0.9} format k, _ := v["k"].(string) w, _ := v["w"].(float64) if k != "" { if w == 0 { w = 0.5 // Default weight } keywords = append(keywords, types.Keyword{K: k, W: w}) } case string: // Plain string, use default weight if v != "" { keywords = append(keywords, types.Keyword{K: v, W: 0.5}) } } } return keywords, nil } ================================================ FILE: agent/search/nlp/keyword/agent_test.go ================================================ package keyword_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/nlp/keyword" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) func TestAgentProviderWithAssistantConfig(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Load the keyword-agent assistant that will provide keywords 
ast, err := assistant.Get("tests.keyword-agent") require.NoError(t, err) require.NotNil(t, ast) // Create test context ctx := newTestContext(t) // Create extractor with agent mode extractor := keyword.NewExtractor("tests.keyword-agent", &searchTypes.KeywordConfig{ MaxKeywords: 5, Language: "auto", }) // Test extraction content := "Machine learning and deep learning are subfields of artificial intelligence" keywords, err := extractor.Extract(ctx, content, nil) require.NoError(t, err) assert.NotEmpty(t, keywords, "Agent should return keywords") assert.LessOrEqual(t, len(keywords), 5, "Should respect max_keywords") // Verify keywords are relevant t.Logf("Extracted keywords: %v", keywords) } func TestAgentProviderWithCustomOptions(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newTestContext(t) // Create extractor with agent mode extractor := keyword.NewExtractor("tests.keyword-agent", &searchTypes.KeywordConfig{ MaxKeywords: 10, }) // Test with runtime options override content := "Python programming language for data science and web development" keywords, err := extractor.Extract(ctx, content, &searchTypes.KeywordOptions{ MaxKeywords: 3, // Override to 3 }) require.NoError(t, err) assert.NotEmpty(t, keywords) assert.LessOrEqual(t, len(keywords), 3, "Should respect runtime max_keywords override") t.Logf("Extracted keywords (max 3): %v", keywords) } func TestAgentProviderWithoutContext(t *testing.T) { // Test that agent mode requires context extractor := keyword.NewExtractor("tests.keyword-agent", nil) _, err := extractor.Extract(nil, "test content", nil) assert.Error(t, err, "Agent mode should require context") assert.Contains(t, err.Error(), "context is required") } func TestAgentProviderAgentNotFound(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } 
// Package keyword provides keyword extraction for web search optimization
// Supports three modes via uses.keyword configuration:
//   - "builtin" or "": Uses __yao.keyword system agent (LLM-powered)
//   - "<assistant-id>": Delegate to a custom LLM-powered assistant
//   - "mcp:<server>.<tool>": Call external MCP tool
config *types.KeywordConfig // Keyword extraction options } // NewExtractor creates a new keyword extractor // usesKeyword: value from uses.keyword config // cfg: keyword extraction options from search config func NewExtractor(usesKeyword string, cfg *types.KeywordConfig) *Extractor { return &Extractor{ usesKeyword: usesKeyword, config: cfg, } } // Extract extracts keywords from content based on configured mode // Returns a list of keywords with weights optimized for search queries func (e *Extractor) Extract(ctx *context.Context, content string, opts *types.KeywordOptions) ([]types.Keyword, error) { // Merge options with config defaults mergedOpts := e.mergeOptions(opts) switch { case e.usesKeyword == "builtin" || e.usesKeyword == "": // Use system keyword agent return e.agentExtract(ctx, content, SystemKeywordAgent, mergedOpts) case strings.HasPrefix(e.usesKeyword, "mcp:"): return e.mcpExtract(ctx, content, mergedOpts) default: // Assume it's an assistant ID for Agent mode return e.agentExtract(ctx, content, e.usesKeyword, mergedOpts) } } // mergeOptions merges runtime options with config defaults func (e *Extractor) mergeOptions(opts *types.KeywordOptions) *types.KeywordOptions { result := &types.KeywordOptions{ MaxKeywords: 10, // default Language: "auto", // default } // Apply config defaults if e.config != nil { if e.config.MaxKeywords > 0 { result.MaxKeywords = e.config.MaxKeywords } if e.config.Language != "" { result.Language = e.config.Language } } // Apply runtime options (highest priority) if opts != nil { if opts.MaxKeywords > 0 { result.MaxKeywords = opts.MaxKeywords } if opts.Language != "" { result.Language = opts.Language } } return result } // agentExtract delegates to an LLM-powered assistant // The assistant can understand context and extract semantically relevant keywords func (e *Extractor) agentExtract(ctx *context.Context, content string, agentID string, opts *types.KeywordOptions) ([]types.Keyword, error) { if ctx == nil { return nil, 
fmt.Errorf("context is required for keyword extraction") } provider := NewAgentProvider(agentID) return provider.Extract(ctx, content, opts) } // mcpExtract calls an external MCP tool // Format: "mcp:." func (e *Extractor) mcpExtract(ctx *context.Context, content string, opts *types.KeywordOptions) ([]types.Keyword, error) { mcpRef := strings.TrimPrefix(e.usesKeyword, "mcp:") provider, err := NewMCPProvider(mcpRef) if err != nil { // Fallback to system agent on invalid MCP format return e.agentExtract(ctx, content, SystemKeywordAgent, e.mergeOptions(nil)) } return provider.Extract(ctx, content, opts) } ================================================ FILE: agent/search/nlp/keyword/extractor_test.go ================================================ package keyword_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/search/nlp/keyword" "github.com/yaoapp/yao/agent/search/types" ) func TestExtractor_BuiltinMode_RequiresContext(t *testing.T) { // Test builtin mode requires context (now uses __yao.keyword agent) extractor := keyword.NewExtractor("builtin", &types.KeywordConfig{ MaxKeywords: 5, Language: "auto", }) // Without context, should return error _, err := extractor.Extract(nil, "How to build a search engine with Elasticsearch?", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestExtractor_EmptyUsesKeyword_RequiresContext(t *testing.T) { // Empty uses.keyword should default to __yao.keyword agent extractor := keyword.NewExtractor("", nil) // Without context, should return error _, err := extractor.Extract(nil, "Machine learning algorithms", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestExtractor_AgentMode_RequiresContext(t *testing.T) { // Custom agent mode requires context extractor := keyword.NewExtractor("custom.keyword.agent", nil) _, err := extractor.Extract(nil, "Test query", nil) assert.Error(t, err) assert.Contains(t, err.Error(), 
// MCPProvider delegates keyword extraction to an MCP tool
type MCPProvider struct {
	serverID string // MCP server ID
	toolName string // Tool name to call
}

// NewMCPProvider creates a new MCP-based keyword extractor.
//
// mcpRef format: "server.tool" (e.g., "nlp.extract_keywords"). The reference
// is split at the first dot, so the tool name may itself contain dots.
//
// An error is returned when mcpRef contains no dot, or when either the server
// ID or the tool name is empty (e.g. "server." or ".tool"); such a reference
// can never address a tool and is rejected here rather than at call time.
func NewMCPProvider(mcpRef string) (*MCPProvider, error) {
	serverID, toolName, found := strings.Cut(mcpRef, ".")
	if !found || serverID == "" || toolName == "" {
		return nil, fmt.Errorf("invalid MCP format, expected 'server.tool', got '%s'", mcpRef)
	}
	return &MCPProvider{
		serverID: serverID,
		toolName: toolName,
	}, nil
}
client.CallTool(ctx, p.toolName, arguments) if err != nil { return nil, fmt.Errorf("MCP tool call failed: %w", err) } // Parse the result return p.parseResult(callResult) } // parseResult extracts keywords from the MCP tool response func (p *MCPProvider) parseResult(result *gouMCPTypes.CallToolResponse) ([]types.Keyword, error) { if result == nil { return []types.Keyword{}, nil } // Check for errors in result if result.IsError { errMsg := "MCP tool returned error" if len(result.Content) > 0 && result.Content[0].Text != "" { errMsg = result.Content[0].Text } return nil, fmt.Errorf("%s", errMsg) } // Parse content - expect JSON data with "keywords" field if len(result.Content) == 0 { return []types.Keyword{}, nil } // Try to extract keywords from content for _, content := range result.Content { // Check text content type if content.Type == gouMCPTypes.ToolContentTypeText && content.Text != "" { // Try to parse as JSON var data map[string]interface{} if err := json.Unmarshal([]byte(content.Text), &data); err == nil { // Look for "keywords" field if kw, ok := data["keywords"]; ok { return p.extractKeywordsFromValue(kw) } } // Try to parse as direct array of keywords var keywords []types.Keyword if err := json.Unmarshal([]byte(content.Text), &keywords); err == nil { return keywords, nil } } } return []types.Keyword{}, nil } // extractKeywordsFromValue extracts Keyword array from various types func (p *MCPProvider) extractKeywordsFromValue(v interface{}) ([]types.Keyword, error) { switch kw := v.(type) { case []types.Keyword: return kw, nil case []interface{}: keywords := make([]types.Keyword, 0, len(kw)) for _, item := range kw { switch v := item.(type) { case map[string]interface{}: // Handle {k: "keyword", w: 0.9} format k, _ := v["k"].(string) w, _ := v["w"].(float64) if k != "" { if w == 0 { w = 0.5 // Default weight } keywords = append(keywords, types.Keyword{K: k, W: w}) } case string: // Plain string, use default weight if v != "" { keywords = append(keywords, 
types.Keyword{K: v, W: 0.5}) } } } return keywords, nil case string: var keywords []types.Keyword if err := json.Unmarshal([]byte(kw), &keywords); err == nil { return keywords, nil } return []types.Keyword{{K: kw, W: 0.5}}, nil } return []types.Keyword{}, nil } ================================================ FILE: agent/search/nlp/keyword/mcp_test.go ================================================ package keyword_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/nlp/keyword" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) func TestMCPProviderWithAssistantConfig(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newMCPTestContext(t) // Create extractor with MCP mode extractor := keyword.NewExtractor("mcp:search.extract_keywords", &searchTypes.KeywordConfig{ MaxKeywords: 5, Language: "auto", }) // Test extraction content := "Machine learning and deep learning are subfields of artificial intelligence" keywords, err := extractor.Extract(ctx, content, nil) require.NoError(t, err) assert.NotEmpty(t, keywords, "MCP should return keywords") assert.LessOrEqual(t, len(keywords), 5, "Should respect max_keywords") t.Logf("Extracted keywords via MCP: %v", keywords) } func TestMCPProviderWithCustomOptions(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newMCPTestContext(t) // Create extractor with MCP mode extractor := keyword.NewExtractor("mcp:search.extract_keywords", &searchTypes.KeywordConfig{ MaxKeywords: 10, }) // Test with 
runtime options override content := "Python programming language for data science and web development" keywords, err := extractor.Extract(ctx, content, &searchTypes.KeywordOptions{ MaxKeywords: 3, // Override to 3 }) require.NoError(t, err) assert.NotEmpty(t, keywords) assert.LessOrEqual(t, len(keywords), 3, "Should respect runtime max_keywords override") t.Logf("Extracted keywords via MCP (max 3): %v", keywords) } func TestMCPProviderInvalidFormat(t *testing.T) { // Test invalid MCP format fallback to system agent (requires context) extractor := keyword.NewExtractor("mcp:invalid", nil) // Should fallback to system agent which requires context _, err := extractor.Extract(nil, "test content for keyword extraction", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestMCPProviderServerNotFound(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newMCPTestContext(t) // Create extractor with non-existent MCP server extractor := keyword.NewExtractor("mcp:nonexistent.extract_keywords", &searchTypes.KeywordConfig{}) _, err := extractor.Extract(ctx, "test content", nil) assert.Error(t, err, "Should error for non-existent MCP server") assert.Contains(t, err.Error(), "not found") } func TestMCPProviderToolNotFound(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newMCPTestContext(t) // Create extractor with non-existent tool extractor := keyword.NewExtractor("mcp:search.nonexistent_tool", &searchTypes.KeywordConfig{}) _, err := extractor.Extract(ctx, "test content", nil) assert.Error(t, err, "Should error for non-existent MCP tool") } func TestMCPProviderEmptyContent(t *testing.T) { // Skip if running short tests if 
testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newMCPTestContext(t) // Create extractor with MCP mode extractor := keyword.NewExtractor("mcp:search.extract_keywords", nil) // Test with empty content - MCP tool should return error _, err := extractor.Extract(ctx, "", nil) assert.Error(t, err, "Should error for empty content") } // newMCPTestContext creates a test context for MCP tests func newMCPTestContext(t *testing.T) *context.Context { t.Helper() authorized := &oauthTypes.AuthorizedInfo{ UserID: "test-user", } chatID := "test-chat-mcp-keyword" ctx := context.New(t.Context(), authorized, chatID) return ctx } ================================================ FILE: agent/search/nlp/querydsl/agent.go ================================================ package querydsl import ( "encoding/json" "fmt" "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/gou/query/linter" "github.com/yaoapp/yao/agent/caller" agentContext "github.com/yaoapp/yao/agent/context" ) // AgentProvider delegates QueryDSL generation to an LLM-powered assistant // The assistant can understand context and generate semantically correct QueryDSL type AgentProvider struct { agentID string // Assistant ID to delegate to } // NewAgentProvider creates a new agent-based QueryDSL generator func NewAgentProvider(agentID string) *AgentProvider { return &AgentProvider{ agentID: agentID, } } // Generate generates QueryDSL by calling the target agent with retry and lint validation // The agent receives the query and schema, returns generated QueryDSL func (p *AgentProvider) Generate(ctx *agentContext.Context, input *Input) (*Result, error) { if ctx == nil { return nil, fmt.Errorf("context is required for agent QueryDSL generation") } // Check if AgentGetterFunc is initialized if caller.AgentGetterFunc == nil { return nil, fmt.Errorf("AgentGetterFunc not initialized") } // Get the agent agent, err := 
caller.AgentGetterFunc(p.agentID) if err != nil { return nil, fmt.Errorf("failed to get agent %s: %w", p.agentID, err) } var lastError error var lastLintErrors string for attempt := 1; attempt <= MaxRetries; attempt++ { // Build the request message in the format expected by querydsl agent requestMessage := p.buildRequestMessage(input, attempt, lastLintErrors) // Create message for the agent messages := []agentContext.Message{ { Role: "user", Content: requestMessage, }, } // Call the agent with skip options (no history, no output) options := &agentContext.Options{ Skip: &agentContext.Skip{ History: true, Output: true, }, } result, err := agent.Stream(ctx, messages, options) if err != nil { lastError = fmt.Errorf("agent call failed: %w", err) continue } // Parse the result from response genResult, err := p.parseResponse(result) if err != nil { lastError = err continue } // Validate with linter if DSL is present if genResult.DSL != nil { lintResult := p.validateDSL(genResult.DSL) if lintResult.Valid { return genResult, nil } // Lint failed, prepare error message for retry lastLintErrors = lintResult.FormatDiagnostics() lastError = fmt.Errorf("QueryDSL validation failed: %s", lastLintErrors) // Add lint warnings to result warnings for _, diag := range lintResult.Diagnostics { genResult.Warnings = append(genResult.Warnings, fmt.Sprintf("[%s] %s: %s", diag.Code, diag.Path, diag.Message)) } continue } // No DSL returned lastError = fmt.Errorf("no QueryDSL returned from agent") } return nil, fmt.Errorf("QueryDSL generation failed after %d attempts: %w", MaxRetries, lastError) } // buildRequestMessage constructs the request message for the agent // Returns JSON format for structured communication with the agent func (p *AgentProvider) buildRequestMessage(input *Input, attempt int, lastLintErrors string) string { // Build request data as JSON requestData := map[string]interface{}{ "query": input.Query, "models": input.ModelIDs, "limit": input.Limit, } // Add schema from 
extra params if provided if input.ExtraParams != nil { if schema, ok := input.ExtraParams["schema"]; ok { requestData["schema"] = schema } } // Add scenario hint if specified (filter, aggregation, join, complex) if input.Scenario != "" { requestData["scenario"] = string(input.Scenario) } // Add allowed fields if specified if len(input.AllowedFields) > 0 { requestData["allowed_fields"] = input.AllowedFields } // Add retry context if this is a retry attempt if attempt > 1 && lastLintErrors != "" { requestData["retry"] = map[string]interface{}{ "attempt": attempt, "lint_errors": lastLintErrors, "instructions": "The previous QueryDSL was invalid. Please fix the errors and regenerate.", } } jsonBytes, _ := json.Marshal(requestData) return string(jsonBytes) } // validateDSL validates the generated QueryDSL using the linter func (p *AgentProvider) validateDSL(dsl *gou.QueryDSL) *linter.LintResult { // Marshal DSL to JSON for linting jsonBytes, err := json.Marshal(dsl) if err != nil { result := &linter.LintResult{Valid: false} return result } _, lintResult := linter.Parse(string(jsonBytes)) return lintResult } // parseResponse extracts QueryDSL from the agent's *context.Response // Now that agent.Stream() returns *context.Response directly, // we can access fields without type assertions. 
// // The querydsl agent returns QueryDSL in response.Next field // Or returns error JSON: {"error": "code", "message": "..."} func (p *AgentProvider) parseResponse(response *agentContext.Response) (*Result, error) { if response == nil { return &Result{}, nil } // Check Next field first (custom hook data) if response.Next != nil { return p.parseNextData(response.Next) } // No Next data, return empty result return &Result{}, nil } // parseNextData extracts QueryDSL from Next hook data func (p *AgentProvider) parseNextData(next interface{}) (*Result, error) { if next == nil { return &Result{}, nil } // Try to convert to map first var data map[string]interface{} switch v := next.(type) { case map[string]interface{}: data = v case string: // Try to parse as JSON if err := json.Unmarshal([]byte(v), &data); err != nil { return nil, fmt.Errorf("failed to parse agent response: %w", err) } default: // Try to marshal and unmarshal jsonBytes, err := json.Marshal(next) if err != nil { return &Result{}, nil } if err := json.Unmarshal(jsonBytes, &data); err != nil { return &Result{}, nil } } genResult := &Result{} // Check for error response: {"error": "code", "message": "..."} if errCode, hasError := data["error"]; hasError { errMsg := "" if msg, ok := data["message"].(string); ok { errMsg = msg } return nil, fmt.Errorf("QueryDSL generation error [%v]: %s", errCode, errMsg) } // Check if this is a direct QueryDSL (has "from" or "select" field) // The querydsl agent returns QueryDSL directly, e.g., {"select": [...], "from": "table", ...} if _, hasFrom := data["from"]; hasFrom { genResult.DSL = p.extractDSL(data) return genResult, nil } if _, hasSelect := data["select"]; hasSelect { genResult.DSL = p.extractDSL(data) return genResult, nil } // Check for "dsl" field wrapper: { dsl: {...} } if dsl, ok := data["dsl"]; ok { genResult.DSL = p.extractDSL(dsl) if explain, ok := data["explain"].(string); ok { genResult.Explain = explain } if warnings, ok := data["warnings"]; ok { 
genResult.Warnings = p.extractWarnings(warnings) } return genResult, nil } // Check for "data" field wrapper: { data: { dsl: {...}, explain: "...", warnings: [] } } if d, ok := data["data"]; ok { if dm, ok := d.(map[string]interface{}); ok { // Check if data.data contains dsl field: { data: { dsl: {...} } } if dsl, ok := dm["dsl"]; ok { genResult.DSL = p.extractDSL(dsl) } else if _, hasFrom := dm["from"]; hasFrom { // data.data is directly a QueryDSL (from __yao.querydsl Next hook) genResult.DSL = p.extractDSL(dm) } else if _, hasSelect := dm["select"]; hasSelect { // data.data is directly a QueryDSL genResult.DSL = p.extractDSL(dm) } // Extract explain and warnings from data.data if explain, ok := dm["explain"].(string); ok { genResult.Explain = explain } if warnings, ok := dm["warnings"]; ok { genResult.Warnings = p.extractWarnings(warnings) } return genResult, nil } } // Fallback: Get explain and warnings from top level if explain, ok := data["explain"].(string); ok { genResult.Explain = explain } if warnings, ok := data["warnings"]; ok { genResult.Warnings = p.extractWarnings(warnings) } return genResult, nil } // extractDSL converts interface{} to gou.QueryDSL func (p *AgentProvider) extractDSL(v interface{}) *gou.QueryDSL { if v == nil { return nil } // Marshal and unmarshal to gou.QueryDSL jsonBytes, err := json.Marshal(v) if err != nil { return nil } var dsl gou.QueryDSL if err := json.Unmarshal(jsonBytes, &dsl); err != nil { return nil } return &dsl } // extractWarnings extracts warnings array from various types func (p *AgentProvider) extractWarnings(v interface{}) []string { switch w := v.(type) { case []string: return w case []interface{}: warnings := make([]string, 0, len(w)) for _, item := range w { if s, ok := item.(string); ok { warnings = append(warnings, s) } } return warnings case string: return []string{w} } return nil } ================================================ FILE: agent/search/nlp/querydsl/agent_test.go 
================================================ package querydsl_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/nlp/querydsl" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) func TestNewAgentProvider(t *testing.T) { t.Run("create_provider", func(t *testing.T) { provider := querydsl.NewAgentProvider("tests.querydsl-agent") assert.NotNil(t, provider) }) } func TestAgentProvider_Generate(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Load the querydsl-agent assistant ast, err := assistant.Get("tests.querydsl-agent") require.NoError(t, err) require.NotNil(t, ast) // Create test context ctx := newTestContext(t) // Create Agent provider for tests.querydsl-agent provider := querydsl.NewAgentProvider("tests.querydsl-agent") assert.NotNil(t, provider) t.Run("verify_fixed_structure", func(t *testing.T) { input := &querydsl.Input{ Query: "find active users", ModelIDs: []string{"user"}, Limit: 15, } result, err := provider.Generate(ctx, input) if err != nil { t.Logf("Generate error: %v", err) } require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.DSL, "DSL should not be nil") // Verify fixed DSL structure from mock // select: ["id", "name", "status", "created_at"] assert.Len(t, result.DSL.Select, 4) if len(result.DSL.Select) >= 4 { assert.Equal(t, "id", result.DSL.Select[0].Field) assert.Equal(t, "name", result.DSL.Select[1].Field) assert.Equal(t, "status", result.DSL.Select[2].Field) assert.Equal(t, "created_at", result.DSL.Select[3].Field) } // wheres: [{ field: "status", op: "=", value: "active" }] assert.Len(t, result.DSL.Wheres, 1) if len(result.DSL.Wheres) > 0 { assert.Equal(t, "status", 
result.DSL.Wheres[0].Field.Field) assert.Equal(t, "=", result.DSL.Wheres[0].OP) assert.Equal(t, "active", result.DSL.Wheres[0].Value) } // orders: [{ field: "created_at", sort: "desc" }] assert.Len(t, result.DSL.Orders, 1) if len(result.DSL.Orders) > 0 { assert.Equal(t, "created_at", result.DSL.Orders[0].Field.Field) assert.Equal(t, "desc", result.DSL.Orders[0].Sort) } // limit: 15 (from input) assert.Equal(t, float64(15), result.DSL.Limit) // explain should contain query assert.Contains(t, result.Explain, "find active users") // warnings should be empty assert.Empty(t, result.Warnings) }) } func TestAgentProvider_Generate_Error(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newTestContext(t) t.Run("non-existent_agent", func(t *testing.T) { provider := querydsl.NewAgentProvider("tests.nonexistent-agent") result, err := provider.Generate(ctx, &querydsl.Input{ Query: "test", ModelIDs: []string{"user"}, }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "failed to get agent") }) t.Run("nil_context", func(t *testing.T) { provider := querydsl.NewAgentProvider("tests.querydsl-agent") result, err := provider.Generate(nil, &querydsl.Input{ Query: "test", ModelIDs: []string{"user"}, }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "context is required") }) } func TestGenerator_Agent_Integration(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Create test context ctx := newTestContext(t) // Create generator with Agent mode (assistant ID without mcp: prefix) gen := querydsl.NewGenerator("tests.querydsl-agent", nil) t.Run("generate_via_agent", func(t *testing.T) { input := &querydsl.Input{ Query: "find active users", 
ModelIDs: []string{"user"}, Limit: 10, } result, err := gen.Generate(ctx, input) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.DSL) // Verify structure from agent mock assert.Len(t, result.DSL.Select, 4) assert.Len(t, result.DSL.Wheres, 1) assert.Len(t, result.DSL.Orders, 1) assert.Contains(t, result.Explain, "find active users") }) t.Run("allowed_fields_validation", func(t *testing.T) { input := &querydsl.Input{ Query: "find users", ModelIDs: []string{"user"}, AllowedFields: []string{"id", "name"}, // Only allow id and name Limit: 10, } result, err := gen.Generate(ctx, input) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.DSL) // "status" and "created_at" fields should be filtered out from select // since they are not in AllowedFields for _, expr := range result.DSL.Select { assert.Contains(t, []string{"id", "name"}, expr.Field) } // Should have warning about removed fields assert.NotEmpty(t, result.Warnings) }) } func TestAgentProvider_Generate_WithRetry(t *testing.T) { // Skip if running short tests if testing.Short() { t.Skip("Skipping integration test") } // Initialize test environment testutils.Prepare(t) defer testutils.Clean(t) // Load the querydsl-agent-retry assistant ast, err := assistant.Get("tests.querydsl-agent-retry") require.NoError(t, err) require.NotNil(t, ast) // Create test context ctx := newTestContext(t) // Create Agent provider for tests.querydsl-agent-retry // This agent returns invalid DSL on first call, valid on second provider := querydsl.NewAgentProvider("tests.querydsl-agent-retry") assert.NotNil(t, provider) t.Run("retry_on_lint_failure", func(t *testing.T) { input := &querydsl.Input{ Query: "test retry mechanism", ModelIDs: []string{"user"}, Limit: 10, } // This should succeed after retry // First call returns invalid DSL (missing 'from') // Second call (with lint_errors) returns valid DSL result, err := provider.Generate(ctx, input) require.NoError(t, err) require.NotNil(t, 
result) if result.DSL != nil { // Should have valid DSL after retry assert.NotNil(t, result.DSL.From, "DSL should have 'from' field after retry") // Explain should indicate this was fixed after receiving lint errors assert.Contains(t, result.Explain, "fixed after receiving lint errors") } }) } // newTestContext creates a test context with required fields func newTestContext(t *testing.T) *context.Context { t.Helper() authorized := &oauthTypes.AuthorizedInfo{ UserID: "test-user", } chatID := "test-chat-querydsl" ctx := context.New(t.Context(), authorized, chatID) return ctx } ================================================ FILE: agent/search/nlp/querydsl/generator.go ================================================ // Package querydsl provides QueryDSL generation from natural language for DB search // Supports three modes via uses.querydsl configuration: // - "builtin" or "": Uses __yao.querydsl system agent (LLM-powered) // - "": Delegate to a custom LLM-powered assistant // - "mcp:.": Call external MCP tool package querydsl import ( "fmt" "strings" "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // SystemQueryDSLAgent is the default system agent for QueryDSL generation const SystemQueryDSLAgent = "__yao.querydsl" // Generator generates QueryDSL from natural language // Mode is determined by uses.querydsl configuration type Generator struct { usesQueryDSL string // "builtin", "", "mcp:." 
config *types.QueryDSLConfig // QueryDSL generation options } // NewGenerator creates a new QueryDSL generator // usesQueryDSL: value from uses.querydsl config // cfg: QueryDSL generation options from search config func NewGenerator(usesQueryDSL string, cfg *types.QueryDSLConfig) *Generator { return &Generator{ usesQueryDSL: usesQueryDSL, config: cfg, } } // Generate generates QueryDSL from natural language based on configured mode // Returns a QueryDSL ready for execution func (g *Generator) Generate(ctx *context.Context, input *Input) (*Result, error) { var result *Result var err error switch { case g.usesQueryDSL == "builtin" || g.usesQueryDSL == "": // Use system querydsl agent result, err = g.agentGenerate(ctx, input, SystemQueryDSLAgent) case strings.HasPrefix(g.usesQueryDSL, "mcp:"): result, err = g.mcpGenerate(ctx, input) default: // Assume it's an assistant ID for Agent mode result, err = g.agentGenerate(ctx, input, g.usesQueryDSL) } if err != nil { return nil, err } // Validate generated DSL against allowed fields whitelist if result != nil && result.DSL != nil && len(input.AllowedFields) > 0 { result = g.validateFields(result, input.AllowedFields) } return result, nil } // agentGenerate delegates to an LLM-powered assistant // The assistant can understand context and generate semantically correct QueryDSL func (g *Generator) agentGenerate(ctx *context.Context, input *Input, agentID string) (*Result, error) { if ctx == nil { return nil, fmt.Errorf("context is required for QueryDSL generation") } provider := NewAgentProvider(agentID) return provider.Generate(ctx, input) } // mcpGenerate calls an external MCP tool // Format: "mcp:." 
func (g *Generator) mcpGenerate(ctx *context.Context, input *Input) (*Result, error) { mcpRef := strings.TrimPrefix(g.usesQueryDSL, "mcp:") provider, err := NewMCPProvider(mcpRef) if err != nil { // Fallback to system agent on invalid MCP format return g.agentGenerate(ctx, input, SystemQueryDSLAgent) } return provider.Generate(ctx, input) } // validateFields validates that all fields in the generated DSL are in the allowed list // If a field is not allowed, it's removed and a warning is added func (g *Generator) validateFields(result *Result, allowedFields []string) *Result { if result.DSL == nil { return result } // Build allowed fields set for fast lookup allowed := make(map[string]bool) for _, f := range allowedFields { allowed[f] = true } var removedFields []string // Validate Select fields if len(result.DSL.Select) > 0 { validSelect := make([]gou.Expression, 0, len(result.DSL.Select)) for _, expr := range result.DSL.Select { if allowed[expr.Field] { validSelect = append(validSelect, expr) } else if expr.Field != "" { removedFields = append(removedFields, "select:"+expr.Field) } } result.DSL.Select = validSelect } // Validate Where fields (recursive) result.DSL.Wheres = g.validateWheres(result.DSL.Wheres, allowed, &removedFields) // Validate Order fields if len(result.DSL.Orders) > 0 { validOrders := make(gou.Orders, 0, len(result.DSL.Orders)) for _, order := range result.DSL.Orders { if order.Field != nil && allowed[order.Field.Field] { validOrders = append(validOrders, order) } else if order.Field != nil && order.Field.Field != "" { removedFields = append(removedFields, "order:"+order.Field.Field) } } result.DSL.Orders = validOrders } // Add warnings for removed fields if len(removedFields) > 0 { warning := "removed fields not in allowed list: " + strings.Join(removedFields, ", ") result.Warnings = append(result.Warnings, warning) } return result } // validateWheres recursively validates where conditions func (g *Generator) validateWheres(wheres []gou.Where, 
allowed map[string]bool, removedFields *[]string) []gou.Where { if len(wheres) == 0 { return wheres } validWheres := make([]gou.Where, 0, len(wheres)) for _, w := range wheres { // Check if the field is allowed fieldAllowed := true if w.Field != nil && w.Field.Field != "" { if !allowed[w.Field.Field] { *removedFields = append(*removedFields, "where:"+w.Field.Field) fieldAllowed = false } } if fieldAllowed { // Recursively validate nested wheres if len(w.Wheres) > 0 { w.Wheres = g.validateWheres(w.Wheres, allowed, removedFields) } validWheres = append(validWheres, w) } } return validWheres } ================================================ FILE: agent/search/nlp/querydsl/generator_test.go ================================================ package querydsl import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/yao/agent/search/types" ) func TestNewGenerator(t *testing.T) { tests := []struct { name string usesQueryDSL string config *types.QueryDSLConfig }{ { name: "builtin mode", usesQueryDSL: "builtin", config: nil, }, { name: "empty defaults to builtin", usesQueryDSL: "", config: nil, }, { name: "agent mode", usesQueryDSL: "my-querydsl-agent", config: &types.QueryDSLConfig{Strict: true}, }, { name: "mcp mode", usesQueryDSL: "mcp:nlp.generate_querydsl", config: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gen := NewGenerator(tt.usesQueryDSL, tt.config) assert.NotNil(t, gen) assert.Equal(t, tt.usesQueryDSL, gen.usesQueryDSL) assert.Equal(t, tt.config, gen.config) }) } } func TestGenerator_Generate_Builtin_RequiresContext(t *testing.T) { // Builtin mode now uses __yao.querydsl agent which requires context gen := NewGenerator("builtin", nil) input := &Input{ Query: "find all active users", ModelIDs: []string{"user"}, Limit: 10, } // Without context, should return error _, err := gen.Generate(nil, input) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func 
TestGenerator_Generate_EmptyMode_RequiresContext(t *testing.T) { // Empty mode defaults to __yao.querydsl agent which requires context gen := NewGenerator("", nil) input := &Input{ Query: "search products", ModelIDs: []string{"product"}, Limit: 5, } // Without context, should return error _, err := gen.Generate(nil, input) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestGenerator_Generate_AgentMode_RequiresContext(t *testing.T) { // Custom agent mode requires context gen := NewGenerator("custom.querydsl.agent", nil) input := &Input{ Query: "find users", ModelIDs: []string{"user"}, Limit: 10, } _, err := gen.Generate(nil, input) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestGenerator_Generate_MCPMode_InvalidFormat(t *testing.T) { // Invalid MCP format should fallback to system agent (which requires context) gen := NewGenerator("mcp:invalid", nil) input := &Input{ Query: "find users", ModelIDs: []string{"user"}, Limit: 10, } _, err := gen.Generate(nil, input) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestSystemQueryDSLAgentConstant(t *testing.T) { // Verify the system querydsl agent constant assert.Equal(t, "__yao.querydsl", SystemQueryDSLAgent) } func TestResult(t *testing.T) { result := &Result{ DSL: &gou.QueryDSL{ Limit: 10, }, Explain: "Generated query for finding users", Warnings: []string{"using placeholder implementation"}, } assert.NotNil(t, result.DSL) assert.Equal(t, 10, result.DSL.Limit) assert.NotEmpty(t, result.Explain) assert.Len(t, result.Warnings, 1) } func TestGenerator_ValidateFields(t *testing.T) { gen := NewGenerator("", nil) t.Run("validate select fields", func(t *testing.T) { result := &Result{ DSL: &gou.QueryDSL{ Select: []gou.Expression{ {Field: "id"}, {Field: "name"}, {Field: "secret_field"}, }, }, } allowedFields := []string{"id", "name", "email"} validated := gen.validateFields(result, allowedFields) assert.NotNil(t, 
// MaxRetries is the maximum number of retry attempts for QueryDSL generation
const MaxRetries = 3

// MCPProvider delegates QueryDSL generation to an MCP tool
type MCPProvider struct {
	serverID string // MCP server ID
	toolName string // Tool name to call
}

// NewMCPProvider creates a new MCP-based QueryDSL generator.
//
// mcpRef has the format "server.tool" (e.g., "nlp.generate_querydsl"). It is
// split at the first dot, so the tool name may itself contain dots. An error
// is returned when the dot is missing or when either the server ID or the
// tool name is empty.
func NewMCPProvider(mcpRef string) (*MCPProvider, error) {
	serverID, toolName, found := strings.Cut(mcpRef, ".")
	if !found || serverID == "" || toolName == "" {
		return nil, fmt.Errorf("invalid MCP format, expected 'server.tool', got '%s'", mcpRef)
	}

	return &MCPProvider{
		serverID: serverID,
		toolName: toolName,
	}, nil
}
buildArguments(input *Input, attempt int, lastLintErrors string) map[string]interface{} { arguments := map[string]interface{}{ "query": input.Query, "models": input.ModelIDs, "limit": input.Limit, } // Add optional fields if len(input.Wheres) > 0 { arguments["wheres"] = input.Wheres } if len(input.Orders) > 0 { arguments["orders"] = input.Orders } if len(input.AllowedFields) > 0 { arguments["allowed_fields"] = input.AllowedFields } if len(input.ExtraParams) > 0 { arguments["extra"] = input.ExtraParams } // Add retry context if this is a retry attempt if attempt > 1 && lastLintErrors != "" { arguments["retry"] = map[string]interface{}{ "attempt": attempt, "lint_errors": lastLintErrors, "instructions": "The previous QueryDSL was invalid. Please fix the errors and regenerate.", } } return arguments } // validateDSL validates the generated QueryDSL using the linter func (p *MCPProvider) validateDSL(dsl *gou.QueryDSL) *linter.LintResult { // Marshal DSL to JSON for linting jsonBytes, err := json.Marshal(dsl) if err != nil { result := &linter.LintResult{Valid: false} return result } _, lintResult := linter.Parse(string(jsonBytes)) return lintResult } // parseResult extracts QueryDSL from the MCP tool response func (p *MCPProvider) parseResult(result *gouMCPTypes.CallToolResponse) (*Result, error) { if result == nil { return &Result{}, nil } // Check for errors in result if result.IsError { errMsg := "MCP tool returned error" if len(result.Content) > 0 && result.Content[0].Text != "" { errMsg = result.Content[0].Text } return nil, fmt.Errorf("%s", errMsg) } // Parse content - expect JSON data with "dsl" field if len(result.Content) == 0 { return &Result{}, nil } genResult := &Result{} // Try to extract QueryDSL from content for _, content := range result.Content { // Check text content type if content.Type == gouMCPTypes.ToolContentTypeText && content.Text != "" { // Try to parse as JSON var data map[string]interface{} if err := json.Unmarshal([]byte(content.Text), 
&data); err == nil { // Look for "dsl" field if dsl, ok := data["dsl"]; ok { genResult.DSL = p.extractDSL(dsl) } if explain, ok := data["explain"].(string); ok { genResult.Explain = explain } if warnings, ok := data["warnings"]; ok { genResult.Warnings = p.extractWarnings(warnings) } return genResult, nil } // Try to parse as direct QueryDSL var dsl gou.QueryDSL if err := json.Unmarshal([]byte(content.Text), &dsl); err == nil { genResult.DSL = &dsl return genResult, nil } } } return genResult, nil } // extractDSL converts interface{} to gou.QueryDSL func (p *MCPProvider) extractDSL(v interface{}) *gou.QueryDSL { if v == nil { return nil } // Marshal and unmarshal to gou.QueryDSL jsonBytes, err := json.Marshal(v) if err != nil { return nil } var dsl gou.QueryDSL if err := json.Unmarshal(jsonBytes, &dsl); err != nil { return nil } return &dsl } // extractWarnings extracts warnings array from various types func (p *MCPProvider) extractWarnings(v interface{}) []string { switch w := v.(type) { case []string: return w case []interface{}: warnings := make([]string, 0, len(w)) for _, item := range w { if s, ok := item.(string); ok { warnings = append(warnings, s) } } return warnings case string: return []string{w} } return nil } ================================================ FILE: agent/search/nlp/querydsl/mcp_test.go ================================================ package querydsl import ( stdContext "context" "os" "testing" "github.com/stretchr/testify/assert" agentContext "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // newTestContext creates a test context for MCP testing func newTestContext() *agentContext.Context { ctx := agentContext.New(stdContext.Background(), nil, "test-chat") ctx.AssistantID = "test-assistant" ctx.Locale = "en" ctx.Referer = agentContext.RefererAPI stack, _, _ := agentContext.EnterStack(ctx, "test-assistant", &agentContext.Options{}) ctx.Stack = stack return ctx } func 
TestMCPProvider_Generate(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create context ctx := newTestContext() // Create MCP provider for search.generate_querydsl provider, err := NewMCPProvider("search.generate_querydsl") assert.NoError(t, err) assert.NotNil(t, provider) assert.Equal(t, "search", provider.serverID) assert.Equal(t, "generate_querydsl", provider.toolName) t.Run("verify_fixed_structure", func(t *testing.T) { input := &Input{ Query: "find active users", ModelIDs: []string{"user"}, Limit: 10, } result, err := provider.Generate(ctx, input) if err != nil { t.Logf("Generate error: %v", err) } assert.NoError(t, err) assert.NotNil(t, result) if result == nil { t.Fatal("result is nil") } if !assert.NotNil(t, result.DSL, "DSL should not be nil") { t.Logf("Result: Explain=%s, Warnings=%v", result.Explain, result.Warnings) return } // Verify fixed DSL structure from mock // select: ["id", "name", "status"] - parsed as Expression with Field property assert.Len(t, result.DSL.Select, 3) if len(result.DSL.Select) >= 3 { assert.Equal(t, "id", result.DSL.Select[0].Field) assert.Equal(t, "name", result.DSL.Select[1].Field) assert.Equal(t, "status", result.DSL.Select[2].Field) } // wheres: [{ field: "status", op: "=", value: "active" }] assert.Len(t, result.DSL.Wheres, 1) if len(result.DSL.Wheres) > 0 { assert.Equal(t, "status", result.DSL.Wheres[0].Field.Field) assert.Equal(t, "=", result.DSL.Wheres[0].OP) assert.Equal(t, "active", result.DSL.Wheres[0].Value) } // orders: [{ field: "created_at", sort: "desc" }] assert.Len(t, result.DSL.Orders, 1) if len(result.DSL.Orders) > 0 { assert.Equal(t, "created_at", result.DSL.Orders[0].Field.Field) assert.Equal(t, "desc", result.DSL.Orders[0].Sort) } // limit: 10 (from input, returned as float64 from JSON) assert.Equal(t, float64(10), result.DSL.Limit) // explain should contain query assert.Contains(t, result.Explain, "find active users") // warnings should be empty assert.Empty(t, result.Warnings) }) } 
func TestNewMCPProvider(t *testing.T) { t.Run("valid format", func(t *testing.T) { provider, err := NewMCPProvider("nlp.generate_querydsl") assert.NoError(t, err) assert.NotNil(t, provider) assert.Equal(t, "nlp", provider.serverID) assert.Equal(t, "generate_querydsl", provider.toolName) }) t.Run("invalid format - no dot", func(t *testing.T) { provider, err := NewMCPProvider("invalid") assert.Error(t, err) assert.Nil(t, provider) assert.Contains(t, err.Error(), "invalid MCP format") }) t.Run("complex tool name", func(t *testing.T) { provider, err := NewMCPProvider("server.tool.with.dots") assert.NoError(t, err) assert.NotNil(t, provider) assert.Equal(t, "server", provider.serverID) assert.Equal(t, "tool.with.dots", provider.toolName) }) } func TestMCPProvider_Generate_Error(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestContext() t.Run("non-existent server", func(t *testing.T) { provider, _ := NewMCPProvider("nonexistent.tool") result, err := provider.Generate(ctx, &Input{ Query: "test", ModelIDs: []string{"user"}, }) assert.Error(t, err) assert.Nil(t, result) assert.Contains(t, err.Error(), "not found") }) } func TestGenerator_MCP_Integration(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Skip if not in integration test mode if os.Getenv("YAO_TEST_MCP") != "true" { t.Skip("Skipping MCP integration test (set YAO_TEST_MCP=true to run)") } ctx := newTestContext() // Create generator with MCP mode gen := NewGenerator("mcp:search.generate_querydsl", nil) t.Run("generate_via_mcp", func(t *testing.T) { input := &Input{ Query: "find active users", ModelIDs: []string{"user"}, Limit: 15, } result, err := gen.Generate(ctx, input) assert.NoError(t, err) assert.NotNil(t, result) assert.NotNil(t, result.DSL) // Verify fixed structure is correctly parsed assert.Len(t, result.DSL.Select, 3) assert.Len(t, result.DSL.Wheres, 1) assert.Len(t, result.DSL.Orders, 1) assert.Equal(t, float64(15), result.DSL.Limit) assert.Contains(t, 
result.Explain, "find active users") }) t.Run("allowed_fields_validation", func(t *testing.T) { input := &Input{ Query: "find users", ModelIDs: []string{"user"}, AllowedFields: []string{"id", "name"}, // Only allow id and name Limit: 10, } result, err := gen.Generate(ctx, input) assert.NoError(t, err) assert.NotNil(t, result) assert.NotNil(t, result.DSL) // "status" field should be filtered out from select and wheres // since it's not in AllowedFields for _, expr := range result.DSL.Select { assert.Contains(t, []string{"id", "name"}, expr.Field) } // Should have warning about removed fields assert.NotEmpty(t, result.Warnings) }) } func TestMCPProvider_Generate_WithRetry(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() ctx := newTestContext() // Create MCP provider for search.generate_querydsl_with_retry // This tool returns invalid DSL on first call, valid on second provider, err := NewMCPProvider("search.generate_querydsl_with_retry") assert.NoError(t, err) assert.NotNil(t, provider) t.Run("retry_on_lint_failure", func(t *testing.T) { input := &Input{ Query: "test retry mechanism", ModelIDs: []string{"user"}, Limit: 10, } // This should succeed after retry // First call returns invalid DSL (missing 'from') // Second call (with lint_errors) returns valid DSL result, err := provider.Generate(ctx, input) assert.NoError(t, err) assert.NotNil(t, result) if result != nil && result.DSL != nil { // Should have valid DSL after retry assert.NotNil(t, result.DSL.From, "DSL should have 'from' field after retry") // Explain should indicate this was fixed after receiving lint errors assert.Contains(t, result.Explain, "fixed after receiving lint errors") } }) } ================================================ FILE: agent/search/nlp/querydsl/types.go ================================================ package querydsl import ( "github.com/yaoapp/gou/query/gou" "github.com/yaoapp/yao/agent/search/types" ) // Input contains all information needed to generate QueryDSL 
type Input struct {
	// Query is the natural language query to turn into a QueryDSL.
	Query string
	// ModelIDs lists the target model IDs (e.g., ["user", "order", "product"]).
	ModelIDs []string
	// Scenario is the QueryDSL scenario: "filter", "aggregation", "join", "complex".
	Scenario types.ScenarioType
	// Wheres holds pre-defined filters (optional).
	Wheres []gou.Where
	// Orders holds sort orders (optional).
	Orders gou.Orders
	// AllowedFields is an optional whitelist of fields, used for security validation.
	AllowedFields []string
	// Limit is the maximum number of results.
	// NOTE(review): the meaning of a zero Limit is not visible here - confirm with the providers.
	Limit int
	// ExtraParams carries additional parameters.
	ExtraParams map[string]interface{}
}
When citing a reference, use this exact HTML format: [{id}] Example: According to the product data[1], the price is $999.` // BuildReferences converts search results to unified Reference format func BuildReferences(results []*types.Result) []*types.Reference { var refs []*types.Reference for _, result := range results { if result == nil { continue } for _, item := range result.Items { if item == nil { continue } refs = append(refs, &types.Reference{ ID: item.CitationID, Type: item.Type, Source: item.Source, Weight: item.Weight, Score: item.Score, Title: item.Title, Content: item.Content, URL: item.URL, }) } } return refs } // FormatReferencesXML formats references as XML for LLM context func FormatReferencesXML(refs []*types.Reference) string { if len(refs) == 0 { return "" } var sb strings.Builder sb.WriteString("\n") for _, ref := range refs { if ref == nil { continue } sb.WriteString(fmt.Sprintf(``, ref.ID, ref.Type, ref.Weight, ref.Source)) sb.WriteString("\n") if ref.Title != "" { sb.WriteString(ref.Title) sb.WriteString("\n") } sb.WriteString(ref.Content) if ref.URL != "" { sb.WriteString("\nURL: ") sb.WriteString(ref.URL) } sb.WriteString("\n\n") } sb.WriteString("") return sb.String() } // GetCitationPrompt returns the citation instruction prompt func GetCitationPrompt(cfg *types.CitationConfig) string { if cfg == nil { return DefaultCitationPrompt } if cfg.CustomPrompt != "" { return cfg.CustomPrompt } return DefaultCitationPrompt } // BuildReferenceContext builds the complete reference context for LLM func BuildReferenceContext(results []*types.Result, cfg *types.CitationConfig) *types.ReferenceContext { refs := BuildReferences(results) return &types.ReferenceContext{ References: refs, XML: FormatReferencesXML(refs), Prompt: GetCitationPrompt(cfg), } } ================================================ FILE: agent/search/reference_test.go ================================================ package search import ( "strings" "testing" 
"github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/search/types" ) func TestBuildReferences(t *testing.T) { tests := []struct { name string results []*types.Result expected int }{ { name: "nil results", results: nil, expected: 0, }, { name: "empty results", results: []*types.Result{}, expected: 0, }, { name: "single result with items", results: []*types.Result{ { Type: types.SearchTypeWeb, Query: "test query", Items: []*types.ResultItem{ { CitationID: "1", Type: types.SearchTypeWeb, Source: types.SourceAuto, Weight: 0.6, Score: 0.9, Title: "Test Title", Content: "Test content", URL: "https://example.com", }, { CitationID: "2", Type: types.SearchTypeWeb, Source: types.SourceAuto, Weight: 0.6, Score: 0.8, Title: "Test Title 2", Content: "Test content 2", URL: "https://example2.com", }, }, }, }, expected: 2, }, { name: "multiple results", results: []*types.Result{ { Type: types.SearchTypeWeb, Items: []*types.ResultItem{ {CitationID: "1", Type: types.SearchTypeWeb, Content: "Web content"}, }, }, { Type: types.SearchTypeKB, Items: []*types.ResultItem{ {CitationID: "2", Type: types.SearchTypeKB, Content: "KB content"}, }, }, { Type: types.SearchTypeDB, Items: []*types.ResultItem{ {CitationID: "3", Type: types.SearchTypeDB, Content: "DB content"}, }, }, }, expected: 3, }, { name: "result with nil items", results: []*types.Result{ { Type: types.SearchTypeWeb, Items: []*types.ResultItem{ {CitationID: "1", Content: "Content 1"}, nil, {CitationID: "2", Content: "Content 2"}, }, }, }, expected: 2, }, { name: "nil result in slice", results: []*types.Result{ { Type: types.SearchTypeWeb, Items: []*types.ResultItem{ {CitationID: "1", Content: "Content"}, }, }, nil, { Type: types.SearchTypeKB, Items: []*types.ResultItem{ {CitationID: "2", Content: "Content 2"}, }, }, }, expected: 2, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { refs := BuildReferences(tt.results) assert.Equal(t, tt.expected, len(refs)) }) } } func 
TestBuildReferences_FieldMapping(t *testing.T) { item := &types.ResultItem{ CitationID: "1", Type: types.SearchTypeWeb, Source: types.SourceHook, Weight: 0.8, Score: 0.95, Title: "Test Title", Content: "Test Content", URL: "https://example.com", } results := []*types.Result{ {Items: []*types.ResultItem{item}}, } refs := BuildReferences(results) assert.Equal(t, 1, len(refs)) ref := refs[0] assert.Equal(t, "1", ref.ID) assert.Equal(t, types.SearchTypeWeb, ref.Type) assert.Equal(t, types.SourceHook, ref.Source) assert.Equal(t, 0.8, ref.Weight) assert.Equal(t, 0.95, ref.Score) assert.Equal(t, "Test Title", ref.Title) assert.Equal(t, "Test Content", ref.Content) assert.Equal(t, "https://example.com", ref.URL) } func TestFormatReferencesXML(t *testing.T) { tests := []struct { name string refs []*types.Reference contains []string excludes []string }{ { name: "nil refs", refs: nil, contains: []string{}, excludes: []string{""}, }, { name: "empty refs", refs: []*types.Reference{}, contains: []string{}, excludes: []string{""}, }, { name: "single ref with all fields", refs: []*types.Reference{ { ID: "1", Type: types.SearchTypeWeb, Source: types.SourceUser, Weight: 1.0, Score: 0.9, Title: "Test Title", Content: "Test Content", URL: "https://example.com", }, }, contains: []string{ "", "", ``, "", "Test Title", "Test Content", "URL: https://example.com", }, }, { name: "ref without title", refs: []*types.Reference{ { ID: "1", Type: types.SearchTypeKB, Source: types.SourceHook, Weight: 0.8, Content: "Content without title", }, }, contains: []string{ ``, "Content without title", }, excludes: []string{ "URL:", }, }, { name: "ref without URL", refs: []*types.Reference{ { ID: "1", Type: types.SearchTypeDB, Source: types.SourceAuto, Weight: 0.6, Title: "DB Record", Content: "Database content", }, }, contains: []string{ ``, "DB Record", "Database content", }, excludes: []string{ "URL:", }, }, { name: "multiple refs", refs: []*types.Reference{ {ID: "1", Type: types.SearchTypeWeb, Source: 
types.SourceUser, Weight: 1.0, Content: "Content 1"}, {ID: "2", Type: types.SearchTypeKB, Source: types.SourceHook, Weight: 0.8, Content: "Content 2"}, {ID: "3", Type: types.SearchTypeDB, Source: types.SourceAuto, Weight: 0.6, Content: "Content 3"}, }, contains: []string{ "", "", `id="1"`, `id="2"`, `id="3"`, "Content 1", "Content 2", "Content 3", }, }, { name: "nil ref in slice", refs: []*types.Reference{ {ID: "1", Type: types.SearchTypeWeb, Weight: 1.0, Content: "Content 1"}, nil, {ID: "2", Type: types.SearchTypeKB, Weight: 0.8, Content: "Content 2"}, }, contains: []string{ `id="1"`, `id="2"`, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { xml := FormatReferencesXML(tt.refs) for _, s := range tt.contains { assert.Contains(t, xml, s, "expected XML to contain: %s", s) } for _, s := range tt.excludes { assert.NotContains(t, xml, s, "expected XML to not contain: %s", s) } }) } } func TestFormatReferencesXML_Structure(t *testing.T) { refs := []*types.Reference{ { ID: "1", Type: types.SearchTypeWeb, Source: types.SourceUser, Weight: 1.0, Title: "Title", Content: "Content", URL: "https://example.com", }, } xml := FormatReferencesXML(refs) // Check structure assert.True(t, strings.HasPrefix(xml, "\n")) assert.True(t, strings.HasSuffix(xml, "")) assert.Contains(t, xml, "\n") } func TestGetCitationPrompt(t *testing.T) { tests := []struct { name string cfg *types.CitationConfig expected string }{ { name: "nil config", cfg: nil, expected: DefaultCitationPrompt, }, { name: "empty config", cfg: &types.CitationConfig{}, expected: DefaultCitationPrompt, }, { name: "config with custom prompt", cfg: &types.CitationConfig{ CustomPrompt: "Custom citation instructions", }, expected: "Custom citation instructions", }, { name: "config with empty custom prompt", cfg: &types.CitationConfig{ CustomPrompt: "", }, expected: DefaultCitationPrompt, }, { name: "config with format but no custom prompt", cfg: &types.CitationConfig{ Format: "[{id}]", }, expected: 
DefaultCitationPrompt, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { prompt := GetCitationPrompt(tt.cfg) assert.Equal(t, tt.expected, prompt) }) } } func TestDefaultCitationPrompt(t *testing.T) { // Verify default prompt contains key instructions assert.Contains(t, DefaultCitationPrompt, "") assert.Contains(t, DefaultCitationPrompt, "id: Citation identifier") assert.Contains(t, DefaultCitationPrompt, "type: Data type") assert.Contains(t, DefaultCitationPrompt, "weight: Relevance weight") assert.Contains(t, DefaultCitationPrompt, "source: Origin") assert.Contains(t, DefaultCitationPrompt, `") assert.Contains(t, ctx.XML, `id="1"`) assert.Equal(t, DefaultCitationPrompt, ctx.Prompt) }) t.Run("with custom prompt config", func(t *testing.T) { cfg := &types.CitationConfig{ CustomPrompt: "Custom prompt", } ctx := BuildReferenceContext(results, cfg) assert.NotNil(t, ctx) assert.Equal(t, "Custom prompt", ctx.Prompt) }) t.Run("with empty results", func(t *testing.T) { ctx := BuildReferenceContext([]*types.Result{}, nil) assert.NotNil(t, ctx) assert.Equal(t, 0, len(ctx.References)) assert.Equal(t, "", ctx.XML) assert.Equal(t, DefaultCitationPrompt, ctx.Prompt) }) } func TestBuildReferenceContext_Integration(t *testing.T) { // Simulate a real-world scenario with multiple search types results := []*types.Result{ { Type: types.SearchTypeWeb, Query: "AI developments", Items: []*types.ResultItem{ { CitationID: "1", Type: types.SearchTypeWeb, Source: types.SourceAuto, Weight: 0.6, Score: 0.95, Title: "OpenAI Announces GPT-5", Content: "OpenAI has announced the development of GPT-5...", URL: "https://news.example.com/gpt5", }, }, }, { Type: types.SearchTypeKB, Query: "AI developments", Items: []*types.ResultItem{ { CitationID: "2", Type: types.SearchTypeKB, Source: types.SourceHook, Weight: 0.8, Score: 0.88, Title: "Internal AI Research Notes", Content: "Our internal research on AI capabilities...", }, }, }, { Type: types.SearchTypeDB, Query: "AI developments", 
Items: []*types.ResultItem{ { CitationID: "3", Type: types.SearchTypeDB, Source: types.SourceUser, Weight: 1.0, Score: 0.92, Title: "Product: AI Assistant", Content: "Name: AI Assistant\nPrice: $99\nCategory: Software", }, }, }, } ctx := BuildReferenceContext(results, nil) // Verify all references are included assert.Equal(t, 3, len(ctx.References)) // Verify XML contains all references assert.Contains(t, ctx.XML, `id="1"`) assert.Contains(t, ctx.XML, `id="2"`) assert.Contains(t, ctx.XML, `id="3"`) // Verify different source types are represented assert.Contains(t, ctx.XML, `source="auto"`) assert.Contains(t, ctx.XML, `source="hook"`) assert.Contains(t, ctx.XML, `source="user"`) // Verify different search types are represented assert.Contains(t, ctx.XML, `type="web"`) assert.Contains(t, ctx.XML, `type="kb"`) assert.Contains(t, ctx.XML, `type="db"`) } ================================================ FILE: agent/search/registry.go ================================================ package search import ( "github.com/yaoapp/yao/agent/search/interfaces" "github.com/yaoapp/yao/agent/search/types" ) // Registry manages search handlers type Registry struct { handlers map[types.SearchType]interfaces.Handler } // NewRegistry creates a new handler registry func NewRegistry() *Registry { return &Registry{ handlers: make(map[types.SearchType]interfaces.Handler), } } // Register registers a handler for a search type func (r *Registry) Register(handler interfaces.Handler) { r.handlers[handler.Type()] = handler } // Get returns the handler for a search type func (r *Registry) Get(t types.SearchType) (interfaces.Handler, bool) { h, ok := r.handlers[t] return h, ok } ================================================ FILE: agent/search/registry_test.go ================================================ package search import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/search/handlers/db" "github.com/yaoapp/yao/agent/search/handlers/kb" 
"github.com/yaoapp/yao/agent/search/handlers/web" "github.com/yaoapp/yao/agent/search/types" ) func TestNewRegistry(t *testing.T) { r := NewRegistry() assert.NotNil(t, r) assert.NotNil(t, r.handlers) assert.Equal(t, 0, len(r.handlers)) } func TestRegistry_Register(t *testing.T) { r := NewRegistry() // Register web handler webHandler := web.NewHandler("builtin", nil) r.Register(webHandler) h, ok := r.Get(types.SearchTypeWeb) assert.True(t, ok) assert.Equal(t, types.SearchTypeWeb, h.Type()) } func TestRegistry_RegisterMultiple(t *testing.T) { r := NewRegistry() // Register all handlers r.Register(web.NewHandler("builtin", nil)) r.Register(kb.NewHandler(nil)) r.Register(db.NewHandler("builtin", nil)) // Verify all are registered webH, ok := r.Get(types.SearchTypeWeb) assert.True(t, ok) assert.Equal(t, types.SearchTypeWeb, webH.Type()) kbH, ok := r.Get(types.SearchTypeKB) assert.True(t, ok) assert.Equal(t, types.SearchTypeKB, kbH.Type()) dbH, ok := r.Get(types.SearchTypeDB) assert.True(t, ok) assert.Equal(t, types.SearchTypeDB, dbH.Type()) } func TestRegistry_Get_NotFound(t *testing.T) { r := NewRegistry() h, ok := r.Get(types.SearchTypeWeb) assert.False(t, ok) assert.Nil(t, h) } func TestRegistry_RegisterOverwrite(t *testing.T) { r := NewRegistry() // Register first handler h1 := web.NewHandler("builtin", nil) r.Register(h1) // Register second handler (same type) h2 := web.NewHandler("agent", nil) r.Register(h2) // Should get the second handler h, ok := r.Get(types.SearchTypeWeb) assert.True(t, ok) assert.NotNil(t, h) } ================================================ FILE: agent/search/rerank/agent.go ================================================ package rerank import ( "encoding/json" "fmt" "strings" "github.com/yaoapp/yao/agent/caller" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // AgentProvider implements reranking by delegating to another agent // The agent should have a Next Hook that accepts rerank request and returns 
reordered items type AgentProvider struct { agentID string // Assistant ID to delegate to } // NewAgentProvider creates a new agent reranker func NewAgentProvider(agentID string) *AgentProvider { return &AgentProvider{agentID: agentID} } // Rerank delegates reranking to an LLM-powered assistant // The assistant receives items and query, returns reordered item IDs or items func (p *AgentProvider) Rerank(ctx *context.Context, query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { if ctx == nil { return nil, fmt.Errorf("context is required for agent rerank") } // Get agent via caller interface (avoids circular dependency) agent, err := caller.AgentGetterFunc(p.agentID) if err != nil { return nil, fmt.Errorf("failed to get agent %s: %w", p.agentID, err) } // Build request message with items to rerank requestData := map[string]interface{}{ "query": query, "items": items, "top_n": opts.TopN, "action": "rerank", } requestJSON, _ := json.Marshal(requestData) // Create messages for agent messages := []context.Message{ { Role: "user", Content: string(requestJSON), }, } // Call agent's Stream method with skip options (no history, no output) options := &context.Options{ Skip: &context.Skip{ History: true, Output: true, }, } response, err := agent.Stream(ctx, messages, options) if err != nil { return nil, fmt.Errorf("agent stream failed: %w", err) } // Parse response from response.Next return p.parseAgentResponse(response, items, opts) } // parseAgentResponse extracts reranked items from agent's *context.Response // Now that agent.Stream() returns *context.Response directly, // we can access fields without type assertions. // // Expected response.Next format: // { "order": ["ref_001", "ref_003", "ref_002"] } // Or: { "items": [{ "citation_id": "ref_001", ... }, ...] 
} func (p *AgentProvider) parseAgentResponse(response *context.Response, originalItems []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { if response == nil || response.Next == nil { return originalItems, nil } // Build index map for quick lookup itemMap := make(map[string]*types.ResultItem) for _, item := range originalItems { if item.CitationID != "" { itemMap[item.CitationID] = item } } // Extract response data from Next field data := extractNextData(response.Next) if data == nil { return originalItems, nil } // Try to get reranked order from data // Expected format: { "order": ["ref_001", "ref_003", "ref_002"] } // Or: { "items": [{ "citation_id": "ref_001", ... }, ...] } var reranked []*types.ResultItem // Try "order" field (list of citation IDs) if order, ok := data["order"]; ok { if orderList := toStringSlice(order); len(orderList) > 0 { for _, id := range orderList { if item, exists := itemMap[id]; exists { reranked = append(reranked, item) delete(itemMap, id) // Avoid duplicates } } // Append remaining items not in order for _, item := range originalItems { if _, exists := itemMap[item.CitationID]; exists { reranked = append(reranked, item) } } } } // Try "items" field (full items or items with citation_id) if len(reranked) == 0 { if items, ok := data["items"]; ok { if itemsList := toItemsList(items); len(itemsList) > 0 { for _, respItem := range itemsList { // Check if it's just a reference or full item if citationID, ok := respItem["citation_id"].(string); ok { if item, exists := itemMap[citationID]; exists { reranked = append(reranked, item) delete(itemMap, citationID) } } } // Append remaining items for _, item := range originalItems { if _, exists := itemMap[item.CitationID]; exists { reranked = append(reranked, item) } } } } } // If no valid response, return original items if len(reranked) == 0 { reranked = originalItems } // Apply top N if opts.TopN > 0 && opts.TopN < len(reranked) { reranked = reranked[:opts.TopN] } 
return reranked, nil } // extractNextData extracts the actual data from response.Next field // Handles nested structures like { "data": { ... } } func extractNextData(next interface{}) map[string]interface{} { if next == nil { return nil } switch v := next.(type) { case map[string]interface{}: // Check for "data" wrapper if data, ok := v["data"].(map[string]interface{}); ok { return data } return v case string: // Try to parse as JSON var data map[string]interface{} if err := json.Unmarshal([]byte(v), &data); err == nil { return extractNextData(data) } } // Try to handle other types by converting to JSON and back if bytes, err := json.Marshal(next); err == nil { var data map[string]interface{} if err := json.Unmarshal(bytes, &data); err == nil { return extractNextData(data) } } return nil } // toStringSlice converts interface to string slice func toStringSlice(v interface{}) []string { switch val := v.(type) { case []string: return val case []interface{}: result := make([]string, 0, len(val)) for _, item := range val { if s, ok := item.(string); ok { result = append(result, s) } } return result } return nil } // toItemsList converts interface to list of maps func toItemsList(v interface{}) []map[string]interface{} { switch val := v.(type) { case []map[string]interface{}: return val case []interface{}: result := make([]map[string]interface{}, 0, len(val)) for _, item := range val { if m, ok := item.(map[string]interface{}); ok { result = append(result, m) } } return result } return nil } // extractAgentID extracts assistant ID from uses.rerank value // For backward compatibility, strips any prefix if present func extractAgentID(usesRerank string) string { // Remove any prefix like "agent:" if present if strings.HasPrefix(usesRerank, "agent:") { return strings.TrimPrefix(usesRerank, "agent:") } return usesRerank } ================================================ FILE: agent/search/rerank/agent_test.go ================================================ package 
rerank_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/rerank" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) func TestAgentProviderWithAssistantConfig(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) // Load the rerank-agent assistant ast, err := assistant.Get("tests.rerank-agent") require.NoError(t, err) require.NotNil(t, ast) // Create test context ctx := newTestContext(t) // Create provider with test assistant provider := rerank.NewAgentProvider("tests.rerank-agent") items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0, Title: "First"}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0, Title: "Second"}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0, Title: "Third"}, } result, err := provider.Rerank(ctx, "test query", items, &types.RerankOptions{TopN: 10}) require.NoError(t, err) assert.NotEmpty(t, result) // The mock agent reverses the order // So we expect: ref_003, ref_002, ref_001 assert.Len(t, result, 3) assert.Equal(t, "ref_003", result[0].CitationID) assert.Equal(t, "ref_002", result[1].CitationID) assert.Equal(t, "ref_001", result[2].CitationID) } func TestAgentProviderWithTopN(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newTestContext(t) provider := rerank.NewAgentProvider("tests.rerank-agent") items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0}, {CitationID: "ref_004", Score: 0.6, Weight: 1.0}, {CitationID: "ref_005", Score: 0.5, Weight: 1.0}, } // Request top 2 only result, err := provider.Rerank(ctx, "test query", items, &types.RerankOptions{TopN: 2}) require.NoError(t, err) assert.Len(t, result, 2) } func 
TestAgentProviderWithoutContext(t *testing.T) { provider := rerank.NewAgentProvider("tests.rerank-agent") items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, } _, err := provider.Rerank(nil, "test query", items, &types.RerankOptions{TopN: 10}) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestAgentProviderAgentNotFound(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newTestContext(t) provider := rerank.NewAgentProvider("non-existent-agent") items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, } _, err := provider.Rerank(ctx, "test query", items, &types.RerankOptions{TopN: 10}) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to get agent") } func TestAgentProviderEmptyItems(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newTestContext(t) provider := rerank.NewAgentProvider("tests.rerank-agent") result, err := provider.Rerank(ctx, "test query", []*types.ResultItem{}, &types.RerankOptions{TopN: 10}) require.NoError(t, err) assert.Empty(t, result) } // newTestContext creates a test context with required fields func newTestContext(t *testing.T) *context.Context { t.Helper() authorized := &oauthTypes.AuthorizedInfo{ UserID: "test-user", } chatID := "test-chat-rerank" return context.New(t.Context(), authorized, chatID) } ================================================ FILE: agent/search/rerank/builtin.go ================================================ package rerank import ( "sort" "github.com/yaoapp/yao/agent/search/types" ) // BuiltinReranker implements simple score-based reranking // For production use cases requiring semantic understanding, use Agent or MCP mode. 
type BuiltinReranker struct{} // NewBuiltinReranker creates a new builtin reranker func NewBuiltinReranker() *BuiltinReranker { return &BuiltinReranker{} } // Rerank sorts items by weighted score (score * weight) and returns top N // This is a simple implementation without semantic understanding. func (r *BuiltinReranker) Rerank(query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { if len(items) == 0 { return items, nil } // Calculate weighted scores type scoredItem struct { item *types.ResultItem weightedScore float64 } scored := make([]scoredItem, len(items)) for i, item := range items { // Weighted score = base score * source weight // Higher weight sources (user=1.0) get priority over lower (auto=0.6) weight := item.Weight if weight == 0 { weight = 0.6 // Default weight for items without weight } scored[i] = scoredItem{ item: item, weightedScore: item.Score * weight, } } // Sort by weighted score descending sort.Slice(scored, func(i, j int) bool { return scored[i].weightedScore > scored[j].weightedScore }) // Get top N topN := opts.TopN if topN <= 0 || topN > len(scored) { topN = len(scored) } result := make([]*types.ResultItem, topN) for i := 0; i < topN; i++ { result[i] = scored[i].item } return result, nil } ================================================ FILE: agent/search/rerank/builtin_test.go ================================================ package rerank import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/search/types" ) func TestBuiltinReranker_EmptyItems(t *testing.T) { reranker := NewBuiltinReranker() result, err := reranker.Rerank("test query", []*types.ResultItem{}, &types.RerankOptions{TopN: 5}) assert.NoError(t, err) assert.Empty(t, result) } func TestBuiltinReranker_SortByWeightedScore(t *testing.T) { reranker := NewBuiltinReranker() items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.8, Weight: 0.6}, // weighted: 0.48 {CitationID: "ref_002", Score: 0.6, 
Weight: 1.0}, // weighted: 0.60 {CitationID: "ref_003", Score: 0.9, Weight: 0.8}, // weighted: 0.72 {CitationID: "ref_004", Score: 0.5, Weight: 1.0}, // weighted: 0.50 } result, err := reranker.Rerank("test query", items, &types.RerankOptions{TopN: 10}) assert.NoError(t, err) assert.Len(t, result, 4) // Should be sorted by weighted score: ref_003 (0.72) > ref_002 (0.60) > ref_004 (0.50) > ref_001 (0.48) assert.Equal(t, "ref_003", result[0].CitationID) assert.Equal(t, "ref_002", result[1].CitationID) assert.Equal(t, "ref_004", result[2].CitationID) assert.Equal(t, "ref_001", result[3].CitationID) } func TestBuiltinReranker_TopN(t *testing.T) { reranker := NewBuiltinReranker() items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0}, {CitationID: "ref_004", Score: 0.6, Weight: 1.0}, {CitationID: "ref_005", Score: 0.5, Weight: 1.0}, } result, err := reranker.Rerank("test query", items, &types.RerankOptions{TopN: 3}) assert.NoError(t, err) assert.Len(t, result, 3) assert.Equal(t, "ref_001", result[0].CitationID) assert.Equal(t, "ref_002", result[1].CitationID) assert.Equal(t, "ref_003", result[2].CitationID) } func TestBuiltinReranker_DefaultWeight(t *testing.T) { reranker := NewBuiltinReranker() // Items without weight should use default 0.6 items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 0}, // weighted: 0.9 * 0.6 = 0.54 {CitationID: "ref_002", Score: 0.5, Weight: 1.0}, // weighted: 0.5 * 1.0 = 0.50 } result, err := reranker.Rerank("test query", items, &types.RerankOptions{TopN: 10}) assert.NoError(t, err) assert.Len(t, result, 2) // ref_001 (0.54) > ref_002 (0.50) assert.Equal(t, "ref_001", result[0].CitationID) assert.Equal(t, "ref_002", result[1].CitationID) } func TestBuiltinReranker_TopNLargerThanItems(t *testing.T) { reranker := NewBuiltinReranker() items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, 
// MCPProvider implements reranking by calling an MCP tool.
type MCPProvider struct {
	serverID string // MCP server ID (text before the first '.')
	toolName string // Tool name (everything after the first '.')
}

// NewMCPProvider creates a new MCP reranker.
//
// mcpRef format: "server_id.tool_name". The reference is split at the first
// '.', so the tool name itself may contain dots. An error is returned when
// the separator is missing or when either side of it is empty, so a malformed
// reference is rejected here instead of surfacing later as a failed lookup.
func NewMCPProvider(mcpRef string) (*MCPProvider, error) {
	serverID, toolName, found := strings.Cut(mcpRef, ".")
	if !found || serverID == "" || toolName == "" {
		return nil, fmt.Errorf("invalid MCP format, expected 'server.tool', got '%s'", mcpRef)
	}

	return &MCPProvider{
		serverID: serverID,
		toolName: toolName,
	}, nil
}
response["order"]; ok { if orderList := toStringSlice(order); len(orderList) > 0 { return p.reorderByIDs(orderList, itemMap, originalItems, opts) } } // Try "items" field if items, ok := response["items"]; ok { if itemsList := toItemsList(items); len(itemsList) > 0 { return p.reorderByItems(itemsList, itemMap, originalItems, opts) } } return originalItems, nil } // reorderByIDs reorders items based on list of citation IDs func (p *MCPProvider) reorderByIDs(order []string, itemMap map[string]*types.ResultItem, originalItems []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { var result []*types.ResultItem // Add items in specified order for _, id := range order { if item, exists := itemMap[id]; exists { result = append(result, item) delete(itemMap, id) } } // Append remaining items for _, item := range originalItems { if _, exists := itemMap[item.CitationID]; exists { result = append(result, item) } } // Apply top N if opts.TopN > 0 && opts.TopN < len(result) { result = result[:opts.TopN] } return result, nil } // reorderByItems reorders items based on list of item references func (p *MCPProvider) reorderByItems(itemsList []map[string]interface{}, itemMap map[string]*types.ResultItem, originalItems []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { var result []*types.ResultItem // Add items in specified order for _, respItem := range itemsList { if citationID, ok := respItem["citation_id"].(string); ok { if item, exists := itemMap[citationID]; exists { result = append(result, item) delete(itemMap, citationID) } } } // Append remaining items for _, item := range originalItems { if _, exists := itemMap[item.CitationID]; exists { result = append(result, item) } } // Apply top N if opts.TopN > 0 && opts.TopN < len(result) { result = result[:opts.TopN] } return result, nil } ================================================ FILE: agent/search/rerank/mcp_test.go ================================================ package 
rerank_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/rerank" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" oauthTypes "github.com/yaoapp/yao/openapi/oauth/types" ) func TestMCPProviderWithSearchRerank(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newMCPTestContext(t) provider, err := rerank.NewMCPProvider("search.rerank") require.NoError(t, err) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0, Title: "First"}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0, Title: "Second"}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0, Title: "Third"}, } result, err := provider.Rerank(ctx, "test query", items, &types.RerankOptions{TopN: 10}) require.NoError(t, err) assert.NotEmpty(t, result) // The mock MCP reverses the order // So we expect: ref_003, ref_002, ref_001 assert.Len(t, result, 3) assert.Equal(t, "ref_003", result[0].CitationID) assert.Equal(t, "ref_002", result[1].CitationID) assert.Equal(t, "ref_001", result[2].CitationID) } func TestMCPProviderWithTopN(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newMCPTestContext(t) provider, err := rerank.NewMCPProvider("search.rerank") require.NoError(t, err) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0}, {CitationID: "ref_004", Score: 0.6, Weight: 1.0}, {CitationID: "ref_005", Score: 0.5, Weight: 1.0}, } // Request top 2 only result, err := provider.Rerank(ctx, "test query", items, &types.RerankOptions{TopN: 2}) require.NoError(t, err) assert.Len(t, result, 2) } func TestMCPProviderInvalidFormat(t *testing.T) { _, err := rerank.NewMCPProvider("invalid-format") assert.Error(t, err) assert.Contains(t, err.Error(), "invalid MCP format") } func 
TestMCPProviderServerNotFound(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newMCPTestContext(t) provider, err := rerank.NewMCPProvider("nonexistent.rerank") require.NoError(t, err) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, } _, err = provider.Rerank(ctx, "test query", items, &types.RerankOptions{TopN: 10}) assert.Error(t, err) assert.Contains(t, err.Error(), "not found") } func TestMCPProviderToolNotFound(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newMCPTestContext(t) provider, err := rerank.NewMCPProvider("search.nonexistent_tool") require.NoError(t, err) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, } _, err = provider.Rerank(ctx, "test query", items, &types.RerankOptions{TopN: 10}) assert.Error(t, err) } func TestMCPProviderWithoutContext(t *testing.T) { provider, err := rerank.NewMCPProvider("search.rerank") require.NoError(t, err) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, } _, err = provider.Rerank(nil, "test query", items, &types.RerankOptions{TopN: 10}) assert.Error(t, err) assert.Contains(t, err.Error(), "context is required") } func TestMCPProviderEmptyItems(t *testing.T) { testutils.Prepare(t) defer testutils.Clean(t) ctx := newMCPTestContext(t) provider, err := rerank.NewMCPProvider("search.rerank") require.NoError(t, err) result, err := provider.Rerank(ctx, "test query", []*types.ResultItem{}, &types.RerankOptions{TopN: 10}) require.NoError(t, err) assert.Empty(t, result) } // newMCPTestContext creates a test context with required fields func newMCPTestContext(t *testing.T) *context.Context { t.Helper() authorized := &oauthTypes.AuthorizedInfo{ UserID: "test-user", } chatID := "test-chat-rerank-mcp" return context.New(t.Context(), authorized, chatID) } ================================================ FILE: agent/search/rerank/reranker.go ================================================ // Package 
rerank provides result reranking for search module // Supports three modes via uses.rerank configuration: // - "builtin": Simple score-based sorting (no external dependencies) // - "": Delegate to an LLM-powered assistant for semantic reranking // - "mcp:.": Call external MCP tool // // For production use cases requiring high accuracy, use Agent or MCP mode. package rerank import ( "strings" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/types" ) // Reranker reorders search results by relevance // Mode is determined by uses.rerank configuration type Reranker struct { usesRerank string // "builtin", "", "mcp:." config *types.RerankConfig // Rerank options } // NewReranker creates a new reranker // usesRerank: value from uses.rerank config // cfg: rerank options from search config func NewReranker(usesRerank string, cfg *types.RerankConfig) *Reranker { return &Reranker{ usesRerank: usesRerank, config: cfg, } } // Rerank reorders results based on configured mode // Returns reordered items, potentially truncated to top N func (r *Reranker) Rerank(ctx *context.Context, query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { if len(items) == 0 { return items, nil } // Merge options with config defaults mergedOpts := r.mergeOptions(opts) switch { case r.usesRerank == "builtin" || r.usesRerank == "": return r.builtinRerank(query, items, mergedOpts) case strings.HasPrefix(r.usesRerank, "mcp:"): return r.mcpRerank(ctx, query, items, mergedOpts) default: // Assume it's an assistant ID for Agent mode return r.agentRerank(ctx, query, items, mergedOpts) } } // mergeOptions merges runtime options with config defaults func (r *Reranker) mergeOptions(opts *types.RerankOptions) *types.RerankOptions { result := &types.RerankOptions{ TopN: 10, // default } // Apply config defaults if r.config != nil { if r.config.TopN > 0 { result.TopN = r.config.TopN } } // Apply runtime options (highest priority) if opts != nil { 
if opts.TopN > 0 { result.TopN = opts.TopN } } return result } // builtinRerank uses simple score-based sorting func (r *Reranker) builtinRerank(query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { reranker := NewBuiltinReranker() return reranker.Rerank(query, items, opts) } // agentRerank delegates to an LLM-powered assistant func (r *Reranker) agentRerank(ctx *context.Context, query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { provider := NewAgentProvider(r.usesRerank) return provider.Rerank(ctx, query, items, opts) } // mcpRerank calls an external MCP tool func (r *Reranker) mcpRerank(ctx *context.Context, query string, items []*types.ResultItem, opts *types.RerankOptions) ([]*types.ResultItem, error) { mcpRef := strings.TrimPrefix(r.usesRerank, "mcp:") provider, err := NewMCPProvider(mcpRef) if err != nil { // Fallback to builtin on invalid MCP format return r.builtinRerank(query, items, opts) } return provider.Rerank(ctx, query, items, opts) } ================================================ FILE: agent/search/rerank/reranker_test.go ================================================ package rerank import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/search/types" ) func TestReranker_BuiltinMode(t *testing.T) { reranker := NewReranker("builtin", &types.RerankConfig{TopN: 5}) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0}, } result, err := reranker.Rerank(nil, "test query", items, nil) assert.NoError(t, err) assert.Len(t, result, 3) assert.Equal(t, "ref_001", result[0].CitationID) } func TestReranker_EmptyUsesRerank(t *testing.T) { // Empty usesRerank should use builtin reranker := NewReranker("", &types.RerankConfig{TopN: 5}) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, } 
result, err := reranker.Rerank(nil, "test query", items, nil) assert.NoError(t, err) assert.Len(t, result, 1) } func TestReranker_MergeOptions(t *testing.T) { // Config sets TopN = 5 reranker := NewReranker("builtin", &types.RerankConfig{TopN: 5}) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0}, {CitationID: "ref_004", Score: 0.6, Weight: 1.0}, {CitationID: "ref_005", Score: 0.5, Weight: 1.0}, {CitationID: "ref_006", Score: 0.4, Weight: 1.0}, } // Runtime opts override config result, err := reranker.Rerank(nil, "test query", items, &types.RerankOptions{TopN: 3}) assert.NoError(t, err) assert.Len(t, result, 3) } func TestReranker_ConfigTopN(t *testing.T) { // Config sets TopN = 3 reranker := NewReranker("builtin", &types.RerankConfig{TopN: 3}) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, {CitationID: "ref_002", Score: 0.8, Weight: 1.0}, {CitationID: "ref_003", Score: 0.7, Weight: 1.0}, {CitationID: "ref_004", Score: 0.6, Weight: 1.0}, {CitationID: "ref_005", Score: 0.5, Weight: 1.0}, } // No runtime opts, should use config TopN result, err := reranker.Rerank(nil, "test query", items, nil) assert.NoError(t, err) assert.Len(t, result, 3) } func TestReranker_NilConfig(t *testing.T) { reranker := NewReranker("builtin", nil) items := []*types.ResultItem{ {CitationID: "ref_001", Score: 0.9, Weight: 1.0}, } result, err := reranker.Rerank(nil, "test query", items, nil) assert.NoError(t, err) assert.Len(t, result, 1) } func TestReranker_EmptyItems(t *testing.T) { reranker := NewReranker("builtin", &types.RerankConfig{TopN: 5}) result, err := reranker.Rerank(nil, "test query", []*types.ResultItem{}, nil) assert.NoError(t, err) assert.Empty(t, result) } ================================================ FILE: agent/search/search.go ================================================ package search import ( "sync" 
"github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/search/handlers/db" "github.com/yaoapp/yao/agent/search/handlers/kb" "github.com/yaoapp/yao/agent/search/handlers/web" "github.com/yaoapp/yao/agent/search/interfaces" "github.com/yaoapp/yao/agent/search/rerank" "github.com/yaoapp/yao/agent/search/types" ) // Searcher is the main search implementation type Searcher struct { config *types.Config // Merged config (global + assistant) handlers map[types.SearchType]interfaces.Handler reranker *rerank.Reranker citation *CitationGenerator } // Uses contains the search-specific uses configuration // These are extracted from context.Uses and search config type Uses struct { Search string // "builtin", "disabled", "", "mcp:." Web string // "builtin", "", "mcp:." Keyword string // "builtin", "", "mcp:." QueryDSL string // "builtin", "", "mcp:." Rerank string // "builtin", "", "mcp:." } // New creates a new Searcher instance // cfg: merged config from agent/load.go + assistant config // uses: merged uses configuration (global → assistant → hook) func New(cfg *types.Config, uses *Uses) *Searcher { if uses == nil { uses = &Uses{} } if cfg == nil { cfg = &types.Config{} } return &Searcher{ config: cfg, handlers: map[types.SearchType]interfaces.Handler{ types.SearchTypeWeb: web.NewHandler(uses.Web, cfg.Web), types.SearchTypeKB: kb.NewHandler(cfg.KB), types.SearchTypeDB: db.NewHandler(uses.QueryDSL, cfg.DB), }, reranker: rerank.NewReranker(uses.Rerank, cfg.Rerank), citation: NewCitationGenerator(), } } // Search executes a single search request func (s *Searcher) Search(ctx *context.Context, req *types.Request) (*types.Result, error) { handler, ok := s.handlers[req.Type] if !ok { return &types.Result{Error: "unsupported search type"}, nil } // Execute search - use context if handler supports it var result *types.Result var err error if ctxHandler, ok := handler.(interfaces.ContextHandler); ok { result, err = ctxHandler.SearchWithContext(ctx, req) } else { result, 
err = handler.Search(req) } if err != nil { return &types.Result{Error: err.Error()}, nil } // Assign weights based on source for _, item := range result.Items { item.Weight = s.config.GetWeight(req.Source) } // Rerank if requested if req.Rerank != nil && s.reranker != nil { result.Items, _ = s.reranker.Rerank(ctx, req.Query, result.Items, req.Rerank) } // Generate citation IDs for _, item := range result.Items { item.CitationID = s.citation.Next() } return result, nil } // All executes all searches and waits for all to complete (like Promise.all) func (s *Searcher) All(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { if len(reqs) == 0 { return []*types.Result{}, nil } return s.parallelAll(ctx, reqs) } // Any returns as soon as any search succeeds with results (like Promise.any) func (s *Searcher) Any(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { if len(reqs) == 0 { return []*types.Result{}, nil } return s.parallelAny(ctx, reqs) } // Race returns as soon as any search completes (like Promise.race) func (s *Searcher) Race(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { if len(reqs) == 0 { return []*types.Result{}, nil } return s.parallelRace(ctx, reqs) } // parallelAll executes all searches and waits for all to complete (like Promise.all) func (s *Searcher) parallelAll(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { results := make([]*types.Result, len(reqs)) var wg sync.WaitGroup var mu sync.Mutex for i, req := range reqs { wg.Add(1) go func(idx int, r *types.Request) { defer wg.Done() defer func() { if err := recover(); err != nil { mu.Lock() results[idx] = &types.Result{Error: "search panic recovered"} mu.Unlock() } }() result, err := s.Search(ctx, r) mu.Lock() if err != nil { results[idx] = &types.Result{Error: err.Error()} } else if result == nil { results[idx] = &types.Result{Error: "empty result"} } else { results[idx] = result } mu.Unlock() }(i, req) } wg.Wait() 
return results, nil } // parallelAny returns as soon as any search succeeds (has results) (like Promise.any) // Other searches continue in background but results are discarded func (s *Searcher) parallelAny(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { results := make([]*types.Result, len(reqs)) resultChan := make(chan struct { idx int result *types.Result }, len(reqs)) var wg sync.WaitGroup done := make(chan struct{}) for i, req := range reqs { wg.Add(1) go func(idx int, r *types.Request) { defer wg.Done() // Check if done before starting select { case <-done: return default: } result, _ := s.Search(ctx, r) // Try to send result select { case <-done: // Already found a successful result case resultChan <- struct { idx int result *types.Result }{idx, result}: } }(i, req) } // Close channel when all goroutines complete go func() { wg.Wait() close(resultChan) }() // Collect results until we find one with items (success) var foundSuccess bool for res := range resultChan { results[res.idx] = res.result // Check if this result has items (success = has results and no error) if !foundSuccess && res.result != nil && len(res.result.Items) > 0 && res.result.Error == "" { foundSuccess = true close(done) // Signal other goroutines to stop } } // All goroutines have completed (resultChan is closed) return results, nil } // parallelRace returns as soon as any search completes (like Promise.race) // Returns immediately when first result arrives, regardless of success/failure // Note: Still waits for all goroutines to complete before returning to avoid resource leaks func (s *Searcher) parallelRace(ctx *context.Context, reqs []*types.Request) ([]*types.Result, error) { results := make([]*types.Result, len(reqs)) resultChan := make(chan struct { idx int result *types.Result }, len(reqs)) var wg sync.WaitGroup done := make(chan struct{}) for i, req := range reqs { wg.Add(1) go func(idx int, r *types.Request) { defer wg.Done() // Check if done before 
starting select { case <-done: return default: } result, _ := s.Search(ctx, r) // Try to send result select { case <-done: // Already got first result case resultChan <- struct { idx int result *types.Result }{idx, result}: } }(i, req) } // Close channel when all goroutines complete go func() { wg.Wait() close(resultChan) }() // Get first result and signal others to stop var gotFirst bool for res := range resultChan { results[res.idx] = res.result if !gotFirst { gotFirst = true close(done) // Signal other goroutines to stop } } // All goroutines have completed (resultChan is closed) return results, nil } // BuildReferences converts search results to unified Reference format func (s *Searcher) BuildReferences(results []*types.Result) []*types.Reference { return BuildReferences(results) } ================================================ FILE: agent/search/search_test.go ================================================ package search import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent/search/types" ) func TestNew(t *testing.T) { t.Run("with nil config and uses", func(t *testing.T) { s := New(nil, nil) assert.NotNil(t, s) assert.NotNil(t, s.config) assert.NotNil(t, s.handlers) assert.NotNil(t, s.citation) assert.Equal(t, 3, len(s.handlers)) // web, kb, db }) t.Run("with config", func(t *testing.T) { cfg := &types.Config{ Web: &types.WebConfig{ Provider: "tavily", MaxResults: 10, }, KB: &types.KBConfig{ Collections: []string{"docs"}, Threshold: 0.8, }, DB: &types.DBConfig{ Models: []string{"product"}, MaxResults: 20, }, } s := New(cfg, nil) assert.NotNil(t, s) assert.Equal(t, cfg, s.config) }) t.Run("with uses", func(t *testing.T) { uses := &Uses{ Search: "builtin", Web: "builtin", Keyword: "builtin", QueryDSL: "builtin", Rerank: "builtin", } s := New(nil, uses) assert.NotNil(t, s) }) } func TestSearcher_Search_UnsupportedType(t *testing.T) { s := New(nil, nil) req := &types.Request{ Type: "unsupported", Query: "test", } result, err := 
s.Search(nil, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, "unsupported search type", result.Error) } func TestSearcher_Search_Web(t *testing.T) { // Note: This test uses skeleton handlers that return empty results // Real tests with actual API calls are in handlers/web/*_test.go cfg := &types.Config{ Web: &types.WebConfig{ Provider: "tavily", }, } s := New(cfg, &Uses{Web: "builtin"}) req := &types.Request{ Type: types.SearchTypeWeb, Query: "test query", Source: types.SourceAuto, } result, err := s.Search(nil, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.SearchTypeWeb, result.Type) assert.Equal(t, "test query", result.Query) // Note: actual result depends on API key availability } func TestSearcher_Search_KB(t *testing.T) { cfg := &types.Config{ KB: &types.KBConfig{ Collections: []string{"docs"}, Threshold: 0.7, }, } s := New(cfg, nil) req := &types.Request{ Type: types.SearchTypeKB, Query: "test query", Source: types.SourceHook, Collections: []string{"docs"}, } result, err := s.Search(nil, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.SearchTypeKB, result.Type) assert.Equal(t, "test query", result.Query) assert.Equal(t, types.SourceHook, result.Source) // Skeleton returns empty items assert.Equal(t, 0, len(result.Items)) } func TestSearcher_Search_DB(t *testing.T) { cfg := &types.Config{ DB: &types.DBConfig{ Models: []string{"product"}, MaxResults: 20, }, } s := New(cfg, &Uses{QueryDSL: "builtin"}) req := &types.Request{ Type: types.SearchTypeDB, Query: "find products under $100", Source: types.SourceUser, Models: []string{"product"}, } result, err := s.Search(nil, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, types.SearchTypeDB, result.Type) assert.Equal(t, "find products under $100", result.Query) assert.Equal(t, types.SourceUser, result.Source) // Skeleton returns empty items assert.Equal(t, 0, len(result.Items)) } func 
TestSearcher_Search_WeightAssignment(t *testing.T) { cfg := &types.Config{ KB: &types.KBConfig{ Collections: []string{"docs"}, }, Weights: &types.WeightsConfig{ User: 1.0, Hook: 0.8, Auto: 0.6, }, } s := New(cfg, nil) // Test with different sources sources := []struct { source types.SourceType weight float64 }{ {types.SourceUser, 1.0}, {types.SourceHook, 0.8}, {types.SourceAuto, 0.6}, } for _, tc := range sources { req := &types.Request{ Type: types.SearchTypeKB, Query: "test", Source: tc.source, Collections: []string{"docs"}, } result, err := s.Search(nil, req) assert.NoError(t, err) assert.NotNil(t, result) // Items are empty in skeleton, so weight assignment can't be verified here // This test ensures the code path works without error } } func TestSearcher_All(t *testing.T) { cfg := &types.Config{ KB: &types.KBConfig{ Collections: []string{"docs"}, }, DB: &types.DBConfig{ Models: []string{"product"}, }, } s := New(cfg, nil) reqs := []*types.Request{ { Type: types.SearchTypeKB, Query: "KB query", Source: types.SourceAuto, Collections: []string{"docs"}, }, { Type: types.SearchTypeDB, Query: "DB query", Source: types.SourceAuto, Models: []string{"product"}, }, } // Test All() - waits for all searches to complete (like Promise.all) results, err := s.All(nil, reqs) assert.NoError(t, err) assert.Equal(t, 2, len(results)) // Verify each result corresponds to its request assert.Equal(t, types.SearchTypeKB, results[0].Type) assert.Equal(t, "KB query", results[0].Query) assert.Equal(t, types.SearchTypeDB, results[1].Type) assert.Equal(t, "DB query", results[1].Query) } func TestSearcher_Any(t *testing.T) { cfg := &types.Config{ KB: &types.KBConfig{ Collections: []string{"docs"}, }, DB: &types.DBConfig{ Models: []string{"product"}, }, } s := New(cfg, nil) reqs := []*types.Request{ { Type: types.SearchTypeKB, Query: "KB query", Source: types.SourceAuto, Collections: []string{"docs"}, }, { Type: types.SearchTypeDB, Query: "DB query", Source: types.SourceAuto, Models: 
[]string{"product"}, }, } // Test Any() - returns when first search has results (like Promise.any) // Note: With skeleton handlers returning empty results, this will wait for all results, err := s.Any(nil, reqs) assert.NoError(t, err) assert.Equal(t, 2, len(results)) } func TestSearcher_Race(t *testing.T) { cfg := &types.Config{ KB: &types.KBConfig{ Collections: []string{"docs"}, }, DB: &types.DBConfig{ Models: []string{"product"}, }, } s := New(cfg, nil) reqs := []*types.Request{ { Type: types.SearchTypeKB, Query: "KB query", Source: types.SourceAuto, Collections: []string{"docs"}, }, { Type: types.SearchTypeDB, Query: "DB query", Source: types.SourceAuto, Models: []string{"product"}, }, } // Test Race() - returns when first search completes (like Promise.race) results, err := s.Race(nil, reqs) assert.NoError(t, err) // At least one result should be set hasResult := false for _, r := range results { if r != nil { hasResult = true break } } assert.True(t, hasResult) } func TestSearcher_All_Empty(t *testing.T) { s := New(nil, nil) results, err := s.All(nil, []*types.Request{}) assert.NoError(t, err) assert.Equal(t, 0, len(results)) } func TestSearcher_Any_Empty(t *testing.T) { s := New(nil, nil) results, err := s.Any(nil, []*types.Request{}) assert.NoError(t, err) assert.Equal(t, 0, len(results)) } func TestSearcher_Race_Empty(t *testing.T) { s := New(nil, nil) results, err := s.Race(nil, []*types.Request{}) assert.NoError(t, err) assert.Equal(t, 0, len(results)) } func TestSearcher_All_ManyRequests(t *testing.T) { cfg := &types.Config{ KB: &types.KBConfig{ Collections: []string{"docs"}, }, } s := New(cfg, nil) // Create multiple requests to test parallel execution reqs := make([]*types.Request, 10) for i := 0; i < 10; i++ { reqs[i] = &types.Request{ Type: types.SearchTypeKB, Query: "test query", Source: types.SourceAuto, Collections: []string{"docs"}, } } results, err := s.All(nil, reqs) assert.NoError(t, err) assert.Equal(t, 10, len(results)) // All results should 
be valid for _, result := range results { assert.NotNil(t, result) assert.Equal(t, types.SearchTypeKB, result.Type) } } func TestSearcher_BuildReferences(t *testing.T) { s := New(nil, nil) results := []*types.Result{ { Type: types.SearchTypeWeb, Items: []*types.ResultItem{ { CitationID: "1", Type: types.SearchTypeWeb, Source: types.SourceAuto, Weight: 0.6, Title: "Web Result", Content: "Web content", URL: "https://example.com", }, }, }, { Type: types.SearchTypeKB, Items: []*types.ResultItem{ { CitationID: "2", Type: types.SearchTypeKB, Source: types.SourceHook, Weight: 0.8, Title: "KB Result", Content: "KB content", }, }, }, } refs := s.BuildReferences(results) assert.Equal(t, 2, len(refs)) assert.Equal(t, "1", refs[0].ID) assert.Equal(t, "2", refs[1].ID) } func TestSearcher_CitationGeneration(t *testing.T) { s := New(nil, nil) // Reset citation generator for predictable IDs s.citation.Reset() // Note: This test would need actual results with items to verify citation generation // The skeleton handlers return empty items, so we test the citation generator directly id1 := s.citation.Next() id2 := s.citation.Next() id3 := s.citation.Next() // Citation IDs are now simple integers assert.Equal(t, "1", id1) assert.Equal(t, "2", id2) assert.Equal(t, "3", id3) } ================================================ FILE: agent/search/search_web_test.go ================================================ package search_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/search" "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/testutils" ) // ============================================================================= // Web Search Integration Tests - Single Search // ============================================================================= // TestWebSearch_Tavily tests web search using Tavily provider via assistant config // Skip: requires 
external API key (Tavily) func TestWebSearch_Tavily(t *testing.T) { t.Skip("Skipping: requires external API key (Tavily)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-tavily test assistant ast, err := assistant.LoadPath("/assistants/tests/web-tavily") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Verify assistant config assert.Equal(t, "tavily", ast.Search.Web.Provider) // Create Searcher with assistant's config uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Execute search req := &types.Request{ Type: types.SearchTypeWeb, Query: "Yao App Engine low-code platform", Source: types.SourceAuto, Limit: 5, } result, err := s.Search(nil, req) require.NoError(t, err) require.NotNil(t, result) require.Empty(t, result.Error, "Search should succeed, got error: %s", result.Error) // Verify results assert.NotEmpty(t, result.Items, "Should have search results") for _, item := range result.Items { assert.NotEmpty(t, item.CitationID, "Each item should have citation ID") assert.NotEmpty(t, item.Content, "Each item should have content") t.Logf(" [%s] %s - %s", item.CitationID, item.Title, item.URL) } t.Logf("Tavily search returned %d results", len(result.Items)) } // TestWebSearch_Serper tests web search using Serper provider via assistant config // Skip: requires external API key (Serper) func TestWebSearch_Serper(t *testing.T) { t.Skip("Skipping: requires external API key (Serper)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Verify assistant config assert.Equal(t, "serper", ast.Search.Web.Provider) // Create Searcher with assistant's config uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Execute search req := &types.Request{ Type: types.SearchTypeWeb, Query: "Go programming 
language concurrency", Source: types.SourceAuto, Limit: 5, } result, err := s.Search(nil, req) require.NoError(t, err) require.NotNil(t, result) require.Empty(t, result.Error, "Search should succeed, got error: %s", result.Error) // Verify results assert.NotEmpty(t, result.Items, "Should have search results") for _, item := range result.Items { assert.NotEmpty(t, item.CitationID, "Each item should have citation ID") t.Logf(" [%s] %s - %s", item.CitationID, item.Title, item.URL) } t.Logf("Serper search returned %d results", len(result.Items)) } // TestWebSearch_SerpAPI tests web search using SerpAPI provider via assistant config // Skip: requires external API key (SerpAPI) func TestWebSearch_SerpAPI(t *testing.T) { t.Skip("Skipping: requires external API key (SerpAPI)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serpapi test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serpapi") require.NoError(t, err) require.NotNil(t, ast.Search) require.NotNil(t, ast.Search.Web) // Verify assistant config assert.Equal(t, "serpapi", ast.Search.Web.Provider) // Create Searcher with assistant's config uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Execute search req := &types.Request{ Type: types.SearchTypeWeb, Query: "Kubernetes container orchestration", Source: types.SourceAuto, Limit: 5, } result, err := s.Search(nil, req) require.NoError(t, err) require.NotNil(t, result) require.Empty(t, result.Error, "Search should succeed, got error: %s", result.Error) // Verify results assert.NotEmpty(t, result.Items, "Should have search results") for _, item := range result.Items { assert.NotEmpty(t, item.CitationID, "Each item should have citation ID") t.Logf(" [%s] %s - %s", item.CitationID, item.Title, item.URL) } t.Logf("SerpAPI search returned %d results", len(result.Items)) } // ============================================================================= // Web Search Integration Tests - Parallel Search // 
============================================================================= // TestWebSearch_All tests parallel web search with All() - like Promise.all // Skip: requires external API key (Serper) func TestWebSearch_All(t *testing.T) { t.Skip("Skipping: requires external API key (Serper)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) // Create Searcher uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Multiple queries reqs := []*types.Request{ {Type: types.SearchTypeWeb, Query: "artificial intelligence", Source: types.SourceAuto, Limit: 3}, {Type: types.SearchTypeWeb, Query: "machine learning", Source: types.SourceAuto, Limit: 3}, {Type: types.SearchTypeWeb, Query: "deep learning", Source: types.SourceAuto, Limit: 3}, } // Execute parallel search with All() - waits for all searches to complete results, err := s.All(nil, reqs) require.NoError(t, err) require.Len(t, results, 3, "Should have 3 results") // Verify all results for i, result := range results { require.NotNil(t, result, "Result %d should not be nil", i) if result.Error == "" { assert.NotEmpty(t, result.Items, "Result %d should have items", i) t.Logf("Query '%s': %d results", reqs[i].Query, len(result.Items)) } else { t.Logf("Query '%s': error - %s", reqs[i].Query, result.Error) } } } // TestWebSearch_Any tests parallel web search with Any() - like Promise.any // Skip: requires external API key (Serper) func TestWebSearch_Any(t *testing.T) { t.Skip("Skipping: requires external API key (Serper)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) // Create Searcher uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Multiple queries reqs := []*types.Request{ 
{Type: types.SearchTypeWeb, Query: "golang channels", Source: types.SourceAuto, Limit: 3}, {Type: types.SearchTypeWeb, Query: "rust ownership", Source: types.SourceAuto, Limit: 3}, {Type: types.SearchTypeWeb, Query: "python asyncio", Source: types.SourceAuto, Limit: 3}, } // Execute parallel search with Any() - returns when first search succeeds results, err := s.Any(nil, reqs) require.NoError(t, err) // Any() returns as soon as any search succeeds hasSuccess := false for _, result := range results { if result != nil && len(result.Items) > 0 && result.Error == "" { hasSuccess = true t.Logf("First success: '%s' with %d results", result.Query, len(result.Items)) break } } assert.True(t, hasSuccess, "At least one search should succeed") } // TestWebSearch_Race tests parallel web search with Race() - like Promise.race // Skip: requires external API key (Serper) func TestWebSearch_Race(t *testing.T) { t.Skip("Skipping: requires external API key (Serper)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) // Create Searcher uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Multiple queries reqs := []*types.Request{ {Type: types.SearchTypeWeb, Query: "docker containers", Source: types.SourceAuto, Limit: 3}, {Type: types.SearchTypeWeb, Query: "kubernetes pods", Source: types.SourceAuto, Limit: 3}, } // Execute parallel search with Race() - returns when first search completes results, err := s.Race(nil, reqs) require.NoError(t, err) // Race() returns immediately when first result arrives hasResult := false for _, result := range results { if result != nil { hasResult = true t.Logf("First to complete: '%s'", result.Query) break } } assert.True(t, hasResult, "Should have at least one result") } // ============================================================================= // Web Search - Citation and 
Reference Tests // ============================================================================= // TestWebSearch_BuildReferences tests building references from web search results // Skip: requires external API key (Serper) func TestWebSearch_BuildReferences(t *testing.T) { t.Skip("Skipping: requires external API key (Serper)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) require.NotNil(t, ast.Search) // Create Searcher with weights config uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Execute search req := &types.Request{ Type: types.SearchTypeWeb, Query: "OpenAI GPT-4", Source: types.SourceAuto, Limit: 5, } result, err := s.Search(nil, req) require.NoError(t, err) require.NotNil(t, result) require.Empty(t, result.Error, "Search should succeed") require.NotEmpty(t, result.Items, "Should have results") // Build references refs := s.BuildReferences([]*types.Result{result}) assert.NotEmpty(t, refs, "Should have references") for _, ref := range refs { assert.NotEmpty(t, ref.ID, "Reference should have ID") assert.Equal(t, types.SearchTypeWeb, ref.Type, "Reference type should be web") assert.Equal(t, types.SourceAuto, ref.Source, "Reference source should be auto") t.Logf(" Ref: %s - %s (weight: %.2f)", ref.ID, ref.Title, ref.Weight) } } // ============================================================================= // Web Search - Error Handling Tests // ============================================================================= // TestWebSearch_SiteRestriction tests web search with site restriction // Skip: requires external API key (Serper) func TestWebSearch_SiteRestriction(t *testing.T) { t.Skip("Skipping: requires external API key (Serper)") testutils.Prepare(t) defer testutils.Clean(t) // Load the web-serper test assistant ast, err := assistant.LoadPath("/assistants/tests/web-serper") require.NoError(t, err) 
require.NotNil(t, ast.Search) // Create Searcher uses := &search.Uses{Web: "builtin"} s := search.New(ast.Search, uses) // Execute search with site restriction req := &types.Request{ Type: types.SearchTypeWeb, Query: "yao-app-engine", Source: types.SourceAuto, Sites: []string{"github.com"}, Limit: 5, } result, err := s.Search(nil, req) require.NoError(t, err) require.NotNil(t, result) if result.Error == "" && len(result.Items) > 0 { // Log results for _, item := range result.Items { t.Logf(" %s - %s", item.Title, item.URL) } } } ================================================ FILE: agent/search/types/config.go ================================================ package types // Config represents the complete search configuration type Config struct { Web *WebConfig `json:"web,omitempty" yaml:"web,omitempty"` KB *KBConfig `json:"kb,omitempty" yaml:"kb,omitempty"` DB *DBConfig `json:"db,omitempty" yaml:"db,omitempty"` Keyword *KeywordConfig `json:"keyword,omitempty" yaml:"keyword,omitempty"` QueryDSL *QueryDSLConfig `json:"querydsl,omitempty" yaml:"querydsl,omitempty"` Rerank *RerankConfig `json:"rerank,omitempty" yaml:"rerank,omitempty"` Citation *CitationConfig `json:"citation,omitempty" yaml:"citation,omitempty"` Weights *WeightsConfig `json:"weights,omitempty" yaml:"weights,omitempty"` Options *OptionsConfig `json:"options,omitempty" yaml:"options,omitempty"` } // WebConfig for web search settings // Note: uses.web determines the mode (builtin/agent/mcp) // Provider is only used when uses.web = "builtin" type WebConfig struct { Provider string `json:"provider,omitempty" yaml:"provider,omitempty"` // "tavily", "serper", or "serpapi" (for builtin mode) APIKeyEnv string `json:"api_key_env,omitempty" yaml:"api_key_env,omitempty"` // Environment variable for API key MaxResults int `json:"max_results,omitempty" yaml:"max_results,omitempty"` // Max results (default: 10) Engine string `json:"engine,omitempty" yaml:"engine,omitempty"` // Search engine for SerpAPI: "google", 
"bing", "baidu", "yandex", etc. (default: "google") } // KBConfig for knowledge base search settings type KBConfig struct { Collections []string `json:"collections,omitempty" yaml:"collections,omitempty"` // Default collections Threshold float64 `json:"threshold,omitempty" yaml:"threshold,omitempty"` // Similarity threshold (default: 0.7) Graph bool `json:"graph,omitempty" yaml:"graph,omitempty"` // Enable GraphRAG (default: false) } // DBConfig for database search settings type DBConfig struct { Models []string `json:"models,omitempty" yaml:"models,omitempty"` // Default models MaxResults int `json:"max_results,omitempty" yaml:"max_results,omitempty"` // Max results (default: 20) } // KeywordConfig for keyword extraction type KeywordConfig struct { MaxKeywords int `json:"max_keywords,omitempty" yaml:"max_keywords,omitempty"` // Max keywords (default: 10) Language string `json:"language,omitempty" yaml:"language,omitempty"` // "auto", "en", "zh", etc. } // KeywordOptions for keyword extraction (runtime options) type KeywordOptions struct { MaxKeywords int `json:"max_keywords,omitempty"` Language string `json:"language,omitempty"` } // QueryDSLConfig for QueryDSL generation from natural language type QueryDSLConfig struct { Strict bool `json:"strict,omitempty" yaml:"strict,omitempty"` // Fail if generation fails (default: false) } // RerankConfig for reranking type RerankConfig struct { TopN int `json:"top_n,omitempty" yaml:"top_n,omitempty"` // Return top N (default: 10) } // CitationConfig for citation format type CitationConfig struct { Format string `json:"format,omitempty" yaml:"format,omitempty"` // Default: "#ref:{id}" AutoInjectPrompt bool `json:"auto_inject_prompt,omitempty" yaml:"auto_inject_prompt,omitempty"` // Auto-inject prompt (default: true) CustomPrompt string `json:"custom_prompt,omitempty" yaml:"custom_prompt,omitempty"` // Custom prompt template } // WeightsConfig for source weighting type WeightsConfig struct { User float64 
`json:"user,omitempty" yaml:"user,omitempty"` // User-provided (default: 1.0) Hook float64 `json:"hook,omitempty" yaml:"hook,omitempty"` // Hook results (default: 0.8) Auto float64 `json:"auto,omitempty" yaml:"auto,omitempty"` // Auto search (default: 0.6) } // OptionsConfig for search behavior type OptionsConfig struct { SkipThreshold int `json:"skip_threshold,omitempty" yaml:"skip_threshold,omitempty"` // Skip auto search if user provides >= N results } // GetWeight returns the weight for a source type func (c *Config) GetWeight(source SourceType) float64 { if c == nil || c.Weights == nil { return getDefaultWeight(source) } switch source { case SourceUser: if c.Weights.User > 0 { return c.Weights.User } return 1.0 case SourceHook: if c.Weights.Hook > 0 { return c.Weights.Hook } return 0.8 case SourceAuto: if c.Weights.Auto > 0 { return c.Weights.Auto } return 0.6 default: return 0.6 } } // getDefaultWeight returns default weight for a source type func getDefaultWeight(source SourceType) float64 { switch source { case SourceUser: return 1.0 case SourceHook: return 0.8 case SourceAuto: return 0.6 default: return 0.6 } } ================================================ FILE: agent/search/types/graph.go ================================================ package types // GraphNode represents a related entity from knowledge graph type GraphNode struct { ID string `json:"id"` Type string `json:"type"` // Entity type Name string `json:"name"` // Entity name Description string `json:"description,omitempty"` // Entity description Relation string `json:"relation,omitempty"` // Relationship to query Score float64 `json:"score,omitempty"` // Relevance score Metadata map[string]interface{} `json:"metadata,omitempty"` } ================================================ FILE: agent/search/types/reference.go ================================================ package types // Reference is the unified structure for all data sources // Used to build LLM context from search results type 
Reference struct { ID string `json:"id"` // Unique citation ID: "ref_001", "ref_002" Type SearchType `json:"type"` // Data type: "web", "kb", "db" Source SourceType `json:"source"` // Origin: "user", "hook", "auto" Weight float64 `json:"weight"` // Relevance weight (1.0=highest, 0.6=lowest) Score float64 `json:"score"` // Relevance score (0-1) Title string `json:"title"` // Optional title Content string `json:"content"` // Main content URL string `json:"url"` // Optional URL Meta map[string]interface{} `json:"meta"` // Additional metadata } // ReferenceContext holds the formatted references for LLM input type ReferenceContext struct { References []*Reference `json:"references"` // All references XML string `json:"xml"` // Formatted XML Prompt string `json:"prompt"` // Citation instruction prompt } ================================================ FILE: agent/search/types/types.go ================================================ package types import ( "github.com/yaoapp/gou/query/gou" ) // SearchType represents the type of search type SearchType string // SearchType constants const ( SearchTypeWeb SearchType = "web" // Web/Internet search SearchTypeKB SearchType = "kb" // Knowledge base vector search SearchTypeDB SearchType = "db" // Database search (Yao Model/QueryDSL) ) // ScenarioType represents the QueryDSL generation scenario type ScenarioType string // ScenarioType constants for QueryDSL generation const ( ScenarioFilter ScenarioType = "filter" // Simple filtering queries ScenarioAggregation ScenarioType = "aggregation" // Aggregation/grouping queries ScenarioJoin ScenarioType = "join" // Multi-table join queries ScenarioComplex ScenarioType = "complex" // Complex queries combining multiple features ) // SourceType represents where the search result came from type SourceType string // SourceType constants const ( SourceUser SourceType = "user" // User-provided DataContent (highest priority) SourceHook SourceType = "hook" // Hook ctx.search.*() results 
SourceAuto SourceType = "auto" // Auto search results (lowest priority) ) // Request represents a search request type Request struct { // Common fields Query string `json:"query"` // Search query (natural language) Type SearchType `json:"type"` // Search type: "web", "kb", or "db" Limit int `json:"limit,omitempty"` // Max results (default: 10) Source SourceType `json:"source"` // Source of this request (user/hook/auto) // Web search specific Sites []string `json:"sites,omitempty"` // Restrict to specific sites TimeRange string `json:"time_range,omitempty"` // "day", "week", "month", "year" // Knowledge base specific Collections []string `json:"collections,omitempty"` // KB collection IDs Threshold float64 `json:"threshold,omitempty"` // Similarity threshold (0-1) Graph bool `json:"graph,omitempty"` // Enable graph association Metadata map[string]interface{} `json:"metadata,omitempty"` // Metadata filter for KB search // Database search specific Models []string `json:"models,omitempty"` // Model IDs (e.g., "user", "agents.mybot.product") Scenario ScenarioType `json:"scenario,omitempty"` // QueryDSL scenario: "filter", "aggregation", "join", "complex" Wheres []gou.Where `json:"wheres,omitempty"` // Pre-defined filters (optional), uses GOU QueryDSL Where Orders gou.Orders `json:"orders,omitempty"` // Sort orders (optional), uses GOU QueryDSL Orders Select []string `json:"select,omitempty"` // Fields to return (optional) // Reranking Rerank *RerankOptions `json:"rerank,omitempty"` } // RerankOptions controls result reranking // Reranker type is determined by uses.rerank in agent/agent.yml type RerankOptions struct { TopN int `json:"top_n,omitempty"` // Return top N after reranking } // Result represents the search result type Result struct { Type SearchType `json:"type"` // Search type Query string `json:"query"` // Original query Source SourceType `json:"source"` // Source of this result Items []*ResultItem `json:"items"` // Result items Total int `json:"total"` // 
Total matches Duration int64 `json:"duration_ms"` // Search duration in ms Error string `json:"error,omitempty"` // Error message if failed // Graph associations (KB only, if enabled) GraphNodes []*GraphNode `json:"graph_nodes,omitempty"` // DB specific DSL map[string]interface{} `json:"dsl,omitempty"` // Generated QueryDSL (DB only) } // ResultItem represents a single search result item type ResultItem struct { // Citation CitationID string `json:"citation_id"` // Unique ID for LLM reference: "ref_001" // Weighting Source SourceType `json:"source"` // Source type: "user", "hook", "auto" Weight float64 `json:"weight"` // Source weight (from config) Score float64 `json:"score,omitempty"` // Relevance score (0-1) // Common fields Type SearchType `json:"type"` // Search type for this item Title string `json:"title,omitempty"` // Title/headline Content string `json:"content"` // Main content/snippet URL string `json:"url,omitempty"` // Source URL // KB specific DocumentID string `json:"document_id,omitempty"` // Source document ID Collection string `json:"collection,omitempty"` // Collection name // DB specific Model string `json:"model,omitempty"` // Model ID RecordID interface{} `json:"record_id,omitempty"` // Record primary key Data map[string]interface{} `json:"data,omitempty"` // Full record data // Metadata Metadata map[string]interface{} `json:"metadata,omitempty"` // Additional metadata } // ProcessedQuery represents a processed query ready for execution type ProcessedQuery struct { Type SearchType `json:"type"` Keywords []string `json:"keywords,omitempty"` // For web search Vector []float32 `json:"vector,omitempty"` // For KB search DSL *gou.QueryDSL `json:"dsl,omitempty"` // For DB search, uses GOU QueryDSL } // Keyword represents an extracted keyword with weight type Keyword struct { K string `json:"k"` // Keyword text W float64 `json:"w"` // Weight (0.1-1.0), higher means more relevant } // Note: For QueryDSL and Model types, use GOU types directly: // - 
github.com/yaoapp/gou/query/gou.QueryDSL // - github.com/yaoapp/gou/model.Model // - github.com/yaoapp/gou/model.Column ================================================ FILE: agent/store/CHAT_STORAGE_DESIGN.md ================================================ # Chat Storage Design This document describes the design for storing chat conversations, messages, and execution steps in the YAO Agent system. ## Table of Contents - [Overview](#overview) - [Architecture](#architecture) - [Data Models](#data-models) - [Write Strategy](#write-strategy) - [API Interface](#api-interface) - [Usage Examples](#usage-examples) - [Related Documents](#related-documents) ## Overview The chat storage system is designed to: 1. **Store user-visible messages** - All messages sent via `ctx.Send()`, including text, images, loading states, etc. 2. **Support resume/retry** - Track execution steps to enable recovery from interruptions or failures 3. **Efficient writes** - Batch message writes at request end ### Design Goals | Goal | Solution | | ------------------------ | ------------------------------------------------ | | Complete chat history | Store final content of all `ctx.Send()` messages | | Resume from interruption | Track step status and input/output | | Retry failed operations | Store step input for re-execution | | Minimize database writes | Batch writes at request end | ### Non-Goals - **Tracing/debugging** - Handled by separate [Trace module](../../trace/README.md) - **Streaming replay** - Not needed, history shows final content only - **Request tracking/billing** - Handled by [OpenAPI Request module](../../openapi/request/REQUEST_DESIGN.md) ### Relationship with OpenAPI Request The Agent storage focuses on **chat content and execution state**, while request tracking (billing, rate limiting, auditing) is handled globally by the OpenAPI layer: | Concern | Module | Table | | ---------------- | ----------------- | ----------------- | | Request tracking | `openapi/request` | 
`openapi_request` | | Billing (tokens) | `openapi/request` | `openapi_request` | | Rate limiting | `openapi/request` | - | | Chat sessions | `agent/store` | `agent_chat` | | Chat messages | `agent/store` | `agent_message` | | Resume/Retry | `agent/store` | `agent_resume` | The `request_id` from OpenAPI middleware is passed to Agent and stored in messages/steps for correlation. ## Architecture ``` ┌─────────────────────────────────────────────────────────────┐ │ Chat Storage │ ├─────────────────────────────────────────────────────────────┤ │ │ │ ┌─────────────────┐ │ │ │ Chat │ Metadata: title, assistant, user │ │ └────────┬────────┘ │ │ │ │ │ │ 1:N │ │ ▼ │ │ ┌─────────────────┐ │ │ │ Message │ User-visible: type, props, role │ │ └────────┬────────┘ │ │ │ │ │ │ N:N (via request_id) │ │ ▼ │ │ ┌─────────────────┐ │ │ │ Resume │ Recovery: type, status, input/output │ │ │ (only on fail) │ Only saved when interrupted/failed │ │ └─────────────────┘ │ │ │ └─────────────────────────────────────────────────────────────┘ ``` ## Data Models ### 1. Chat Table Stores chat metadata and session information. 
**Table Name:** `agent_chat` | Column | Type | Nullable | Index | Description | | ----------------- | ----------- | -------- | ------ | -------------------------------- | | `id` | ID | No | PK | Auto-increment primary key | | `chat_id` | string(64) | No | Unique | Unique chat identifier | | `title` | string(500) | Yes | - | Chat title | | `assistant_id` | string(200) | No | Yes | Associated assistant ID | | `last_connector` | string(200) | Yes | Yes | Last used connector ID | | `last_mode` | string(50) | Yes | - | Last used chat mode (chat/task) | | `status` | enum | No | Yes | Status: `active`, `archived` | | `public` | boolean | No | - | Whether shared across all teams | | `share` | enum | No | Yes | Sharing scope: `private`, `team` | | `sort` | integer | No | - | Sort order for display | | `last_message_at` | timestamp | Yes | Yes | Timestamp of last message | | `metadata` | json | Yes | - | Additional metadata | | `created_at` | timestamp | No | Yes | Creation timestamp | | `updated_at` | timestamp | No | - | Last update timestamp | **Model Options:** ```json { "option": { "soft_deletes": true, "permission": true, "timestamps": true } } ``` **Note:** `permission: true` enables Yao's built-in permission management, which automatically adds the following fields: | Field | Type | Description | | ------------------ | ----------- | ------------------------------ | | `__yao_created_by` | string(200) | User ID who created the record | | `__yao_updated_by` | string(200) | User ID who last updated | | `__yao_team_id` | string(200) | Team ID for team-level access | | `__yao_tenant_id` | string(200) | Tenant ID for multi-tenancy | These fields are automatically managed by the framework and used for access control filtering. 
**Indexes:** | Name | Columns | Type | | -------------------- | ----------------- | ----- | | `idx_chat_assistant` | `assistant_id` | index | | `idx_chat_last_conn` | `last_connector` | index | | `idx_chat_status` | `status` | index | | `idx_chat_share` | `share` | index | | `idx_chat_last_msg` | `last_message_at` | index | ### 2. Message Table Stores user-visible messages (both user input and assistant responses). **Table Name:** `agent_message` | Column | Type | Nullable | Index | Description | | -------------- | ----------- | -------- | ----- | ------------------------------------------- | | `id` | ID | No | PK | Auto-increment primary key | | `message_id` | string(64) | No | - | Message identifier (unique within request) | | `chat_id` | string(64) | No | Yes | Parent chat ID | | `request_id` | string(64) | Yes | Yes | Request ID for grouping | | `role` | enum | No | Yes | Role: `user`, `assistant` | | `type` | string(50) | No | - | Message type (text, image, loading, etc.) | | `props` | json | No | - | Message properties (content, url, etc.) 
|
| `block_id` | string(64) | Yes | Yes | Block grouping ID |
| `thread_id` | string(64) | Yes | Yes | Thread grouping ID |
| `assistant_id` | string(200) | Yes | Yes | Assistant ID (join to get name/avatar) |
| `connector` | string(200) | Yes | Yes | Connector ID used for this message |
| `mode` | string(50) | Yes | - | Chat mode used for this message (chat/task) |
| `sequence` | integer | No | - | Message order within its request (indexed via the composite `idx_msg_chat_seq`) |
| `metadata` | json | Yes | - | Additional metadata |
| `created_at` | timestamp | No | Yes | Creation timestamp |
| `updated_at` | timestamp | No | - | Last update timestamp |

**Indexes:**

| Name | Columns | Type |
| ------------------------- | -------------------------- | ------ |
| `idx_msg_chat_seq` | `chat_id`, `sequence` | index |
| `idx_msg_request_message` | `request_id`, `message_id` | unique |
| `idx_msg_request` | `request_id` | index |
| `idx_msg_role` | `role` | index |
| `idx_msg_block` | `block_id` | index |
| `idx_msg_thread` | `thread_id` | index |
| `idx_msg_assistant` | `assistant_id` | index |

**Message Ordering:**

Messages are ordered by `created_at` first, then by `sequence` within the same timestamp. This ensures correct chronological order when there are multiple requests with overlapping sequence numbers:

```sql
ORDER BY created_at ASC, sequence ASC
```

**Why this ordering?**

- `sequence` is assigned per-request, so different requests may have the same sequence numbers
- Messages are batch-written at the end of each request, so all messages of one request share the same `created_at`
- `created_at` therefore groups messages by request time, ensuring messages from earlier requests appear first
- Within the same request (same `created_at`), `sequence` preserves the internal ordering

**Message Types:**

All message types except transient `event` messages are stored, including built-in types and custom types. See `agent/output/BUILTIN_TYPES.md` for built-in Props structures.

| Type | Description | Props Example | Stored?
| | ------------ | -------------------------------- | ------------------------------------------------------------------------------------------- | ------- | | `user_input` | User input (frontend display) | `{"content": "Hello", "role": "user", "name": "John"}` | ✅ Yes | | `text` | Text/Markdown content | `{"content": "Hello **world**!"}` | ✅ Yes | | `thinking` | Reasoning process (o1, DeepSeek) | `{"content": "Let me analyze..."}` | ✅ Yes | | `loading` | Loading/processing indicator | `{"message": "Searching knowledge base..."}` | ✅ Yes | | `tool_call` | LLM tool/function call | `{"id": "call_abc123", "name": "get_weather", "arguments": "{\"location\":\"SF\"}"}` | ✅ Yes | | `retrieval` | KB/Web search results | `{"query": "...", "sources": [...], "total_results": 10}` | ✅ Yes | | `error` | Error message | `{"message": "Connection timeout", "code": "TIMEOUT", "details": "..."}` | ✅ Yes | | `image` | Image content | `{"url": "...", "alt": "...", "width": 200, "height": 200, "detail": "auto"}` | ✅ Yes | | `audio` | Audio content | `{"url": "...", "format": "mp3", "duration": 120.5, "transcript": "...", "controls": true}` | ✅ Yes | | `video` | Video content | `{"url": "...", "format": "mp4", "thumbnail": "...", "width": 640, "height": 360}` | ✅ Yes | | `action` | System action (CUI only) | `{"name": "open_panel", "payload": {"panel_id": "user_profile"}}` | ✅ Yes | | `event` | Lifecycle event (CUI only) | `{"event": "stream_start", "message": "...", "data": {...}}` | ❌ No | | `*` (custom) | Any custom type | `{"chartType": "bar", "data": [...], "options": {...}}` | ✅ Yes | **Note on `event` type:** Lifecycle events (`stream_start`, `stream_end`, etc.) are transient control signals and are NOT stored. They are only used for real-time streaming coordination. **Note on custom types:** Any type not in the built-in list is stored as-is with its original `type` and `props` structure. 
**Tool Call Storage:** Tool calls from LLM responses are stored as `tool_call` type messages. The raw tool call data is preserved in `props`: ```json { "message_id": "msg_001", "chat_id": "chat_123", "role": "assistant", "type": "tool_call", "props": { "id": "call_abc123", "name": "get_weather", "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}" }, "block_id": "B1", "sequence": 5 } ``` **Tool Result Storage:** Tool execution results can be stored as `text` type messages, with metadata marking each one as a tool result: ```json { "message_id": "msg_002", "chat_id": "chat_123", "role": "assistant", "type": "text", "props": { "content": "The weather in San Francisco is 18°C and sunny." }, "metadata": { "tool_call_id": "call_abc123", "tool_name": "get_weather", "is_tool_result": true }, "block_id": "B1", "sequence": 6 } ``` **Custom Types:** Any type not in the built-in list is considered a custom type and stored with its original structure: ```json { "type": "chart", "props": { "chartType": "bar", "data": [...], "options": {...} } } ``` **Multimodal User Input:** User input with multimodal content (text + images + files) is stored as `user_input` type: ```json { "message_id": "msg_000", "chat_id": "chat_123", "role": "user", "type": "user_input", "props": { "content": [ { "type": "text", "text": "What's in this image?" }, { "type": "image_url", "image_url": { "url": "https://example.com/photo.jpg", "detail": "high" } } ], "role": "user", "name": "John" }, "sequence": 1 } ``` ### Knowledge Base & Web Search Results Retrieval results from knowledge bases and web searches need to be stored for: 1. **User Feedback** - Users can rate (👍/👎) individual sources 2. **Quality Analytics** - Track which documents/sources are most useful 3. **Source Attribution** - Display citations in the UI 4. **RAG Optimization** - Improve retrieval based on feedback **Storage Approach:** Store retrieval results as a special message type `retrieval` with structured props. 
**Retrieval Message Structure:** ```json { "message_id": "msg_retrieval_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "retrieval", "props": { "query": "How to configure Yao models?", "sources": [ { "id": "src_001", "type": "kb", "collection_id": "col_docs", "document_id": "doc_123", "chunk_id": "chunk_456", "title": "Model Configuration Guide", "content": "To configure a model in Yao, create a .mod.yao file...", "score": 0.92, "metadata": { "file_path": "/docs/model.md", "page": 3 } }, { "id": "src_002", "type": "kb", "collection_id": "col_docs", "document_id": "doc_124", "chunk_id": "chunk_789", "title": "Advanced Model Options", "content": "Models support various options including soft_deletes...", "score": 0.87, "metadata": { "file_path": "/docs/advanced.md", "page": 12 } }, { "id": "src_003", "type": "web", "url": "https://yaoapps.com/docs/models", "title": "Yao Models Documentation", "content": "Official documentation for Yao model system...", "score": 0.85, "metadata": { "domain": "yaoapps.com", "fetched_at": "2024-01-15T10:30:00Z" } } ], "total_results": 15, "query_time_ms": 120 }, "block_id": "B1", "assistant_id": "docs_assistant", "sequence": 2 } ``` **Source Types:** | Type | Description | Key Fields | | ------ | ----------------------- | ------------------------------------------ | | `kb` | Knowledge base document | `collection_id`, `document_id`, `chunk_id` | | `web` | Web search result | `url`, `domain` | | `file` | Uploaded file | `file_id`, `file_path` | | `api` | External API result | `api_name`, `endpoint` | | `mcp` | MCP tool result | `server`, `tool` | **Source Feedback:** User feedback on retrieval sources is handled by the Knowledge Base module. See [KB Feedback](../../kb/README.md) for details. 
**Example: KB Search in Create Hook:** ```typescript // In Create hook, search knowledge base and store results const results = await ctx.kb.search("col_docs", query, { limit: 5 }); // Send retrieval message (stored automatically) ctx.Send({ type: "retrieval", props: { query: query, sources: results.documents.map((doc, idx) => ({ id: `src_${idx}`, type: "kb", collection_id: "col_docs", document_id: doc.document.metadata.document_id, chunk_id: doc.document.id, title: doc.document.metadata.title || "Untitled", content: doc.document.content, score: doc.score, metadata: doc.document.metadata, })), total_results: results.total, query_time_ms: results.query_time_ms, }, }); // Also send loading message for user feedback ctx.Send({ type: "loading", props: { message: `Found ${results.total} relevant documents...` }, }); ``` **Example: Web Search Results:** ```json { "type": "retrieval", "props": { "query": "latest AI news 2024", "sources": [ { "id": "src_001", "type": "web", "url": "https://example.com/ai-news", "title": "AI Breakthroughs in 2024", "content": "Summary of the article...", "score": 0.95, "metadata": { "domain": "example.com", "published_at": "2024-01-10", "fetched_at": "2024-01-15T10:30:00Z", "snippet": "The year 2024 has seen remarkable..." } } ], "provider": "tavily", "total_results": 10, "query_time_ms": 850 } } ``` ### 3. Resume Table Stores execution state for resume/retry functionality. 
**Only written when request is interrupted or failed.** **Table Name:** `agent_resume` | Column | Type | Nullable | Index | Description | | ----------------- | ----------- | -------- | ------ | ---------------------------------------- | | `id` | ID | No | PK | Auto-increment primary key | | `resume_id` | string(64) | No | Unique | Unique resume record identifier | | `chat_id` | string(64) | No | Yes | Parent chat ID | | `request_id` | string(64) | No | Yes | Request ID | | `assistant_id` | string(200) | No | Yes | Assistant executing this step | | `stack_id` | string(64) | No | Yes | Stack node ID for this execution | | `stack_parent_id` | string(64) | Yes | Yes | Parent stack ID (for A2A calls) | | `stack_depth` | integer | No | - | Call depth (0=root, 1+=nested) | | `type` | enum | No | Yes | Step type | | `status` | enum | No | Yes | Status: `interrupted`, `failed` | | `input` | json | Yes | - | Step input data | | `output` | json | Yes | - | Step output data (partial) | | `space_snapshot` | json | Yes | - | Space data snapshot for recovery | | `error` | text | Yes | - | Error message if failed | | `sequence` | integer | No | - | Step order within request (in composite) | | `metadata` | json | Yes | - | Additional metadata | | `created_at` | timestamp | No | Yes | Creation timestamp | | `updated_at` | timestamp | No | - | Last update timestamp | **Space Snapshot:** The `space_snapshot` field stores the shared data space (`ctx.Space`) at each step for recovery purposes. 
```typescript // Example: In Next hook, set data to Space before delegate ctx.space.Set("choose_prompt", "query"); return { delegate: { agent_id: "expense", messages: payload.messages }, }; ``` If interrupted during delegate, the `space_snapshot` allows restoring `ctx.Space` state: ```json { "choose_prompt": "query", "user_preferences": { "currency": "USD" } } ``` **Resume Step Types:** | Type | Description | Input | Output | | ------------- | --------------------- | ---------------------- | ------------------------------------- | | `input` | User input received | `{messages: [...]}` | - | | `hook_create` | Create hook execution | `{messages: [...]}` | `{messages: [...], ...}` | | `llm` | LLM completion call | `{messages: [...]}` | `{content: "...", tool_calls: [...]}` | | `tool` | Tool/MCP execution | `{server, tool, args}` | `{result: ...}` | | `hook_next` | Next hook execution | `{completion, tools}` | `{data: ...}` | | `delegate` | A2A delegation | `{agent_id, messages}` | `{response: ...}` | **Resume Status (only two values - table only stores failed/interrupted):** | Status | Description | Action | | ------------- | ----------------- | -------- | | `failed` | Failed with error | Retry | | `interrupted` | User interrupted | Continue | **Indexes:** | Name | Columns | Type | | ---------------------- | ------------------------ | ----- | | `idx_resume_chat` | `chat_id` | index | | `idx_resume_request` | `request_id`, `sequence` | index | | `idx_resume_type` | `type` | index | | `idx_resume_status` | `status` | index | | `idx_resume_stack` | `stack_id` | index | | `idx_resume_parent` | `stack_parent_id` | index | | `idx_resume_assistant` | `assistant_id` | index | ## Write Strategy ### Single-Write Strategy All data is buffered in memory during execution and written to database **only once** when `Stream()` exits: **Note**: Request tracking (status, tokens, duration) is handled by [OpenAPI Request Middleware](../../openapi/request/REQUEST_DESIGN.md). 
``` Stream() Entry │ ├── Buffer user input message (role=user) │ ├── Execution (all in memory) │ - ctx.Send() → messageBuffer │ - ctx.Append() → update messageBuffer │ - ctx.Replace() → update messageBuffer │ - Each step → stepBuffer │ └── 【Single Write】Save final state (via defer) │ ├── Always: │ - Batch write all messages (user input + assistant responses) │ - Update token usage in openapi_request (via request_id) │ └── Only on error/interrupt: - Batch write all steps (for resume/retry) ``` ### Write Points | Event | Message Table | Step Table | Token Usage | | ---------------- | -------------------------------------- | ----------------------------------- | ----------- | | Stream entry | Buffer user input | - | - | | During execution | Buffer in memory | Buffer in memory | - | | **Completed** | **Batch write all (user + assistant)** | **❌ Skip (no need to resume)** | ✅ Update | | On interrupt | Batch write all buffered | ✅ Batch write (status=interrupted) | ✅ Update | | On error | Batch write all buffered | ✅ Batch write (status=failed) | ✅ Update | **Why skip Steps on success?** - Steps are only needed for resume/retry operations - If completed successfully, there's nothing to resume - Reduces database writes and keeps Resume table clean ### Why Single Write? | Scenario | What Happens | Data Safe? | | ------------------ | --------------------------------- | ---------- | | Normal completion | `defer` triggers → Write executes | ✅ | | User clicks stop | `defer` triggers → Write executes | ✅ | | LLM timeout | `defer` triggers → Write executes | ✅ | | Tool failure | `defer` triggers → Write executes | ✅ | | Network disconnect | `defer` triggers → Write executes | ✅ | | Process crash | Service is down, user must retry | N/A | **Note**: Process crash is a catastrophic failure handled at infrastructure level, not application level. 
### Write Count Comparison For a typical request: user input → hook_create → llm → tool → hook_next → 5 messages | Strategy | Database Writes | Notes | | ------------------------- | --------------- | --------------------- | | Write per operation | 1 + 5 + 5 = 11 | One write per step | | **Single-write strategy** | **1** | Exit only (via defer) | ### Implementation ````go func (ast *Assistant) Stream(ctx, inputMessages, options) { // ========== Memory Buffers ========== messageBuffer := NewMessageBuffer() stepBuffer := NewStepBuffer() // Buffer user input message (not written yet) userMsg := createUserMessage(ctx, inputMessages) messageBuffer.Add(userMsg) // Track current step for error handling var currentStep *Step defer func() { // ========== Single Write: Exit (always executes) ========== // Determine final status for incomplete steps finalStatus := "completed" if ctx.IsInterrupted() { finalStatus = "interrupted" } if r := recover(); r != nil { finalStatus = "failed" } // Update status of any incomplete step if currentStep != nil && currentStep.Status == "running" { currentStep.Status = finalStatus } // Batch write all buffered messages (user input + assistant responses) chatStore.SaveMessages(ctx.ChatID, messageBuffer.GetAll()) // Only save steps on error/interrupt (not on success) if finalStatus != "completed" { chatStore.SaveResume(stepBuffer.GetAll()) } // Update token usage in OpenAPI request record if ctx.RequestID != "" && completionResponse != nil { request.UpdateTokenUsage( ctx.RequestID, completionResponse.Usage.PromptTokens, completionResponse.Usage.CompletionTokens, ) } }() // ========== Execution (all in memory) ========== // Note: request_id = ctx.RequestID (from OpenAPI middleware) // hook_create currentStep = stepBuffer.Add(createStep(ctx, "hook_create", "running", input, nil)) createResponse := ast.HookScript.Create(...) 
currentStep.Output = createResponse currentStep.Status = "completed" // llm currentStep = stepBuffer.Add(createStep(ctx, "llm", "running", messages, nil)) completionResponse := ast.executeLLMStream(...) currentStep.Output = completionResponse currentStep.Status = "completed" // tool (if any) for _, toolCall := range completionResponse.ToolCalls { currentStep = stepBuffer.Add(createStep(ctx, "tool", "running", toolCall, nil)) result := executeToolCall(toolCall) currentStep.Output = result currentStep.Status = "completed" } // hook_next currentStep = stepBuffer.Add(createStep(ctx, "hook_next", "running", payload, nil)) nextResponse := ast.HookScript.Next(...) currentStep.Output = nextResponse currentStep.Status = "completed" currentStep = nil // All done // Messages are automatically buffered via ctx.Send() } // createResumeRecord creates a resume record with context information // Only called when request fails or is interrupted func createResumeRecord(ctx *Context, stepType, status string, input, output interface{}, err error) *Resume { // Capture Space snapshot for recovery var spaceSnapshot map[string]interface{} if ctx.Space != nil { spaceSnapshot = ctx.Space.Snapshot() // Get all key-value pairs } errorMsg := "" if err != nil { errorMsg = err.Error() } return &Resume{ ResumeID: generateID(), ChatID: ctx.ChatID, // ChatID RequestID: ctx.RequestID, // From OpenAPI middleware AssistantID: ctx.AssistantID, StackID: ctx.Stack.ID, StackParentID: ctx.Stack.ParentID, StackDepth: ctx.Stack.Depth, Type: stepType, Status: status, // "failed" or "interrupted" Input: input, Output: output, SpaceSnapshot: spaceSnapshot, // Shared space data for recovery Error: errorMsg, Sequence: nextSequence(), } } ```` ## API Interface ### ChatStore Interface ```go // ChatStore defines the chat storage interface // Provides operations for chat, message, and resume management type ChatStore interface { // ========================================================================== // Chat 
Management // ========================================================================== // CreateChat creates a new chat session CreateChat(chat *Chat) error // GetChat retrieves a single chat by ID GetChat(chatID string) (*Chat, error) // UpdateChat updates chat fields UpdateChat(chatID string, updates map[string]interface{}) error // DeleteChat deletes a chat and its associated messages DeleteChat(chatID string) error // ListChats retrieves a paginated list of chats with optional grouping ListChats(filter ChatFilter) (*ChatList, error) // ========================================================================== // Message Management // ========================================================================== // SaveMessages batch saves messages for a chat // This is the primary write method - messages are buffered during execution // and batch-written at the end of a request SaveMessages(chatID string, messages []*Message) error // GetMessages retrieves messages for a chat with filtering GetMessages(chatID string, filter MessageFilter) ([]*Message, error) // UpdateMessage updates a single message UpdateMessage(messageID string, updates map[string]interface{}) error // DeleteMessages deletes specific messages from a chat DeleteMessages(chatID string, messageIDs []string) error // ========================================================================== // Resume Management (only called on failure/interrupt) // ========================================================================== // SaveResume batch saves resume records // Only called when request is interrupted or failed SaveResume(records []*Resume) error // GetResume retrieves all resume records for a chat GetResume(chatID string) ([]*Resume, error) // GetLastResume retrieves the last (most recent) resume record for a chat GetLastResume(chatID string) (*Resume, error) // GetResumeByStackID retrieves resume records for a specific stack GetResumeByStackID(stackID string) ([]*Resume, error) // GetStackPath 
returns the stack path from root to the given stack // Returns: [root_stack_id, ..., current_stack_id] GetStackPath(stackID string) ([]string, error) // DeleteResume deletes all resume records for a chat // Called after successful resume to clean up DeleteResume(chatID string) error } // AssistantStore defines the assistant storage interface // Separated from ChatStore for clearer responsibility type AssistantStore interface { // SaveAssistant saves assistant information SaveAssistant(assistant *AssistantModel) (string, error) // UpdateAssistant updates assistant fields UpdateAssistant(assistantID string, updates map[string]interface{}) error // DeleteAssistant deletes an assistant DeleteAssistant(assistantID string) error // GetAssistants retrieves a paginated list of assistants with filtering GetAssistants(filter AssistantFilter, locale ...string) (*AssistantList, error) // GetAssistantTags retrieves all unique tags from assistants with filtering GetAssistantTags(filter AssistantFilter, locale ...string) ([]Tag, error) // GetAssistant retrieves a single assistant by ID GetAssistant(assistantID string, fields []string, locale ...string) (*AssistantModel, error) // DeleteAssistants deletes assistants based on filter conditions DeleteAssistants(filter AssistantFilter) (int64, error) } // Store combines ChatStore and AssistantStore interfaces // This is the main interface for the storage layer type Store interface { ChatStore AssistantStore } // SpaceStore defines the interface for Space snapshot operations // Note: Space itself uses plan.Space interface, this is for persistence type SpaceStore interface { // Snapshot returns all key-value pairs in the space Snapshot() map[string]interface{} // Restore sets multiple key-value pairs from a snapshot Restore(data map[string]interface{}) error } ``` ### Data Structures ```go // Chat represents a chat session type Chat struct { ChatID string `json:"chat_id"` Title string `json:"title,omitempty"` AssistantID string 
`json:"assistant_id"` LastConnector string `json:"last_connector,omitempty"` // Last used connector ID (updated on each message) LastMode string `json:"last_mode,omitempty"` // Last used chat mode (updated on each message) Status string `json:"status"` // "active" or "archived" Public bool `json:"public"` // Whether shared across all teams Share string `json:"share"` // "private" or "team" Sort int `json:"sort"` // Sort order for display LastMessageAt *time.Time `json:"last_message_at,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } // Message represents a chat message type Message struct { MessageID string `json:"message_id"` ChatID string `json:"chat_id"` RequestID string `json:"request_id,omitempty"` Role string `json:"role"` // "user" or "assistant" Type string `json:"type"` // "text", "image", "loading", "tool_call", "retrieval", etc. Props map[string]interface{} `json:"props"` BlockID string `json:"block_id,omitempty"` ThreadID string `json:"thread_id,omitempty"` AssistantID string `json:"assistant_id,omitempty"` Connector string `json:"connector,omitempty"` // Connector ID used for this message Mode string `json:"mode,omitempty"` // Chat mode used for this message (chat or task) Sequence int `json:"sequence"` Metadata map[string]interface{} `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } // Resume represents an execution state for recovery // Only stored when request is interrupted or failed type Resume struct { ResumeID string `json:"resume_id"` ChatID string `json:"chat_id"` RequestID string `json:"request_id"` AssistantID string `json:"assistant_id"` StackID string `json:"stack_id"` StackParentID string `json:"stack_parent_id,omitempty"` StackDepth int `json:"stack_depth"` Type string `json:"type"` // "input", "hook_create", "llm", "tool", "hook_next", "delegate" Status string 
`json:"status"` // "failed" or "interrupted" Input map[string]interface{} `json:"input,omitempty"` Output map[string]interface{} `json:"output,omitempty"` SpaceSnapshot map[string]interface{} `json:"space_snapshot,omitempty"` // Shared space data for recovery Error string `json:"error,omitempty"` Sequence int `json:"sequence"` Metadata map[string]interface{} `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } // ResumeStatus constants const ( ResumeStatusFailed = "failed" ResumeStatusInterrupted = "interrupted" ) // ResumeType constants const ( ResumeTypeInput = "input" ResumeTypeHookCreate = "hook_create" ResumeTypeLLM = "llm" ResumeTypeTool = "tool" ResumeTypeHookNext = "hook_next" ResumeTypeDelegate = "delegate" ) ``` ### Filter Structures ```go // ChatFilter for listing chats type ChatFilter struct { // Permission filters (direct filtering on Yao permission fields) UserID string `json:"user_id,omitempty"` // Filter by __yao_created_by TeamID string `json:"team_id,omitempty"` // Filter by __yao_team_id // Business filters AssistantID string `json:"assistant_id,omitempty"` Status string `json:"status,omitempty"` Keywords string `json:"keywords,omitempty"` // Time range filter StartTime *time.Time `json:"start_time,omitempty"` // Filter chats after this time EndTime *time.Time `json:"end_time,omitempty"` // Filter chats before this time TimeField string `json:"time_field,omitempty"` // Field for time filter: "created_at" or "last_message_at" (default) // Sorting OrderBy string `json:"order_by,omitempty"` // Field to sort by (default: "last_message_at") Order string `json:"order,omitempty"` // Sort order: "desc" (default) or "asc" // Response format GroupBy string `json:"group_by,omitempty"` // "time" for time-based groups, empty for flat list // Pagination Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` // Advanced permission filter (not serialized) // Use for complex conditions 
like: (created_by = user OR team_id = team) QueryFilter func(query.Query) `json:"-"` } // MessageFilter for listing messages type MessageFilter struct { RequestID string `json:"request_id,omitempty"` Role string `json:"role,omitempty"` BlockID string `json:"block_id,omitempty"` ThreadID string `json:"thread_id,omitempty"` Type string `json:"type,omitempty"` Limit int `json:"limit,omitempty"` Offset int `json:"offset,omitempty"` } // ChatList paginated response with time-based grouping type ChatList struct { Data []*Chat `json:"data"` Groups []*ChatGroup `json:"groups,omitempty"` // Time-based groups for UI display Page int `json:"page"` PageSize int `json:"pagesize"` PageCount int `json:"pagecount"` Total int `json:"total"` } // ChatGroup represents a time-based group of chats type ChatGroup struct { Label string `json:"label"` // "Today", "Yesterday", "This Week", "This Month", "Earlier" Key string `json:"key"` // "today", "yesterday", "this_week", "this_month", "earlier" Chats []*Chat `json:"chats"` // Chats in this group Count int `json:"count"` // Number of chats in group } ``` ## Usage Examples ### 1. Complete Message Storage Example A typical conversation with various message types stored in `agent_message`: ``` User: "What's the weather in SF? Also show me a chart." Timeline (user input → hook_create → llm → tool → hook_next): 1. User sends input 2. Create hook shows loading state 3. LLM thinks and calls tool 4. Tool executes and returns result 5. Next hook generates text response and image chart ``` **Stored Messages:** ```json [ // 1. User input (role=user, type=user_input) { "message_id": "msg_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "user", "type": "user_input", "props": { "content": "What's the weather in SF? Also show me a chart.", "role": "user" }, "sequence": 1 }, // 2. 
Loading state from Create hook (role=assistant, type=loading) { "message_id": "msg_002", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "loading", "props": { "message": "Searching knowledge base..." }, "block_id": "B1", "assistant_id": "weather_assistant", "sequence": 2 }, // 3. LLM thinking process (role=assistant, type=thinking) { "message_id": "msg_003", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "thinking", "props": { "content": "User wants weather info for San Francisco. I should use the get_weather tool..." }, "block_id": "B2", "assistant_id": "weather_assistant", "sequence": 3 }, // 4. LLM tool call (role=assistant, type=tool_call) { "message_id": "msg_004", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "tool_call", "props": { "id": "call_weather_001", "name": "get_weather", "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}" }, "block_id": "B2", "assistant_id": "weather_assistant", "sequence": 4 }, // 5. Tool result from Next hook (role=assistant, type=text, with tool metadata) { "message_id": "msg_005", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "text", "props": { "content": "The weather in San Francisco is currently **18°C** and sunny with 65% humidity. Perfect weather for outdoor activities!" }, "block_id": "B3", "metadata": { "tool_call_id": "call_weather_001", "tool_name": "get_weather" }, "assistant_id": "weather_assistant", "sequence": 5 }, // 6. 
Chart image from Next hook (role=assistant, type=image) { "message_id": "msg_006", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "image", "props": { "url": "https://charts.example.com/weather_sf.png", "alt": "San Francisco 7-day weather forecast", "width": 800, "height": 400 }, "block_id": "B3", "assistant_id": "weather_assistant", "sequence": 6 } ] ``` **Streaming IDs (from `STREAMING.md`):** During streaming, messages include additional fields for real-time delivery: | Field | Purpose | Stored? | | ------------ | ------------------------------ | ------- | | `chunk_id` | Deduplication, ordering, debug | ❌ No | | `message_id` | Delta merge target | ✅ Yes | | `block_id` | UI block/section grouping | ✅ Yes | | `thread_id` | Concurrent stream distinction | ✅ Yes | | `delta` | Whether this is a delta chunk | ❌ No | | `delta_path` | Path for delta merge | ❌ No | **Note:** `chunk_id`, `delta`, and `delta_path` are transient streaming control fields and are NOT stored. Only the final merged content is persisted. ### 2. Error Message Storage When errors occur, they are stored as `error` type: ```json { "message_id": "msg_err_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "error", "props": { "message": "Failed to connect to weather service", "code": "SERVICE_UNAVAILABLE", "details": "Connection timeout after 30 seconds" }, "block_id": "B2", "assistant_id": "weather_assistant", "sequence": 5 } ``` ### 3. Action Message Storage (CUI clients) System actions are stored but only processed by CUI clients: ```json { "message_id": "msg_action_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "action", "props": { "name": "open_panel", "payload": { "panel_id": "weather_details", "location": "San Francisco" } }, "block_id": "B2", "assistant_id": "weather_assistant", "sequence": 6 } ``` ### 4. 
Audio/Video Message Storage Multimedia content storage: ```json // Audio message { "message_id": "msg_audio_001", "chat_id": "chat_123", "role": "assistant", "type": "audio", "props": { "url": "https://storage.example.com/audio/response.mp3", "format": "mp3", "duration": 45.5, "transcript": "Here's the weather forecast for today...", "controls": true }, "sequence": 7 } // Video message { "message_id": "msg_video_001", "chat_id": "chat_123", "role": "assistant", "type": "video", "props": { "url": "https://storage.example.com/video/weather_report.mp4", "format": "mp4", "thumbnail": "https://storage.example.com/video/weather_report_thumb.jpg", "duration": 120.0, "width": 1280, "height": 720, "controls": true }, "sequence": 8 } ``` ### 5. Load Chat History ```go // Example 1: Filter by user (simple permission check) chats, _ := chatStore.ListChats(ChatFilter{ UserID: "user123", // Filters by __yao_created_by Status: "active", OrderBy: "last_message_at", Order: "desc", Page: 1, PageSize: 20, }) // Response: chats.Data = [...], chats.Groups = nil // Example 2: Filter by team chats, _ := chatStore.ListChats(ChatFilter{ TeamID: "team456", // Filters by __yao_team_id Status: "active", Page: 1, PageSize: 20, }) // Example 3: Filter by user AND team (both must match) chats, _ := chatStore.ListChats(ChatFilter{ UserID: "user123", TeamID: "team456", Page: 1, PageSize: 20, }) // Example 4: Complex permission filter (user OR team) using QueryFilter chats, _ := chatStore.ListChats(ChatFilter{ Page: 1, PageSize: 20, QueryFilter: func(qb query.Query) { qb.Where(func(sub query.Query) { sub.Where("__yao_created_by", "user123"). 
OrWhere("__yao_team_id", "team456") }) }, }) // Example 5: Grouped by time chats, _ := chatStore.ListChats(ChatFilter{ UserID: "user123", GroupBy: "time", // Enable time-based grouping OrderBy: "last_message_at", Order: "desc", Page: 1, PageSize: 20, }) // Response includes time-based groups: // chats.Groups = [ // { Key: "today", Label: "Today", Chats: [...], Count: 3 }, // { Key: "yesterday", Label: "Yesterday", Chats: [...], Count: 5 }, // { Key: "this_week", Label: "This Week", Chats: [...], Count: 8 }, // { Key: "this_month", Label: "This Month", Chats: [...], Count: 4 }, // { Key: "earlier", Label: "Earlier", Chats: [...], Count: 0 }, // ] // Example 6: Filter by time range startTime := time.Now().AddDate(0, 0, -7) // Last 7 days chats, _ := chatStore.ListChats(ChatFilter{ UserID: "user123", StartTime: &startTime, TimeField: "last_message_at", // Filter by last message time OrderBy: "last_message_at", Order: "desc", }) // Example 7: Filter specific date range start := time.Date(2024, 12, 1, 0, 0, 0, 0, time.Local) end := time.Date(2024, 12, 31, 23, 59, 59, 0, time.Local) chats, _ := chatStore.ListChats(ChatFilter{ UserID: "user123", StartTime: &start, EndTime: &end, TimeField: "created_at", // Filter by creation time }) // Example 8: Combine permission with business filters chats, _ := chatStore.ListChats(ChatFilter{ UserID: "user123", TeamID: "team456", AssistantID: "weather_assistant", Status: "active", Keywords: "weather", Page: 1, PageSize: 20, }) // Get messages for a chat messages, _ := chatStore.GetMessages("chat_123", MessageFilter{ Limit: 100, }) // Return to frontend return map[string]interface{}{ "chat": chat, "messages": messages, } ``` ### 6. Resume from Interruption ```go func (ast *Assistant) Resume(ctx *Context) error { // 1. Find last resume record record, _ := chatStore.GetLastResume(ctx.ChatID) if record == nil { return nil // Nothing to resume } // 2. 
Restore Space data from snapshot if record.SpaceSnapshot != nil && ctx.Space != nil { for key, value := range record.SpaceSnapshot { ctx.Space.Set(key, value) } } // 3. Check if this is an A2A nested call if record.StackDepth > 0 { // Need to rebuild the call stack return ast.ResumeNestedCall(ctx, record) } // 4. Resume based on step type var err error switch record.Type { case "llm": // Re-execute LLM call with saved input messages := record.Input["messages"].([]Message) err = ast.executeLLMStream(ctx, messages, ...) case "tool": // Retry tool call err = ast.retryToolCall(ctx, record) case "hook_next": // Re-execute hook err = ast.executeHookNext(ctx, record.Input) case "delegate": // Resume delegated agent call agentID := record.Input["agent_id"].(string) messages := record.Input["messages"].([]Message) err = ast.delegateToAgent(ctx, agentID, messages) } // 5. Clean up resume records on success if err == nil { chatStore.DeleteResume(ctx.ChatID) } return err } ``` ### 7. Resume A2A Nested Calls For agent-to-agent (A2A) recursive calls, the stack information is essential for proper recovery. ```go func (ast *Assistant) ResumeNestedCall(ctx *Context, step *Step) error { // 1. Rebuild the call stack from root to interrupted point stackPath, _ := chatStore.GetStackPath(step.StackID) // stackPath: [root_stack_id, parent_stack_id, ..., current_stack_id] // 2. Get all steps for each stack level for _, stackID := range stackPath { steps, _ := chatStore.GetStepsByStackID(stackID) // Restore context for each level } // 3. Resume from the interrupted assistant targetAssistant := assistant.Select(step.AssistantID) return targetAssistant.Stream(ctx, step.Input["messages"], ...) } ``` ### 8. Handle Interruption Interruption is handled automatically by the `defer` block in the single-write strategy. When `ctx.IsInterrupted()` returns true, the status is set to `interrupted` and all buffered data is saved. 
```go // Inside the defer block (see Write Strategy - Implementation) if ctx.IsInterrupted() { status = "interrupted" } // Then batch write all buffered messages and steps ``` ## A2A (Agent-to-Agent) Call Example When Assistant A delegates to Assistant B, the step records look like: ``` Request: User asks "analyze this data and visualize it" Step Records: ┌─────┬─────────────┬─────────────┬──────────┬───────┬───────┬─────────────┬─────────────────────────────┐ │ seq │ assistant │ stack_id │ parent │ depth │ type │ status │ space_snapshot │ ├─────┼─────────────┼─────────────┼──────────┼───────┼───────┼─────────────┼─────────────────────────────┤ │ 1 │ analyzer │ stk_001 │ null │ 0 │ input │ completed │ {} │ │ 2 │ analyzer │ stk_001 │ null │ 0 │ llm │ completed │ {} │ │ 3 │ analyzer │ stk_001 │ null │ 0 │ delegate │ running │ {"choose_prompt": "query"} │ ← Space data set before delegate │ 4 │ visualizer │ stk_002 │ stk_001 │ 1 │ input │ completed │ {"choose_prompt": "query"} │ │ 5 │ visualizer │ stk_002 │ stk_001 │ 1 │ llm │ interrupted │ {"choose_prompt": "query"} │ ← interrupted here └─────┴─────────────┴─────────────┴──────────┴───────┴───────┴─────────────┴─────────────────────────────┘ Resume Flow: 1. Find step with status="interrupted" → step 5 2. Restore Space from space_snapshot: {"choose_prompt": "query"} 3. Check stack_depth=1 → nested call 4. Get stack path: [stk_001, stk_002] 5. Resume visualizer assistant with step 5's input 6. When visualizer completes, update step 3 (delegate) to completed ``` **Space Snapshot Use Case (from expense assistant):** ```typescript // In Next hook, before delegating to another agent ctx.space.Set("choose_prompt", "query"); return { delegate: { agent_id: "expense", messages: payload.messages }, }; // If interrupted during delegate, Resume will: // 1. Restore space_snapshot → ctx.space now has "choose_prompt": "query" // 2. 
The delegated agent's Create hook can read: ctx.space.GetDel("choose_prompt") ``` ## Concurrent Operations Storage When an Agent makes parallel calls (e.g., multiple MCP tools, multiple sub-agents), messages use `block_id` and `thread_id` for grouping: ``` Main Agent concurrently calls 3 tasks: ├── Thread T1: Weather query (MCP) ├── Thread T2: News search (MCP) ├── Thread T3: Stock query (MCP) └── Wait for all to complete, then summarize ``` **Stored Messages:** ```json [ // All concurrent messages share the same block_id, different thread_id // Messages may arrive in any order due to concurrency // Thread T1: Weather result { "message_id": "msg_t1_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "text", "props": { "content": "Weather in SF: 18°C, sunny" }, "block_id": "B1", "thread_id": "T1", "assistant_id": "main_assistant", "sequence": 2 }, // Thread T2: News result { "message_id": "msg_t2_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "text", "props": { "content": "Top news: AI breakthrough announced..." }, "block_id": "B1", "thread_id": "T2", "assistant_id": "main_assistant", "sequence": 3 }, // Thread T3: Stock result { "message_id": "msg_t3_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "text", "props": { "content": "AAPL: $185.50 (+1.2%)" }, "block_id": "B1", "thread_id": "T3", "assistant_id": "main_assistant", "sequence": 4 }, // After all threads complete, main agent summarizes (new block) { "message_id": "msg_summary", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "text", "props": { "content": "Here's your daily briefing: The weather is great at 18°C..." 
}, "block_id": "B2", "thread_id": null, "assistant_id": "main_assistant", "sequence": 5 } ] ``` **Key Points:** | Field | Concurrent Usage | | ----------- | -------------------------------------------------- | | `block_id` | Same for all parallel operations (B1) | | `thread_id` | Different for each concurrent task (T1, T2, T3) | | `sequence` | Reflects actual arrival order (may be interleaved) | **Frontend Rendering:** - Group messages by `block_id` for visual blocks - Within a block, optionally group by `thread_id` to show parallel results - Use `sequence` for chronological display ## HTTP API The chat storage provides RESTful HTTP APIs for managing chat sessions and messages. **Base Path:** `/v1/chat` ### Chat Sessions | Method | Endpoint | Description | |--------|----------|-------------| | `GET` | `/sessions` | List chat sessions with pagination and filtering | | `GET` | `/sessions/:chat_id` | Get a single chat session | | `PUT` | `/sessions/:chat_id` | Update chat session (title, status, metadata) | | `DELETE` | `/sessions/:chat_id` | Delete chat session | | `GET` | `/sessions/:chat_id/messages` | Get messages for a chat session | ### List Chat Sessions **Request:** ``` GET /v1/chat/sessions?page=1&pagesize=20&assistant_id=xxx&status=active&keywords=search&group_by=time ``` **Query Parameters:** | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `page` | int | 1 | Page number | | `pagesize` | int | 20 | Items per page (max 100) | | `assistant_id` | string | - | Filter by assistant ID | | `status` | string | - | Filter by status: `active`, `archived` | | `keywords` | string | - | Search in title | | `start_time` | RFC3339 | - | Filter chats after this time | | `end_time` | RFC3339 | - | Filter chats before this time | | `time_field` | string | `last_message_at` | Field for time filter: `created_at` or `last_message_at` | | `order_by` | string | `last_message_at` | Sort field | | `order` | string | `desc` | Sort order: 
`asc` or `desc` | | `group_by` | string | - | Set to `time` for time-based grouping | **Response:** ```json { "data": [ { "chat_id": "chat_123", "title": "Weather Query", "assistant_id": "weather_assistant", "status": "active", "last_message_at": "2024-01-15T10:30:00Z", "created_at": "2024-01-15T10:00:00Z" } ], "groups": [ { "key": "today", "label": "Today", "chats": [...], "count": 3 }, { "key": "yesterday", "label": "Yesterday", "chats": [...], "count": 5 } ], "page": 1, "pagesize": 20, "pagecount": 5, "total": 100 } ``` ### Get Chat Session **Request:** ``` GET /v1/chat/sessions/chat_123 ``` **Response:** ```json { "chat_id": "chat_123", "title": "Weather Query", "assistant_id": "weather_assistant", "last_connector": "deepseek.v3", "last_mode": "chat", "status": "active", "public": false, "share": "private", "last_message_at": "2024-01-15T10:30:00Z", "metadata": {}, "created_at": "2024-01-15T10:00:00Z", "updated_at": "2024-01-15T10:30:00Z" } ``` ### Update Chat Session **Request:** ``` PUT /v1/chat/sessions/chat_123 Content-Type: application/json { "title": "New Title", "status": "archived", "metadata": {"custom_field": "value"} } ``` **Response:** ```json { "message": "Chat updated successfully", "chat_id": "chat_123" } ``` ### Delete Chat Session **Request:** ``` DELETE /v1/chat/sessions/chat_123 ``` **Response:** ```json { "message": "Chat deleted successfully", "chat_id": "chat_123" } ``` ### Get Chat Messages **Request:** ``` GET /v1/chat/sessions/chat_123/messages?limit=100&offset=0&role=assistant&type=text ``` **Query Parameters:** | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `request_id` | string | - | Filter by request ID | | `role` | string | - | Filter by role: `user`, `assistant` | | `block_id` | string | - | Filter by block ID | | `thread_id` | string | - | Filter by thread ID | | `type` | string | - | Filter by message type | | `limit` | int | 100 | Max messages to return (max 1000) | | `offset` | int 
| 0 | Offset for pagination | | `locale` | string | - | Locale for assistant info (e.g., `zh-cn`, `en-us`). Falls back to `Accept-Language` header | **Locale Resolution Priority:** 1. Query parameter `locale` 2. HTTP header `Accept-Language` **Response:** ```json { "chat_id": "chat_123", "messages": [ { "message_id": "msg_001", "chat_id": "chat_123", "request_id": "req_abc", "role": "user", "type": "user_input", "props": { "content": "What's the weather?", "role": "user" }, "sequence": 1, "created_at": "2024-01-15T10:00:00Z" }, { "message_id": "msg_002", "chat_id": "chat_123", "request_id": "req_abc", "role": "assistant", "type": "text", "props": { "content": "The weather in San Francisco is 18°C and sunny." }, "block_id": "B1", "assistant_id": "weather_assistant", "sequence": 2, "created_at": "2024-01-15T10:00:05Z" } ], "count": 2, "assistants": { "weather_assistant": { "assistant_id": "weather_assistant", "name": "Weather Assistant", "avatar": "https://example.com/weather-avatar.png", "description": "Get weather information for any location" } } } ``` **Note:** The `assistants` field contains localized assistant information (name, avatar, description) for all unique `assistant_id` values found in the messages. This allows the frontend to display assistant details without additional API calls. The locale is determined by the `locale` query parameter or `Accept-Language` header. 
### Permission Filtering All endpoints respect Yao's permission system: | Constraint | Behavior | |------------|----------| | `OwnerOnly` | User can only access their own chats (`__yao_created_by` matches) | | `TeamOnly` | User can access own chats OR team-shared chats (`share = "team"`) | | No constraints | Full access (for admin users) | **Permission Fields Used:** - `__yao_created_by`: User who created the chat - `__yao_team_id`: Team ID for team-level access - `public`: Whether chat is public to all - `share`: Sharing scope (`private` or `team`) ## Related Documents - [OpenAPI Request Design](../../openapi/request/REQUEST_DESIGN.md) - Global request tracking, billing, rate limiting - [Trace Module](../../trace/README.md) - Detailed execution tracing for debugging - [Agent Context](../context/README.md) - Context and message handling ```` ================================================ FILE: agent/store/README.md ================================================ # YAO Agent Store YAO Agent Store is a comprehensive storage abstraction layer for managing conversations, assistants, attachments, and knowledge collections in the YAO Agent platform. It provides a unified interface that supports multiple storage backends including databases (via Xun), Redis, and MongoDB. ## Table of Contents - [Architecture](#architecture) - [Storage Backends](#storage-backends) - [Configuration](#configuration) - [Initialization](#initialization) - [API Reference](#api-reference) - [Data Models](#data-models) - [Usage Examples](#usage-examples) - [Testing](#testing) ## Architecture The store package provides a unified `Store` interface that abstracts different storage implementations: ``` ┌─────────────────┐ │ Store API │ ← Unified Interface ├─────────────────┤ │ Xun (Database) │ ← Primary Implementation │ Redis │ ← Cache/Memory Store │ MongoDB │ ← Document Store └─────────────────┘ ``` ### Core Entities 1. **Conversations & Chat History** - Manage chat sessions and message history 2. 
**Assistants** - AI assistant configurations and metadata 3. **Attachments** - File attachments with metadata and access control 4. **Knowledge Collections** - Knowledge bases for AI assistants ## Storage Backends ### 1. Xun (Database) - Primary Backend The main implementation using SQL databases with automatic schema management: - **Supported Databases**: MySQL, PostgreSQL, SQLite, etc. - **Features**: ACID transactions, complex queries, automatic migrations - **Use Case**: Production environments requiring data consistency ### 2. Redis - Cache Backend Redis implementation for high-performance caching: - **Features**: In-memory storage, pub/sub capabilities - **Use Case**: Session management, temporary data, real-time features ### 3. MongoDB - Document Backend MongoDB implementation for document-based storage: - **Features**: Schema flexibility, horizontal scaling - **Use Case**: Large-scale deployments, unstructured data ## Configuration ### Setting Structure ```go type Setting struct { Connector string `json:"connector,omitempty"` // Storage connector name UserField string `json:"user_field,omitempty"` // User ID field name (default: "user_id") Prefix string `json:"prefix,omitempty"` // Database table name prefix MaxSize int `json:"max_size,omitempty" yaml:"max_size,omitempty"` // Maximum history size limit TTL int `json:"ttl,omitempty" yaml:"ttl,omitempty"` // Time To Live in seconds } ``` ### Configuration Examples #### Database Configuration ```yaml # agent.yml agent: store: connector: "mysql" # or "postgresql", "sqlite", "default" prefix: "agent_" # Table prefix max_size: 100 # Maximum chat history size ttl: 7200 # 2 hours TTL for conversations user_field: "user_id" # User identification field ``` #### Redis Configuration ```yaml agent: store: connector: "redis" prefix: "agent:" ttl: 3600 ``` #### MongoDB Configuration ```yaml agent: store: connector: "mongodb" prefix: "agent_" ttl: 7200 ``` ## Initialization ### Automatic Initialization (Recommended) The 
store is automatically initialized when the Agent system starts: ```go // From yao/agent/load.go func initStore() error { var err error if Agent.StoreSetting.Connector == "default" || Agent.StoreSetting.Connector == "" { Agent.Store, err = store.NewXun(Agent.StoreSetting) return err } // Other connector types conn, err := connector.Select(Agent.StoreSetting.Connector) if err != nil { return err } if conn.Is(connector.DATABASE) { Agent.Store, err = store.NewXun(Agent.StoreSetting) return err } else if conn.Is(connector.REDIS) { Agent.Store = store.NewRedis() return nil } else if conn.Is(connector.MONGO) { Agent.Store = store.NewMongo() return nil } return fmt.Errorf("%s store connector %s not support", Agent.ID, Agent.StoreSetting.Connector) } ``` ### Manual Initialization ```go import "github.com/yaoapp/yao/agent/store" // Database backend setting := store.Setting{ Connector: "mysql", Prefix: "agent_", MaxSize: 100, TTL: 3600, } xunStore, err := store.NewXun(setting) // Redis backend redisStore := store.NewRedis() // MongoDB backend mongoStore := store.NewMongo() ``` ## API Reference ### Store Interface ```go type Store interface { // Chat Management GetChats(sid string, filter ChatFilter, locale ...string) (*ChatGroupResponse, error) GetChat(sid string, cid string, locale ...string) (*ChatInfo, error) GetChatWithFilter(sid string, cid string, filter ChatFilter, locale ...string) (*ChatInfo, error) UpdateChatTitle(sid string, cid string, title string) error DeleteChat(sid string, cid string) error DeleteAllChats(sid string) error // Message History GetHistory(sid string, cid string, locale ...string) ([]map[string]interface{}, error) GetHistoryWithFilter(sid string, cid string, filter ChatFilter, locale ...string) ([]map[string]interface{}, error) SaveHistory(sid string, messages []map[string]interface{}, cid string, context map[string]interface{}) error // Assistant Management SaveAssistant(assistant map[string]interface{}) (interface{}, error) GetAssistants(filter
AssistantFilter, locale ...string) (*AssistantResponse, error) GetAssistant(assistantID string, locale ...string) (map[string]interface{}, error) DeleteAssistant(assistantID string) error DeleteAssistants(filter AssistantFilter) (int64, error) GetAssistantTags(locale ...string) ([]Tag, error) // Attachment Management SaveAttachment(attachment map[string]interface{}) (interface{}, error) GetAttachments(filter AttachmentFilter, locale ...string) (*AttachmentResponse, error) GetAttachment(fileID string, locale ...string) (map[string]interface{}, error) DeleteAttachment(fileID string) error DeleteAttachments(filter AttachmentFilter) (int64, error) // Knowledge Management SaveKnowledge(knowledge map[string]interface{}) (interface{}, error) GetKnowledges(filter KnowledgeFilter, locale ...string) (*KnowledgeResponse, error) GetKnowledge(collectionID string, locale ...string) (map[string]interface{}, error) DeleteKnowledge(collectionID string) error DeleteKnowledges(filter KnowledgeFilter) (int64, error) // Resource Management Close() error } ``` ## Data Models ### Database Schema #### 1. History Table (Conversations) ```sql CREATE TABLE agent_history ( id BIGINT PRIMARY KEY AUTO_INCREMENT, sid VARCHAR(255) INDEX, -- Session ID cid VARCHAR(200) INDEX, -- Chat ID uid VARCHAR(255) INDEX, -- User ID role VARCHAR(200) INDEX, -- Message role (user/assistant/system) name VARCHAR(200), -- Message sender name content TEXT, -- Message content context JSON, -- Message context assistant_id VARCHAR(200) INDEX, -- Associated assistant ID assistant_name VARCHAR(200), -- Assistant name assistant_avatar VARCHAR(200), -- Assistant avatar URL mentions JSON, -- Mentions in the message silent BOOLEAN DEFAULT FALSE INDEX, -- Silent message flag created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP INDEX, updated_at TIMESTAMP INDEX, expired_at TIMESTAMP INDEX -- TTL expiration ); ``` #### 2. 
Chat Table ```sql CREATE TABLE agent_chat ( id BIGINT PRIMARY KEY AUTO_INCREMENT, chat_id VARCHAR(200) UNIQUE INDEX, -- Unique chat identifier title VARCHAR(200), -- Chat title assistant_id VARCHAR(200) INDEX, -- Associated assistant sid VARCHAR(255) INDEX, -- Session ID silent BOOLEAN DEFAULT FALSE INDEX, -- Silent chat flag created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP INDEX, updated_at TIMESTAMP INDEX ); ``` #### 3. Assistant Table ```sql CREATE TABLE agent_assistant ( id BIGINT PRIMARY KEY AUTO_INCREMENT, assistant_id VARCHAR(200) UNIQUE INDEX, -- Unique assistant identifier type VARCHAR(200) DEFAULT 'assistant' INDEX, -- Assistant type name VARCHAR(200), -- Assistant name avatar VARCHAR(200), -- Avatar URL connector VARCHAR(200) NOT NULL, -- LLM connector description VARCHAR(600) INDEX, -- Description (searchable) path VARCHAR(200), -- Storage path sort INTEGER DEFAULT 9999 INDEX, -- Sort order built_in BOOLEAN DEFAULT FALSE INDEX, -- Built-in assistant flag placeholder JSON, -- UI placeholder text options JSON, -- Assistant options prompts JSON, -- System prompts workflow JSON, -- Workflow configuration knowledge JSON, -- Knowledge base references tools JSON, -- Available tools tags JSON, -- Assistant tags readonly BOOLEAN DEFAULT FALSE INDEX, -- Read-only flag permissions JSON, -- Access permissions locales JSON, -- Internationalization data automated BOOLEAN DEFAULT TRUE INDEX, -- Automation enabled mentionable BOOLEAN DEFAULT TRUE INDEX, -- Can be mentioned in chats created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP INDEX, updated_at TIMESTAMP INDEX ); ``` #### 4. 
Attachment Table ```sql CREATE TABLE agent_attachment ( id BIGINT PRIMARY KEY AUTO_INCREMENT, file_id VARCHAR(255) UNIQUE INDEX, -- Unique file identifier uid VARCHAR(255) INDEX, -- Owner user ID guest BOOLEAN DEFAULT FALSE INDEX, -- Guest upload flag manager VARCHAR(200) INDEX, -- Storage manager content_type VARCHAR(200) INDEX, -- MIME type name VARCHAR(500) INDEX, -- File name (searchable) public BOOLEAN DEFAULT FALSE INDEX, -- Public access flag scope JSON, -- Access scope gzip BOOLEAN DEFAULT FALSE INDEX, -- Compression flag bytes BIGINT INDEX, -- File size collection_id VARCHAR(200) INDEX, -- Associated knowledge collection status ENUM('uploading', 'uploaded', 'indexing', 'indexed', 'upload_failed', 'index_failed') DEFAULT 'uploading' INDEX, -- Processing status progress VARCHAR(200), -- Progress information (nullable) error VARCHAR(600), -- Error message (nullable) created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP INDEX, updated_at TIMESTAMP INDEX ); ``` #### 5. Knowledge Table ```sql CREATE TABLE agent_knowledge ( id BIGINT PRIMARY KEY AUTO_INCREMENT, collection_id VARCHAR(200) UNIQUE INDEX, -- Unique collection identifier name VARCHAR(200) INDEX, -- Collection name (searchable) description VARCHAR(600) INDEX, -- Description (searchable) uid VARCHAR(255) INDEX, -- Owner user ID public BOOLEAN DEFAULT FALSE INDEX, -- Public access flag scope JSON, -- Access scope readonly BOOLEAN DEFAULT FALSE INDEX, -- Read-only flag option JSON, -- Collection options system BOOLEAN DEFAULT FALSE INDEX, -- System collection flag sort INTEGER DEFAULT 9999 INDEX, -- Sort order cover VARCHAR(500), -- Cover image URL created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP INDEX, updated_at TIMESTAMP INDEX ); ``` ### Filter Structures #### ChatFilter ```go type ChatFilter struct { Keywords string `json:"keywords,omitempty"` // Search keywords Page int `json:"page,omitempty"` // Page number (starts from 1) PageSize int `json:"pagesize,omitempty"` // Items per page Order string 
`json:"order,omitempty"` // Sort order (desc/asc) Silent *bool `json:"silent,omitempty"` // Include silent messages } ``` #### AssistantFilter ```go type AssistantFilter struct { Tags []string `json:"tags,omitempty"` // Filter by tags Type string `json:"type,omitempty"` // Filter by type Keywords string `json:"keywords,omitempty"` // Search keywords Connector string `json:"connector,omitempty"` // Filter by connector AssistantID string `json:"assistant_id,omitempty"` // Specific assistant ID AssistantIDs []string `json:"assistant_ids,omitempty"` // Multiple assistant IDs Mentionable *bool `json:"mentionable,omitempty"` // Mentionable status Automated *bool `json:"automated,omitempty"` // Automation status BuiltIn *bool `json:"built_in,omitempty"` // Built-in status Page int `json:"page,omitempty"` // Page number PageSize int `json:"pagesize,omitempty"` // Items per page Select []string `json:"select,omitempty"` // Fields to return } ``` #### AttachmentFilter ```go type AttachmentFilter struct { UID string `json:"uid,omitempty"` // Filter by user ID Guest *bool `json:"guest,omitempty"` // Filter by guest status Manager string `json:"manager,omitempty"` // Filter by upload manager ContentType string `json:"content_type,omitempty"` // Filter by content type Name string `json:"name,omitempty"` // Filter by filename Public *bool `json:"public,omitempty"` // Filter by public status Gzip *bool `json:"gzip,omitempty"` // Filter by gzip compression CollectionID string `json:"collection_id,omitempty"` // Filter by knowledge collection ID Status string `json:"status,omitempty"` // Filter by processing status Keywords string `json:"keywords,omitempty"` // Search in filename Page int `json:"page,omitempty"` // Page number PageSize int `json:"pagesize,omitempty"` // Items per page Select []string `json:"select,omitempty"` // Fields to return } ``` #### KnowledgeFilter ```go type KnowledgeFilter struct { UID string `json:"uid,omitempty"` // Filter by user ID Name string 
`json:"name,omitempty"` // Filter by collection name Keywords string `json:"keywords,omitempty"` // Search in name and description Public *bool `json:"public,omitempty"` // Filter by public status Readonly *bool `json:"readonly,omitempty"` // Filter by readonly status System *bool `json:"system,omitempty"` // Filter by system status Page int `json:"page,omitempty"` // Page number PageSize int `json:"pagesize,omitempty"` // Items per page Select []string `json:"select,omitempty"` // Fields to return } ``` ## Usage Examples ### 1. Chat Management ```go // Save chat history messages := []map[string]interface{}{ {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing well, thank you!"}, } context := map[string]interface{}{ "assistant_id": "gpt-4", "silent": false, } err := store.SaveHistory("user123", messages, "chat456", context) // Get chat history history, err := store.GetHistory("user123", "chat456") // Get chat list with pagination filter := ChatFilter{ Page: 1, PageSize: 20, Order: "desc", } chats, err := store.GetChats("user123", filter) // Update chat title err = store.UpdateChatTitle("user123", "chat456", "New Chat Title") ``` ### 2. Assistant Management ```go // Create an assistant assistant := map[string]interface{}{ "name": "Code Helper", "type": "assistant", "connector": "gpt-4", "description": "A helpful coding assistant", "tags": []string{"coding", "development"}, "sort": 100, "options": map[string]interface{}{ "temperature": 0.7, "max_tokens": 2000, }, "prompts": []string{ "You are a helpful coding assistant.", }, "mentionable": true, "automated": true, } assistantID, err := store.SaveAssistant(assistant) // Get assistants with filtering filter := AssistantFilter{ Tags: []string{"coding"}, Keywords: "helper", Page: 1, PageSize: 10, } assistants, err := store.GetAssistants(filter) // Get specific assistant assistant, err := store.GetAssistant("assistant123") ``` ### 3. 
Attachment Management ```go // Save attachment metadata attachment := map[string]interface{}{ "file_id": "file123", "uid": "user123", "manager": "local", "content_type": "image/jpeg", "name": "profile.jpg", "public": false, "bytes": 102400, "collection_id": "knowledge456", "scope": []string{"user", "admin"}, "status": "uploaded", // Status: uploading, uploaded, indexing, indexed, upload_failed, index_failed "progress": "Upload completed", // Progress information (optional) "error": nil, // Error message (optional, for failed statuses) } fileID, err := store.SaveAttachment(attachment) // Update attachment status during processing workflow attachment["status"] = "indexing" attachment["progress"] = "Processing file for indexing..." _, err = store.SaveAttachment(attachment) // Handle failed upload attachment["status"] = "upload_failed" attachment["progress"] = nil attachment["error"] = "Network connection timeout" _, err = store.SaveAttachment(attachment) // Complete indexing attachment["status"] = "indexed" attachment["progress"] = "File indexed successfully" attachment["error"] = nil _, err = store.SaveAttachment(attachment) // Get attachments with filtering filter := AttachmentFilter{ UID: "user123", ContentType: "image/jpeg", Status: "indexed", // Filter by status Page: 1, PageSize: 20, } attachments, err := store.GetAttachments(filter) // Get all failed uploads failedFilter := AttachmentFilter{ UID: "user123", Status: "upload_failed", Page: 1, PageSize: 10, } failedUploads, err := store.GetAttachments(failedFilter) ``` #### Attachment Status Workflow The attachment system supports a complete file processing workflow with the following status values: - **`uploading`** (default): File upload is in progress - **`uploaded`**: File upload completed successfully - **`indexing`**: File is being processed for search indexing - **`indexed`**: File has been indexed and is ready for use - **`upload_failed`**: File upload failed (check `error` field for details) - 
**`index_failed`**: File indexing failed (check `error` field for details) #### Additional Fields - **`progress`**: Human-readable progress information (string, nullable) - **`error`**: Error message for failed operations (string, nullable, max 600 characters) ### 4. Knowledge Collection Management ```go // Create knowledge collection knowledge := map[string]interface{}{ "collection_id": "kb123", "name": "Programming Guide", "description": "Comprehensive programming tutorials and examples", "uid": "user123", "public": true, "readonly": false, "sort": 100, "option": map[string]interface{}{ "embedding": "openai", "chunk_size": 1000, }, "scope": []string{"developers", "students"}, } collectionID, err := store.SaveKnowledge(knowledge) // Get knowledge collections with filtering filter := KnowledgeFilter{ UID: "user123", Keywords: "programming", Public: &[]bool{true}[0], Page: 1, PageSize: 10, } collections, err := store.GetKnowledges(filter) // Get system knowledge collections systemFilter := KnowledgeFilter{ System: &[]bool{true}[0], Page: 1, PageSize: 20, } systemCollections, err := store.GetKnowledges(systemFilter) // Get readonly knowledge collections with specific fields readonlyFilter := KnowledgeFilter{ Readonly: &[]bool{true}[0], Select: []string{"collection_id", "name", "description", "sort"}, Page: 1, PageSize: 15, } readonlyCollections, err := store.GetKnowledges(readonlyFilter) ``` ### 5. Internationalization Support ```go // Get assistants with locale assistants, err := store.GetAssistants(filter, "zh-CN") // Get chat with locale chat, err := store.GetChat("user123", "chat456", "en-US") ``` ### 6. 
Advanced Filtering and Sorting ```go // Complex assistant filtering filter := AssistantFilter{ Tags: []string{"ai", "assistant"}, Keywords: "helpful", Connector: "gpt-4", Mentionable: &[]bool{true}[0], BuiltIn: &[]bool{false}[0], Select: []string{"assistant_id", "name", "description", "tags"}, Page: 1, PageSize: 50, } assistants, err := store.GetAssistants(filter) // Results are automatically sorted by: // 1. sort field (ASC) - lower numbers appear first // 2. created_at/updated_at (DESC) - newer items appear first ``` ## Testing ### Running Tests ```bash # Run all tests go test -v # Run specific test go test -run TestXunKnowledgeCRUD -v # Run with coverage go test -cover ``` ### Test Structure The test suite includes comprehensive coverage for: - **CRUD Operations**: Create, Read, Update, Delete for all entities - **Filtering**: Various filter combinations and edge cases - **Sorting**: Verify sort order and pagination - **Error Handling**: Invalid inputs and edge cases - **Internationalization**: Locale-specific operations - **Concurrency**: Multiple concurrent operations ### Test Database Setup Tests use isolated table prefixes to avoid conflicts: ```go store, err := NewXun(Setting{ Connector: "default", Prefix: "__unit_test_conversation_", TTL: 3600, }) ``` ## Performance Considerations ### Database Optimization 1. **Indexes**: All frequently queried fields have indexes 2. **TTL**: Automatic cleanup of expired data 3. **Pagination**: All list operations support pagination 4. **Connection Pooling**: Efficient database connection management ### Caching Strategy 1. **Redis Backend**: For high-frequency read operations 2. **Memory Caching**: In-application caching for static data 3. **Query Optimization**: Efficient filtering and sorting ### Scaling 1. **Horizontal Scaling**: MongoDB support for distributed deployments 2. **Read Replicas**: Database read/write splitting 3. 
**Sharding**: Data partitioning strategies ## Migration and Upgrades ### Schema Evolution The Xun backend automatically handles schema migrations: - New tables are created automatically - New fields are added with default values - Indexes are created during initialization ### Data Migration When switching between backends: 1. Export data from source backend 2. Transform data format if necessary 3. Import to target backend 4. Verify data integrity ## Security ### Access Control 1. **User Isolation**: All operations are user-scoped 2. **Permission System**: Fine-grained access control 3. **Public/Private Flags**: Content visibility management ### Data Protection 1. **Input Validation**: All inputs are validated and sanitized 2. **SQL Injection Prevention**: Parameterized queries 3. **XSS Protection**: Content encoding and sanitization ## Troubleshooting ### Common Issues 1. **Connection Errors**: Check connector configuration 2. **Schema Errors**: Verify database permissions 3. **Performance Issues**: Check indexes and query patterns 4. **Memory Issues**: Monitor TTL and cleanup processes ### Debugging Enable debug logging: ```go import "github.com/yaoapp/kun/log" log.SetLevel(log.DebugLevel) ``` ### Monitoring Key metrics to monitor: - Database connection pool usage - Query performance and slow queries - Memory usage and garbage collection - TTL cleanup effectiveness ## Contributing ### Development Setup 1. Clone the repository 2. Install dependencies: `go mod download` 3. Run tests: `go test -v` 4. Follow Go coding standards ### Adding New Features 1. Update the Store interface 2. Implement in all backends (Xun, Redis, MongoDB) 3. Add comprehensive tests 4. Update documentation ## License This project is part of the Yao App Engine and follows the same license terms. 
================================================ FILE: agent/store/mongo/mongo.go ================================================ package mongo import "github.com/yaoapp/yao/agent/store/types" // Mongo represents a MongoDB-based conversation storage type Mongo struct{} // NewMongo create a new mongo store func NewMongo() types.Store { return &Mongo{} } // ============================================================================= // Chat Management // ============================================================================= // CreateChat creates a new chat session func (m *Mongo) CreateChat(chat *types.Chat) error { // TODO: implement return nil } // GetChat retrieves a single chat by ID func (m *Mongo) GetChat(chatID string) (*types.Chat, error) { // TODO: implement return nil, nil } // UpdateChat updates chat fields func (m *Mongo) UpdateChat(chatID string, updates map[string]interface{}) error { // TODO: implement return nil } // DeleteChat deletes a chat and its associated messages func (m *Mongo) DeleteChat(chatID string) error { // TODO: implement return nil } // ListChats retrieves a paginated list of chats with optional grouping func (m *Mongo) ListChats(filter types.ChatFilter) (*types.ChatList, error) { // TODO: implement return nil, nil } // ============================================================================= // Message Management // ============================================================================= // SaveMessages batch saves messages for a chat func (m *Mongo) SaveMessages(chatID string, messages []*types.Message) error { // TODO: implement return nil } // GetMessages retrieves messages for a chat with filtering func (m *Mongo) GetMessages(chatID string, filter types.MessageFilter) ([]*types.Message, error) { // TODO: implement return nil, nil } // UpdateMessage updates a single message func (m *Mongo) UpdateMessage(messageID string, updates map[string]interface{}) error { // TODO: implement return nil } // DeleteMessages 
deletes specific messages from a chat func (m *Mongo) DeleteMessages(chatID string, messageIDs []string) error { // TODO: implement return nil } // ============================================================================= // Resume Management (only called on failure/interrupt) // ============================================================================= // SaveResume batch saves resume records func (m *Mongo) SaveResume(records []*types.Resume) error { // TODO: implement return nil } // GetResume retrieves all resume records for a chat func (m *Mongo) GetResume(chatID string) ([]*types.Resume, error) { // TODO: implement return nil, nil } // GetLastResume retrieves the last resume record for a chat func (m *Mongo) GetLastResume(chatID string) (*types.Resume, error) { // TODO: implement return nil, nil } // GetResumeByStackID retrieves resume records for a specific stack func (m *Mongo) GetResumeByStackID(stackID string) ([]*types.Resume, error) { // TODO: implement return nil, nil } // GetStackPath returns the stack path from root to the given stack func (m *Mongo) GetStackPath(stackID string) ([]string, error) { // TODO: implement return nil, nil } // DeleteResume deletes all resume records for a chat func (m *Mongo) DeleteResume(chatID string) error { // TODO: implement return nil } // ============================================================================= // Assistant Management // ============================================================================= // SaveAssistant saves assistant information func (m *Mongo) SaveAssistant(assistant *types.AssistantModel) (string, error) { // TODO: implement return assistant.ID, nil } // UpdateAssistant updates specific fields of an assistant func (m *Mongo) UpdateAssistant(assistantID string, updates map[string]interface{}) error { // TODO: implement return nil } // DeleteAssistant deletes an assistant func (m *Mongo) DeleteAssistant(assistantID string) error { // TODO: implement return nil } // 
GetAssistants retrieves a list of assistants func (m *Mongo) GetAssistants(filter types.AssistantFilter, locale ...string) (*types.AssistantList, error) { // TODO: implement return &types.AssistantList{}, nil } // GetAssistantTags retrieves all unique tags from assistants with filtering func (m *Mongo) GetAssistantTags(filter types.AssistantFilter, locale ...string) ([]types.Tag, error) { // TODO: implement return []types.Tag{}, nil } // GetAssistant retrieves a single assistant by ID func (m *Mongo) GetAssistant(assistantID string, fields []string, locale ...string) (*types.AssistantModel, error) { // TODO: implement return nil, nil } // DeleteAssistants deletes assistants based on filter conditions func (m *Mongo) DeleteAssistants(filter types.AssistantFilter) (int64, error) { // TODO: implement return 0, nil } // ============================================================================= // Search Management // ============================================================================= // SaveSearch saves a search record for a request func (m *Mongo) SaveSearch(search *types.Search) error { // TODO: implement return nil } // GetSearches retrieves all search records for a request func (m *Mongo) GetSearches(requestID string) ([]*types.Search, error) { // TODO: implement return nil, nil } // GetReference retrieves a single reference by request ID and index func (m *Mongo) GetReference(requestID string, index int) (*types.Reference, error) { // TODO: implement return nil, nil } // DeleteSearches deletes all search records for a chat func (m *Mongo) DeleteSearches(chatID string) error { // TODO: implement return nil } ================================================ FILE: agent/store/redis/redis.go ================================================ package redis import "github.com/yaoapp/yao/agent/store/types" // Redis represents a Redis-based conversation storage type Redis struct{} // NewRedis create a new redis store func NewRedis() types.Store { return 
&Redis{} } // ============================================================================= // Chat Management // ============================================================================= // CreateChat creates a new chat session func (r *Redis) CreateChat(chat *types.Chat) error { // TODO: implement return nil } // GetChat retrieves a single chat by ID func (r *Redis) GetChat(chatID string) (*types.Chat, error) { // TODO: implement return nil, nil } // UpdateChat updates chat fields func (r *Redis) UpdateChat(chatID string, updates map[string]interface{}) error { // TODO: implement return nil } // DeleteChat deletes a chat and its associated messages func (r *Redis) DeleteChat(chatID string) error { // TODO: implement return nil } // ListChats retrieves a paginated list of chats with optional grouping func (r *Redis) ListChats(filter types.ChatFilter) (*types.ChatList, error) { // TODO: implement return nil, nil } // ============================================================================= // Message Management // ============================================================================= // SaveMessages batch saves messages for a chat func (r *Redis) SaveMessages(chatID string, messages []*types.Message) error { // TODO: implement return nil } // GetMessages retrieves messages for a chat with filtering func (r *Redis) GetMessages(chatID string, filter types.MessageFilter) ([]*types.Message, error) { // TODO: implement return nil, nil } // UpdateMessage updates a single message func (r *Redis) UpdateMessage(messageID string, updates map[string]interface{}) error { // TODO: implement return nil } // DeleteMessages deletes specific messages from a chat func (r *Redis) DeleteMessages(chatID string, messageIDs []string) error { // TODO: implement return nil } // ============================================================================= // Resume Management (only called on failure/interrupt) // 
============================================================================= // SaveResume batch saves resume records func (r *Redis) SaveResume(records []*types.Resume) error { // TODO: implement return nil } // GetResume retrieves all resume records for a chat func (r *Redis) GetResume(chatID string) ([]*types.Resume, error) { // TODO: implement return nil, nil } // GetLastResume retrieves the last resume record for a chat func (r *Redis) GetLastResume(chatID string) (*types.Resume, error) { // TODO: implement return nil, nil } // GetResumeByStackID retrieves resume records for a specific stack func (r *Redis) GetResumeByStackID(stackID string) ([]*types.Resume, error) { // TODO: implement return nil, nil } // GetStackPath returns the stack path from root to the given stack func (r *Redis) GetStackPath(stackID string) ([]string, error) { // TODO: implement return nil, nil } // DeleteResume deletes all resume records for a chat func (r *Redis) DeleteResume(chatID string) error { // TODO: implement return nil } // ============================================================================= // Assistant Management // ============================================================================= // SaveAssistant saves assistant information func (r *Redis) SaveAssistant(assistant *types.AssistantModel) (string, error) { // TODO: implement return assistant.ID, nil } // UpdateAssistant updates specific fields of an assistant func (r *Redis) UpdateAssistant(assistantID string, updates map[string]interface{}) error { // TODO: implement return nil } // DeleteAssistant deletes an assistant func (r *Redis) DeleteAssistant(assistantID string) error { // TODO: implement return nil } // GetAssistants retrieves a list of assistants func (r *Redis) GetAssistants(filter types.AssistantFilter, locale ...string) (*types.AssistantList, error) { // TODO: implement return &types.AssistantList{}, nil } // GetAssistantTags retrieves all unique tags from assistants with filtering func (r 
*Redis) GetAssistantTags(filter types.AssistantFilter, locale ...string) ([]types.Tag, error) { // TODO: implement return []types.Tag{}, nil } // GetAssistant retrieves a single assistant by ID func (r *Redis) GetAssistant(assistantID string, fields []string, locale ...string) (*types.AssistantModel, error) { // TODO: implement return nil, nil } // DeleteAssistants deletes assistants based on filter conditions func (r *Redis) DeleteAssistants(filter types.AssistantFilter) (int64, error) { // TODO: implement return 0, nil } // ============================================================================= // Search Management // ============================================================================= // SaveSearch saves a search record for a request func (r *Redis) SaveSearch(search *types.Search) error { // TODO: implement return nil } // GetSearches retrieves all search records for a request func (r *Redis) GetSearches(requestID string) ([]*types.Search, error) { // TODO: implement return nil, nil } // GetReference retrieves a single reference by request ID and index func (r *Redis) GetReference(requestID string, index int) (*types.Reference, error) { // TODO: implement return nil, nil } // DeleteSearches deletes all search records for a chat func (r *Redis) DeleteSearches(chatID string) error { // TODO: implement return nil } ================================================ FILE: agent/store/types/convert.go ================================================ package types import ( "fmt" "strings" "time" jsoniter "github.com/json-iterator/go" "github.com/spf13/cast" "github.com/yaoapp/gou/connector" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" searchTypes "github.com/yaoapp/yao/agent/search/types" ) // ToKnowledgeBase converts various types to KnowledgeBase func ToKnowledgeBase(v interface{}) (*KnowledgeBase, error) { if v == nil { return nil, nil } switch kb := v.(type) { case *KnowledgeBase: return kb, nil 
case KnowledgeBase: return &kb, nil case []string: return &KnowledgeBase{Collections: kb}, nil case []interface{}: var collections []string for _, item := range kb { collections = append(collections, cast.ToString(item)) } return &KnowledgeBase{Collections: collections}, nil default: raw, err := jsoniter.Marshal(kb) if err != nil { return nil, fmt.Errorf("kb format error: %s", err.Error()) } var knowledgeBase KnowledgeBase err = jsoniter.Unmarshal(raw, &knowledgeBase) if err != nil { return nil, fmt.Errorf("kb format error: %s", err.Error()) } return &knowledgeBase, nil } } // ToDatabase converts various types to Database func ToDatabase(v interface{}) (*Database, error) { if v == nil { return nil, nil } switch db := v.(type) { case *Database: return db, nil case Database: return &db, nil case []string: return &Database{Models: db}, nil case []interface{}: var models []string for _, item := range db { models = append(models, cast.ToString(item)) } return &Database{Models: models}, nil default: raw, err := jsoniter.Marshal(db) if err != nil { return nil, fmt.Errorf("db format error: %s", err.Error()) } var database Database err = jsoniter.Unmarshal(raw, &database) if err != nil { return nil, fmt.Errorf("db format error: %s", err.Error()) } return &database, nil } } // ToMCPServers converts various types to MCPServers func ToMCPServers(v interface{}) (*MCPServers, error) { if v == nil { return nil, nil } switch mcp := v.(type) { case *MCPServers: return mcp, nil case MCPServers: return &mcp, nil default: // For any type (including []string, []interface{}, map[string]interface{}), // marshal and unmarshal to MCPServers using custom UnmarshalJSON raw, err := jsoniter.Marshal(mcp) if err != nil { return nil, fmt.Errorf("mcp format error: %s", err.Error()) } var mcpServers MCPServers err = jsoniter.Unmarshal(raw, &mcpServers) if err != nil { return nil, fmt.Errorf("mcp format error: %s", err.Error()) } return &mcpServers, nil } } // ToWorkflow converts various types to 
Workflow func ToWorkflow(v interface{}) (*Workflow, error) { if v == nil { return nil, nil } switch workflow := v.(type) { case *Workflow: return workflow, nil case Workflow: return &workflow, nil case []string: return &Workflow{Workflows: workflow}, nil case []interface{}: var workflows []string for _, item := range workflow { workflows = append(workflows, cast.ToString(item)) } return &Workflow{Workflows: workflows}, nil default: raw, err := jsoniter.Marshal(workflow) if err != nil { return nil, fmt.Errorf("workflow format error: %s", err.Error()) } var wf Workflow err = jsoniter.Unmarshal(raw, &wf) if err != nil { return nil, fmt.Errorf("workflow format error: %s", err.Error()) } return &wf, nil } } // ToSandbox converts various types to Sandbox func ToSandbox(v interface{}) (*Sandbox, error) { if v == nil { return nil, nil } switch sandbox := v.(type) { case *Sandbox: return sandbox, nil case Sandbox: return &sandbox, nil default: raw, err := jsoniter.Marshal(sandbox) if err != nil { return nil, fmt.Errorf("sandbox format error: %s", err.Error()) } var sb Sandbox err = jsoniter.Unmarshal(raw, &sb) if err != nil { return nil, fmt.Errorf("sandbox format error: %s", err.Error()) } return &sb, nil } } // ToMySQLTime converts various types to MySQL datetime format func ToMySQLTime(v interface{}) string { switch val := v.(type) { case int64: if val == 0 { return "0000-00-00 00:00:00" } return time.Unix(val/1e9, val%1e9).Format("2006-01-02 15:04:05") case int: if val == 0 { return "0000-00-00 00:00:00" } return time.Unix(int64(val)/1e9, int64(val)%1e9).Format("2006-01-02 15:04:05") case string: // If already in MySQL format, return as-is if _, err := time.Parse("2006-01-02 15:04:05", val); err == nil { return val } // Try RFC3339 format if ts, err := time.Parse(time.RFC3339, val); err == nil { return ts.Format("2006-01-02 15:04:05") } // Try parsing as Unix timestamp if ts, err := cast.ToInt64E(val); err == nil { if ts == 0 { return "0000-00-00 00:00:00" } return 
time.Unix(ts/1e9, ts%1e9).Format("2006-01-02 15:04:05") } return val case time.Time: if val.IsZero() { return "0000-00-00 00:00:00" } return val.Format("2006-01-02 15:04:05") case nil: return "0000-00-00 00:00:00" default: return "0000-00-00 00:00:00" } } // ToAssistantModel converts various types to AssistantModel func ToAssistantModel(v interface{}) (*AssistantModel, error) { if v == nil { return nil, nil } // If already an AssistantModel, return it switch model := v.(type) { case *AssistantModel: return model, nil case AssistantModel: return &model, nil } // Convert to map first if needed var data map[string]interface{} switch v := v.(type) { case map[string]interface{}: data = v default: // Try to marshal and unmarshal raw, err := jsoniter.Marshal(v) if err != nil { return nil, fmt.Errorf("failed to marshal to AssistantModel: %w", err) } err = jsoniter.Unmarshal(raw, &data) if err != nil { return nil, fmt.Errorf("failed to unmarshal to map: %w", err) } } model := &AssistantModel{} // Basic string fields if id, ok := data["assistant_id"].(string); ok { model.ID = id } if typ, ok := data["type"].(string); ok { model.Type = typ } if name, ok := data["name"].(string); ok { model.Name = name } if avatar, ok := data["avatar"].(string); ok { model.Avatar = avatar } if connector, ok := data["connector"].(string); ok { model.Connector = connector } if path, ok := data["path"].(string); ok { model.Path = path } if source, ok := data["source"].(string); ok { model.Source = source } if description, ok := data["description"].(string); ok { model.Description = description } if capabilities, ok := data["capabilities"].(string); ok { model.Capabilities = capabilities } if share, ok := data["share"].(string); ok { model.Share = share } // Boolean fields (handle both bool and int types from database) model.BuiltIn = getBoolValue(data, "built_in") model.Readonly = getBoolValue(data, "readonly") model.Public = getBoolValue(data, "public") model.Mentionable = getBoolValue(data, 
"mentionable") model.Automated = getBoolValue(data, "automated") // Integer fields if sort, ok := data["sort"].(int); ok { model.Sort = sort } else if sort, ok := data["sort"].(float64); ok { model.Sort = int(sort) } if createdAt, ok := data["created_at"].(int64); ok { model.CreatedAt = createdAt } else if createdAt, ok := data["created_at"].(float64); ok { model.CreatedAt = int64(createdAt) } if updatedAt, ok := data["updated_at"].(int64); ok { model.UpdatedAt = updatedAt } else if updatedAt, ok := data["updated_at"].(float64); ok { model.UpdatedAt = int64(updatedAt) } // Tags (string array) if tags, ok := data["tags"]; ok && tags != nil { raw, err := jsoniter.Marshal(tags) if err == nil { var t []string if err := jsoniter.Unmarshal(raw, &t); err == nil { model.Tags = t } } } // Modes (string array) if modes, ok := data["modes"]; ok && modes != nil { raw, err := jsoniter.Marshal(modes) if err == nil { var m []string if err := jsoniter.Unmarshal(raw, &m); err == nil { model.Modes = m } } } // DefaultMode (string) if defaultMode, ok := data["default_mode"].(string); ok { model.DefaultMode = defaultMode } // Options (map) if options, ok := data["options"].(map[string]interface{}); ok { model.Options = options } // Prompts if prompts, ok := data["prompts"]; ok && prompts != nil { raw, err := jsoniter.Marshal(prompts) if err == nil { var p []Prompt if err := jsoniter.Unmarshal(raw, &p); err == nil { model.Prompts = p } } } // PromptPresets if promptPresets, ok := data["prompt_presets"]; ok && promptPresets != nil { raw, err := jsoniter.Marshal(promptPresets) if err == nil { var pp map[string][]Prompt if err := jsoniter.Unmarshal(raw, &pp); err == nil { model.PromptPresets = pp } } } // DisableGlobalPrompts model.DisableGlobalPrompts = getBoolValue(data, "disable_global_prompts") // ConnectorOptions if connectorOptions, ok := data["connector_options"]; ok && connectorOptions != nil { raw, err := jsoniter.Marshal(connectorOptions) if err == nil { var co ConnectorOptions 
if err := jsoniter.Unmarshal(raw, &co); err == nil { model.ConnectorOptions = &co } } } // KB if kb, ok := data["kb"]; ok && kb != nil { kbConverted, err := ToKnowledgeBase(kb) if err == nil { model.KB = kbConverted } } // DB if db, ok := data["db"]; ok && db != nil { dbConverted, err := ToDatabase(db) if err == nil { model.DB = dbConverted } } // MCP if mcp, ok := data["mcp"]; ok && mcp != nil { mcpConverted, err := ToMCPServers(mcp) if err == nil { model.MCP = mcpConverted } } // Workflow if workflow, ok := data["workflow"]; ok && workflow != nil { wf, err := ToWorkflow(workflow) if err == nil { model.Workflow = wf } } // Sandbox if sandbox, ok := data["sandbox"]; ok && sandbox != nil { sb, err := ToSandbox(sandbox) if err == nil { model.Sandbox = sb } } // Placeholder if placeholder, ok := data["placeholder"]; ok && placeholder != nil { raw, err := jsoniter.Marshal(placeholder) if err == nil { var ph Placeholder if err := jsoniter.Unmarshal(raw, &ph); err == nil { model.Placeholder = &ph } } } // Locales if locales, ok := data["locales"]; ok && locales != nil { raw, err := jsoniter.Marshal(locales) if err == nil { var loc i18n.Map if err := jsoniter.Unmarshal(raw, &loc); err == nil { model.Locales = loc } } } // Dependencies if deps, ok := data["dependencies"]; ok && deps != nil { raw, err := jsoniter.Marshal(deps) if err == nil { var d map[string]string if err := jsoniter.Unmarshal(raw, &d); err == nil { model.Dependencies = d } } } // Permission fields if createdBy, ok := data["__yao_created_by"].(string); ok { model.YaoCreatedBy = createdBy } if updatedBy, ok := data["__yao_updated_by"].(string); ok { model.YaoUpdatedBy = updatedBy } if teamID, ok := data["__yao_team_id"].(string); ok { model.YaoTeamID = teamID } if tenantID, ok := data["__yao_tenant_id"].(string); ok { model.YaoTenantID = tenantID } return model, nil } // getBoolValue extracts a boolean value from a map, handling both bool and numeric types func getBoolValue(data map[string]interface{}, key 
string) bool { if v, ok := data[key]; ok && v != nil { switch val := v.(type) { case bool: return val case int: return val != 0 case int64: return val != 0 case float64: return val != 0 case string: return val == "true" || val == "1" } } return false } // ModelID generates an OpenAI-compatible model ID from assistant // Format: [prefix-]assistantName-model-yao_assistantID // prefix is optional, if provided, it will be prepended to the model ID func (assistant AssistantModel) ModelID(prefix ...string) string { // Clean assistant name (remove spaces and special characters) assistantName := strings.ReplaceAll(assistant.Name, " ", "-") assistantName = strings.ToLower(assistantName) // Get connector name from assistant connectorName := assistant.Connector if connectorName == "" { log.Error("Assistant %s has no connector configured", assistant.ID) modelID := assistantName + "-unknown-yao_" + assistant.ID if len(prefix) > 0 && prefix[0] != "" { return prefix[0] + modelID } return modelID } // Get model name modelName := "" // First, try to get custom model from Options if assistant.Options != nil { if m, ok := assistant.Options["model"].(string); ok && m != "" { modelName = m } } // If no custom model in options, try to get from connector configuration if modelName == "" { conn, err := connector.Select(connectorName) if err != nil { log.Error("Failed to select connector %s for assistant %s: %v", connectorName, assistant.ID, err) modelID := assistantName + "-unknown-yao_" + assistant.ID if len(prefix) > 0 && prefix[0] != "" { return prefix[0] + modelID } return modelID } // Get model from connector settings settings := conn.Setting() if settings != nil { if m, ok := settings["model"].(string); ok && m != "" { modelName = m } } if modelName == "" { log.Error("Connector %s has no model configured for assistant %s", connectorName, assistant.ID) modelID := assistantName + "-unknown-yao_" + assistant.ID if len(prefix) > 0 && prefix[0] != "" { return prefix[0] + modelID } return 
modelID } } // Format: [prefix-]assistantName-model-yao_assistantID modelID := assistantName + "-" + modelName + "-yao_" + assistant.ID if len(prefix) > 0 && prefix[0] != "" { return prefix[0] + modelID } return modelID } // ParseModelID extracts assistant ID from model ID // Expected format: [prefix-]assistantName-model-yao_assistantID // The function handles optional prefixes (e.g., "yao-agents-") func ParseModelID(modelID string) string { // Find the last occurrence of "yao_" parts := strings.Split(modelID, "-yao_") if len(parts) < 2 { return "" } return parts[len(parts)-1] } // ToConnectorOptions converts various types to ConnectorOptions func ToConnectorOptions(v interface{}) (*ConnectorOptions, error) { if v == nil { return nil, nil } switch opts := v.(type) { case *ConnectorOptions: return opts, nil case ConnectorOptions: return &opts, nil default: raw, err := jsoniter.Marshal(opts) if err != nil { return nil, fmt.Errorf("connector_options format error: %s", err.Error()) } var connOpts ConnectorOptions err = jsoniter.Unmarshal(raw, &connOpts) if err != nil { return nil, fmt.Errorf("connector_options format error: %s", err.Error()) } return &connOpts, nil } } // ToModes converts various types to []string for modes func ToModes(v interface{}) ([]string, error) { if v == nil { return nil, nil } switch modes := v.(type) { case []string: return modes, nil case []interface{}: var result []string for _, item := range modes { result = append(result, cast.ToString(item)) } return result, nil case string: // Single string becomes a slice with one element return []string{modes}, nil default: raw, err := jsoniter.Marshal(modes) if err != nil { return nil, fmt.Errorf("modes format error: %s", err.Error()) } var result []string err = jsoniter.Unmarshal(raw, &result) if err != nil { return nil, fmt.Errorf("modes format error: %s", err.Error()) } return result, nil } } // ToPromptPresets converts various types to map[string][]Prompt func ToPromptPresets(v interface{}) 
(map[string][]Prompt, error) { if v == nil { return nil, nil } switch presets := v.(type) { case map[string][]Prompt: return presets, nil default: raw, err := jsoniter.Marshal(presets) if err != nil { return nil, fmt.Errorf("prompt_presets format error: %s", err.Error()) } var result map[string][]Prompt err = jsoniter.Unmarshal(raw, &result) if err != nil { return nil, fmt.Errorf("prompt_presets format error: %s", err.Error()) } return result, nil } } // ToUses converts various types to context.Uses func ToUses(v interface{}) (*context.Uses, error) { if v == nil { return nil, nil } switch uses := v.(type) { case *context.Uses: return uses, nil case context.Uses: return &uses, nil default: raw, err := jsoniter.Marshal(uses) if err != nil { return nil, fmt.Errorf("uses format error: %s", err.Error()) } var result context.Uses err = jsoniter.Unmarshal(raw, &result) if err != nil { return nil, fmt.Errorf("uses format error: %s", err.Error()) } return &result, nil } } // ToSearchConfig converts various types to searchTypes.Config func ToSearchConfig(v interface{}) (*searchTypes.Config, error) { if v == nil { return nil, nil } switch cfg := v.(type) { case *searchTypes.Config: return cfg, nil case searchTypes.Config: return &cfg, nil default: raw, err := jsoniter.Marshal(cfg) if err != nil { return nil, fmt.Errorf("search config format error: %s", err.Error()) } var result searchTypes.Config err = jsoniter.Unmarshal(raw, &result) if err != nil { return nil, fmt.Errorf("search config format error: %s", err.Error()) } return &result, nil } } ================================================ FILE: agent/store/types/convert_test.go ================================================ package types import ( "testing" "time" ) // TestToDatabase tests the ToDatabase conversion function func TestToDatabase(t *testing.T) { t.Run("NilInput", func(t *testing.T) { result, err := ToDatabase(nil) if err != nil { t.Errorf("Expected no error, got: %v", err) } if result != nil { 
t.Errorf("Expected nil result, got: %v", result) } }) t.Run("DatabasePointer", func(t *testing.T) { db := &Database{Models: []string{"model1", "model2"}} result, err := ToDatabase(db) if err != nil { t.Errorf("Expected no error, got: %v", err) } if result != db { t.Errorf("Expected same pointer") } }) t.Run("DatabaseValue", func(t *testing.T) { db := Database{Models: []string{"model1", "model2"}} result, err := ToDatabase(db) if err != nil { t.Errorf("Expected no error, got: %v", err) } if len(result.Models) != 2 { t.Errorf("Expected 2 models, got %d", len(result.Models)) } }) t.Run("StringSlice", func(t *testing.T) { models := []string{"model1", "model2", "model3"} result, err := ToDatabase(models) if err != nil { t.Errorf("Expected no error, got: %v", err) } if len(result.Models) != 3 { t.Errorf("Expected 3 models, got %d", len(result.Models)) } if result.Models[0] != "model1" { t.Errorf("Expected 'model1', got '%s'", result.Models[0]) } }) t.Run("InterfaceSlice", func(t *testing.T) { models := []interface{}{"model1", "model2", 123} result, err := ToDatabase(models) if err != nil { t.Errorf("Expected no error, got: %v", err) } if len(result.Models) != 3 { t.Errorf("Expected 3 models, got %d", len(result.Models)) } if result.Models[2] != "123" { t.Errorf("Expected '123', got '%s'", result.Models[2]) } }) t.Run("MapInput", func(t *testing.T) { data := map[string]interface{}{ "models": []string{"model1", "model2"}, } result, err := ToDatabase(data) if err != nil { t.Errorf("Expected no error, got: %v", err) } if len(result.Models) != 2 { t.Errorf("Expected 2 models, got %d", len(result.Models)) } }) t.Run("InvalidInput", func(t *testing.T) { // Test with data that can't be marshaled invalidData := make(chan int) _, err := ToDatabase(invalidData) if err == nil { t.Error("Expected error for invalid input") } }) t.Run("InvalidJSONUnmarshal", func(t *testing.T) { // Test with data that marshals but can't unmarshal to Database data := map[string]interface{}{ 
"invalid_field": "should cause unmarshal to fail gracefully", } result, err := ToDatabase(data) // Should not error, just return empty Database if err != nil { t.Errorf("Expected no error, got: %v", err) } if result == nil { t.Error("Expected non-nil result") } }) } // TestToKnowledgeBase tests the ToKnowledgeBase conversion function func TestToKnowledgeBase(t *testing.T) { t.Run("NilInput", func(t *testing.T) { result, err := ToKnowledgeBase(nil) if err != nil { t.Errorf("Expected no error, got: %v", err) } if result != nil { t.Errorf("Expected nil result, got: %v", result) } }) t.Run("KnowledgeBasePointer", func(t *testing.T) { kb := &KnowledgeBase{Collections: []string{"col1", "col2"}} result, err := ToKnowledgeBase(kb) if err != nil { t.Errorf("Expected no error, got: %v", err) } if result != kb { t.Errorf("Expected same pointer") } }) t.Run("KnowledgeBaseValue", func(t *testing.T) { kb := KnowledgeBase{Collections: []string{"col1", "col2"}} result, err := ToKnowledgeBase(kb) if err != nil { t.Errorf("Expected no error, got: %v", err) } if len(result.Collections) != 2 { t.Errorf("Expected 2 collections, got %d", len(result.Collections)) } }) t.Run("StringSlice", func(t *testing.T) { collections := []string{"col1", "col2", "col3"} result, err := ToKnowledgeBase(collections) if err != nil { t.Errorf("Expected no error, got: %v", err) } if len(result.Collections) != 3 { t.Errorf("Expected 3 collections, got %d", len(result.Collections)) } if result.Collections[0] != "col1" { t.Errorf("Expected 'col1', got '%s'", result.Collections[0]) } }) t.Run("InterfaceSlice", func(t *testing.T) { collections := []interface{}{"col1", "col2", 123} result, err := ToKnowledgeBase(collections) if err != nil { t.Errorf("Expected no error, got: %v", err) } if len(result.Collections) != 3 { t.Errorf("Expected 3 collections, got %d", len(result.Collections)) } if result.Collections[2] != "123" { t.Errorf("Expected '123', got '%s'", result.Collections[2]) } }) t.Run("MapInput", func(t 
*testing.T) {
		data := map[string]interface{}{
			"collections": []string{"col1", "col2"},
		}
		result, err := ToKnowledgeBase(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result.Collections) != 2 {
			t.Errorf("Expected 2 collections, got %d", len(result.Collections))
		}
	})

	t.Run("InvalidInput", func(t *testing.T) {
		// Test with data that can't be marshaled
		invalidData := make(chan int)
		_, err := ToKnowledgeBase(invalidData)
		if err == nil {
			t.Error("Expected error for invalid input")
		}
	})

	t.Run("InvalidJSONUnmarshal", func(t *testing.T) {
		// Test with data that marshals but can't unmarshal to KnowledgeBase
		data := map[string]interface{}{
			"invalid_field": "should cause unmarshal to fail gracefully",
		}
		result, err := ToKnowledgeBase(data)
		// Should not error, just return empty KnowledgeBase
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result == nil {
			t.Error("Expected non-nil result")
		}
	})
}

// TestToMCPServers tests the ToMCPServers conversion function.
//
// Checks that guard a later dereference or index of result use t.Fatalf, so a
// failed conversion is reported as a test failure rather than a panic.
func TestToMCPServers(t *testing.T) {
	t.Run("NilInput", func(t *testing.T) {
		result, err := ToMCPServers(nil)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != nil {
			t.Errorf("Expected nil result, got: %v", result)
		}
	})

	t.Run("MCPServersPointer", func(t *testing.T) {
		mcp := &MCPServers{Servers: []MCPServerConfig{{ServerID: "server1"}, {ServerID: "server2"}}}
		result, err := ToMCPServers(mcp)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != mcp {
			t.Errorf("Expected same pointer")
		}
	})

	t.Run("MCPServersValue", func(t *testing.T) {
		mcp := MCPServers{Servers: []MCPServerConfig{{ServerID: "server1"}, {ServerID: "server2"}}}
		result, err := ToMCPServers(mcp)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if len(result.Servers) != 2 {
			t.Errorf("Expected 2 servers, got %d", len(result.Servers))
		}
	})

	t.Run("MapInput", func(t *testing.T) {
		data := map[string]interface{}{
			"servers": []interface{}{"server1", "server2"},
		}
		result, err := ToMCPServers(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		// Fatal: the index below would panic on an empty slice.
		if len(result.Servers) != 2 {
			t.Fatalf("Expected 2 servers, got %d", len(result.Servers))
		}
		if result.Servers[0].ServerID != "server1" {
			t.Errorf("Expected 'server1', got '%s'", result.Servers[0].ServerID)
		}
	})

	t.Run("InvalidInput", func(t *testing.T) {
		// A channel is used as input that the conversion is expected to reject.
		invalidData := make(chan int)
		_, err := ToMCPServers(invalidData)
		if err == nil {
			t.Error("Expected error for invalid input")
		}
	})

	t.Run("InvalidJSONUnmarshal", func(t *testing.T) {
		// The map holds only an unrecognized key; the conversion is expected to
		// succeed and yield a non-nil value rather than fail.
		data := map[string]interface{}{
			"invalid_field": "should cause unmarshal to fail gracefully",
		}
		result, err := ToMCPServers(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result == nil {
			t.Error("Expected non-nil result")
		}
	})
}

// TestToWorkflow tests the ToWorkflow conversion function.
//
// Checks that guard a later dereference or index of result use t.Fatalf, so a
// failed conversion is reported as a test failure rather than a panic.
func TestToWorkflow(t *testing.T) {
	t.Run("NilInput", func(t *testing.T) {
		result, err := ToWorkflow(nil)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != nil {
			t.Errorf("Expected nil result, got: %v", result)
		}
	})

	t.Run("WorkflowPointer", func(t *testing.T) {
		wf := &Workflow{Workflows: []string{"wf1", "wf2"}}
		result, err := ToWorkflow(wf)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != wf {
			t.Errorf("Expected same pointer")
		}
	})

	t.Run("WorkflowValue", func(t *testing.T) {
		wf := Workflow{Workflows: []string{"wf1", "wf2"}}
		result, err := ToWorkflow(wf)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if len(result.Workflows) != 2 {
			t.Errorf("Expected 2 workflows, got %d", len(result.Workflows))
		}
	})

	t.Run("StringSlice", func(t *testing.T) {
		workflows := []string{"wf1", "wf2", "wf3"}
		result, err := ToWorkflow(workflows)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		// Fatal: the index below would panic on an empty slice.
		if len(result.Workflows) != 3 {
			t.Fatalf("Expected 3 workflows, got %d", len(result.Workflows))
		}
		if result.Workflows[0] != "wf1" {
			t.Errorf("Expected 'wf1', got '%s'", result.Workflows[0])
		}
	})

	t.Run("InterfaceSlice", func(t *testing.T) {
		// The non-string element is expected to come back as its string form.
		workflows := []interface{}{"wf1", "wf2", 789}
		result, err := ToWorkflow(workflows)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		// Fatal: the index below would panic on a shorter slice.
		if len(result.Workflows) != 3 {
			t.Fatalf("Expected 3 workflows, got %d", len(result.Workflows))
		}
		if result.Workflows[2] != "789" {
			t.Errorf("Expected '789', got '%s'", result.Workflows[2])
		}
	})

	t.Run("MapInput", func(t *testing.T) {
		data := map[string]interface{}{
			"workflows": []string{"wf1", "wf2"},
		}
		result, err := ToWorkflow(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if len(result.Workflows) != 2 {
			t.Errorf("Expected 2 workflows, got %d", len(result.Workflows))
		}
	})

	t.Run("InvalidInput", func(t *testing.T) {
		// A channel is used as input that the conversion is expected to reject.
		invalidData := make(chan int)
		_, err := ToWorkflow(invalidData)
		if err == nil {
			t.Error("Expected error for invalid input")
		}
	})

	t.Run("InvalidJSONUnmarshal", func(t *testing.T) {
		// The map holds only an unrecognized key; the conversion is expected to
		// succeed and yield a non-nil value rather than fail.
		data := map[string]interface{}{
			"invalid_field": "should cause unmarshal to fail gracefully",
		}
		result, err := ToWorkflow(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result == nil {
			t.Error("Expected non-nil result")
		}
	})
}

// TestToMySQLTime tests the ToMySQLTime conversion function across integer,
// string, time.Time, nil and unsupported inputs.
func TestToMySQLTime(t *testing.T) {
	t.Run("Int64Zero", func(t *testing.T) {
		result := ToMySQLTime(int64(0))
		if result != "0000-00-00 00:00:00" {
			t.Errorf("Expected '0000-00-00 00:00:00', got '%s'", result)
		}
	})

	t.Run("Int64Timestamp", func(t *testing.T) {
		// Unix timestamp in nanoseconds: 1609459200000000000 = 2021-01-01 00:00:00 UTC
		timestamp := int64(1609459200000000000)
		result := ToMySQLTime(timestamp)
		// Should be in format "2021-01-01 00:00:00" or similar depending on timezone.
		// NOTE(review): a length-only check also accepts the 19-character
		// zero-value string; consider a stricter assertion.
		if len(result) != 19 {
			t.Errorf("Expected 19 character timestamp, got %d: '%s'", len(result), result)
		}
	})

	t.Run("IntZero", func(t *testing.T) {
		result := ToMySQLTime(int(0))
		if result != "0000-00-00 00:00:00" {
			t.Errorf("Expected '0000-00-00 00:00:00', got '%s'", result)
		}
	})

	t.Run("IntTimestamp", func(t *testing.T) {
		timestamp := int(1609459200000000000)
		result := ToMySQLTime(timestamp)
		if len(result) != 19 {
			t.Errorf("Expected 19 character timestamp, got %d: '%s'", len(result), result)
		}
	})

	t.Run("StringMySQLFormat", func(t *testing.T) {
		mysqlTime := "2021-01-01 12:30:45"
		result := ToMySQLTime(mysqlTime)
		if result != mysqlTime {
			t.Errorf("Expected '%s', got '%s'", mysqlTime, result)
		}
	})

	t.Run("StringRFC3339", func(t *testing.T) {
		rfc3339Time := "2021-01-01T12:30:45Z"
		result := ToMySQLTime(rfc3339Time)
		expected := "2021-01-01 12:30:45"
		if result != expected {
			t.Errorf("Expected '%s', got '%s'", expected, result)
		}
	})

	t.Run("StringUnixTimestamp", func(t *testing.T) {
		// Unix timestamp in nanoseconds, passed as a string (same instant as
		// the Int64Timestamp case).
		result := ToMySQLTime("1609459200000000000")
		if len(result) != 19 {
			t.Errorf("Expected 19 character timestamp, got %d: '%s'", len(result), result)
		}
	})

	t.Run("StringInvalidFormat", func(t *testing.T) {
		invalidTime := "not-a-valid-time"
		result := ToMySQLTime(invalidTime)
		// Should return the original string when it can't be parsed
		if result != invalidTime {
			t.Errorf("Expected '%s', got '%s'", invalidTime, result)
		}
	})

	t.Run("TimeZero", func(t *testing.T) {
		zeroTime := time.Time{}
		result := ToMySQLTime(zeroTime)
		if result != "0000-00-00 00:00:00" {
			t.Errorf("Expected '0000-00-00 00:00:00', got '%s'", result)
		}
	})

	t.Run("TimeNormal", func(t *testing.T) {
		normalTime := time.Date(2021, 1, 1, 12, 30, 45, 0, time.UTC)
		result := ToMySQLTime(normalTime)
		expected := "2021-01-01 12:30:45"
		if result != expected {
			t.Errorf("Expected '%s', got '%s'", expected, result)
		}
	})

	t.Run("NilInput", func(t *testing.T) {
		result := ToMySQLTime(nil)
		if result != "0000-00-00 00:00:00" {
			t.Errorf("Expected '0000-00-00 00:00:00', got '%s'", result)
		}
	})

	t.Run("UnknownType", func(t *testing.T) {
		// Test with unsupported type
		result := ToMySQLTime(struct{}{})
		if result != "0000-00-00 00:00:00" {
			t.Errorf("Expected '0000-00-00 00:00:00', got '%s'", result)
		}
	})
}

// TestToAssistantModel tests the ToAssistantModel conversion function.
//
// Checks that guard a later dereference or index of result use t.Fatalf, so a
// failed conversion is reported as a test failure rather than a panic.
func TestToAssistantModel(t *testing.T) {
	t.Run("NilInput", func(t *testing.T) {
		result, err := ToAssistantModel(nil)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != nil {
			t.Errorf("Expected nil result, got: %v", result)
		}
	})

	t.Run("AssistantModelPointer", func(t *testing.T) {
		model := &AssistantModel{
			ID:   "test-id",
			Name: "Test Assistant",
		}
		result, err := ToAssistantModel(model)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != model {
			t.Errorf("Expected same pointer")
		}
	})

	t.Run("AssistantModelValue", func(t *testing.T) {
		model := AssistantModel{
			ID:   "test-id",
			Name: "Test Assistant",
		}
		result, err := ToAssistantModel(model)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.ID != "test-id" {
			t.Errorf("Expected 'test-id', got '%s'", result.ID)
		}
	})

	t.Run("MapWithAllFields", func(t *testing.T) {
		data := map[string]interface{}{
			"assistant_id": "test-id",
			"type":         "assistant",
			"name":         "Test Assistant",
			"avatar":       "https://example.com/avatar.png",
			"connector":    "openai",
			"connector_options": map[string]interface{}{
				"optional":   true,
				"connectors": []string{"openai", "anthropic"},
				"filters":    []string{"vision", "tool_calls"},
			},
			"path":         "/path/to/assistant",
			"description":  "Test description",
			"share":        "team",
			"built_in":     true,
			"readonly":     false,
			"public":       true,
			"mentionable":  true,
			"automated":    false,
			"sort":         100,
			"created_at":   int64(1609459200),
			"updated_at":   int64(1609459300),
			"tags":         []string{"tag1", "tag2"},
			"modes":        []string{"chat", "task"},
			"default_mode": "chat",
			"options": map[string]interface{}{
				"temperature": 0.7,
			},
			"prompts": []map[string]interface{}{
				{"role": "system", "content": "You are helpful"},
			},
			"prompt_presets": map[string]interface{}{
				"chat": []map[string]interface{}{
					{"role": "system", "content": "You are a chat assistant"},
				},
				"task": []map[string]interface{}{
					{"role": "system", "content": "You are a task assistant"},
				},
			},
			"disable_global_prompts": true,
			"source":                 "function hook() { return 'test'; }",
			"kb": map[string]interface{}{
				"collections": []string{"col1"},
			},
			"db": map[string]interface{}{
				"models": []string{"model1"},
			},
			"mcp": map[string]interface{}{
				"servers": []string{"server1"},
			},
			"workflow": map[string]interface{}{
				"workflows": []string{"wf1"},
			},
			"placeholder": map[string]interface{}{
				"title": "Enter message",
			},
			"locales": map[string]interface{}{
				"en": map[string]interface{}{
					"name": "English Name",
				},
			},
			"dependencies": map[string]interface{}{
				"echo":     "^1.0.0",
				"customer": ">=2.0.0",
			},
		}

		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}

		// Verify all fields
		if result.ID != "test-id" {
			t.Errorf("Expected ID 'test-id', got '%s'", result.ID)
		}
		if result.Type != "assistant" {
			t.Errorf("Expected Type 'assistant', got '%s'", result.Type)
		}
		if result.Name != "Test Assistant" {
			t.Errorf("Expected Name 'Test Assistant', got '%s'", result.Name)
		}
		if result.Avatar != "https://example.com/avatar.png" {
			t.Errorf("Expected Avatar URL, got '%s'", result.Avatar)
		}
		if result.Connector != "openai" {
			t.Errorf("Expected Connector 'openai', got '%s'", result.Connector)
		}
		if result.ConnectorOptions == nil {
			t.Error("Expected ConnectorOptions to be set")
		} else {
			if result.ConnectorOptions.Optional == nil || !*result.ConnectorOptions.Optional {
				t.Error("Expected ConnectorOptions.Optional to be true")
			}
			if len(result.ConnectorOptions.Connectors) != 2 {
				t.Errorf("Expected 2 connectors in options, got %d", len(result.ConnectorOptions.Connectors))
			}
			if len(result.ConnectorOptions.Filters) != 2 {
				t.Errorf("Expected 2 filters, got %d", len(result.ConnectorOptions.Filters))
			}
		}
		if result.Path != "/path/to/assistant" {
			t.Errorf("Expected Path, got '%s'", result.Path)
		}
		if result.Source != "function hook() { return 'test'; }" {
			t.Errorf("Expected Source, got '%s'", result.Source)
		}
		if result.Description != "Test description" {
			t.Errorf("Expected Description, got '%s'", result.Description)
		}
		if result.Share != "team" {
			t.Errorf("Expected Share 'team', got '%s'", result.Share)
		}
		if !result.BuiltIn {
			t.Error("Expected BuiltIn to be true")
		}
		if result.Readonly {
			t.Error("Expected Readonly to be false")
		}
		if !result.Public {
			t.Error("Expected Public to be true")
		}
		if !result.Mentionable {
			t.Error("Expected Mentionable to be true")
		}
		if result.Automated {
			t.Error("Expected Automated to be false")
		}
		if result.Sort != 100 {
			t.Errorf("Expected Sort 100, got %d", result.Sort)
		}
		if result.CreatedAt != 1609459200 {
			t.Errorf("Expected CreatedAt 1609459200, got %d", result.CreatedAt)
		}
		if result.UpdatedAt != 1609459300 {
			t.Errorf("Expected UpdatedAt 1609459300, got %d", result.UpdatedAt)
		}
		if len(result.Tags) != 2 {
			t.Errorf("Expected 2 tags, got %d", len(result.Tags))
		}
		// Only index Modes when the length check passed, so a wrong length is
		// reported without an out-of-range panic and later checks still run.
		if len(result.Modes) != 2 {
			t.Errorf("Expected 2 modes, got %d", len(result.Modes))
		} else if result.Modes[0] != "chat" {
			t.Errorf("Expected first mode 'chat', got '%s'", result.Modes[0])
		}
		if result.DefaultMode != "chat" {
			t.Errorf("Expected default_mode 'chat', got '%s'", result.DefaultMode)
		}
		if result.Options == nil {
			t.Error("Expected Options to be set")
		}
		if len(result.Prompts) != 1 {
			t.Errorf("Expected 1 prompt, got %d", len(result.Prompts))
		}
		if result.PromptPresets == nil {
			t.Error("Expected PromptPresets to be set")
		} else {
			if len(result.PromptPresets) != 2 {
				t.Errorf("Expected 2 prompt presets, got %d", len(result.PromptPresets))
			}
			if chatPrompts, ok := result.PromptPresets["chat"]; !ok {
				t.Error("Expected 'chat' prompt preset")
			} else if len(chatPrompts) != 1 {
				t.Errorf("Expected 1 chat prompt, got %d", len(chatPrompts))
			}
			if taskPrompts, ok := result.PromptPresets["task"]; !ok {
				t.Error("Expected 'task' prompt preset")
			} else if len(taskPrompts) != 1 {
				t.Errorf("Expected 1 task prompt, got %d", len(taskPrompts))
			}
		}
		if !result.DisableGlobalPrompts {
			t.Error("Expected DisableGlobalPrompts to be true")
		}
		if result.KB == nil {
			t.Error("Expected KB to be set")
		}
		if result.DB == nil {
			t.Error("Expected DB to be set")
		}
		if result.MCP == nil {
			t.Error("Expected MCP to be set")
		}
		if result.Workflow == nil {
			t.Error("Expected Workflow to be set")
		}
		if result.Placeholder == nil {
			t.Error("Expected Placeholder to be set")
		}
		if result.Locales == nil {
			t.Error("Expected Locales to be set")
		}
		if result.Dependencies == nil {
			t.Error("Expected Dependencies to be set")
		} else {
			if len(result.Dependencies) != 2 {
				t.Errorf("Expected 2 dependencies, got %d", len(result.Dependencies))
			}
			if result.Dependencies["echo"] != "^1.0.0" {
				t.Errorf("Expected echo dependency '^1.0.0', got '%s'", result.Dependencies["echo"])
			}
			if result.Dependencies["customer"] != ">=2.0.0" {
				t.Errorf("Expected customer dependency '>=2.0.0', got '%s'", result.Dependencies["customer"])
			}
		}
	})

	t.Run("MapWithFloatNumbers", func(t *testing.T) {
		data := map[string]interface{}{
			"sort":       float64(150),
			"created_at": float64(1609459200),
			"updated_at": float64(1609459300),
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.Sort != 150 {
			t.Errorf("Expected Sort 150, got %d", result.Sort)
		}
		if result.CreatedAt != 1609459200 {
			t.Errorf("Expected CreatedAt 1609459200, got %d", result.CreatedAt)
		}
		if result.UpdatedAt != 1609459300 {
			t.Errorf("Expected UpdatedAt 1609459300, got %d", result.UpdatedAt)
		}
	})

	t.Run("MapWithNilFields", func(t *testing.T) {
		data := map[string]interface{}{
			"assistant_id": "test-id",
			"tags":         nil,
			"modes":        nil,
			"default_mode": "",
			"options":      nil,
			"prompts":      nil,
			"kb":           nil,
			"db":           nil,
			"mcp":          nil,
			"workflow":     nil,
			"placeholder":  nil,
			"locales":      nil,
			"dependencies": nil,
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.ID != "test-id" {
			t.Errorf("Expected ID 'test-id', got '%s'", result.ID)
		}
		// All nil fields should remain nil
		if result.Tags != nil {
			t.Error("Expected Tags to be nil")
		}
		if result.Dependencies != nil {
			t.Error("Expected Dependencies to be nil")
		}
		if result.Modes != nil {
			t.Error("Expected Modes to be nil")
		}
		if result.DefaultMode != "" {
			t.Error("Expected DefaultMode to be empty")
		}
		if result.Options != nil {
			t.Error("Expected Options to be nil")
		}
	})

	t.Run("StructInput", func(t *testing.T) {
		type CustomStruct struct {
			AssistantID string `json:"assistant_id"`
			Name        string `json:"name"`
			Type        string `json:"type"`
		}
		input := CustomStruct{
			AssistantID: "custom-id",
			Name:        "Custom Assistant",
			Type:        "bot",
		}
		result, err := ToAssistantModel(input)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.ID != "custom-id" {
			t.Errorf("Expected ID 'custom-id', got '%s'", result.ID)
		}
		if result.Name != "Custom Assistant" {
			t.Errorf("Expected Name 'Custom Assistant', got '%s'", result.Name)
		}
		if result.Type != "bot" {
			t.Errorf("Expected Type 'bot', got '%s'", result.Type)
		}
	})

	t.Run("InvalidInput", func(t *testing.T) {
		// A channel is used as input that the conversion is expected to reject.
		invalidData := make(chan int)
		_, err := ToAssistantModel(invalidData)
		if err == nil {
			t.Error("Expected error for invalid input")
		}
	})

	t.Run("EmptyMap", func(t *testing.T) {
		data := map[string]interface{}{}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		// Fatal: result is dereferenced immediately below.
		if result == nil {
			t.Fatal("Expected non-nil result")
		}
		// All fields should have default values
		if result.ID != "" {
			t.Errorf("Expected empty ID, got '%s'", result.ID)
		}
	})
}

// TestToAssistantModelNewFields tests the newly added fields
// (connector_options, prompt_presets, disable_global_prompts, source).
func TestToAssistantModelNewFields(t *testing.T) {
	t.Run("ConnectorOptions", func(t *testing.T) {
		data := map[string]interface{}{
			"connector_options": map[string]interface{}{
				"optional":   true,
				"connectors": []string{"openai", "anthropic", "azure"},
				"filters":    []string{"vision", "tool_calls", "audio"},
			},
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.ConnectorOptions == nil {
			t.Fatal("Expected ConnectorOptions to be set")
		}
		if result.ConnectorOptions.Optional == nil || !*result.ConnectorOptions.Optional {
			t.Error("Expected Optional to be true")
		}
		if len(result.ConnectorOptions.Connectors) != 3 {
			t.Errorf("Expected 3 connectors, got %d", len(result.ConnectorOptions.Connectors))
		}
		if len(result.ConnectorOptions.Filters) != 3 {
			t.Errorf("Expected 3 filters, got %d", len(result.ConnectorOptions.Filters))
		}
	})

	t.Run("PromptPresets", func(t *testing.T) {
		data := map[string]interface{}{
			"prompt_presets": map[string]interface{}{
				"chat": []map[string]interface{}{
					{"role": "system", "content": "You are a helpful chat assistant"},
					{"role": "user", "content": "Example question"},
				},
				"task": []map[string]interface{}{
					{"role": "system", "content": "You are a task completion assistant"},
				},
				"analyze": []map[string]interface{}{
					{"role": "system", "content": "You are a data analysis assistant"},
				},
			},
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.PromptPresets == nil {
			t.Fatal("Expected PromptPresets to be set")
		}
		if len(result.PromptPresets) != 3 {
			t.Errorf("Expected 3 prompt preset modes, got %d", len(result.PromptPresets))
		}
		if chatPrompts, ok := result.PromptPresets["chat"]; !ok {
			t.Error("Expected 'chat' mode in prompt presets")
		} else if len(chatPrompts) != 2 {
			t.Errorf("Expected 2 prompts in chat mode, got %d", len(chatPrompts))
		}
		if taskPrompts, ok := result.PromptPresets["task"]; !ok {
			t.Error("Expected 'task' mode in prompt presets")
		} else if len(taskPrompts) != 1 {
			t.Errorf("Expected 1 prompt in task mode, got %d", len(taskPrompts))
		}
		if analyzePrompts, ok := result.PromptPresets["analyze"]; !ok {
			t.Error("Expected 'analyze' mode in prompt presets")
		} else if len(analyzePrompts) != 1 {
			t.Errorf("Expected 1 prompt in analyze mode, got %d", len(analyzePrompts))
		}
	})

	t.Run("Source", func(t *testing.T) {
		// The script is only round-tripped and compared against itself.
		hookScript := ` function beforeChat(context) { console.log('Hook called'); return context; } `
		data := map[string]interface{}{
			"source": hookScript,
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.Source != hookScript {
			t.Errorf("Expected Source to match, got '%s'", result.Source)
		}
	})

	t.Run("AllNewFields", func(t *testing.T) {
		data := map[string]interface{}{
			"connector_options": map[string]interface{}{
				"optional":   true,
				"connectors": []string{"openai"},
				"filters":    []string{"vision"},
			},
			"prompt_presets": map[string]interface{}{
				"chat": []map[string]interface{}{
					{"role": "system", "content": "Chat mode"},
				},
			},
			"disable_global_prompts": true,
			"source":                 "function test() {}",
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.ConnectorOptions == nil {
			t.Error("Expected ConnectorOptions to be set")
		}
		if result.PromptPresets == nil {
			t.Error("Expected PromptPresets to be set")
		}
		if !result.DisableGlobalPrompts {
			t.Error("Expected DisableGlobalPrompts to be true")
		}
		if result.Source == "" {
			t.Error("Expected Source to be set")
		}
	})

	t.Run("NilNewFields", func(t *testing.T) {
		data := map[string]interface{}{
			"connector_options":      nil,
			"prompt_presets":         nil,
			"disable_global_prompts": nil,
			"source":                 nil,
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.ConnectorOptions != nil {
			t.Error("Expected ConnectorOptions to be nil")
		}
		if result.PromptPresets != nil {
			t.Error("Expected PromptPresets to be nil")
		}
		if result.DisableGlobalPrompts {
			t.Error("Expected DisableGlobalPrompts to be false")
		}
		if result.Source != "" {
			t.Error("Expected Source to be empty")
		}
	})

	t.Run("DisableGlobalPrompts", func(t *testing.T) {
		// Test with true
		data := map[string]interface{}{
			"disable_global_prompts": true,
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if !result.DisableGlobalPrompts {
			t.Error("Expected DisableGlobalPrompts to be true")
		}

		// Test with false
		data = map[string]interface{}{
			"disable_global_prompts": false,
		}
		result, err = ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.DisableGlobalPrompts {
			t.Error("Expected DisableGlobalPrompts to be false")
		}

		// Test with int 1
		data = map[string]interface{}{
			"disable_global_prompts": 1,
		}
		result, err = ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if !result.DisableGlobalPrompts {
			t.Error("Expected DisableGlobalPrompts to be true for int 1")
		}

		// Test with string "true"
		data = map[string]interface{}{
			"disable_global_prompts": "true",
		}
		result, err = ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if !result.DisableGlobalPrompts {
			t.Error("Expected DisableGlobalPrompts to be true for string 'true'")
		}
	})
}

// TestToAssistantModelComplexTypes tests complex type conversions in ToAssistantModel
// (nested locale maps and prompt lists supplied as []interface{}).
func TestToAssistantModelComplexTypes(t *testing.T) {
	t.Run("CompleteLocales", func(t *testing.T) {
		data := map[string]interface{}{
			"locales": map[string]interface{}{
				"en": map[string]interface{}{
					"locale": "en",
					"messages": map[string]interface{}{
						"name":        "English Name",
						"description": "English Description",
					},
				},
				"zh": map[string]interface{}{
					"locale": "zh",
					"messages": map[string]interface{}{
						"name":        "中文名称",
						"description": "中文描述",
					},
				},
			},
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.Locales == nil {
			t.Fatal("Expected Locales to be set")
		}
		if len(result.Locales) != 2 {
			t.Errorf("Expected 2 locales, got %d", len(result.Locales))
		}
	})

	t.Run("ComplexPrompts", func(t *testing.T) {
		data := map[string]interface{}{
			"prompts": []interface{}{
				map[string]interface{}{
					"role":    "system",
					"content": "You are a helpful assistant",
				},
				map[string]interface{}{
					"role":    "user",
					"content": "Hello",
				},
			},
		}
		result, err := ToAssistantModel(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if len(result.Prompts) != 2 {
			t.Errorf("Expected 2 prompts, got %d", len(result.Prompts))
		}
	})
}

// TestGetBoolValue tests the getBoolValue helper function.
//
// Each case is one subtest: the map is passed to getBoolValue with the given
// key and the outcome is compared with want; msg is reported on mismatch.
func TestGetBoolValue(t *testing.T) {
	cases := []struct {
		name string
		data map[string]interface{}
		key  string
		want bool
		msg  string
	}{
		{"BoolTrue", map[string]interface{}{"key": true}, "key", true, "Expected true"},
		{"BoolFalse", map[string]interface{}{"key": false}, "key", false, "Expected false"},
		{"IntNonZero", map[string]interface{}{"key": 1}, "key", true, "Expected true for non-zero int"},
		{"IntZero", map[string]interface{}{"key": 0}, "key", false, "Expected false for zero int"},
		{"Int64NonZero", map[string]interface{}{"key": int64(1)}, "key", true, "Expected true for non-zero int64"},
		{"Int64Zero", map[string]interface{}{"key": int64(0)}, "key", false, "Expected false for zero int64"},
		{"Float64NonZero", map[string]interface{}{"key": float64(1.5)}, "key", true, "Expected true for non-zero float64"},
		{"Float64Zero", map[string]interface{}{"key": float64(0)}, "key", false, "Expected false for zero float64"},
		{"StringTrue", map[string]interface{}{"key": "true"}, "key", true, "Expected true for string 'true'"},
		{"StringOne", map[string]interface{}{"key": "1"}, "key", true, "Expected true for string '1'"},
		{"StringFalse", map[string]interface{}{"key": "false"}, "key", false, "Expected false for string 'false'"},
		{"StringOther", map[string]interface{}{"key": "other"}, "key", false, "Expected false for other string values"},
		{"NilValue", map[string]interface{}{"key": nil}, "key", false, "Expected false for nil value"},
		{"MissingKey", map[string]interface{}{}, "missing", false, "Expected false for missing key"},
		{"UnsupportedType", map[string]interface{}{"key": struct{}{}}, "key", false, "Expected false for unsupported type"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := getBoolValue(tc.data, tc.key); got != tc.want {
				t.Error(tc.msg)
			}
		})
	}
}

// TestModelID tests the AssistantModel.ModelID method.
//
// Each case builds an AssistantModel and compares ModelID() with the expected
// "<name>-<model>-yao_<id>" string.
func TestModelID(t *testing.T) {
	cases := []struct {
		name      string
		assistant AssistantModel
		expected  string
	}{
		{
			name: "WithCustomModel",
			assistant: AssistantModel{
				ID:        "test123",
				Name:      "Test Assistant",
				Connector: "openai",
				Options: map[string]interface{}{
					"model": "gpt-4o",
				},
			},
			expected: "test-assistant-gpt-4o-yao_test123",
		},
		{
			name: "WithModelInOptions",
			assistant: AssistantModel{
				ID:        "abc456",
				Name:      "My Bot",
				Connector: "openai",
				Options: map[string]interface{}{
					"model": "gpt-3.5-turbo",
				},
			},
			expected: "my-bot-gpt-3.5-turbo-yao_abc456",
		},
		{
			name: "WithoutCustomModel",
			assistant: AssistantModel{
				ID:        "xyz789",
				Name:      "Default Assistant",
				Connector: "openai",
			},
			// When connector is not loaded in test, it should return unknown
			expected: "default-assistant-unknown-yao_xyz789",
		},
		{
			name: "WithoutConnector",
			assistant: AssistantModel{
				ID:   "noconn",
				Name: "No Connector",
			},
			expected: "no-connector-unknown-yao_noconn",
		},
		{
			name: "WithSpacesInName",
			assistant: AssistantModel{
				ID:        "spaces",
				Name:      "Test Bot With Spaces",
				Connector: "anthropic",
				Options: map[string]interface{}{
					"model": "claude-3",
				},
			},
			expected: "test-bot-with-spaces-claude-3-yao_spaces",
		},
		{
			name: "WithUpperCaseName",
			assistant: AssistantModel{
				ID:        "upper",
				Name:      "UPPERCASE-NAME",
				Connector: "openai",
				Options: map[string]interface{}{
					"model": "GPT-4",
				},
			},
			expected: "uppercase-name-GPT-4-yao_upper",
		},
		{
			name: "WithEmptyOptions",
			assistant: AssistantModel{
				ID:        "empty",
				Name:      "Empty Options",
				Connector: "openai",
				Options:   map[string]interface{}{},
			},
			// When connector is not loaded in test, it should return unknown
			expected: "empty-options-unknown-yao_empty",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assistant := tc.assistant
			result := assistant.ModelID()
			if result != tc.expected {
				t.Errorf("Expected '%s', got '%s'", tc.expected, result)
			}
		})
	}
}

// TestToConnectorOptions tests the ToConnectorOptions conversion function.
//
// Checks that guard a later dereference of result use t.Fatalf, so a failed
// conversion is reported as a test failure rather than a nil-pointer panic.
func TestToConnectorOptions(t *testing.T) {
	t.Run("NilInput", func(t *testing.T) {
		result, err := ToConnectorOptions(nil)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != nil {
			t.Errorf("Expected nil result, got: %v", result)
		}
	})

	t.Run("ConnectorOptionsPointer", func(t *testing.T) {
		optionalTrue := true
		opts := &ConnectorOptions{
			Optional:   &optionalTrue,
			Connectors: []string{"openai", "anthropic"},
			Filters:    []ModelCapability{CapVision, CapToolCalls},
		}
		result, err := ToConnectorOptions(opts)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != opts {
			t.Errorf("Expected same pointer")
		}
	})

	t.Run("ConnectorOptionsValue", func(t *testing.T) {
		optionalTrue := true
		opts := ConnectorOptions{
			Optional:   &optionalTrue,
			Connectors: []string{"openai", "anthropic"},
			Filters:    []ModelCapability{CapVision, CapToolCalls},
		}
		result, err := ToConnectorOptions(opts)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.Optional == nil || !*result.Optional {
			t.Error("Expected Optional to be true")
		}
		if len(result.Connectors) != 2 {
			t.Errorf("Expected 2 connectors, got %d", len(result.Connectors))
		}
		if len(result.Filters) != 2 {
			t.Errorf("Expected 2 filters, got %d", len(result.Filters))
		}
	})

	t.Run("MapInput", func(t *testing.T) {
		data := map[string]interface{}{
			"optional":   true,
			"connectors": []string{"openai", "anthropic", "azure"},
			"filters":    []string{"vision", "tool_calls", "audio"},
		}
		result, err := ToConnectorOptions(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.Optional == nil || !*result.Optional {
			t.Error("Expected Optional to be true")
		}
		if len(result.Connectors) != 3 {
			t.Errorf("Expected 3 connectors, got %d", len(result.Connectors))
		}
		if len(result.Filters) != 3 {
			t.Errorf("Expected 3 filters, got %d", len(result.Filters))
		}
	})

	t.Run("MapInputOptionalOnly", func(t *testing.T) {
		data := map[string]interface{}{
			"optional": true,
		}
		result, err := ToConnectorOptions(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.Optional == nil || !*result.Optional {
			t.Error("Expected Optional to be true")
		}
		if result.Connectors != nil {
			t.Error("Expected Connectors to be nil")
		}
		if result.Filters != nil {
			t.Error("Expected Filters to be nil")
		}
	})

	t.Run("MapInputOptionalFalse", func(t *testing.T) {
		data := map[string]interface{}{
			"optional":   false,
			"connectors": []string{"openai"},
			"filters":    []string{"vision"},
		}
		result, err := ToConnectorOptions(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		// An explicit false must be kept as a set pointer, not dropped to nil.
		if result.Optional == nil {
			t.Error("Expected Optional to be set")
		} else if *result.Optional {
			t.Error("Expected Optional to be false")
		}
		if len(result.Connectors) != 1 {
			t.Errorf("Expected 1 connector, got %d", len(result.Connectors))
		}
	})

	t.Run("MapInputOptionalNil", func(t *testing.T) {
		data := map[string]interface{}{
			"connectors": []string{"openai"},
		}
		result, err := ToConnectorOptions(data)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		if result.Optional != nil {
			t.Errorf("Expected Optional to be nil (not set), got: %v", *result.Optional)
		}
		if len(result.Connectors) != 1 {
			t.Errorf("Expected 1 connector, got %d", len(result.Connectors))
		}
	})

	t.Run("InvalidInput", func(t *testing.T) {
		// A channel is used as input that the conversion is expected to reject.
		invalidData := make(chan int)
		_, err := ToConnectorOptions(invalidData)
		if err == nil {
			t.Error("Expected error for invalid input")
		}
	})

	t.Run("InvalidJSONUnmarshal", func(t *testing.T) {
		// The map holds only an unrecognized key; the conversion is expected to
		// succeed and yield a non-nil value rather than fail.
		data := map[string]interface{}{
			"invalid_field": "should cause unmarshal to fail gracefully",
		}
		result, err := ToConnectorOptions(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result == nil {
			t.Error("Expected non-nil result")
		}
	})
}

// TestToModes tests the ToModes conversion function.
//
// Length checks that precede an index into result use t.Fatalf, so a wrong
// length is reported as a test failure rather than an out-of-range panic.
func TestToModes(t *testing.T) {
	t.Run("NilInput", func(t *testing.T) {
		result, err := ToModes(nil)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != nil {
			t.Errorf("Expected nil result, got: %v", result)
		}
	})

	t.Run("StringSlice", func(t *testing.T) {
		modes := []string{"chat", "task", "analyze"}
		result, err := ToModes(modes)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 3 {
			t.Fatalf("Expected 3 modes, got %d", len(result))
		}
		if result[0] != "chat" {
			t.Errorf("Expected 'chat', got '%s'", result[0])
		}
		if result[1] != "task" {
			t.Errorf("Expected 'task', got '%s'", result[1])
		}
		if result[2] != "analyze" {
			t.Errorf("Expected 'analyze', got '%s'", result[2])
		}
	})

	t.Run("InterfaceSlice", func(t *testing.T) {
		modes := []interface{}{"chat", "task", 123}
		result, err := ToModes(modes)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 3 {
			t.Fatalf("Expected 3 modes, got %d", len(result))
		}
		if result[0] != "chat" {
			t.Errorf("Expected 'chat', got '%s'", result[0])
		}
		if result[2] != "123" {
			t.Errorf("Expected '123', got '%s'", result[2])
		}
	})

	t.Run("SingleString", func(t *testing.T) {
		mode := "chat"
		result, err := ToModes(mode)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 1 {
			t.Fatalf("Expected 1 mode, got %d", len(result))
		}
		if result[0] != "chat" {
			t.Errorf("Expected 'chat', got '%s'", result[0])
		}
	})

	t.Run("EmptySlice", func(t *testing.T) {
		modes := []string{}
		result, err := ToModes(modes)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 0 {
			t.Errorf("Expected 0 modes, got %d", len(result))
		}
	})

	t.Run("InvalidInput", func(t *testing.T) {
		// A channel is used as input that the conversion is expected to reject.
		invalidData := make(chan int)
		_, err := ToModes(invalidData)
		if err == nil {
			t.Error("Expected error for invalid input")
		}
	})

	t.Run("InvalidJSONUnmarshal", func(t *testing.T) {
		// Test with data that marshals but can't unmarshal to []string
		data := map[string]interface{}{
			"invalid": "structure",
		}
		_, err := ToModes(data)
		if err == nil {
			t.Error("Expected error for invalid unmarshal")
		}
	})

	t.Run("MixedTypes", func(t *testing.T) {
		modes := []interface{}{"chat", 456, "task", true}
		result, err := ToModes(modes)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 4 {
			t.Fatalf("Expected 4 modes, got %d", len(result))
		}
		// cast.ToString should convert all to strings
		if result[0] != "chat" {
			t.Errorf("Expected 'chat', got '%s'", result[0])
		}
		if result[1] != "456" {
			t.Errorf("Expected '456', got '%s'", result[1])
		}
		if result[2] != "task" {
			t.Errorf("Expected 'task', got '%s'", result[2])
		}
		if result[3] != "true" {
			t.Errorf("Expected 'true', got '%s'", result[3])
		}
	})
}

// TestToPromptPresets tests the ToPromptPresets conversion function.
func TestToPromptPresets(t *testing.T) {
	t.Run("NilInput", func(t *testing.T) {
		result, err := ToPromptPresets(nil)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result != nil {
			t.Errorf("Expected nil result, got: %v", result)
		}
	})

	t.Run("MapStringPromptSlice", func(t *testing.T) {
		presets := map[string][]Prompt{
			"chat": {
				{Role: "system", Content: "You are a chat assistant"},
			},
			"task": {
				{Role: "system", Content: "You are a task assistant"},
			},
		}
		result, err := ToPromptPresets(presets)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 2 {
			t.Errorf("Expected 2 presets, got %d", len(result))
		}
		if len(result["chat"]) != 1 {
			t.Errorf("Expected 1 chat prompt, got %d", len(result["chat"]))
		}
		if len(result["task"]) != 1 {
			t.Errorf("Expected 1 task prompt, got %d", len(result["task"]))
		}
	})

	t.Run("MapInput", func(t *testing.T) {
		data := map[string]interface{}{
			"chat": []interface{}{
				map[string]interface{}{"role": "system", "content": "Chat mode system prompt"},
				map[string]interface{}{"role": "user", "content": "Example user message"},
			},
			"task": []interface{}{
				map[string]interface{}{"role": "system", "content": "Task mode system prompt"},
			},
			"analyze": []interface{}{
				map[string]interface{}{"role": "system", "content": "Analyze mode system prompt"},
			},
		}
		result, err := ToPromptPresets(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 3 {
			t.Errorf("Expected 3 presets, got %d", len(result))
		}
		if len(result["chat"]) != 2 {
			t.Errorf("Expected 2 chat prompts, got %d", len(result["chat"]))
		}
		if len(result["task"]) != 1 {
			t.Errorf("Expected 1 task prompt, got %d", len(result["task"]))
		}
		if len(result["analyze"]) != 1 {
			t.Errorf("Expected 1 analyze prompt, got %d", len(result["analyze"]))
		}
	})

	t.Run("EmptyMap", func(t *testing.T) {
		data := map[string]interface{}{}
		result, err := ToPromptPresets(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if result == nil {
			t.Error("Expected non-nil result")
		}
		if len(result) != 0 {
			t.Errorf("Expected empty map, got %d entries", len(result))
		}
	})

	t.Run("SinglePreset", func(t *testing.T) {
		data := map[string]interface{}{
			"default": []interface{}{
				map[string]interface{}{"role": "system", "content": "Default prompt"},
			},
		}
		result, err := ToPromptPresets(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		if len(result) != 1 {
			t.Errorf("Expected 1 preset, got %d", len(result))
		}
		if _, ok := result["default"]; !ok {
			t.Error("Expected 'default' key in result")
		}
	})

	t.Run("InvalidInput", func(t *testing.T) {
		// A channel is used as input that the conversion is expected to reject.
		invalidData := make(chan int)
		_, err := ToPromptPresets(invalidData)
		if err == nil {
			t.Error("Expected error for invalid input")
		}
	})

	t.Run("InvalidJSONUnmarshal", func(t *testing.T) {
		// Test with data that marshals but can't unmarshal to map[string][]Prompt
		// This is a string that can be marshaled but won't unmarshal to the expected type
		data := "not a map"
		_, err := ToPromptPresets(data)
		if err == nil {
			t.Error("Expected error for invalid JSON unmarshal")
		}
	})

	t.Run("PromptWithAllFields", func(t *testing.T) {
		data := map[string]interface{}{
			"advanced": []interface{}{
				map[string]interface{}{
					"role":    "system",
					"content": "Advanced system prompt",
					"name":    "system-prompt",
				},
				map[string]interface{}{
					"role":    "user",
					"content": "User example",
					"name":    "user-example",
				},
				map[string]interface{}{
					"role":    "assistant",
					"content": "Assistant response",
					"name":    "assistant-response",
				},
			},
		}
		result, err := ToPromptPresets(data)
		if err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
		// Fatal: the first element is indexed below.
		if len(result["advanced"]) != 3 {
			t.Fatalf("Expected 3 prompts in advanced, got %d", len(result["advanced"]))
		}
		if result["advanced"][0].Role != "system" {
			t.Errorf("Expected role 'system', got '%s'", result["advanced"][0].Role)
		}
		if result["advanced"][0].Content != "Advanced system prompt" {
			t.Errorf("Expected content 'Advanced system prompt', got '%s'", result["advanced"][0].Content)
		}
	})
}

// TestParseModelID tests the ParseModelID function
func TestParseModelID(t *testing.T) {
	t.Run("ValidModelID", func(t *testing.T) {
		modelID := "test-assistant-gpt-4o-yao_test123"
		result := ParseModelID(modelID)
		expected := "test123"
		if result != expected {
			t.Errorf("Expected '%s', got '%s'", expected, result)
		}
	})

	t.Run("ValidModelIDWithMultipleDashes", func(t *testing.T) {
		modelID := "my-test-bot-gpt-3.5-turbo-yao_abc456"
		result := ParseModelID(modelID)
		expected := "abc456"
		if result != expected {
			t.Errorf("Expected '%s', got '%s'", expected, result)
		}
	})

	t.Run("ValidModelIDWithHyphenInID", func(t *testing.T) {
		modelID := "assistant-name-model-yao_id-with-dash"
		result := ParseModelID(modelID)
		expected := "id-with-dash"
		if result != expected {
			t.Errorf("Expected '%s', got '%s'", expected, result)
		}
	})

	t.Run("InvalidModelIDNoYaoPrefix", func(t *testing.T) {
		modelID := "test-assistant-gpt-4o-test123"
		result := ParseModelID(modelID)
		if result != "" {
			t.Errorf("Expected empty string, got '%s'", result)
		}
	})

	t.Run("InvalidModelIDEmpty", func(t *testing.T) {
		modelID := ""
		result := ParseModelID(modelID)
		if result != "" {
			t.Errorf("Expected empty string, got '%s'", result)
		}
	})

	t.Run("InvalidModelIDOnlyYaoPrefix", func(t *testing.T) {
		modelID := "yao_"
		result :=
ParseModelID(modelID) if result != "" { t.Errorf("Expected empty string, got '%s'", result) } }) t.Run("RoundTrip", func(t *testing.T) { assistant := AssistantModel{ ID: "roundtrip123", Name: "Round Trip Test", Connector: "openai", Options: map[string]interface{}{ "model": "gpt-4", }, } modelID := assistant.ModelID() extractedID := ParseModelID(modelID) if extractedID != assistant.ID { t.Errorf("Round trip failed: expected '%s', got '%s'", assistant.ID, extractedID) } }) } ================================================ FILE: agent/store/types/fields.go ================================================ package types import "github.com/yaoapp/kun/log" // AssistantAllowedFields defines the whitelist of fields that can be selected for assistants var AssistantAllowedFields = map[string]bool{ "id": true, "assistant_id": true, "type": true, "name": true, "avatar": true, "connector": true, "connector_options": true, "description": true, "path": true, "sort": true, "built_in": true, "placeholder": true, "options": true, "prompts": true, "prompt_presets": true, "disable_global_prompts": true, "capabilities": true, "workflow": true, "kb": true, "db": true, "mcp": true, "sandbox": true, "source": true, "tags": true, "modes": true, "default_mode": true, "readonly": true, "public": true, "share": true, "locales": true, "uses": true, "search": true, "dependencies": true, "automated": true, "mentionable": true, "created_at": true, "updated_at": true, "__yao_created_by": true, "__yao_updated_by": true, "__yao_team_id": true, "__yao_tenant_id": true, } // AssistantDefaultFields defines the default fields to select for assistants when no specific fields are requested // These are lightweight fields suitable for list views and basic information display var AssistantDefaultFields = []string{ "assistant_id", "type", "name", "avatar", "connector", "description", "capabilities", // Capabilities description for Robot orchestration (lightweight) "tags", // Tags for categorization 
(lightweight) "modes", // Supported modes (lightweight) "default_mode", // Default mode (lightweight) "sort", "built_in", "readonly", "public", "share", "automated", "mentionable", "sandbox", // Sandbox configuration presence (lightweight) "kb", // Knowledge base configuration (lightweight) "db", // Database configuration (lightweight) "mcp", // MCP servers configuration (lightweight) "dependencies", // Dependencies on other MCP Clients (lightweight) "created_at", "updated_at", "__yao_created_by", // Permission: creator user ID "__yao_updated_by", // Permission: updater user ID "__yao_team_id", // Permission: team ID "__yao_tenant_id", // Permission: tenant ID } // AssistantFullFields defines all available fields including complex/large fields // Use this when you need complete assistant data for backend processing var AssistantFullFields = []string{ "assistant_id", "type", "name", "avatar", "connector", "connector_options", "description", "capabilities", "path", "sort", "built_in", "placeholder", "options", "prompts", "prompt_presets", "disable_global_prompts", "workflow", "kb", "db", "mcp", "sandbox", "source", "tags", "modes", "default_mode", "readonly", "public", "share", "locales", "uses", "search", "dependencies", "automated", "mentionable", "created_at", "updated_at", "__yao_created_by", "__yao_updated_by", "__yao_team_id", "__yao_tenant_id", } // ValidateAssistantFields validates and filters assistant select fields against the whitelist // Returns the filtered fields. If input is empty, returns empty slice (meaning no restriction). // If all fields are invalid, returns default fields as fallback. 
func ValidateAssistantFields(fields []string) []string { // If no fields specified, return empty slice (no restriction) if len(fields) == 0 { return []string{} } // Filter out any fields not in the whitelist sanitized := make([]string, 0, len(fields)) for _, field := range fields { if AssistantAllowedFields[field] { sanitized = append(sanitized, field) } else { log.Warn("Ignoring invalid assistant select field: %s", field) } } // If all fields were filtered out, return default fields as fallback if len(sanitized) == 0 { log.Warn("All assistant select fields were invalid, using default fields") return AssistantDefaultFields } return sanitized } ================================================ FILE: agent/store/types/fields_test.go ================================================ package types import ( "reflect" "testing" ) func TestValidateAssistantFields(t *testing.T) { t.Run("EmptyInput_ReturnsEmptySlice", func(t *testing.T) { result := ValidateAssistantFields([]string{}) if len(result) != 0 { t.Errorf("Expected empty slice, got %v", result) } }) t.Run("NilInput_ReturnsEmptySlice", func(t *testing.T) { result := ValidateAssistantFields(nil) if len(result) != 0 { t.Errorf("Expected empty slice, got %v", result) } }) t.Run("ValidFields_ReturnsFiltered", func(t *testing.T) { input := []string{"assistant_id", "name", "type"} result := ValidateAssistantFields(input) expected := []string{"assistant_id", "name", "type"} if !reflect.DeepEqual(result, expected) { t.Errorf("Expected %v, got %v", expected, result) } }) t.Run("MixedValidInvalidFields_ReturnsOnlyValid", func(t *testing.T) { input := []string{"assistant_id", "invalid_field", "name", "malicious_column"} result := ValidateAssistantFields(input) expected := []string{"assistant_id", "name"} if !reflect.DeepEqual(result, expected) { t.Errorf("Expected %v, got %v", expected, result) } }) t.Run("AllInvalidFields_ReturnsDefaultFields", func(t *testing.T) { input := []string{"invalid1", "invalid2", "malicious"} result 
:= ValidateAssistantFields(input) if !reflect.DeepEqual(result, AssistantDefaultFields) { t.Errorf("Expected default fields when all invalid, got %v", result) } }) t.Run("PermissionFields_AreAllowed", func(t *testing.T) { input := []string{"__yao_created_by", "__yao_team_id", "assistant_id"} result := ValidateAssistantFields(input) expected := []string{"__yao_created_by", "__yao_team_id", "assistant_id"} if !reflect.DeepEqual(result, expected) { t.Errorf("Expected %v, got %v", expected, result) } }) t.Run("AllAllowedFields_AreInWhitelist", func(t *testing.T) { // Verify all default fields are in the allowed list for _, field := range AssistantDefaultFields { if !AssistantAllowedFields[field] { t.Errorf("Default field %s is not in allowed fields", field) } } }) t.Run("SQLInjectionAttempt_IsFiltered", func(t *testing.T) { input := []string{"assistant_id", "name; DROP TABLE assistants;--", "type"} result := ValidateAssistantFields(input) expected := []string{"assistant_id", "type"} if !reflect.DeepEqual(result, expected) { t.Errorf("Expected SQL injection attempt to be filtered, got %v", result) } }) t.Run("DuplicateFields_AreKept", func(t *testing.T) { input := []string{"assistant_id", "name", "assistant_id", "name"} result := ValidateAssistantFields(input) // Duplicates should be kept as-is (validation doesn't deduplicate) expected := []string{"assistant_id", "name", "assistant_id", "name"} if !reflect.DeepEqual(result, expected) { t.Errorf("Expected %v, got %v", expected, result) } }) } func TestAssistantAllowedFields(t *testing.T) { t.Run("ContainsBasicFields", func(t *testing.T) { requiredFields := []string{ "assistant_id", "name", "type", "connector", "description", } for _, field := range requiredFields { if !AssistantAllowedFields[field] { t.Errorf("Required field %s is missing from allowed fields", field) } } }) t.Run("ContainsPermissionFields", func(t *testing.T) { permissionFields := []string{ "__yao_created_by", "__yao_updated_by", "__yao_team_id", 
"__yao_tenant_id", } for _, field := range permissionFields { if !AssistantAllowedFields[field] { t.Errorf("Permission field %s is missing from allowed fields", field) } } }) t.Run("ContainsComplexFields", func(t *testing.T) { complexFields := []string{ "options", "prompts", "prompt_presets", "disable_global_prompts", "workflow", "kb", "db", "mcp", "placeholder", "locales", "uses", "connector_options", "source", "modes", "default_mode", } for _, field := range complexFields { if !AssistantAllowedFields[field] { t.Errorf("Complex field %s is missing from allowed fields", field) } } }) } func TestAssistantDefaultFields(t *testing.T) { t.Run("ContainsEssentialFields", func(t *testing.T) { essentialFields := []string{ "assistant_id", "name", "type", "kb", // Knowledge base is essential for assistant functionality "db", // Database is essential for assistant functionality "mcp", // MCP servers are essential for assistant functionality "modes", // Supported modes are essential for mode filtering "default_mode", // Default mode is essential for mode selection "__yao_created_by", // Permission fields are essential for access control "__yao_updated_by", "__yao_team_id", "__yao_tenant_id", } defaultFieldsMap := make(map[string]bool) for _, field := range AssistantDefaultFields { defaultFieldsMap[field] = true } for _, field := range essentialFields { if !defaultFieldsMap[field] { t.Errorf("Essential field %s is missing from default fields", field) } } }) t.Run("DoesNotContainSensitiveFields", func(t *testing.T) { // Default fields should not include complex/large fields by default // Note: kb, db, mcp, tags, modes, and default_mode are lightweight and included in defaults sensitiveFields := []string{ "options", "prompts", "prompt_presets", "workflow", "placeholder", "locales", "uses", "connector_options", "source", } defaultFieldsMap := make(map[string]bool) for _, field := range AssistantDefaultFields { defaultFieldsMap[field] = true } for _, field := range sensitiveFields 
{ if defaultFieldsMap[field] { t.Errorf("Large/complex field %s should not be in default fields", field) } } }) } func TestAssistantFullFields(t *testing.T) { t.Run("ContainsAllAllowedFields", func(t *testing.T) { // Full fields should contain all fields from allowed fields fullFieldsMap := make(map[string]bool) for _, field := range AssistantFullFields { fullFieldsMap[field] = true } for field := range AssistantAllowedFields { if field == "id" { // "id" is an alias for "assistant_id", skip continue } if !fullFieldsMap[field] { t.Errorf("Allowed field %s is missing from full fields", field) } } }) t.Run("AllFieldsAreAllowed", func(t *testing.T) { // All fields in full list should be in allowed fields for _, field := range AssistantFullFields { if !AssistantAllowedFields[field] { t.Errorf("Full field %s is not in allowed fields", field) } } }) t.Run("ContainsComplexFields", func(t *testing.T) { // Full fields should include all complex/large fields complexFields := []string{ "options", "prompts", "prompt_presets", "disable_global_prompts", "workflow", "kb", "db", "mcp", "placeholder", "locales", "uses", "connector_options", "source", "modes", "default_mode", } fullFieldsMap := make(map[string]bool) for _, field := range AssistantFullFields { fullFieldsMap[field] = true } for _, field := range complexFields { if !fullFieldsMap[field] { t.Errorf("Complex field %s is missing from full fields", field) } } }) t.Run("ContainsPermissionFields", func(t *testing.T) { // Full fields should include permission fields permissionFields := []string{ "__yao_created_by", "__yao_updated_by", "__yao_team_id", "__yao_tenant_id", } fullFieldsMap := make(map[string]bool) for _, field := range AssistantFullFields { fullFieldsMap[field] = true } for _, field := range permissionFields { if !fullFieldsMap[field] { t.Errorf("Permission field %s is missing from full fields", field) } } }) } ================================================ FILE: agent/store/types/mcp_test.go 
================================================ package types import ( "encoding/json" "testing" ) func TestMCPServerConfig_UnmarshalJSON(t *testing.T) { tests := []struct { name string input string want MCPServerConfig wantErr bool }{ { name: "Simple string", input: `"server1"`, want: MCPServerConfig{ ServerID: "server1", Resources: nil, Tools: nil, }, wantErr: false, }, { name: "Tools array only", input: `{"server1": ["tool1", "tool2"]}`, want: MCPServerConfig{ ServerID: "server1", Resources: nil, Tools: []string{"tool1", "tool2"}, }, wantErr: false, }, { name: "Full config with resources and tools", input: `{"server1": {"resources": ["res1", "res2"], "tools": ["tool1", "tool2"]}}`, want: MCPServerConfig{ ServerID: "server1", Resources: []string{"res1", "res2"}, Tools: []string{"tool1", "tool2"}, }, wantErr: false, }, { name: "Only resources", input: `{"server1": {"resources": ["res1"]}}`, want: MCPServerConfig{ ServerID: "server1", Resources: []string{"res1"}, Tools: nil, }, wantErr: false, }, { name: "Only tools", input: `{"server1": {"tools": ["tool1"]}}`, want: MCPServerConfig{ ServerID: "server1", Resources: nil, Tools: []string{"tool1"}, }, wantErr: false, }, { name: "Standard object format", input: `{"server_id": "server1", "resources": ["res1"], "tools": ["tool1"]}`, want: MCPServerConfig{ ServerID: "server1", Resources: []string{"res1"}, Tools: []string{"tool1"}, }, wantErr: false, }, { name: "Standard object format - no resources/tools", input: `{"server_id": "server1"}`, want: MCPServerConfig{ ServerID: "server1", Resources: nil, Tools: nil, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got MCPServerConfig err := json.Unmarshal([]byte(tt.input), &got) if (err != nil) != tt.wantErr { t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !tt.wantErr { if got.ServerID != tt.want.ServerID { t.Errorf("ServerID = %v, want %v", got.ServerID, tt.want.ServerID) } if 
!stringSlicesEqual(got.Resources, tt.want.Resources) { t.Errorf("Resources = %v, want %v", got.Resources, tt.want.Resources) } if !stringSlicesEqual(got.Tools, tt.want.Tools) { t.Errorf("Tools = %v, want %v", got.Tools, tt.want.Tools) } } }) } } func TestMCPServers_UnmarshalJSON(t *testing.T) { tests := []struct { name string input string want []MCPServerConfig wantErr bool }{ { name: "Simple string array", input: `{"servers": ["server1", "server2", "server3"]}`, want: []MCPServerConfig{ {ServerID: "server1"}, {ServerID: "server2"}, {ServerID: "server3"}, }, wantErr: false, }, { name: "Mixed formats", input: `{"servers": ["server1", {"server2": ["tool1", "tool2"]}, {"server3": {"resources": ["res1"], "tools": ["tool3"]}}]}`, want: []MCPServerConfig{ {ServerID: "server1"}, {ServerID: "server2", Tools: []string{"tool1", "tool2"}}, {ServerID: "server3", Resources: []string{"res1"}, Tools: []string{"tool3"}}, }, wantErr: false, }, { name: "Empty servers", input: `{"servers": []}`, want: []MCPServerConfig{}, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got MCPServers err := json.Unmarshal([]byte(tt.input), &got) if (err != nil) != tt.wantErr { t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !tt.wantErr { if len(got.Servers) != len(tt.want) { t.Errorf("got %d servers, want %d", len(got.Servers), len(tt.want)) return } for i := range got.Servers { if got.Servers[i].ServerID != tt.want[i].ServerID { t.Errorf("Server[%d].ServerID = %v, want %v", i, got.Servers[i].ServerID, tt.want[i].ServerID) } if !stringSlicesEqual(got.Servers[i].Resources, tt.want[i].Resources) { t.Errorf("Server[%d].Resources = %v, want %v", i, got.Servers[i].Resources, tt.want[i].Resources) } if !stringSlicesEqual(got.Servers[i].Tools, tt.want[i].Tools) { t.Errorf("Server[%d].Tools = %v, want %v", i, got.Servers[i].Tools, tt.want[i].Tools) } } } }) } } func TestMCPServerConfig_MarshalJSON(t *testing.T) { tests := []struct { name 
string config MCPServerConfig want string }{ { name: "Only ServerID - should be simple string", config: MCPServerConfig{ ServerID: "server1", }, want: `"server1"`, }, { name: "With Tools - should be object", config: MCPServerConfig{ ServerID: "server1", Tools: []string{"tool1", "tool2"}, }, want: `{"server_id":"server1","tools":["tool1","tool2"]}`, }, { name: "With Resources - should be object", config: MCPServerConfig{ ServerID: "server1", Resources: []string{"res1"}, }, want: `{"server_id":"server1","resources":["res1"]}`, }, { name: "With Both - should be object", config: MCPServerConfig{ ServerID: "server1", Resources: []string{"res1"}, Tools: []string{"tool1"}, }, want: `{"server_id":"server1","resources":["res1"],"tools":["tool1"]}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := json.Marshal(tt.config) if err != nil { t.Errorf("MarshalJSON() error = %v", err) return } if string(got) != tt.want { t.Errorf("MarshalJSON() = %s, want %s", string(got), tt.want) } }) } } func TestMCPServerConfig_RoundTrip(t *testing.T) { tests := []struct { name string config MCPServerConfig }{ { name: "Simple ServerID", config: MCPServerConfig{ ServerID: "server1", }, }, { name: "With Tools", config: MCPServerConfig{ ServerID: "server2", Tools: []string{"tool1", "tool2"}, }, }, { name: "With Resources and Tools", config: MCPServerConfig{ ServerID: "server3", Resources: []string{"res1", "res2"}, Tools: []string{"tool3", "tool4"}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Marshal data, err := json.Marshal(tt.config) if err != nil { t.Fatalf("Marshal error = %v", err) } // Unmarshal var got MCPServerConfig err = json.Unmarshal(data, &got) if err != nil { t.Fatalf("Unmarshal error = %v", err) } // Compare if got.ServerID != tt.config.ServerID { t.Errorf("ServerID = %v, want %v", got.ServerID, tt.config.ServerID) } if !stringSlicesEqual(got.Resources, tt.config.Resources) { t.Errorf("Resources = %v, want %v", 
got.Resources, tt.config.Resources) } if !stringSlicesEqual(got.Tools, tt.config.Tools) { t.Errorf("Tools = %v, want %v", got.Tools, tt.config.Tools) } }) } } // Helper function to compare string slices (nil-safe) func stringSlicesEqual(a, b []string) bool { if len(a) == 0 && len(b) == 0 { return true } if len(a) != len(b) { return false } for i := range a { if a[i] != b[i] { return false } } return true } ================================================ FILE: agent/store/types/prompt.go ================================================ package types import ( "os" "path/filepath" "regexp" "strings" "time" "github.com/yaoapp/gou/application" "github.com/yaoapp/gou/fs" "gopkg.in/yaml.v3" ) // Prompts is a slice of Prompt with helper methods type Prompts []Prompt // SystemVariables defines the available system variables // These are computed at parse time var SystemVariables = map[string]func() string{ "TIME": func() string { return time.Now().Format("15:04:05") }, "DATE": func() string { return time.Now().Format("2006-01-02") }, "DATETIME": func() string { return time.Now().Format("2006-01-02 15:04:05") }, "TIMEZONE": func() string { return time.Now().Location().String() }, "WEEKDAY": func() string { return time.Now().Weekday().String() }, "YEAR": func() string { return time.Now().Format("2006") }, "MONTH": func() string { return time.Now().Format("01") }, "DAY": func() string { return time.Now().Format("02") }, "HOUR": func() string { return time.Now().Format("15") }, "MINUTE": func() string { return time.Now().Format("04") }, "SECOND": func() string { return time.Now().Format("05") }, "UNIX": func() string { return time.Now().Format("1136239445") }, // Unix timestamp } // Regular expressions for variable replacement var ( reSysVar = regexp.MustCompile(`\$SYS\.([A-Z_]+)`) reEnvVar = regexp.MustCompile(`\$ENV\.([A-Za-z_][A-Za-z0-9_]*)`) reCtxVar = regexp.MustCompile(`\$CTX\.([A-Za-z_][A-Za-z0-9_]*)`) reAssetRef = 
regexp.MustCompile(`@assets/([^\s]+\.(md|yml|yaml|json|txt))`) ) // LoadPrompts loads prompts from a YAML file // Handles @assets/* replacement at load time // file: prompt file path relative to app root (e.g., "assistants/test/prompts.yml") // root: resource root directory for assets (e.g., "assistants/test") // Returns: prompts slice, modification timestamp, error func LoadPrompts(file string, root string) ([]Prompt, int64, error) { app, err := fs.Get("app") if err != nil { return nil, 0, err } ts, err := app.ModTime(file) if err != nil { return nil, 0, err } content, err := app.ReadFile(file) if err != nil { return nil, 0, err } // Replace @assets/xxx references with file content content = replaceAssets(content, root, app) // Parse prompts var prompts []Prompt err = yaml.Unmarshal(content, &prompts) if err != nil { return nil, 0, err } return prompts, ts.UnixNano(), nil } // LoadPromptsRaw loads raw prompt content from a YAML file // Handles @assets/* replacement at load time // Returns raw YAML string for further processing func LoadPromptsRaw(file string, root string) (string, int64, error) { app, err := fs.Get("app") if err != nil { return "", 0, err } ts, err := app.ModTime(file) if err != nil { return "", 0, err } content, err := app.ReadFile(file) if err != nil { return "", 0, err } // Replace @assets/xxx references with file content content = replaceAssets(content, root, app) return string(content), ts.UnixNano(), nil } // LoadGlobalPrompts loads global prompts from agent/prompts.yml // Returns: prompts slice, modification timestamp, error func LoadGlobalPrompts() ([]Prompt, int64, error) { file := filepath.Join("agent", "prompts.yml") // Check if file exists exists, _ := application.App.Exists(file) if !exists { return nil, 0, nil } return LoadPrompts(file, "agent") } // LoadPromptPresets loads prompt presets from a directory // Supports multi-level directories, key is path with "/" replaced by "." 
// e.g., prompts/chat/friendly.yml -> "chat.friendly" func LoadPromptPresets(dir string, root string) (map[string][]Prompt, int64, error) { app, err := fs.Get("app") if err != nil { return nil, 0, err } // Check if directory exists exists, _ := app.Exists(dir) if !exists { return nil, 0, nil } // Read directory recursively files, err := app.ReadDir(dir, true) if err != nil { return nil, 0, err } presets := make(map[string][]Prompt) var latestTs int64 for _, file := range files { // Only process .yml/.yaml files if !strings.HasSuffix(file, ".yml") && !strings.HasSuffix(file, ".yaml") { continue } ts, err := app.ModTime(file) if err != nil { return nil, 0, err } if ts.UnixNano() > latestTs { latestTs = ts.UnixNano() } // Read file content content, err := app.ReadFile(file) if err != nil { return nil, 0, err } // Replace @assets/xxx references with file content content = replaceAssets(content, root, app) // Parse prompts var prompts []Prompt err = yaml.Unmarshal(content, &prompts) if err != nil { return nil, 0, err } // Build key: get relative path from dir, remove extension and replace "/" with "." 
relPath := strings.TrimPrefix(file, dir+"/") key := strings.TrimSuffix(relPath, filepath.Ext(relPath)) key = strings.ReplaceAll(key, "/", ".") presets[key] = prompts } return presets, latestTs, nil } // replaceAssets replaces @assets/xxx references with file content func replaceAssets(content []byte, root string, app fs.FileSystem) []byte { return reAssetRef.ReplaceAllFunc(content, func(s []byte) []byte { matches := reAssetRef.FindStringSubmatch(string(s)) if len(matches) < 2 { return s } asset := matches[1] assetFile := filepath.Join(root, "assets", asset) assetContent, err := app.ReadFile(assetFile) if err != nil { return []byte("") } // Add proper YAML formatting for content (multiline string) lines := strings.Split(string(assetContent), "\n") formattedContent := "|\n" for _, line := range lines { formattedContent += " " + line + "\n" } return []byte(formattedContent) }) } // Parse parses a single prompt, replacing variables // ctx: context variables map, key corresponds to $CTX.{key} // Returns a new Prompt with variables replaced func (p Prompt) Parse(ctx map[string]string) Prompt { result := Prompt{ Role: p.Role, Content: parseVariables(p.Content, ctx), Name: p.Name, } return result } // Parse parses all prompts in the slice, replacing variables // ctx: context variables map, key corresponds to $CTX.{key} // Returns a new Prompts slice with variables replaced func (ps Prompts) Parse(ctx map[string]string) Prompts { result := make(Prompts, len(ps)) for i, p := range ps { result[i] = p.Parse(ctx) } return result } // parseVariables replaces all variable types in content func parseVariables(content string, ctx map[string]string) string { // Replace $SYS.* variables content = reSysVar.ReplaceAllStringFunc(content, func(s string) string { matches := reSysVar.FindStringSubmatch(s) if len(matches) < 2 { return s } varName := matches[1] if fn, ok := SystemVariables[varName]; ok { return fn() } return s // Keep original if not found }) // Replace $ENV.* variables 
content = reEnvVar.ReplaceAllStringFunc(content, func(s string) string { matches := reEnvVar.FindStringSubmatch(s) if len(matches) < 2 { return s } varName := matches[1] return os.Getenv(varName) }) // Replace $CTX.* variables if ctx != nil { content = reCtxVar.ReplaceAllStringFunc(content, func(s string) string { matches := reCtxVar.FindStringSubmatch(s) if len(matches) < 2 { return s } varName := matches[1] if val, ok := ctx[varName]; ok { return val } return "" // Empty string if not found in ctx }) } return content } // Merge merges two prompt slices // globalPrompts are prepended to assistantPrompts func Merge(globalPrompts, assistantPrompts []Prompt) []Prompt { if len(globalPrompts) == 0 { return assistantPrompts } if len(assistantPrompts) == 0 { return globalPrompts } result := make([]Prompt, 0, len(globalPrompts)+len(assistantPrompts)) result = append(result, globalPrompts...) result = append(result, assistantPrompts...) return result } ================================================ FILE: agent/store/types/prompt_test.go ================================================ package types import ( "os" "testing" "time" "github.com/stretchr/testify/assert" ) func TestPromptParse(t *testing.T) { tests := []struct { name string prompt Prompt ctx map[string]string validate func(t *testing.T, result Prompt) }{ { name: "ParseSysTimeVariables", prompt: Prompt{ Role: "system", Content: "Current time: $SYS.TIME, Date: $SYS.DATE", }, ctx: nil, validate: func(t *testing.T, result Prompt) { assert.Equal(t, "system", result.Role) // Check that variables are replaced (not exact match due to time) assert.NotContains(t, result.Content, "$SYS.TIME") assert.NotContains(t, result.Content, "$SYS.DATE") assert.Contains(t, result.Content, "Current time:") assert.Contains(t, result.Content, "Date:") }, }, { name: "ParseSysDatetimeVariable", prompt: Prompt{ Role: "system", Content: "Now: $SYS.DATETIME, Timezone: $SYS.TIMEZONE", }, ctx: nil, validate: func(t *testing.T, result Prompt) 
{ assert.NotContains(t, result.Content, "$SYS.DATETIME") assert.NotContains(t, result.Content, "$SYS.TIMEZONE") }, }, { name: "ParseSysWeekdayVariable", prompt: Prompt{ Role: "system", Content: "Today is $SYS.WEEKDAY", }, ctx: nil, validate: func(t *testing.T, result Prompt) { weekday := time.Now().Weekday().String() assert.Contains(t, result.Content, weekday) }, }, { name: "ParseSysYearMonthDay", prompt: Prompt{ Role: "system", Content: "Year: $SYS.YEAR, Month: $SYS.MONTH, Day: $SYS.DAY", }, ctx: nil, validate: func(t *testing.T, result Prompt) { now := time.Now() assert.Contains(t, result.Content, now.Format("2006")) assert.Contains(t, result.Content, now.Format("01")) assert.Contains(t, result.Content, now.Format("02")) }, }, { name: "ParseSysHourMinuteSecond", prompt: Prompt{ Role: "system", Content: "Hour: $SYS.HOUR, Minute: $SYS.MINUTE, Second: $SYS.SECOND", }, ctx: nil, validate: func(t *testing.T, result Prompt) { assert.NotContains(t, result.Content, "$SYS.HOUR") assert.NotContains(t, result.Content, "$SYS.MINUTE") assert.NotContains(t, result.Content, "$SYS.SECOND") }, }, { name: "ParseEnvVariable", prompt: Prompt{ Role: "system", Content: "App: $ENV.TEST_APP_NAME", }, ctx: nil, validate: func(t *testing.T, result Prompt) { assert.Contains(t, result.Content, "App: TestApp") }, }, { name: "ParseEnvVariableNotFound", prompt: Prompt{ Role: "system", Content: "Value: $ENV.NOT_EXIST_VAR_12345", }, ctx: nil, validate: func(t *testing.T, result Prompt) { // Should be replaced with empty string assert.Equal(t, "Value: ", result.Content) }, }, { name: "ParseCtxVariables", prompt: Prompt{ Role: "system", Content: "User: $CTX.USER_ID, Locale: $CTX.LOCALE", }, ctx: map[string]string{ "USER_ID": "user-123", "LOCALE": "zh-CN", }, validate: func(t *testing.T, result Prompt) { assert.Equal(t, "User: user-123, Locale: zh-CN", result.Content) }, }, { name: "ParseCtxVariableNotFound", prompt: Prompt{ Role: "system", Content: "Value: $CTX.NOT_EXIST", }, ctx: 
map[string]string{ "OTHER": "value", }, validate: func(t *testing.T, result Prompt) { // Should be replaced with empty string assert.Equal(t, "Value: ", result.Content) }, }, { name: "ParseCtxWithNilMap", prompt: Prompt{ Role: "system", Content: "Value: $CTX.SOMETHING", }, ctx: nil, validate: func(t *testing.T, result Prompt) { // Should keep original when ctx is nil assert.Equal(t, "Value: $CTX.SOMETHING", result.Content) }, }, { name: "ParseMixedVariables", prompt: Prompt{ Role: "system", Content: "Time: $SYS.TIME, App: $ENV.TEST_APP_NAME, User: $CTX.USER_ID", }, ctx: map[string]string{ "USER_ID": "user-456", }, validate: func(t *testing.T, result Prompt) { assert.NotContains(t, result.Content, "$SYS.TIME") assert.Contains(t, result.Content, "App: TestApp") assert.Contains(t, result.Content, "User: user-456") }, }, { name: "ParseUnknownSysVariable", prompt: Prompt{ Role: "system", Content: "Value: $SYS.UNKNOWN_VAR", }, ctx: nil, validate: func(t *testing.T, result Prompt) { // Should keep original if not found assert.Equal(t, "Value: $SYS.UNKNOWN_VAR", result.Content) }, }, { name: "ParsePreservesRoleAndName", prompt: Prompt{ Role: "user", Content: "Hello $CTX.NAME", Name: "test_user", }, ctx: map[string]string{ "NAME": "World", }, validate: func(t *testing.T, result Prompt) { assert.Equal(t, "user", result.Role) assert.Equal(t, "Hello World", result.Content) assert.Equal(t, "test_user", result.Name) }, }, { name: "ParseCustomCtxVariables", prompt: Prompt{ Role: "system", Content: "Custom: $CTX.MY_CUSTOM_VAR, Another: $CTX.ANOTHER_VAR", }, ctx: map[string]string{ "MY_CUSTOM_VAR": "custom-value", "ANOTHER_VAR": "another-value", }, validate: func(t *testing.T, result Prompt) { assert.Equal(t, "Custom: custom-value, Another: another-value", result.Content) }, }, { name: "ParseMultilineContent", prompt: Prompt{ Role: "system", Content: `# System Context Current Time: $SYS.TIME User: $CTX.USER_ID App: $ENV.TEST_APP_NAME`, }, ctx: map[string]string{ "USER_ID": 
"user-789", }, validate: func(t *testing.T, result Prompt) { assert.Contains(t, result.Content, "# System Context") assert.NotContains(t, result.Content, "$SYS.TIME") assert.Contains(t, result.Content, "User: user-789") assert.Contains(t, result.Content, "App: TestApp") }, }, } // Set test environment variable os.Setenv("TEST_APP_NAME", "TestApp") defer os.Unsetenv("TEST_APP_NAME") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.prompt.Parse(tt.ctx) tt.validate(t, result) }) } } func TestPromptsParse(t *testing.T) { os.Setenv("TEST_APP_NAME", "TestApp") defer os.Unsetenv("TEST_APP_NAME") prompts := Prompts{ {Role: "system", Content: "Time: $SYS.TIME"}, {Role: "system", Content: "User: $CTX.USER_ID"}, {Role: "user", Content: "App: $ENV.TEST_APP_NAME"}, } ctx := map[string]string{ "USER_ID": "user-123", } result := prompts.Parse(ctx) assert.Len(t, result, 3) assert.NotContains(t, result[0].Content, "$SYS.TIME") assert.Equal(t, "User: user-123", result[1].Content) assert.Equal(t, "App: TestApp", result[2].Content) } func TestMergePrompts(t *testing.T) { tests := []struct { name string globalPrompts []Prompt assistantPrompts []Prompt expectedLen int validate func(t *testing.T, result []Prompt) }{ { name: "MergeBothNonEmpty", globalPrompts: []Prompt{ {Role: "system", Content: "Global prompt 1"}, {Role: "system", Content: "Global prompt 2"}, }, assistantPrompts: []Prompt{ {Role: "system", Content: "Assistant prompt 1"}, }, expectedLen: 3, validate: func(t *testing.T, result []Prompt) { assert.Equal(t, "Global prompt 1", result[0].Content) assert.Equal(t, "Global prompt 2", result[1].Content) assert.Equal(t, "Assistant prompt 1", result[2].Content) }, }, { name: "MergeGlobalEmpty", globalPrompts: []Prompt{}, assistantPrompts: []Prompt{ {Role: "system", Content: "Assistant prompt"}, }, expectedLen: 1, validate: func(t *testing.T, result []Prompt) { assert.Equal(t, "Assistant prompt", result[0].Content) }, }, { name: "MergeAssistantEmpty", 
globalPrompts: []Prompt{ {Role: "system", Content: "Global prompt"}, }, assistantPrompts: []Prompt{}, expectedLen: 1, validate: func(t *testing.T, result []Prompt) { assert.Equal(t, "Global prompt", result[0].Content) }, }, { name: "MergeBothEmpty", globalPrompts: []Prompt{}, assistantPrompts: []Prompt{}, expectedLen: 0, validate: func(t *testing.T, result []Prompt) { assert.Empty(t, result) }, }, { name: "MergeGlobalNil", globalPrompts: nil, assistantPrompts: []Prompt{ {Role: "system", Content: "Assistant prompt"}, }, expectedLen: 1, validate: func(t *testing.T, result []Prompt) { assert.Equal(t, "Assistant prompt", result[0].Content) }, }, { name: "MergeAssistantNil", globalPrompts: []Prompt{ {Role: "system", Content: "Global prompt"}, }, assistantPrompts: nil, expectedLen: 1, validate: func(t *testing.T, result []Prompt) { assert.Equal(t, "Global prompt", result[0].Content) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := Merge(tt.globalPrompts, tt.assistantPrompts) assert.Len(t, result, tt.expectedLen) if tt.validate != nil { tt.validate(t, result) } }) } } func TestSystemVariables(t *testing.T) { // Test that all system variables are defined and return non-empty values expectedVars := []string{ "TIME", "DATE", "DATETIME", "TIMEZONE", "WEEKDAY", "YEAR", "MONTH", "DAY", "HOUR", "MINUTE", "SECOND", "UNIX", } for _, varName := range expectedVars { t.Run(varName, func(t *testing.T) { fn, ok := SystemVariables[varName] assert.True(t, ok, "SystemVariables should contain %s", varName) value := fn() assert.NotEmpty(t, value, "SystemVariables[%s]() should return non-empty value", varName) }) } } func TestParseVariablesEdgeCases(t *testing.T) { os.Setenv("TEST_VAR", "test-value") defer os.Unsetenv("TEST_VAR") tests := []struct { name string content string ctx map[string]string expected string }{ { name: "EmptyContent", content: "", ctx: nil, expected: "", }, { name: "NoVariables", content: "Hello, World!", ctx: nil, expected: "Hello, 
World!", }, { name: "PartialVariableSyntax", content: "Value: $SYS Value: $ENV Value: $CTX", ctx: nil, expected: "Value: $SYS Value: $ENV Value: $CTX", }, { name: "VariableInMiddleOfWord", content: "prefix$SYS.TIMEsuffix", ctx: nil, expected: "prefix$SYS.TIMEsuffix", // Should not match - variable must be followed by valid char }, { name: "MultipleOccurrences", content: "$CTX.VAR and $CTX.VAR again", ctx: map[string]string{"VAR": "value"}, expected: "value and value again", }, { name: "SpecialCharactersInValue", content: "User: $CTX.USER", ctx: map[string]string{"USER": "user@example.com"}, expected: "User: user@example.com", }, { name: "UnicodeInValue", content: "Name: $CTX.NAME", ctx: map[string]string{"NAME": "用户名"}, expected: "Name: 用户名", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := parseVariables(tt.content, tt.ctx) if tt.name == "VariableInMiddleOfWord" { // This case depends on regex behavior - just check it doesn't crash assert.NotEmpty(t, result) } else { assert.Equal(t, tt.expected, result) } }) } } ================================================ FILE: agent/store/types/sandbox_v2.go ================================================ package types import ( "crypto/sha256" "encoding/json" "fmt" "os" "path/filepath" "sort" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/application" sandboxTypes "github.com/yaoapp/yao/agent/sandbox/v2/types" ) // LoadSandboxConfig reads a sandbox.yao file (JSON or YAML) and returns // the V2 SandboxConfig. Called during Assistant.Load(). 
func LoadSandboxConfig(filePath string) (*sandboxTypes.SandboxConfig, error) { data, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("read sandbox config %s: %w", filePath, err) } var cfg sandboxTypes.SandboxConfig if err := application.Parse(filepath.Base(filePath), data, &cfg); err != nil { return nil, fmt.Errorf("parse sandbox config: %w", err) } if cfg.Version != sandboxTypes.SandboxVersionV2 { return nil, fmt.Errorf("sandbox.yao version must be %q, got %q", sandboxTypes.SandboxVersionV2, cfg.Version) } return &cfg, nil } // ToSandboxV2 converts a generic value (typically map[string]any from DSL // parsing) into a V2 SandboxConfig. func ToSandboxV2(v any) (*sandboxTypes.SandboxConfig, error) { if v == nil { return nil, nil } switch sb := v.(type) { case *sandboxTypes.SandboxConfig: return sb, nil case sandboxTypes.SandboxConfig: return &sb, nil default: raw, err := jsoniter.Marshal(v) if err != nil { return nil, fmt.Errorf("sandbox v2 format error: %w", err) } var cfg sandboxTypes.SandboxConfig if err := jsoniter.Unmarshal(raw, &cfg); err != nil { return nil, fmt.Errorf("sandbox v2 format error: %w", err) } return &cfg, nil } } // ComputeConfigHash computes a SHA-256 fingerprint of the sandbox configuration, // MCP servers, and skills directory. Used for hot-reload detection in prepare // step "once" logic. 
func ComputeConfigHash(cfg *sandboxTypes.SandboxConfig, mcpServers []MCPServerConfig, skillsDir string) string { h := sha256.New() raw, _ := json.Marshal(cfg) h.Write(raw) if len(mcpServers) > 0 { mcpRaw, _ := json.Marshal(mcpServers) h.Write(mcpRaw) } if skillsDir != "" { h.Write([]byte(skillsDir)) entries, err := os.ReadDir(skillsDir) if err == nil { names := make([]string, 0, len(entries)) for _, e := range entries { names = append(names, e.Name()) } sort.Strings(names) for _, n := range names { h.Write([]byte(n)) } } } return fmt.Sprintf("%x", h.Sum(nil)) } ================================================ FILE: agent/store/types/store.go ================================================ package types // ChatStore defines the chat storage interface // Provides operations for chat, message, and resume management type ChatStore interface { // ========================================================================== // Chat Management // ========================================================================== // CreateChat creates a new chat session // chat: Chat session to create // Returns: Potential error CreateChat(chat *Chat) error // GetChat retrieves a single chat by ID // chatID: Chat ID // Returns: Chat information and potential error GetChat(chatID string) (*Chat, error) // UpdateChat updates chat fields // chatID: Chat ID // updates: Map of fields to update // Returns: Potential error UpdateChat(chatID string, updates map[string]interface{}) error // DeleteChat deletes a chat and its associated messages // chatID: Chat ID // Returns: Potential error DeleteChat(chatID string) error // ListChats retrieves a paginated list of chats with optional grouping // filter: Filter conditions including time range, sorting, and grouping // Returns: Paginated chat list (flat or grouped) and potential error ListChats(filter ChatFilter) (*ChatList, error) // ========================================================================== // Message Management // 
========================================================================== // SaveMessages batch saves messages for a chat // This is the primary write method - messages are buffered during execution // and batch-written at the end of a request // chatID: Parent chat ID // messages: Messages to save (includes user input and assistant responses) // Returns: Potential error SaveMessages(chatID string, messages []*Message) error // GetMessages retrieves messages for a chat with filtering // chatID: Chat ID // filter: Filter conditions (role, type, block, thread, etc.) // Returns: Message list and potential error GetMessages(chatID string, filter MessageFilter) ([]*Message, error) // UpdateMessage updates a single message // messageID: Message ID // updates: Map of fields to update // Returns: Potential error UpdateMessage(messageID string, updates map[string]interface{}) error // DeleteMessages deletes specific messages from a chat // chatID: Chat ID // messageIDs: List of message IDs to delete // Returns: Potential error DeleteMessages(chatID string, messageIDs []string) error // ========================================================================== // Resume Management (only called on failure/interrupt) // ========================================================================== // SaveResume batch saves resume records // Only called when request is interrupted or failed // records: Resume records to save // Returns: Potential error SaveResume(records []*Resume) error // GetResume retrieves all resume records for a chat // chatID: Chat ID // Returns: Resume records and potential error GetResume(chatID string) ([]*Resume, error) // GetLastResume retrieves the last (most recent) resume record for a chat // chatID: Chat ID // Returns: Last resume record and potential error GetLastResume(chatID string) (*Resume, error) // GetResumeByStackID retrieves resume records for a specific stack // stackID: Stack ID // Returns: Resume records and potential error 
GetResumeByStackID(stackID string) ([]*Resume, error) // GetStackPath returns the stack path from root to the given stack // stackID: Current stack ID // Returns: Stack path [root_stack_id, ..., current_stack_id] and potential error GetStackPath(stackID string) ([]string, error) // DeleteResume deletes all resume records for a chat // Called after successful resume to clean up // chatID: Chat ID // Returns: Potential error DeleteResume(chatID string) error // ========================================================================== // Search Management // ========================================================================== // SaveSearch saves a search record for a request // Used for citation support, debugging, and replay // search: Search record to save // Returns: Potential error SaveSearch(search *Search) error // GetSearches retrieves all search records for a request // requestID: Request ID // Returns: Search records and potential error GetSearches(requestID string) ([]*Search, error) // GetReference retrieves a single reference by request ID and index // Used for citation click handling // requestID: Request ID // index: Reference index (1-based) // Returns: Reference and potential error GetReference(requestID string, index int) (*Reference, error) // DeleteSearches deletes all search records for a chat // Called when deleting a chat // chatID: Chat ID // Returns: Potential error DeleteSearches(chatID string) error } // AssistantStore defines the assistant storage interface // Separated from ChatStore for clearer responsibility type AssistantStore interface { // SaveAssistant saves assistant information // assistant: Assistant information // Returns: Assistant ID and potential error SaveAssistant(assistant *AssistantModel) (string, error) // UpdateAssistant updates assistant fields // assistantID: Assistant ID // updates: Map of fields to update // Returns: Potential error UpdateAssistant(assistantID string, updates map[string]interface{}) error // 
DeleteAssistant deletes an assistant // assistantID: Assistant ID // Returns: Potential error DeleteAssistant(assistantID string) error // GetAssistants retrieves a paginated list of assistants with filtering // filter: Filter conditions for querying assistants // locale: Optional locale for i18n translations // Returns: Paginated assistant list and potential error GetAssistants(filter AssistantFilter, locale ...string) (*AssistantList, error) // GetAssistantTags retrieves all unique tags from assistants with filtering // filter: Filter conditions including QueryFilter for permission filtering // locale: Optional locale for i18n translations // Returns: List of tags and potential error GetAssistantTags(filter AssistantFilter, locale ...string) ([]Tag, error) // GetAssistant retrieves a single assistant by ID // assistantID: Assistant ID // fields: List of fields to select, empty/nil means default fields // locale: Optional locale for i18n translations // Returns: Assistant information and potential error GetAssistant(assistantID string, fields []string, locale ...string) (*AssistantModel, error) // DeleteAssistants deletes assistants based on filter conditions // filter: Filter conditions // Returns: Number of deleted records and potential error DeleteAssistants(filter AssistantFilter) (int64, error) } // Store combines ChatStore and AssistantStore interfaces // This is the main interface for the storage layer type Store interface { ChatStore AssistantStore } // SpaceStore defines the interface for Space snapshot operations // Note: Space itself uses plan.Space interface, this is for persistence type SpaceStore interface { // Snapshot returns all key-value pairs in the space Snapshot() map[string]interface{} // Restore sets multiple key-value pairs from a snapshot Restore(data map[string]interface{}) error } ================================================ FILE: agent/store/types/types.go ================================================ package types import ( 
"encoding/json" "fmt" "time" graphragtypes "github.com/yaoapp/gou/graphrag/types" "github.com/yaoapp/xun/dbal/query" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" sandboxTypes "github.com/yaoapp/yao/agent/sandbox/v2/types" searchTypes "github.com/yaoapp/yao/agent/search/types" ) // Setting represents the conversation configuration structure // Used to configure basic conversation parameters including connector, user field, table name, etc. type Setting struct { Connector string `json:"connector,omitempty" yaml:"connector,omitempty"` // Connector name, default is "default" MaxSize int `json:"max_size,omitempty" yaml:"max_size,omitempty"` // Maximum storage size limit, default is 20 TTL int `json:"ttl,omitempty" yaml:"ttl,omitempty"` // Time To Live in seconds, default is 90 * 24 * 60 * 60 (90 days) Options map[string]interface{} `json:"optional,omitempty" yaml:"optional,omitempty"` // The options for the store } // ============================================================================= // Chat Types // ============================================================================= // Chat represents a chat session type Chat struct { ChatID string `json:"chat_id"` Title string `json:"title,omitempty"` AssistantID string `json:"assistant_id"` LastConnector string `json:"last_connector,omitempty"` // Last used connector ID (updated on each message) LastMode string `json:"last_mode,omitempty"` // Last used chat mode (updated on each message) Status string `json:"status"` // "active" or "archived" Public bool `json:"public"` // Whether shared across all teams Share string `json:"share"` // "private" or "team" Sort int `json:"sort"` // Sort order for display LastMessageAt *time.Time `json:"last_message_at,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` // Permission fields (managed by Yao framework when permission: true) CreatedBy string 
`json:"__yao_created_by,omitempty"` // User ID who created the record UpdatedBy string `json:"__yao_updated_by,omitempty"` // User ID who last updated TeamID string `json:"__yao_team_id,omitempty"` // Team ID for team-level access TenantID string `json:"__yao_tenant_id,omitempty"` // Tenant ID for multi-tenancy } // ChatFilter for listing chats type ChatFilter struct { UserID string `json:"user_id,omitempty"` TeamID string `json:"team_id,omitempty"` AssistantID string `json:"assistant_id,omitempty"` Status string `json:"status,omitempty"` Keywords string `json:"keywords,omitempty"` ChatIDPrefix string `json:"chat_id_prefix,omitempty"` // Time range filter StartTime *time.Time `json:"start_time,omitempty"` // Filter chats after this time EndTime *time.Time `json:"end_time,omitempty"` // Filter chats before this time TimeField string `json:"time_field,omitempty"` // Field for time filter: "created_at" or "last_message_at" (default) // Sorting OrderBy string `json:"order_by,omitempty"` // Field to sort by (default: "last_message_at") Order string `json:"order,omitempty"` // Sort order: "desc" (default) or "asc" // Response format GroupBy string `json:"group_by,omitempty"` // "time" for time-based groups, empty for flat list // Pagination Page int `json:"page,omitempty"` PageSize int `json:"pagesize,omitempty"` // Permission filter (not serialized) QueryFilter func(query.Query) `json:"-"` // Custom query function for permission filtering } // ChatList paginated response with time-based grouping type ChatList struct { Data []*Chat `json:"data"` Groups []*ChatGroup `json:"groups,omitempty"` // Time-based groups for UI display Page int `json:"page"` PageSize int `json:"pagesize"` PageCount int `json:"pagecount"` Total int `json:"total"` } // ChatGroup represents a time-based group of chats type ChatGroup struct { Label string `json:"label"` // "Today", "Yesterday", "This Week", "This Month", "Earlier" Key string `json:"key"` // "today", "yesterday", "this_week", 
"this_month", "earlier" Chats []*Chat `json:"chats"` // Chats in this group Count int `json:"count"` // Number of chats in group } // ============================================================================= // Message Types // ============================================================================= // Message represents a chat message type Message struct { MessageID string `json:"message_id"` ChatID string `json:"chat_id"` RequestID string `json:"request_id,omitempty"` Role string `json:"role"` // "user" or "assistant" Type string `json:"type"` // "text", "image", "loading", "tool_call", "retrieval", etc. Props map[string]interface{} `json:"props"` BlockID string `json:"block_id,omitempty"` ThreadID string `json:"thread_id,omitempty"` AssistantID string `json:"assistant_id,omitempty"` Connector string `json:"connector,omitempty"` // Connector ID used for this message Mode string `json:"mode,omitempty"` // Chat mode used for this message (chat or task) Sequence int `json:"sequence"` Metadata map[string]interface{} `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } // MessageFilter for listing messages type MessageFilter struct { RequestID string `json:"request_id,omitempty"` Role string `json:"role,omitempty"` BlockID string `json:"block_id,omitempty"` ThreadID string `json:"thread_id,omitempty"` Type string `json:"type,omitempty"` Limit int `json:"limit,omitempty"` Offset int `json:"offset,omitempty"` } // ============================================================================= // Resume Types (for recovery from interruption/failure) // ============================================================================= // Resume represents an execution state for recovery // Only stored when request is interrupted or failed type Resume struct { ResumeID string `json:"resume_id"` ChatID string `json:"chat_id"` RequestID string `json:"request_id"` AssistantID string `json:"assistant_id"` StackID 
string `json:"stack_id"` StackParentID string `json:"stack_parent_id,omitempty"` StackDepth int `json:"stack_depth"` Type string `json:"type"` // "input", "hook_create", "llm", "tool", "hook_next", "delegate" Status string `json:"status"` // "failed" or "interrupted" Input map[string]interface{} `json:"input,omitempty"` Output map[string]interface{} `json:"output,omitempty"` SpaceSnapshot map[string]interface{} `json:"space_snapshot,omitempty"` // Shared space data for recovery Error string `json:"error,omitempty"` Sequence int `json:"sequence"` Metadata map[string]interface{} `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } // ResumeStatus constants const ( ResumeStatusFailed = "failed" ResumeStatusInterrupted = "interrupted" ) // ResumeType constants const ( ResumeTypeInput = "input" ResumeTypeHookCreate = "hook_create" ResumeTypeLLM = "llm" ResumeTypeTool = "tool" ResumeTypeHookNext = "hook_next" ResumeTypeDelegate = "delegate" ) // AssistantFilter represents the assistant filter structure // Used for filtering and pagination when retrieving assistant lists type AssistantFilter struct { Tags []string `json:"tags,omitempty"` // Filter by tags Type string `json:"type,omitempty"` // Filter by type (single value) Types []string `json:"types,omitempty"` // Filter by types (multiple values, IN query) Keywords string `json:"keywords,omitempty"` // Search in name and description Connector string `json:"connector,omitempty"` // Filter by connector AssistantID string `json:"assistant_id,omitempty"` // Filter by assistant ID AssistantIDs []string `json:"assistant_ids,omitempty"` // Filter by assistant IDs Mentionable *bool `json:"mentionable,omitempty"` // Filter by mentionable status Automated *bool `json:"automated,omitempty"` // Filter by automation status BuiltIn *bool `json:"built_in,omitempty"` // Filter by built-in status Sandbox *bool `json:"sandbox,omitempty"` // Filter by sandbox configuration (true=has 
sandbox, false=no sandbox) Page int `json:"page,omitempty"` // Page number, starting from 1 PageSize int `json:"pagesize,omitempty"` // Items per page Select []string `json:"select,omitempty"` // Fields to return, returns all fields if empty QueryFilter func(query.Query) `json:"-"` // Custom query function for permission filtering (not serialized) } // AssistantList represents the paginated assistant list response structure // Used for returning paginated assistant lists with metadata type AssistantList struct { Data []*AssistantModel `json:"data"` // List of assistants Page int `json:"page"` // Current page number (1-based) PageSize int `json:"pagesize"` // Number of items per page PageCount int `json:"pagecount"` // Total number of pages Next int `json:"next"` // Next page number (0 if no next page) Prev int `json:"prev"` // Previous page number (0 if no previous page) Total int `json:"total"` // Total number of items across all pages } // AssistantInfo contains basic assistant information for display // Used in chat history to show assistant details with i18n support type AssistantInfo struct { AssistantID string `json:"assistant_id"` Name string `json:"name"` Avatar string `json:"avatar,omitempty"` Description string `json:"description,omitempty"` Connector string `json:"connector,omitempty"` ConnectorOptions *ConnectorOptions `json:"connector_options,omitempty"` Modes []string `json:"modes,omitempty"` DefaultMode string `json:"default_mode,omitempty"` Sandbox bool `json:"sandbox,omitempty"` ComputerFilter *sandboxTypes.ComputerFilter `json:"computer_filter,omitempty"` } // Tag represents a tag type Tag struct { Value string `json:"value"` Label string `json:"label"` } // Prompt a prompt type Prompt struct { Role string `json:"role"` Content string `json:"content"` Name string `json:"name,omitempty"` } // KBSetting Knowledge Base configuration for agent (from agent/kb.yml) type KBSetting struct { Chat *ChatKBSetting `json:"chat,omitempty" yaml:"chat,omitempty"` 
// Chat session KB settings } // ChatKBSetting represents KB settings for chat sessions type ChatKBSetting struct { EmbeddingProviderID string `json:"embedding_provider_id" yaml:"embedding_provider_id"` // Embedding provider ID EmbeddingOptionID string `json:"embedding_option_id" yaml:"embedding_option_id"` // Embedding option ID Locale string `json:"locale,omitempty" yaml:"locale,omitempty"` // Locale for content processing Config *graphragtypes.CreateCollectionOptions `json:"config,omitempty" yaml:"config,omitempty"` // Vector index configuration Metadata map[string]interface{} `json:"metadata,omitempty" yaml:"metadata,omitempty"` // Collection metadata defaults DocumentDefaults *DocumentDefaults `json:"document_defaults,omitempty" yaml:"document_defaults,omitempty"` // Document processing defaults } // DocumentDefaults represents default settings for document processing type DocumentDefaults struct { Chunking *ProviderOption `json:"chunking,omitempty" yaml:"chunking,omitempty"` // Chunking provider configuration Extraction *ProviderOption `json:"extraction,omitempty" yaml:"extraction,omitempty"` // Extraction provider configuration Converter *ProviderOption `json:"converter,omitempty" yaml:"converter,omitempty"` // Converter provider configuration } // ProviderOption represents a provider and option ID pair type ProviderOption struct { ProviderID string `json:"provider_id" yaml:"provider_id"` // Provider ID OptionID string `json:"option_id" yaml:"option_id"` // Option ID within the provider } // KnowledgeBase the knowledge base configuration type KnowledgeBase struct { Collections []string `json:"collections,omitempty"` // Knowledge base collection IDs Options map[string]interface{} `json:"options,omitempty"` // Additional options for knowledge base } // Database the database configuration type Database struct { Models []string `json:"models,omitempty"` // Database models Options map[string]interface{} `json:"options,omitempty"` // Additional options for 
database } // MCPServers the MCP servers configuration // Supports multiple formats in the servers array: // - Simple string: "server_id" // - With tools: {"server_id": ["tool1", "tool2"]} // - With resources and tools: {"server_id": {"resources": [...], "tools": [...]}} type MCPServers struct { Servers []MCPServerConfig `json:"servers,omitempty"` // MCP server configurations Options map[string]interface{} `json:"options,omitempty"` // Additional options for MCP servers } // MCPServerConfig represents a single MCP server configuration type MCPServerConfig struct { ServerID string `json:"server_id,omitempty"` // MCP server ID Resources []string `json:"resources,omitempty"` // Resources to use (optional) Tools []string `json:"tools,omitempty"` // Tools to use (optional) } // UnmarshalJSON implements custom JSON unmarshaling for MCPServerConfig // Supports multiple input formats: // 1. Simple string: "server_id" // 2. Standard object: {"server_id": "server1", "resources": [...], "tools": [...]} // 3. Tools array: {"server_id": ["tool1", "tool2"]} // 4. 
Full config: {"server_id": {"resources": [...], "tools": [...]}} func (m *MCPServerConfig) UnmarshalJSON(data []byte) error { // Try to unmarshal as string first var str string if err := json.Unmarshal(data, &str); err == nil { m.ServerID = str return nil } // Try to unmarshal as standard object with server_id field type Alias MCPServerConfig var stdObj Alias if err := json.Unmarshal(data, &stdObj); err == nil && stdObj.ServerID != "" { *m = MCPServerConfig(stdObj) return nil } // Try to unmarshal as object with single key (alternative formats) var obj map[string]json.RawMessage if err := json.Unmarshal(data, &obj); err != nil { return err } // Should have exactly one key (the server ID) if len(obj) != 1 { return fmt.Errorf("MCPServerConfig object must have exactly one key or server_id field") } // Get the server ID (the only key) for serverID, value := range obj { m.ServerID = serverID // Try to unmarshal value as array of strings (format c: tools only) var tools []string if err := json.Unmarshal(value, &tools); err == nil { m.Tools = tools return nil } // Try to unmarshal as object with resources and tools (format b) var detail struct { Resources []string `json:"resources,omitempty"` Tools []string `json:"tools,omitempty"` } if err := json.Unmarshal(value, &detail); err == nil { m.Resources = detail.Resources m.Tools = detail.Tools return nil } return fmt.Errorf("invalid format for server '%s'", serverID) } return nil } // MarshalJSON implements custom JSON marshaling for MCPServerConfig // Serializes to different formats based on content: // 1. If only ServerID: "server_id" // 2. 
If has Resources or Tools: {"server_id": "...", "resources": [...], "tools": [...]} func (m MCPServerConfig) MarshalJSON() ([]byte, error) { // If only ServerID, serialize as simple string if len(m.Resources) == 0 && len(m.Tools) == 0 { return json.Marshal(m.ServerID) } // Otherwise, use standard object format type Alias MCPServerConfig return json.Marshal(Alias(m)) } // Workflow the workflow configuration type Workflow struct { Workflows []string `json:"workflows,omitempty"` // Workflow IDs Options map[string]interface{} `json:"options,omitempty"` // Additional workflow options } // Sandbox the sandbox configuration for coding agents (Claude CLI, Cursor CLI) type Sandbox struct { Command string `json:"command"` // Command type: "claude" or "cursor" Image string `json:"image,omitempty"` // Docker image (optional, auto-selected by command) MaxMemory string `json:"max_memory,omitempty"` // Memory limit (e.g., "4g") MaxCPU float64 `json:"max_cpu,omitempty"` // CPU limit (e.g., 2.0) Timeout string `json:"timeout,omitempty"` // Execution timeout (e.g., "10m") Arguments map[string]interface{} `json:"arguments,omitempty"` // Command-specific arguments Secrets map[string]string `json:"secrets,omitempty"` // Secrets to pass to container (e.g., GITHUB_TOKEN: "$ENV.GITHUB_TOKEN") } // Tool represents a tool configuration for storage type Tool struct { Type string `json:"type,omitempty"` Name string `json:"name"` Description string `json:"description,omitempty"` Parameters map[string]interface{} `json:"parameters,omitempty"` } // ToolCalls the tool calls type ToolCalls struct { Tools []Tool `json:"tools,omitempty"` Prompts []Prompt `json:"prompts,omitempty"` } // Placeholder the assistant placeholder type Placeholder struct { Title string `json:"title,omitempty"` Description string `json:"description,omitempty"` Prompts []string `json:"prompts,omitempty"` } // ModelCapability defines the available model capability filters type ModelCapability string // Model capability 
// AssistantModel the assistant database model.
//
// Fields tagged `json:"-"` are never included in JSON output: this covers the
// runtime-only sandbox V2 fields and the permission management fields at the
// end of the struct. All other fields are serialized under the key given in
// their tag, and most are dropped from the output when empty (omitempty).
type AssistantModel struct {
	ID                   string                       `json:"assistant_id"`                     // Assistant ID
	Type                 string                       `json:"type,omitempty"`                   // Assistant Type, default is assistant
	Name                 string                       `json:"name,omitempty"`                   // Assistant Name
	Avatar               string                       `json:"avatar,omitempty"`                 // Assistant Avatar
	Connector            string                       `json:"connector"`                        // AI Connector (default connector)
	ConnectorOptions     *ConnectorOptions            `json:"connector_options,omitempty"`      // Connector selection options for user to choose from
	Path                 string                       `json:"path,omitempty"`                   // Assistant Path
	BuiltIn              bool                         `json:"built_in,omitempty"`               // Whether this is a built-in assistant
	Sort                 int                          `json:"sort,omitempty"`                   // Assistant Sort
	Description          string                       `json:"description,omitempty"`            // Assistant Description
	Capabilities         string                       `json:"capabilities,omitempty"`           // Assistant capabilities description (useful for Robot orchestration)
	Tags                 []string                     `json:"tags,omitempty"`                   // Assistant Tags
	Modes                []string                     `json:"modes,omitempty"`                  // Supported modes (e.g., ["task", "chat"]), null means all modes are supported
	DefaultMode          string                       `json:"default_mode,omitempty"`           // Default mode, can be empty
	Readonly             bool                         `json:"readonly,omitempty"`               // Whether this assistant is readonly
	Public               bool                         `json:"public,omitempty"`                 // Whether this assistant is shared across all teams in the platform
	Share                string                       `json:"share,omitempty"`                  // Assistant sharing scope (private/team)
	Mentionable          bool                         `json:"mentionable,omitempty"`            // Whether this assistant is mentionable
	Automated            bool                         `json:"automated,omitempty"`              // Whether this assistant is automated
	Options              map[string]interface{}       `json:"options,omitempty"`                // AI Options
	Prompts              []Prompt                     `json:"prompts,omitempty"`                // AI Prompts (default prompts)
	PromptPresets        map[string][]Prompt          `json:"prompt_presets,omitempty"`         // Prompt presets organized by mode (e.g., "chat", "task", etc.)
	DisableGlobalPrompts bool                         `json:"disable_global_prompts,omitempty"` // Whether to disable global prompts, default is false
	KB                   *KnowledgeBase               `json:"kb,omitempty"`                     // Knowledge base configuration
	DB                   *Database                    `json:"db,omitempty"`                     // Database configuration
	MCP                  *MCPServers                  `json:"mcp,omitempty"`                    // MCP servers configuration
	Workflow             *Workflow                    `json:"workflow,omitempty"`               // Workflow configuration
	Sandbox              *Sandbox                     `json:"sandbox,omitempty"`                // Sandbox configuration for coding agents (V1)
	SandboxV2            *sandboxTypes.SandboxConfig  `json:"-"`                                // V2 sandbox configuration (runtime only, not persisted in DB)
	IsSandbox            bool                         `json:"-"`                                // Whether this is a Sandbox assistant (derived from SandboxV2 presence)
	ComputerFilter       *sandboxTypes.ComputerFilter `json:"-"`                                // Computer filter from DSL sandbox.filter (runtime only)
	ConfigHash           string                       `json:"-"`                                // V2 sandbox config fingerprint for hot-reload
	Placeholder          *Placeholder                 `json:"placeholder,omitempty"`            // Assistant Placeholder
	Source               string                       `json:"source,omitempty"`                 // Hook script source code
	Locales              i18n.Map                     `json:"locales,omitempty"`                // Assistant Locales
	Uses                 *context.Uses                `json:"uses,omitempty"`                   // Assistant-specific wrapper configurations for vision, audio, etc. If not set, use global settings
	Search               *searchTypes.Config          `json:"search,omitempty"`                 // Search configuration (web, kb, db, citation, weights, etc.)
	Dependencies         map[string]string            `json:"dependencies,omitempty"`           // Dependencies on other MCP Clients (name -> version constraint)
	CreatedAt            int64                        `json:"created_at"`                       // Creation timestamp (always emitted, no omitempty)
	UpdatedAt            int64                        `json:"updated_at"`                       // Last update timestamp (always emitted, no omitempty)

	// Permission management fields (not exposed in JSON API responses)
	YaoCreatedBy string `json:"-"` // User who created the assistant (not exposed in JSON)
	YaoUpdatedBy string `json:"-"` // User who last updated the assistant (not exposed in JSON)
	YaoTeamID    string `json:"-"` // Team ID for team-based access control (not exposed in JSON)
	YaoTenantID  string `json:"-"` // Tenant ID for multi-tenancy support (not exposed in JSON)
}
// Reference represents a single reference with global index (for storage).
// Index, Type and Title are always present in the JSON form; the remaining
// fields are omitted when empty.
type Reference struct {
	Index    int            `json:"index"`              // Global index (1-based, unique within request)
	Type     string         `json:"type"`               // web/kb/db
	Title    string         `json:"title"`              // Reference title
	URL      string         `json:"url,omitempty"`      // URL (for web)
	Snippet  string         `json:"snippet,omitempty"`  // Short snippet
	Content  string         `json:"content,omitempty"`  // Full content
	Metadata map[string]any `json:"metadata,omitempty"` // Free-form extra data; its keys are not defined by this type
}
SaveAssistant saves assistant information func (store *Xun) SaveAssistant(assistant *types.AssistantModel) (string, error) { if assistant == nil { return "", fmt.Errorf("assistant cannot be nil") } // Validate required fields if assistant.Name == "" { return "", fmt.Errorf("field name is required") } if assistant.Type == "" { return "", fmt.Errorf("field type is required") } if assistant.Connector == "" { return "", fmt.Errorf("field connector is required") } // Generate assistant_id if not provided if assistant.ID == "" { var err error assistant.ID, err = store.GenerateAssistantID() if err != nil { return "", err } } // Check if assistant exists exists, err := store.query.New(). Table(store.getAssistantTable()). Where("assistant_id", assistant.ID). Exists() if err != nil { return "", err } // Convert model to map for database storage data := make(map[string]interface{}) data["assistant_id"] = assistant.ID data["type"] = assistant.Type data["connector"] = assistant.Connector data["built_in"] = assistant.BuiltIn data["sort"] = assistant.Sort data["readonly"] = assistant.Readonly data["public"] = assistant.Public data["mentionable"] = assistant.Mentionable data["automated"] = assistant.Automated data["disable_global_prompts"] = assistant.DisableGlobalPrompts // Set timestamps now := time.Now().UnixNano() if exists { // Update: set updated_at, keep created_at unchanged if assistant.UpdatedAt == 0 { data["updated_at"] = now } else { data["updated_at"] = assistant.UpdatedAt } // Don't modify created_at on update } else { // Create: set created_at, updated_at is null if assistant.CreatedAt == 0 { data["created_at"] = now } else { data["created_at"] = assistant.CreatedAt } data["updated_at"] = nil } // Handle nullable string fields from assistant.mod.yao // Store as nil if empty string (this matches database nullable: true fields) if assistant.Name != "" { data["name"] = assistant.Name } else { data["name"] = nil } if assistant.Avatar != "" { data["avatar"] = 
assistant.Avatar } else { data["avatar"] = nil } if assistant.Description != "" { data["description"] = assistant.Description } else { data["description"] = nil } if assistant.Capabilities != "" { data["capabilities"] = assistant.Capabilities } else { data["capabilities"] = nil } if assistant.Path != "" { data["path"] = assistant.Path } else { data["path"] = nil } if assistant.Source != "" { data["source"] = assistant.Source } else { data["source"] = nil } // Share field: nullable: false with default "private" // Apply default if empty if assistant.Share != "" { data["share"] = assistant.Share } else { data["share"] = "private" // Apply default value } // Permission management fields - store as nil if empty if assistant.YaoCreatedBy != "" { data["__yao_created_by"] = assistant.YaoCreatedBy } else { data["__yao_created_by"] = nil } if assistant.YaoUpdatedBy != "" { data["__yao_updated_by"] = assistant.YaoUpdatedBy } else { data["__yao_updated_by"] = nil } if assistant.YaoTeamID != "" { data["__yao_team_id"] = assistant.YaoTeamID } else { data["__yao_team_id"] = nil } if assistant.YaoTenantID != "" { data["__yao_tenant_id"] = assistant.YaoTenantID } else { data["__yao_tenant_id"] = nil } // DefaultMode is a simple string field if assistant.DefaultMode != "" { data["default_mode"] = assistant.DefaultMode } else { data["default_mode"] = nil } // Handle all JSON fields uniformly via marshalJSONFields. // Uses isNil() to correctly skip typed nils stored in interface{}. 
jsonFields := map[string]interface{}{ "options": assistant.Options, "tags": assistant.Tags, "modes": assistant.Modes, "prompts": assistant.Prompts, "prompt_presets": assistant.PromptPresets, "connector_options": assistant.ConnectorOptions, "kb": assistant.KB, "db": assistant.DB, "mcp": assistant.MCP, "workflow": assistant.Workflow, "sandbox": assistant.Sandbox, "placeholder": assistant.Placeholder, "locales": assistant.Locales, "uses": assistant.Uses, "search": assistant.Search, "dependencies": assistant.Dependencies, } if err := marshalJSONFields(data, jsonFields); err != nil { return "", err } // Update or insert if exists { _, err := store.query.New(). Table(store.getAssistantTable()). Where("assistant_id", assistant.ID). Update(data) if err != nil { return "", err } return assistant.ID, nil } err = store.query.New(). Table(store.getAssistantTable()). Insert(data) if err != nil { return "", err } return assistant.ID, nil } // UpdateAssistant updates specific fields of an assistant func (store *Xun) UpdateAssistant(assistantID string, updates map[string]interface{}) error { if assistantID == "" { return fmt.Errorf("assistant_id is required") } if len(updates) == 0 { return fmt.Errorf("no fields to update") } // Check if assistant exists exists, err := store.query.New(). Table(store.getAssistantTable()). Where("assistant_id", assistantID). 
Exists() if err != nil { return err } if !exists { return fmt.Errorf("assistant %s not found", assistantID) } // Prepare update data data := make(map[string]interface{}) // List of fields that need JSON marshaling jsonFields := []string{"options", "tags", "modes", "prompts", "prompt_presets", "connector_options", "kb", "db", "mcp", "workflow", "sandbox", "placeholder", "locales", "uses", "search", "dependencies"} jsonFieldSet := make(map[string]bool) for _, field := range jsonFields { jsonFieldSet[field] = true } // List of nullable string fields nullableStringFields := []string{"name", "avatar", "description", "capabilities", "path", "source", "default_mode", "__yao_created_by", "__yao_updated_by", "__yao_team_id", "__yao_tenant_id"} nullableFieldSet := make(map[string]bool) for _, field := range nullableStringFields { nullableFieldSet[field] = true } // Process each update field for key, value := range updates { // Skip system fields that shouldn't be updated directly if key == "assistant_id" || key == "created_at" { continue } // Handle JSON fields if jsonFieldSet[key] { if isNil(value) { data[key] = nil } else { jsonStr, err := jsoniter.MarshalToString(value) if err != nil { return fmt.Errorf("failed to marshal %s: %w", key, err) } data[key] = jsonStr } } else { // Handle regular fields // Convert empty strings to nil for nullable fields if strVal, ok := value.(string); ok && strVal == "" && nullableFieldSet[key] { data[key] = nil continue } data[key] = value } } // Always update updated_at timestamp data["updated_at"] = types.ToMySQLTime(time.Now().UnixNano()) if len(data) == 0 { return fmt.Errorf("no valid fields to update") } // Perform update _, err = store.query.New(). Table(store.getAssistantTable()). Where("assistant_id", assistantID). Update(data) return err } // DeleteAssistant deletes an assistant by assistant_id func (store *Xun) DeleteAssistant(assistantID string) error { // Check if assistant exists exists, err := store.query.New(). 
Table(store.getAssistantTable()). Where("assistant_id", assistantID). Exists() if err != nil { return err } if !exists { return fmt.Errorf("assistant %s not found", assistantID) } _, err = store.query.New(). Table(store.getAssistantTable()). Where("assistant_id", assistantID). Delete() return err } // GetAssistants retrieves assistants with pagination and filtering func (store *Xun) GetAssistants(filter types.AssistantFilter, locale ...string) (*types.AssistantList, error) { qb := store.query.New(). Table(store.getAssistantTable()) // Apply tag filter if provided if len(filter.Tags) > 0 { qb.Where(func(qb query.Query) { for i, tag := range filter.Tags { // For each tag, we need to match it as part of a JSON array // This will match both single tag arrays ["tag1"] and multi-tag arrays ["tag1","tag2"] pattern := fmt.Sprintf("%%\"%s\"%%", tag) if i == 0 { qb.Where("tags", "like", pattern) } else { qb.OrWhere("tags", "like", pattern) } } }) } // Apply keyword filter if provided if filter.Keywords != "" { qb.Where(func(qb query.Query) { qb.Where("name", "like", fmt.Sprintf("%%%s%%", filter.Keywords)). OrWhere("description", "like", fmt.Sprintf("%%%s%%", filter.Keywords)). OrWhere("capabilities", "like", fmt.Sprintf("%%%s%%", filter.Keywords)). 
OrWhere("locales", "like", fmt.Sprintf("%%%s%%", filter.Keywords)) }) } // Apply type filter if provided (single value) if filter.Type != "" { qb.Where("type", filter.Type) } // Apply types filter if provided (multiple values, IN query) if len(filter.Types) > 0 { qb.WhereIn("type", filter.Types) } // Apply connector filter if provided if filter.Connector != "" { qb.Where("connector", filter.Connector) } // Apply assistant_id filter if provided if filter.AssistantID != "" { qb.Where("assistant_id", filter.AssistantID) } // Apply assistantIDs filter if provided if len(filter.AssistantIDs) > 0 { qb.WhereIn("assistant_id", filter.AssistantIDs) } // Apply mentionable filter if provided if filter.Mentionable != nil { qb.Where("mentionable", *filter.Mentionable) } // Apply automated filter if provided if filter.Automated != nil { qb.Where("automated", *filter.Automated) } // Apply built_in filter if provided if filter.BuiltIn != nil { qb.Where("built_in", *filter.BuiltIn) } // Apply sandbox filter (true = has sandbox config, false = no sandbox config) // MySQL JSON columns distinguish between SQL NULL and JSON literal null. // CAST(sandbox AS CHAR) returns 'null' for JSON null and NULL for SQL NULL. if filter.Sandbox != nil { if *filter.Sandbox { qb.WhereNotNull("sandbox"). WhereRaw("CAST(`sandbox` AS CHAR) <> 'null'") } else { qb.Where(func(qb query.Query) { qb.WhereNull("sandbox"). 
OrWhereRaw("CAST(`sandbox` AS CHAR) = 'null'") }) } } // Apply custom query filter function (for permission filtering) if filter.QueryFilter != nil { qb.Where(filter.QueryFilter) } // Set defaults for pagination if filter.PageSize <= 0 { filter.PageSize = 20 } if filter.Page <= 0 { filter.Page = 1 } // Get total count total, err := qb.Clone().Count() if err != nil { return nil, err } // Calculate pagination offset := (filter.Page - 1) * filter.PageSize totalPages := int(math.Ceil(float64(total) / float64(filter.PageSize))) nextPage := filter.Page + 1 if nextPage > totalPages { nextPage = 0 } prevPage := filter.Page - 1 if prevPage < 1 { prevPage = 0 } // Apply select fields with security validation (only if fields are explicitly specified) if len(filter.Select) > 0 { // ValidateAssistantFields will validate fields against whitelist sanitized := types.ValidateAssistantFields(filter.Select) selectFields := make([]interface{}, len(sanitized)) for i, field := range sanitized { selectFields[i] = field } qb.Select(selectFields...) } // If no select fields specified, query will return all fields (SELECT *) // Get paginated results rows, err := qb.OrderBy("sort", "asc"). OrderBy("updated_at", "desc"). Offset(offset). Limit(filter.PageSize). 
Get() if err != nil { return nil, err } // Convert rows to types.AssistantModel slice assistants := make([]*types.AssistantModel, 0, len(rows)) jsonFields := []string{"tags", "options", "prompts", "prompt_presets", "connector_options", "workflow", "sandbox", "kb", "mcp", "placeholder", "locales", "uses", "search", "dependencies"} for _, row := range rows { data := row.ToMap() if data == nil { continue } // Parse JSON fields store.parseJSONFields(data, jsonFields) // Convert map to types.AssistantModel using existing helper function model, err := types.ToAssistantModel(data) if err != nil { log.Error("Failed to convert row to types.AssistantModel: %s", err.Error()) continue } // Apply i18n translations if locale is provided if len(locale) > 0 && locale[0] != "" && model != nil { store.translate(model, model.ID, locale[0]) } assistants = append(assistants, model) } return &types.AssistantList{ Data: assistants, Page: filter.Page, PageSize: filter.PageSize, PageCount: totalPages, Next: nextPage, Prev: prevPage, Total: int(total), }, nil } // GetAssistant retrieves a single assistant by ID func (store *Xun) GetAssistant(assistantID string, fields []string, locale ...string) (*types.AssistantModel, error) { qb := store.query.New(). Table(store.getAssistantTable()). Where("assistant_id", assistantID) // Apply select fields with security validation // If no fields specified, use default fields fieldsToSelect := fields if len(fieldsToSelect) == 0 { fieldsToSelect = types.AssistantDefaultFields } // ValidateAssistantFields will validate fields against whitelist sanitized := types.ValidateAssistantFields(fieldsToSelect) selectFields := make([]interface{}, len(sanitized)) for i, field := range sanitized { selectFields[i] = field } qb.Select(selectFields...) 
row, err := qb.First() if err != nil { return nil, err } if row == nil { return nil, fmt.Errorf("assistant %s not found", assistantID) } data := row.ToMap() if len(data) == 0 { return nil, fmt.Errorf("the assistant %s is empty", assistantID) } // Parse JSON fields jsonFields := []string{"tags", "modes", "options", "prompts", "prompt_presets", "connector_options", "workflow", "sandbox", "kb", "db", "mcp", "placeholder", "locales", "uses", "search", "dependencies"} store.parseJSONFields(data, jsonFields) // Convert map to types.AssistantModel model := &types.AssistantModel{ ID: getString(data, "assistant_id"), Type: getString(data, "type"), Name: getString(data, "name"), Avatar: getString(data, "avatar"), Connector: getString(data, "connector"), Path: getString(data, "path"), Source: getString(data, "source"), BuiltIn: getBool(data, "built_in"), Sort: getInt(data, "sort"), Description: getString(data, "description"), Capabilities: getString(data, "capabilities"), DefaultMode: getString(data, "default_mode"), Readonly: getBool(data, "readonly"), Public: getBool(data, "public"), Share: getString(data, "share"), Mentionable: getBool(data, "mentionable"), Automated: getBool(data, "automated"), DisableGlobalPrompts: getBool(data, "disable_global_prompts"), CreatedAt: getInt64(data, "created_at"), UpdatedAt: getInt64(data, "updated_at"), YaoCreatedBy: getString(data, "__yao_created_by"), YaoUpdatedBy: getString(data, "__yao_updated_by"), YaoTeamID: getString(data, "__yao_team_id"), YaoTenantID: getString(data, "__yao_tenant_id"), } // Handle Tags if tags, ok := data["tags"].([]interface{}); ok { model.Tags = make([]string, len(tags)) for i, tag := range tags { if s, ok := tag.(string); ok { model.Tags[i] = s } } } // Handle Modes if modes, ok := data["modes"].([]interface{}); ok { model.Modes = make([]string, len(modes)) for i, mode := range modes { if s, ok := mode.(string); ok { model.Modes[i] = s } } } // Handle Options if options, ok := 
data["options"].(map[string]interface{}); ok { model.Options = options } // Handle typed fields with conversion if prompts, has := data["prompts"]; has && prompts != nil { // Try to unmarshal to []Prompt raw, err := jsoniter.Marshal(prompts) if err == nil { var p []types.Prompt if err := jsoniter.Unmarshal(raw, &p); err == nil { model.Prompts = p } } } if promptPresets, has := data["prompt_presets"]; has && promptPresets != nil { raw, err := jsoniter.Marshal(promptPresets) if err == nil { var pp map[string][]types.Prompt if err := jsoniter.Unmarshal(raw, &pp); err == nil { model.PromptPresets = pp } } } if connectorOptions, has := data["connector_options"]; has && connectorOptions != nil { raw, err := jsoniter.Marshal(connectorOptions) if err == nil { var co types.ConnectorOptions if err := jsoniter.Unmarshal(raw, &co); err == nil { model.ConnectorOptions = &co } } } if kb, has := data["kb"]; has && kb != nil { kbConverted, err := types.ToKnowledgeBase(kb) if err == nil { model.KB = kbConverted } } if db, has := data["db"]; has && db != nil { dbConverted, err := types.ToDatabase(db) if err == nil { model.DB = dbConverted } } if mcp, has := data["mcp"]; has && mcp != nil { mcpConverted, err := types.ToMCPServers(mcp) if err == nil { model.MCP = mcpConverted } } if workflow, has := data["workflow"]; has && workflow != nil { wf, err := types.ToWorkflow(workflow) if err == nil { model.Workflow = wf } } if sandbox, has := data["sandbox"]; has && sandbox != nil { sb, err := types.ToSandbox(sandbox) if err == nil { model.Sandbox = sb } } if placeholder, has := data["placeholder"]; has && placeholder != nil { raw, err := jsoniter.Marshal(placeholder) if err == nil { var ph types.Placeholder if err := jsoniter.Unmarshal(raw, &ph); err == nil { model.Placeholder = &ph } } } if locales, has := data["locales"]; has && locales != nil { raw, err := jsoniter.Marshal(locales) if err == nil { var loc i18n.Map if err := jsoniter.Unmarshal(raw, &loc); err == nil { model.Locales = loc 
} } } if uses, has := data["uses"]; has && uses != nil { raw, err := jsoniter.Marshal(uses) if err == nil { var u context.Uses if err := jsoniter.Unmarshal(raw, &u); err == nil { model.Uses = &u } } } if search, has := data["search"]; has && search != nil { raw, err := jsoniter.Marshal(search) if err == nil { var s searchTypes.Config if err := jsoniter.Unmarshal(raw, &s); err == nil { model.Search = &s } } } if deps, has := data["dependencies"]; has && deps != nil { raw, err := jsoniter.Marshal(deps) if err == nil { var d map[string]string if err := jsoniter.Unmarshal(raw, &d); err == nil { model.Dependencies = d } } } // Apply i18n translation if locale is provided if len(locale) > 0 && locale[0] != "" { store.translate(model, assistantID, locale[0]) } return model, nil } // DeleteAssistants deletes assistants based on filter conditions func (store *Xun) DeleteAssistants(filter types.AssistantFilter) (int64, error) { qb := store.query.New(). Table(store.getAssistantTable()) // Apply tag filter if provided if len(filter.Tags) > 0 { qb.Where(func(qb query.Query) { for i, tag := range filter.Tags { pattern := fmt.Sprintf("%%\"%s\"%%", tag) if i == 0 { qb.Where("tags", "like", pattern) } else { qb.OrWhere("tags", "like", pattern) } } }) } // Apply keyword filter if provided if filter.Keywords != "" { qb.Where(func(qb query.Query) { qb.Where("name", "like", fmt.Sprintf("%%%s%%", filter.Keywords)). 
OrWhere("description", "like", fmt.Sprintf("%%%s%%", filter.Keywords)) }) } // Apply connector filter if provided if filter.Connector != "" { qb.Where("connector", filter.Connector) } // Apply assistant_id filter if provided if filter.AssistantID != "" { qb.Where("assistant_id", filter.AssistantID) } // Apply assistantIDs filter if provided if len(filter.AssistantIDs) > 0 { qb.WhereIn("assistant_id", filter.AssistantIDs) } // Apply mentionable filter if provided if filter.Mentionable != nil { qb.Where("mentionable", *filter.Mentionable) } // Apply automated filter if provided if filter.Automated != nil { qb.Where("automated", *filter.Automated) } // Apply built_in filter if provided if filter.BuiltIn != nil { qb.Where("built_in", *filter.BuiltIn) } // Execute delete and return number of deleted records return qb.Delete() } // GetAssistantTags retrieves all unique tags from assistants with filtering func (store *Xun) GetAssistantTags(filter types.AssistantFilter, locale ...string) ([]types.Tag, error) { qb := store.query.New().Table(store.getAssistantTable()) // Apply type filter (default to "assistant") typeFilter := "assistant" if filter.Type != "" { typeFilter = filter.Type } qb.Where("type", typeFilter) // Apply custom query filter function (for permission filtering) if filter.QueryFilter != nil { qb.Where(filter.QueryFilter) } // Apply other filters if provided if filter.Connector != "" { qb.Where("connector", filter.Connector) } if filter.BuiltIn != nil { qb.Where("built_in", *filter.BuiltIn) } if filter.Mentionable != nil { qb.Where("mentionable", *filter.Mentionable) } if filter.Automated != nil { qb.Where("automated", *filter.Automated) } // Apply keyword filter if provided if filter.Keywords != "" { qb.Where(func(qb query.Query) { qb.Where("name", "like", fmt.Sprintf("%%%s%%", filter.Keywords)). 
OrWhere("description", "like", fmt.Sprintf("%%%s%%", filter.Keywords)) }) } rows, err := qb.Select("tags").GroupBy("tags").Get() if err != nil { return nil, err } tagSet := map[string]bool{} for _, row := range rows { if tags, ok := row["tags"].(string); ok && tags != "" { var tagList []string if err := jsoniter.UnmarshalFromString(tags, &tagList); err == nil { for _, tag := range tagList { tagSet[tag] = true } } } } lang := "en" if len(locale) > 0 { lang = locale[0] } // Convert map keys to slice tags := make([]types.Tag, 0, len(tagSet)) for tag := range tagSet { tags = append(tags, types.Tag{ Value: tag, Label: i18n.TranslateGlobal(lang, tag).(string), }) } return tags, nil } // translate applies i18n translation to assistant model fields func (store *Xun) translate(model *types.AssistantModel, assistantID string, locale string) { if model == nil { return } // Translate name if translated := i18n.Translate(assistantID, locale, model.Name); translated != nil { if s, ok := translated.(string); ok { model.Name = s } } // Translate description if translated := i18n.Translate(assistantID, locale, model.Description); translated != nil { if s, ok := translated.(string); ok { model.Description = s } } // Translate capabilities if translated := i18n.Translate(assistantID, locale, model.Capabilities); translated != nil { if s, ok := translated.(string); ok { model.Capabilities = s } } // Translate prompts if model.Prompts != nil { for i := range model.Prompts { if translated := i18n.Translate(assistantID, locale, model.Prompts[i].Name); translated != nil { if s, ok := translated.(string); ok { model.Prompts[i].Name = s } } if translated := i18n.Translate(assistantID, locale, model.Prompts[i].Content); translated != nil { if s, ok := translated.(string); ok { model.Prompts[i].Content = s } } } } // Translate placeholder if model.Placeholder != nil { if translated := i18n.Translate(assistantID, locale, model.Placeholder.Title); translated != nil { if s, ok := 
translated.(string); ok { model.Placeholder.Title = s } } if translated := i18n.Translate(assistantID, locale, model.Placeholder.Description); translated != nil { if s, ok := translated.(string); ok { model.Placeholder.Description = s } } if translated := i18n.Translate(assistantID, locale, model.Placeholder.Prompts); translated != nil { if prompts, ok := translated.([]string); ok { model.Placeholder.Prompts = prompts } } } // Translate tags if translated := i18n.Translate(assistantID, locale, model.Tags); translated != nil { if tags, ok := translated.([]string); ok { model.Tags = tags } } } ================================================ FILE: agent/store/xun/assistant_test.go ================================================ package xun_test import ( "fmt" "os" "strings" "testing" "time" "github.com/yaoapp/xun/dbal/query" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/i18n" searchTypes "github.com/yaoapp/yao/agent/search/types" "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/store/xun" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestMain(m *testing.M) { // Setup will be done in each test via test.Prepare test.Prepare(nil, config.Conf) defer test.Clean() // Run tests and exit with appropriate exit code code := m.Run() os.Exit(code) } // TestSaveAssistant tests creating and updating assistants func TestSaveAssistant(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create a new xun store store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("CreateNewAssistant", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Test Assistant", Type: "assistant", Connector: "openai", Description: "A test assistant for unit testing", Avatar: "https://example.com/avatar.png", Tags: []string{"test", "automation"}, Options: map[string]interface{}{"temperature": 0.7}, Sort: 100, BuiltIn: false, Readonly: false, 
Public: false, Share: "private", Mentionable: true, Automated: true, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant: %v", err) } if id == "" { t.Error("Expected non-empty assistant ID") } if assistant.ID == "" { t.Error("Expected assistant.ID to be set") } t.Logf("Created assistant with ID: %s", id) }) t.Run("UpdateExistingAssistant", func(t *testing.T) { // Create initial assistant assistant := &types.AssistantModel{ Name: "Update Test Assistant", Type: "assistant", Connector: "openai", Description: "Original description", Tags: []string{"original"}, Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update the assistant assistant.Description = "Updated description" assistant.Tags = []string{"updated", "modified"} assistant.Sort = 200 updatedID, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to update assistant: %v", err) } if updatedID != id { t.Errorf("Expected ID %s, got %s", id, updatedID) } // Verify update - request all fields to see the update retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve updated assistant: %v", err) } if retrieved.Description != "Updated description" { t.Errorf("Expected description 'Updated description', got '%s'", retrieved.Description) } if len(retrieved.Tags) != 2 || retrieved.Tags[0] != "updated" { t.Errorf("Expected tags [updated, modified], got %v", retrieved.Tags) } }) t.Run("ValidationErrors", func(t *testing.T) { // Test nil assistant _, err := store.SaveAssistant(nil) if err == nil { t.Error("Expected error for nil assistant") } // Test missing name assistant := &types.AssistantModel{ Type: "assistant", Connector: "openai", } _, err = store.SaveAssistant(assistant) if err == nil { t.Error("Expected error for missing name") } // Test missing type assistant = &types.AssistantModel{ Name: "Test", Connector: "openai", 
} _, err = store.SaveAssistant(assistant) if err == nil { t.Error("Expected error for missing type") } // Test missing connector assistant = &types.AssistantModel{ Name: "Test", Type: "assistant", } _, err = store.SaveAssistant(assistant) if err == nil { t.Error("Expected error for missing connector") } }) t.Run("ComplexDataTypes", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Complex Assistant", Type: "assistant", Connector: "openai", Share: "private", Prompts: []types.Prompt{ {Role: "system", Content: "You are a helpful assistant"}, {Role: "user", Content: "Hello"}, }, Options: map[string]interface{}{ "temperature": 0.8, "max_tokens": 2000, }, Tags: []string{"complex", "testing", "data"}, Placeholder: &types.Placeholder{ Title: "Type your message", Description: "Enter your message here...", Prompts: []string{"What can I help you with?"}, }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save complex assistant: %v", err) } // Retrieve and verify - request all fields for complex data retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve complex assistant: %v", err) } if len(retrieved.Prompts) != 2 { t.Errorf("Expected 2 prompts, got %d", len(retrieved.Prompts)) } if retrieved.Placeholder == nil { t.Error("Expected placeholder to be set") } if len(retrieved.Tags) != 3 { t.Errorf("Expected 3 tags, got %d", len(retrieved.Tags)) } }) t.Run("SaveWithMCPServers", func(t *testing.T) { // Test creating assistant with MCP servers directly // This will test that: // - server1 (no tools/resources) serializes as "server1" // - server2 (with tools) serializes as {"server_id":"server2","tools":[...]} // - server3 (with both) serializes as {"server_id":"server3","resources":[...],"tools":[...]} assistant := &types.AssistantModel{ Name: "MCP Save Test", Type: "assistant", Connector: "openai", Share: "private", MCP: &types.MCPServers{ Servers: []types.MCPServerConfig{ 
{ServerID: "server1"}, { ServerID: "server2", Tools: []string{"tool1", "tool2"}, }, { ServerID: "server3", Resources: []string{"res1", "res2"}, Tools: []string{"tool3", "tool4"}, }, }, Options: map[string]interface{}{ "timeout": 30, }, }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with MCP: %v", err) } // Retrieve and verify MCP configuration - mcp is in default fields retrieved, err := store.GetAssistant(id, []string{}) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.MCP == nil { t.Fatal("Expected MCP to be set") } if len(retrieved.MCP.Servers) != 3 { t.Errorf("Expected 3 MCP servers, got %d", len(retrieved.MCP.Servers)) } // Verify server1 (simple format) if retrieved.MCP.Servers[0].ServerID != "server1" { t.Errorf("Expected server1, got '%s'", retrieved.MCP.Servers[0].ServerID) } // Verify server2 (with tools) if retrieved.MCP.Servers[1].ServerID != "server2" { t.Errorf("Expected server2, got '%s'", retrieved.MCP.Servers[1].ServerID) } if len(retrieved.MCP.Servers[1].Tools) != 2 { t.Errorf("Expected 2 tools for server2, got %d", len(retrieved.MCP.Servers[1].Tools)) } // Verify server3 (with resources and tools) if retrieved.MCP.Servers[2].ServerID != "server3" { t.Errorf("Expected server3, got '%s'", retrieved.MCP.Servers[2].ServerID) } if len(retrieved.MCP.Servers[2].Resources) != 2 { t.Errorf("Expected 2 resources for server3, got %d", len(retrieved.MCP.Servers[2].Resources)) } if len(retrieved.MCP.Servers[2].Tools) != 2 { t.Errorf("Expected 2 tools for server3, got %d", len(retrieved.MCP.Servers[2].Tools)) } // Verify options if retrieved.MCP.Options == nil { t.Error("Expected MCP options to be set") } if timeout, ok := retrieved.MCP.Options["timeout"].(float64); !ok || timeout != 30 { t.Errorf("Expected timeout 30, got %v", retrieved.MCP.Options["timeout"]) } t.Logf("Successfully verified MCP configuration for assistant %s", id) }) t.Run("UpdateWithMCPServers", func(t 
*testing.T) { // Create assistant without MCP assistant := &types.AssistantModel{ Name: "MCP Update Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update assistant with MCP assistant.MCP = &types.MCPServers{ Servers: []types.MCPServerConfig{ {ServerID: "new-server1"}, { ServerID: "new-server2", Tools: []string{"newtool1"}, }, }, } _, err = store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to update assistant with MCP: %v", err) } // Retrieve and verify - mcp is in default fields retrieved, err := store.GetAssistant(id, []string{}) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.MCP == nil || len(retrieved.MCP.Servers) != 2 { t.Errorf("Expected 2 MCP servers, got %v", retrieved.MCP) } if retrieved.MCP.Servers[0].ServerID != "new-server1" { t.Errorf("Expected new-server1, got '%s'", retrieved.MCP.Servers[0].ServerID) } t.Logf("Successfully updated and verified MCP for assistant %s", id) }) t.Run("UsesConfiguration", func(t *testing.T) { // Test assistant with Uses configuration assistant := &types.AssistantModel{ Name: "Uses Test Assistant", Type: "assistant", Connector: "openai", Share: "private", Uses: &context.Uses{ Vision: "mcp:vision-server", Audio: "agent", Search: "mcp:search-server", Fetch: "agent", }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with uses: %v", err) } // Retrieve and verify uses configuration - uses is NOT in default fields, need to request all retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Uses == nil { t.Fatal("Expected uses to be set") } if retrieved.Uses.Vision != "mcp:vision-server" { t.Errorf("Expected vision 'mcp:vision-server', got '%s'", retrieved.Uses.Vision) } if retrieved.Uses.Audio != "agent" { 
t.Errorf("Expected audio 'agent', got '%s'", retrieved.Uses.Audio) } if retrieved.Uses.Search != "mcp:search-server" { t.Errorf("Expected search 'mcp:search-server', got '%s'", retrieved.Uses.Search) } if retrieved.Uses.Fetch != "agent" { t.Errorf("Expected fetch 'agent', got '%s'", retrieved.Uses.Fetch) } t.Logf("Successfully saved and retrieved assistant with uses configuration") }) t.Run("NilUses", func(t *testing.T) { // Test assistant without Uses configuration assistant := &types.AssistantModel{ Name: "No Uses Assistant", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant without uses: %v", err) } // Retrieve and verify uses is nil - request all fields to check uses retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Uses != nil { t.Errorf("Expected uses to be nil, got %+v", retrieved.Uses) } }) t.Run("PartialUsesConfiguration", func(t *testing.T) { // Test assistant with partial Uses configuration assistant := &types.AssistantModel{ Name: "Partial Uses Assistant", Type: "assistant", Connector: "openai", Share: "private", Uses: &context.Uses{ Vision: "mcp:vision-only", // Audio, Search, Fetch not set }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with partial uses: %v", err) } // Retrieve and verify - request all fields for uses retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Uses == nil { t.Fatal("Expected uses to be set") } if retrieved.Uses.Vision != "mcp:vision-only" { t.Errorf("Expected vision 'mcp:vision-only', got '%s'", retrieved.Uses.Vision) } if retrieved.Uses.Audio != "" { t.Errorf("Expected audio to be empty, got '%s'", retrieved.Uses.Audio) } if retrieved.Uses.Search != "" { t.Errorf("Expected search 
to be empty, got '%s'", retrieved.Uses.Search) } if retrieved.Uses.Fetch != "" { t.Errorf("Expected fetch to be empty, got '%s'", retrieved.Uses.Fetch) } }) t.Run("SearchConfiguration", func(t *testing.T) { // Test assistant with Search configuration assistant := &types.AssistantModel{ Name: "Search Config Test Assistant", Type: "assistant", Connector: "openai", Share: "private", Search: &searchTypes.Config{ Web: &searchTypes.WebConfig{ Provider: "tavily", MaxResults: 15, }, KB: &searchTypes.KBConfig{ Collections: []string{"docs", "faq"}, Threshold: 0.8, Graph: true, }, DB: &searchTypes.DBConfig{ Models: []string{"user", "product"}, MaxResults: 50, }, Citation: &searchTypes.CitationConfig{ Format: "#ref:{id}", AutoInjectPrompt: true, }, Weights: &searchTypes.WeightsConfig{ User: 1.0, Hook: 0.9, Auto: 0.7, }, }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with search config: %v", err) } // Retrieve and verify search configuration - search is NOT in default fields retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Search == nil { t.Fatal("Expected search to be set") } // Verify Web config if retrieved.Search.Web == nil { t.Fatal("Expected search.web to be set") } if retrieved.Search.Web.Provider != "tavily" { t.Errorf("Expected web provider 'tavily', got '%s'", retrieved.Search.Web.Provider) } if retrieved.Search.Web.MaxResults != 15 { t.Errorf("Expected web max_results 15, got %d", retrieved.Search.Web.MaxResults) } // Verify KB config if retrieved.Search.KB == nil { t.Fatal("Expected search.kb to be set") } if len(retrieved.Search.KB.Collections) != 2 { t.Errorf("Expected 2 KB collections, got %d", len(retrieved.Search.KB.Collections)) } if retrieved.Search.KB.Collections[0] != "docs" { t.Errorf("Expected first collection 'docs', got '%s'", retrieved.Search.KB.Collections[0]) } if retrieved.Search.KB.Threshold != 0.8 { 
t.Errorf("Expected KB threshold 0.8, got %f", retrieved.Search.KB.Threshold) } if !retrieved.Search.KB.Graph { t.Error("Expected KB graph to be true") } // Verify DB config if retrieved.Search.DB == nil { t.Fatal("Expected search.db to be set") } if len(retrieved.Search.DB.Models) != 2 { t.Errorf("Expected 2 DB models, got %d", len(retrieved.Search.DB.Models)) } if retrieved.Search.DB.MaxResults != 50 { t.Errorf("Expected DB max_results 50, got %d", retrieved.Search.DB.MaxResults) } // Verify Citation config if retrieved.Search.Citation == nil { t.Fatal("Expected search.citation to be set") } if retrieved.Search.Citation.Format != "#ref:{id}" { t.Errorf("Expected citation format '#ref:{id}', got '%s'", retrieved.Search.Citation.Format) } if !retrieved.Search.Citation.AutoInjectPrompt { t.Error("Expected citation auto_inject_prompt to be true") } // Verify Weights config if retrieved.Search.Weights == nil { t.Fatal("Expected search.weights to be set") } if retrieved.Search.Weights.User != 1.0 { t.Errorf("Expected weights.user 1.0, got %f", retrieved.Search.Weights.User) } if retrieved.Search.Weights.Hook != 0.9 { t.Errorf("Expected weights.hook 0.9, got %f", retrieved.Search.Weights.Hook) } if retrieved.Search.Weights.Auto != 0.7 { t.Errorf("Expected weights.auto 0.7, got %f", retrieved.Search.Weights.Auto) } t.Logf("Successfully saved and retrieved assistant with search configuration") }) t.Run("NilSearchConfiguration", func(t *testing.T) { // Test assistant without Search configuration assistant := &types.AssistantModel{ Name: "No Search Config Assistant", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant without search: %v", err) } // Retrieve and verify search is nil - request all fields to check search retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Search != 
nil { t.Errorf("Expected search to be nil, got %+v", retrieved.Search) } }) t.Run("PartialSearchConfiguration", func(t *testing.T) { // Test assistant with partial Search configuration assistant := &types.AssistantModel{ Name: "Partial Search Config Assistant", Type: "assistant", Connector: "openai", Share: "private", Search: &searchTypes.Config{ Web: &searchTypes.WebConfig{ Provider: "serper", }, // KB, DB, Citation, Weights not set }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with partial search: %v", err) } // Retrieve and verify - request all fields for search retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Search == nil { t.Fatal("Expected search to be set") } if retrieved.Search.Web == nil { t.Fatal("Expected search.web to be set") } if retrieved.Search.Web.Provider != "serper" { t.Errorf("Expected web provider 'serper', got '%s'", retrieved.Search.Web.Provider) } // Other fields should be nil if retrieved.Search.KB != nil { t.Errorf("Expected search.kb to be nil, got %+v", retrieved.Search.KB) } if retrieved.Search.DB != nil { t.Errorf("Expected search.db to be nil, got %+v", retrieved.Search.DB) } if retrieved.Search.Citation != nil { t.Errorf("Expected search.citation to be nil, got %+v", retrieved.Search.Citation) } if retrieved.Search.Weights != nil { t.Errorf("Expected search.weights to be nil, got %+v", retrieved.Search.Weights) } }) t.Run("ConnectorOptions", func(t *testing.T) { // Test assistant with connector options optionalTrue := true assistant := &types.AssistantModel{ Name: "Connector Options Test", Type: "assistant", Connector: "openai", Share: "private", ConnectorOptions: &types.ConnectorOptions{ Optional: &optionalTrue, Connectors: []string{"openai", "anthropic"}, Filters: []types.ModelCapability{types.CapVision, types.CapToolCalls}, }, } id, err := store.SaveAssistant(assistant) if err != 
nil { t.Fatalf("Failed to save assistant with connector options: %v", err) } // Retrieve and verify - connector_options is NOT in default fields retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.ConnectorOptions == nil { t.Fatal("Expected connector options to be set") } if retrieved.ConnectorOptions.Optional == nil || !*retrieved.ConnectorOptions.Optional { t.Error("Expected optional to be true") } if len(retrieved.ConnectorOptions.Connectors) != 2 { t.Errorf("Expected 2 connectors, got %d", len(retrieved.ConnectorOptions.Connectors)) } if len(retrieved.ConnectorOptions.Filters) != 2 { t.Errorf("Expected 2 filters, got %d", len(retrieved.ConnectorOptions.Filters)) } if retrieved.ConnectorOptions.Filters[0] != types.CapVision { t.Errorf("Expected first filter to be vision, got '%s'", retrieved.ConnectorOptions.Filters[0]) } t.Logf("Successfully saved and retrieved connector options for assistant %s", id) }) t.Run("PromptPresets", func(t *testing.T) { // Test assistant with prompt presets assistant := &types.AssistantModel{ Name: "Prompt Presets Test", Type: "assistant", Connector: "openai", Share: "private", PromptPresets: map[string][]types.Prompt{ "chat": { {Role: "system", Content: "You are a friendly chatbot"}, {Role: "user", Content: "Hello!"}, }, "task": { {Role: "system", Content: "You are a task executor"}, }, }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with prompt presets: %v", err) } // Retrieve and verify - prompt_presets is NOT in default fields retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.PromptPresets == nil { t.Fatal("Expected prompt presets to be set") } if len(retrieved.PromptPresets) != 2 { t.Errorf("Expected 2 preset groups, got %d", len(retrieved.PromptPresets)) } chatPrompts, ok := 
retrieved.PromptPresets["chat"] if !ok { t.Fatal("Expected 'chat' preset to exist") } if len(chatPrompts) != 2 { t.Errorf("Expected 2 chat prompts, got %d", len(chatPrompts)) } if chatPrompts[0].Role != "system" { t.Errorf("Expected system role, got '%s'", chatPrompts[0].Role) } taskPrompts, ok := retrieved.PromptPresets["task"] if !ok { t.Fatal("Expected 'task' preset to exist") } if len(taskPrompts) != 1 { t.Errorf("Expected 1 task prompt, got %d", len(taskPrompts)) } t.Logf("Successfully saved and retrieved prompt presets for assistant %s", id) }) t.Run("SourceField", func(t *testing.T) { // Test assistant with source code sourceCode := `function onMessage(msg) { console.log("Received:", msg); return { status: "ok" }; }` assistant := &types.AssistantModel{ Name: "Source Field Test", Type: "assistant", Connector: "openai", Share: "private", Source: sourceCode, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with source: %v", err) } // Retrieve and verify - source is NOT in default fields retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Source != sourceCode { t.Errorf("Expected source code to match, got '%s'", retrieved.Source) } t.Logf("Successfully saved and retrieved source code for assistant %s", id) }) t.Run("SandboxConfiguration", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Sandbox Test Assistant", Type: "assistant", Connector: "openai", Share: "private", Sandbox: &types.Sandbox{ Command: "claude", Timeout: "5m", Arguments: map[string]interface{}{ "max_turns": 10, "permission_mode": "bypassPermissions", }, }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with sandbox: %v", err) } retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Sandbox == nil { 
t.Fatal("Expected sandbox to be set") } if retrieved.Sandbox.Command != "claude" { t.Errorf("Expected command 'claude', got '%s'", retrieved.Sandbox.Command) } if retrieved.Sandbox.Timeout != "5m" { t.Errorf("Expected timeout '5m', got '%s'", retrieved.Sandbox.Timeout) } if retrieved.Sandbox.Arguments == nil { t.Fatal("Expected sandbox arguments to be set") } if maxTurns, ok := retrieved.Sandbox.Arguments["max_turns"].(float64); !ok || maxTurns != 10 { t.Errorf("Expected max_turns 10, got %v", retrieved.Sandbox.Arguments["max_turns"]) } t.Logf("Successfully saved and retrieved sandbox configuration for assistant %s", id) }) t.Run("SandboxWithImage", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Sandbox Image Test", Type: "assistant", Connector: "openai", Share: "private", Sandbox: &types.Sandbox{ Command: "claude", Image: "yaoapp/sandbox-claude-desktop:latest", Timeout: "20m", MaxMemory: "4g", MaxCPU: 2.0, Arguments: map[string]interface{}{ "max_turns": 500, "permission_mode": "bypassPermissions", "disallowed_tools": "WebSearch", }, Secrets: map[string]string{ "GITHUB_TOKEN": "$ENV.GITHUB_TOKEN", }, }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with sandbox image: %v", err) } retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Sandbox == nil { t.Fatal("Expected sandbox to be set") } if retrieved.Sandbox.Image != "yaoapp/sandbox-claude-desktop:latest" { t.Errorf("Expected image 'yaoapp/sandbox-claude-desktop:latest', got '%s'", retrieved.Sandbox.Image) } if retrieved.Sandbox.MaxMemory != "4g" { t.Errorf("Expected max_memory '4g', got '%s'", retrieved.Sandbox.MaxMemory) } if retrieved.Sandbox.MaxCPU != 2.0 { t.Errorf("Expected max_cpu 2.0, got %f", retrieved.Sandbox.MaxCPU) } if retrieved.Sandbox.Secrets == nil || retrieved.Sandbox.Secrets["GITHUB_TOKEN"] != "$ENV.GITHUB_TOKEN" { t.Errorf("Expected 
secrets to contain GITHUB_TOKEN, got %v", retrieved.Sandbox.Secrets) } t.Logf("Successfully saved and retrieved sandbox with image for assistant %s", id) }) t.Run("NilSandbox", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "No Sandbox Assistant", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant without sandbox: %v", err) } retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Sandbox != nil { t.Errorf("Expected sandbox to be nil, got %+v", retrieved.Sandbox) } }) t.Run("CapabilitiesField", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Capabilities Test", Type: "assistant", Connector: "openai", Share: "private", Description: "A test assistant", Capabilities: "Can search the web, analyze data, write code, and summarize documents.", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with capabilities: %v", err) } retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Capabilities != "Can search the web, analyze data, write code, and summarize documents." 
{ t.Errorf("Expected capabilities to match, got '%s'", retrieved.Capabilities) } if retrieved.Description != "A test assistant" { t.Errorf("Expected description 'A test assistant', got '%s'", retrieved.Description) } t.Logf("Successfully saved and retrieved capabilities for assistant %s", id) }) t.Run("EmptyCapabilities", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "No Capabilities Assistant", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant without capabilities: %v", err) } retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Capabilities != "" { t.Errorf("Expected empty capabilities, got '%s'", retrieved.Capabilities) } }) t.Run("CapabilitiesWithI18n", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "{{name}}", Type: "assistant", Connector: "openai", Share: "private", Description: "{{description}}", Capabilities: "{{capabilities}}", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with i18n capabilities: %v", err) } // Setup i18n i18n.Locales[id] = map[string]i18n.I18n{ "en": { Locale: "en", Messages: map[string]any{ "name": "i18n Test", "description": "Description in English", "capabilities": "Can do X, Y, and Z", }, }, "zh-cn": { Locale: "zh-cn", Messages: map[string]any{ "name": "国际化测试", "description": "中文描述", "capabilities": "可以做X、Y和Z", }, }, } retrievedEN, err := store.GetAssistant(id, types.AssistantFullFields, "en") if err != nil { t.Fatalf("Failed to get assistant with EN locale: %v", err) } if retrievedEN.Capabilities != "Can do X, Y, and Z" { t.Errorf("Expected capabilities 'Can do X, Y, and Z', got '%s'", retrievedEN.Capabilities) } retrievedZH, err := store.GetAssistant(id, types.AssistantFullFields, "zh-cn") if err != nil { t.Fatalf("Failed to get assistant with ZH locale: 
%v", err) } if retrievedZH.Capabilities != "可以做X、Y和Z" { t.Errorf("Expected capabilities '可以做X、Y和Z', got '%s'", retrievedZH.Capabilities) } // Cleanup delete(i18n.Locales, id) t.Logf("Successfully tested capabilities i18n for assistant %s", id) }) t.Run("CapabilitiesInKeywordSearch", func(t *testing.T) { uniqueCapability := fmt.Sprintf("unique-cap-%d", time.Now().UnixNano()) assistant := &types.AssistantModel{ Name: "Capabilities Search Test", Type: "assistant", Connector: "openai", Share: "private", Capabilities: uniqueCapability, } _, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant: %v", err) } response, err := store.GetAssistants(types.AssistantFilter{ Keywords: uniqueCapability, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to search by capabilities keyword: %v", err) } if len(response.Data) < 1 { t.Error("Expected to find assistant by capabilities keyword search") } found := false for _, a := range response.Data { if a.Capabilities == uniqueCapability { found = true break } } if !found { t.Error("Expected to find assistant with matching capabilities") } }) t.Run("UpdateSandbox", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Update Sandbox Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with sandbox updates := map[string]interface{}{ "sandbox": &types.Sandbox{ Command: "claude", Timeout: "10m", }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update sandbox: %v", err) } retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Sandbox == nil { t.Fatal("Expected sandbox to be set") } if retrieved.Sandbox.Command != "claude" { t.Errorf("Expected command 'claude', got '%s'", retrieved.Sandbox.Command) } // Update to remove sandbox 
updates2 := map[string]interface{}{ "sandbox": nil, } err = store.UpdateAssistant(id, updates2) if err != nil { t.Fatalf("Failed to remove sandbox: %v", err) } retrieved2, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved2.Sandbox != nil { t.Errorf("Expected sandbox to be nil, got %+v", retrieved2.Sandbox) } }) t.Run("UpdateCapabilities", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Update Capabilities Test", Type: "assistant", Connector: "openai", Share: "private", Capabilities: "Original capabilities", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update capabilities updates := map[string]interface{}{ "capabilities": "Updated capabilities: can search, analyze, and write code", } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update capabilities: %v", err) } retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Capabilities != "Updated capabilities: can search, analyze, and write code" { t.Errorf("Expected updated capabilities, got '%s'", retrieved.Capabilities) } // Update to clear capabilities updates2 := map[string]interface{}{ "capabilities": "", } err = store.UpdateAssistant(id, updates2) if err != nil { t.Fatalf("Failed to clear capabilities: %v", err) } retrieved2, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved2.Capabilities != "" { t.Errorf("Expected empty capabilities, got '%s'", retrieved2.Capabilities) } }) t.Run("FilterBySandbox", func(t *testing.T) { // Create one assistant with sandbox withSandbox := &types.AssistantModel{ Name: "Filter Sandbox Yes", Type: "assistant", Connector: "openai", Share: "private", Sandbox: &types.Sandbox{ Command: "claude", Timeout: "5m", }, 
} idWith, err := store.SaveAssistant(withSandbox) if err != nil { t.Fatalf("Failed to save assistant with sandbox: %v", err) } // Create one assistant without sandbox withoutSandbox := &types.AssistantModel{ Name: "Filter Sandbox No", Type: "assistant", Connector: "openai", Share: "private", } idWithout, err := store.SaveAssistant(withoutSandbox) if err != nil { t.Fatalf("Failed to save assistant without sandbox: %v", err) } testIDs := []string{idWith, idWithout} // Filter: sandbox=true, scoped to test IDs trueVal := true result, err := store.GetAssistants(types.AssistantFilter{ Page: 1, PageSize: 100, Sandbox: &trueVal, AssistantIDs: testIDs, }) if err != nil { t.Fatalf("Failed to filter with sandbox=true: %v", err) } if len(result.Data) != 1 { t.Errorf("Expected 1 result for sandbox=true, got %d", len(result.Data)) } else if result.Data[0].ID != idWith { t.Errorf("Expected assistant %s, got %s", idWith, result.Data[0].ID) } // Filter: sandbox=false, scoped to test IDs falseVal := false result2, err := store.GetAssistants(types.AssistantFilter{ Page: 1, PageSize: 100, Sandbox: &falseVal, AssistantIDs: testIDs, }) if err != nil { t.Fatalf("Failed to filter with sandbox=false: %v", err) } if len(result2.Data) != 1 { t.Errorf("Expected 1 result for sandbox=false, got %d", len(result2.Data)) } else if result2.Data[0].ID != idWithout { t.Errorf("Expected assistant %s, got %s", idWithout, result2.Data[0].ID) } t.Logf("Sandbox filter test passed: sandbox=true returned %d, sandbox=false returned %d", len(result.Data), len(result2.Data)) }) t.Run("AllNewFieldsTogether", func(t *testing.T) { // Test assistant with all new fields together optionalFalse := false assistant := &types.AssistantModel{ Name: "All New Fields Test", Type: "assistant", Connector: "openai", Share: "private", ConnectorOptions: &types.ConnectorOptions{ Optional: &optionalFalse, Connectors: []string{"openai"}, Filters: []types.ModelCapability{types.CapVision}, }, PromptPresets: map[string][]types.Prompt{ 
"default": {
					{Role: "system", Content: "Default system prompt"},
				},
			},
			Source: "// Hook code here",
		}

		id, err := store.SaveAssistant(assistant)
		if err != nil {
			t.Fatalf("Failed to save assistant with all new fields: %v", err)
		}

		// Retrieve and verify all new fields
		retrieved, err := store.GetAssistant(id, types.AssistantFullFields)
		if err != nil {
			t.Fatalf("Failed to retrieve assistant: %v", err)
		}

		if retrieved.ConnectorOptions == nil {
			t.Error("Expected connector options to be set")
		}
		if retrieved.PromptPresets == nil {
			t.Error("Expected prompt presets to be set")
		}
		if retrieved.Source == "" {
			t.Error("Expected source to be set")
		}

		t.Logf("Successfully saved and retrieved all new fields for assistant %s", id)
	})
}

// TestDeleteAssistant exercises DeleteAssistant twice: once for a record that
// was just saved, and once for the ID "nonexistent-id".
func TestDeleteAssistant(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	setting := types.Setting{Connector: "default"}
	store, err := xun.NewXun(setting)
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("DeleteExistingAssistant", func(t *testing.T) {
		// Seed a record so there is something to remove.
		id, err := store.SaveAssistant(&types.AssistantModel{
			Name:      "Delete Test",
			Type:      "assistant",
			Connector: "openai",
			Share:     "private",
		})
		if err != nil {
			t.Fatalf("Failed to create assistant: %v", err)
		}

		// Remove it.
		if err := store.DeleteAssistant(id); err != nil {
			t.Fatalf("Failed to delete assistant: %v", err)
		}

		// Looking the same ID up again is expected to return an error.
		if _, err := store.GetAssistant(id, nil); err == nil {
			t.Error("Expected error when getting deleted assistant")
		}
	})

	t.Run("DeleteNonExistentAssistant", func(t *testing.T) {
		if err := store.DeleteAssistant("nonexistent-id"); err == nil {
			t.Error("Expected error when deleting non-existent assistant")
		}
	})
}

// TestGetAssistant exercises GetAssistant for a record that was just saved
// and for the ID "nonexistent-id".
func TestGetAssistant(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	setting := types.Setting{Connector: "default"}
	store, err := xun.NewXun(setting)
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("GetExistingAssistant", func(t *testing.T) {
		// Seed a record carrying the values asserted further down.
		seed := &types.AssistantModel{
			Name:        "Get Test",
			Type:        "assistant",
			Connector:   "openai",
			Description: "Test description",
			Avatar:      "https://example.com/avatar.png",
			Tags:        []string{"tag1", "tag2"},
			Sort:        150,
			BuiltIn:     false,
			Share:       "private",
			Mentionable: true,
		}
		id, err := store.SaveAssistant(seed)
		if err != nil {
			t.Fatalf("Failed to create assistant: %v", err)
		}

		// Fetch with nil fields, i.e. the default field set (tags are part of it).
		got, err := store.GetAssistant(id, nil)
		if err != nil {
			t.Fatalf("Failed to get assistant: %v", err)
		}

		if got.ID != id {
			t.Errorf("Expected ID %s, got %s", id, got.ID)
		}
		if got.Name != "Get Test" {
			t.Errorf("Expected name 'Get Test', got '%s'", got.Name)
		}
		if got.Description != "Test description" {
			t.Errorf("Expected description 'Test description', got '%s'", got.Description)
		}
		if tagCount := len(got.Tags); tagCount != 2 {
			t.Errorf("Expected 2 tags, got %d", tagCount)
		}
		if got.Sort != 150 {
			t.Errorf("Expected sort 150, got %d", got.Sort)
		}
	})

	t.Run("GetNonExistentAssistant", func(t *testing.T) {
		if _, err := store.GetAssistant("nonexistent-id", nil); err == nil {
			t.Error("Expected error when getting non-existent assistant")
		}
	})
}

// TestGetAssistants tests retrieving multiple assistants with filtering and pagination
func TestGetAssistants(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	// Clean up existing data before creating test assistants
	deleted, err := store.DeleteAssistants(types.AssistantFilter{})
	if err != nil {
		t.Logf("Warning: Failed to clean up existing assistants: %v", err)
	} else if deleted > 0 {
		t.Logf("Cleaned up %d existing assistants", deleted)
	}

	// Create test assistants
	assistants := []types.AssistantModel{
		{
			Name: "Assistant 1",
			Type: "assistant",
Connector: "openai", Description: "First test assistant", Tags: []string{"test", "automation"}, Sort: 100, Share: "private", Mentionable: true, Automated: true, }, { Name: "Assistant 2", Type: "assistant", Connector: "anthropic", Description: "Second test assistant", Tags: []string{"test", "manual"}, Sort: 200, Share: "private", Mentionable: false, Automated: false, }, { Name: "Assistant 3", Type: "bot", Connector: "openai", Description: "Third test bot", Tags: []string{"bot", "automation"}, Sort: 50, Share: "private", Mentionable: true, Automated: true, }, } createdIDs := []string{} for _, asst := range assistants { id, err := store.SaveAssistant(&asst) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } createdIDs = append(createdIDs, id) } t.Run("GetAllAssistants", func(t *testing.T) { response, err := store.GetAssistants(types.AssistantFilter{ Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants: %v", err) } if len(response.Data) < 3 { t.Errorf("Expected at least 3 assistants, got %d", len(response.Data)) } if response.Total < 3 { t.Errorf("Expected total >= 3, got %d", response.Total) } }) t.Run("FilterByType", func(t *testing.T) { response, err := store.GetAssistants(types.AssistantFilter{ Type: "assistant", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants by type: %v", err) } for _, assistant := range response.Data { if assistant.Type != "assistant" { t.Errorf("Expected type 'assistant', got '%s'", assistant.Type) } } }) t.Run("FilterByConnector", func(t *testing.T) { response, err := store.GetAssistants(types.AssistantFilter{ Connector: "openai", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants by connector: %v", err) } for _, assistant := range response.Data { if assistant.Connector != "openai" { t.Errorf("Expected connector 'openai', got '%s'", assistant.Connector) } } }) t.Run("FilterByTags", func(t *testing.T) { response, err := 
store.GetAssistants(types.AssistantFilter{ Tags: []string{"automation"}, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants by tags: %v", err) } // Should find assistants with "automation" tag found := false for _, assistant := range response.Data { for _, tag := range assistant.Tags { if tag == "automation" { found = true break } } if found { break } } if !found && len(response.Data) > 0 { t.Error("Expected to find assistants with 'automation' tag") } }) t.Run("FilterByKeywords", func(t *testing.T) { response, err := store.GetAssistants(types.AssistantFilter{ Keywords: "Second", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants by keywords: %v", err) } // Should find "Assistant 2" found := false for _, assistant := range response.Data { if assistant.Name == "Assistant 2" { found = true break } } if !found { t.Error("Expected to find assistant with keyword 'Second'") } }) t.Run("FilterByMentionable", func(t *testing.T) { mentionableTrue := true response, err := store.GetAssistants(types.AssistantFilter{ Mentionable: &mentionableTrue, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get mentionable assistants: %v", err) } if len(response.Data) != 2 { t.Errorf("Expected 2 mentionable assistants, got %d", len(response.Data)) } for _, assistant := range response.Data { if !assistant.Mentionable { t.Errorf("Expected assistant %s (%s) to be mentionable, but it's not", assistant.ID, assistant.Name) } } }) t.Run("FilterByAutomated", func(t *testing.T) { automatedFalse := false response, err := store.GetAssistants(types.AssistantFilter{ Automated: &automatedFalse, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get non-automated assistants: %v", err) } for _, assistant := range response.Data { if assistant.Automated { t.Error("Expected all assistants to be non-automated") } } }) t.Run("Pagination", func(t *testing.T) { // Test first page response1, err := 
store.GetAssistants(types.AssistantFilter{ Page: 1, PageSize: 2, }) if err != nil { t.Fatalf("Failed to get first page: %v", err) } if len(response1.Data) > 2 { t.Errorf("Expected max 2 results, got %d", len(response1.Data)) } if response1.Page != 1 { t.Errorf("Expected page 1, got %d", response1.Page) } if response1.PageSize != 2 { t.Errorf("Expected page size 2, got %d", response1.PageSize) } // Test second page if there are enough records if response1.Total > 2 { response2, err := store.GetAssistants(types.AssistantFilter{ Page: 2, PageSize: 2, }) if err != nil { t.Fatalf("Failed to get second page: %v", err) } if response2.Page != 2 { t.Errorf("Expected page 2, got %d", response2.Page) } } }) t.Run("FieldSelection", func(t *testing.T) { response, err := store.GetAssistants(types.AssistantFilter{ Select: []string{"assistant_id", "name", "type"}, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants with field selection: %v", err) } if len(response.Data) > 0 { assistant := response.Data[0] if assistant.ID == "" { t.Error("Expected assistant_id field") } if assistant.Name == "" { t.Error("Expected name field") } if assistant.Type == "" { t.Error("Expected type field") } } }) t.Run("FilterByAssistantID", func(t *testing.T) { if len(createdIDs) > 0 { response, err := store.GetAssistants(types.AssistantFilter{ AssistantID: createdIDs[0], Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistant by ID: %v", err) } if len(response.Data) != 1 { t.Errorf("Expected 1 result, got %d", len(response.Data)) } if response.Data[0].ID != createdIDs[0] { t.Errorf("Expected assistant_id %s, got %s", createdIDs[0], response.Data[0].ID) } } }) t.Run("FilterByAssistantIDs", func(t *testing.T) { if len(createdIDs) >= 2 { filterIDs := []string{createdIDs[0], createdIDs[1]} response, err := store.GetAssistants(types.AssistantFilter{ AssistantIDs: filterIDs, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants by IDs: %v", 
err) } if len(response.Data) < 2 { t.Errorf("Expected at least 2 results, got %d", len(response.Data)) } } }) } // TestDeleteAssistants tests bulk deletion with filters func TestDeleteAssistants(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("DeleteByTag", func(t *testing.T) { // Create assistants with specific tag tag := fmt.Sprintf("delete-test-%d", time.Now().UnixNano()) for i := 0; i < 3; i++ { assistant := &types.AssistantModel{ Name: fmt.Sprintf("Delete Test %d", i), Type: "assistant", Connector: "openai", Tags: []string{tag}, Share: "private", } _, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } } // Delete by tag count, err := store.DeleteAssistants(types.AssistantFilter{ Tags: []string{tag}, }) if err != nil { t.Fatalf("Failed to delete assistants: %v", err) } if count < 3 { t.Errorf("Expected at least 3 deletions, got %d", count) } }) t.Run("DeleteByConnector", func(t *testing.T) { // Create assistants with specific connector connector := fmt.Sprintf("test-connector-%d", time.Now().UnixNano()) for i := 0; i < 2; i++ { assistant := &types.AssistantModel{ Name: fmt.Sprintf("Connector Test %d", i), Type: "assistant", Connector: connector, Share: "private", } _, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } } // Delete by connector count, err := store.DeleteAssistants(types.AssistantFilter{ Connector: connector, }) if err != nil { t.Fatalf("Failed to delete assistants: %v", err) } if count < 2 { t.Errorf("Expected at least 2 deletions, got %d", count) } }) t.Run("DeleteByKeywords", func(t *testing.T) { // Create assistants with specific keyword keyword := fmt.Sprintf("unique-keyword-%d", time.Now().UnixNano()) assistant := &types.AssistantModel{ Name: fmt.Sprintf("Assistant with %s", keyword), Type: 
"assistant", Connector: "openai", Description: "Test description", Share: "private", } _, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Delete by keyword count, err := store.DeleteAssistants(types.AssistantFilter{ Keywords: keyword, }) if err != nil { t.Fatalf("Failed to delete assistants: %v", err) } if count < 1 { t.Errorf("Expected at least 1 deletion, got %d", count) } }) t.Run("DeleteByAssistantID", func(t *testing.T) { // Create an assistant assistant := &types.AssistantModel{ Name: "Single Delete Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Delete by ID count, err := store.DeleteAssistants(types.AssistantFilter{ AssistantID: id, }) if err != nil { t.Fatalf("Failed to delete assistant: %v", err) } if count != 1 { t.Errorf("Expected 1 deletion, got %d", count) } }) } // TestGetAssistantTags tests retrieving unique tags func TestGetAssistantTags(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("GetUniqueTags", func(t *testing.T) { // Create assistants with various tags uniqueTag := fmt.Sprintf("tag-test-%d", time.Now().UnixNano()) assistants := []types.AssistantModel{ { Name: "Tags Test 1", Type: "assistant", Connector: "openai", Tags: []string{uniqueTag, "common"}, Share: "private", }, { Name: "Tags Test 2", Type: "assistant", Connector: "openai", Tags: []string{uniqueTag, "different"}, Share: "private", }, { Name: "Tags Test 3", Type: "assistant", Connector: "openai", Tags: []string{"common", "another"}, Share: "private", }, } for _, asst := range assistants { _, err := store.SaveAssistant(&asst) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } } // Get all tags tags, err := 
store.GetAssistantTags(types.AssistantFilter{}) if err != nil { t.Fatalf("Failed to get tags: %v", err) } // Verify we have some tags if len(tags) == 0 { t.Error("Expected at least some tags") } // Verify tag structure for _, tag := range tags { if tag.Value == "" { t.Error("Expected tag to have non-empty value") } if tag.Label == "" { t.Error("Expected tag to have non-empty label") } } t.Logf("Found %d unique tags", len(tags)) }) t.Run("GetTagsWithFilter", func(t *testing.T) { // Create test assistants with specific tags and attributes uniqueTag := fmt.Sprintf("filter-tag-%d", time.Now().UnixNano()) assistants := []types.AssistantModel{ { Name: "Filtered Tags Test 1", Type: "assistant", Connector: "openai", Tags: []string{uniqueTag, "ai"}, Share: "private", BuiltIn: false, Mentionable: true, }, { Name: "Filtered Tags Test 2", Type: "assistant", Connector: "anthropic", Tags: []string{uniqueTag, "coding"}, Share: "private", BuiltIn: true, Mentionable: false, }, { Name: "Filtered Tags Test 3", Type: "assistant", Connector: "openai", Tags: []string{uniqueTag, "search"}, Share: "private", BuiltIn: false, Automated: true, }, } for _, asst := range assistants { _, err := store.SaveAssistant(&asst) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } } // Test: Get tags filtered by connector tagsOpenAI, err := store.GetAssistantTags(types.AssistantFilter{ Connector: "openai", }) if err != nil { t.Fatalf("Failed to get tags with connector filter: %v", err) } t.Logf("Found %d tags for openai connector", len(tagsOpenAI)) // Test: Get tags filtered by built_in builtInFalse := false tagsNonBuiltIn, err := store.GetAssistantTags(types.AssistantFilter{ BuiltIn: &builtInFalse, }) if err != nil { t.Fatalf("Failed to get tags with built_in filter: %v", err) } t.Logf("Found %d tags for non-built-in assistants", len(tagsNonBuiltIn)) // Test: Get tags filtered by mentionable mentionableTrue := true tagsMentionable, err := store.GetAssistantTags(types.AssistantFilter{ 
Mentionable: &mentionableTrue, }) if err != nil { t.Fatalf("Failed to get tags with mentionable filter: %v", err) } t.Logf("Found %d tags for mentionable assistants", len(tagsMentionable)) // Test: Get tags filtered by keywords tagsWithKeywords, err := store.GetAssistantTags(types.AssistantFilter{ Keywords: "Filtered Tags Test", }) if err != nil { t.Fatalf("Failed to get tags with keywords filter: %v", err) } t.Logf("Found %d tags with keywords filter", len(tagsWithKeywords)) }) t.Run("GetTagsWithQueryFilter", func(t *testing.T) { // Create test assistants with permission fields permTag := fmt.Sprintf("perm-tag-%d", time.Now().UnixNano()) assistants := []types.AssistantModel{ { Name: "Permission Tags Test 1", Type: "assistant", Connector: "openai", Tags: []string{permTag, "public-tag"}, Share: "private", Public: true, YaoCreatedBy: "user-1", YaoTeamID: "team-1", }, { Name: "Permission Tags Test 2", Type: "assistant", Connector: "openai", Tags: []string{permTag, "team-tag"}, Share: "team", Public: false, YaoCreatedBy: "user-2", YaoTeamID: "team-1", }, { Name: "Permission Tags Test 3", Type: "assistant", Connector: "openai", Tags: []string{permTag, "private-tag"}, Share: "private", Public: false, YaoCreatedBy: "user-3", YaoTeamID: "team-2", }, } for _, asst := range assistants { _, err := store.SaveAssistant(&asst) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } } // Test: Get tags for public assistants only tagsPublic, err := store.GetAssistantTags(types.AssistantFilter{ QueryFilter: func(qb query.Query) { qb.Where("public", true) }, }) if err != nil { t.Fatalf("Failed to get tags for public assistants: %v", err) } t.Logf("Found %d tags for public assistants", len(tagsPublic)) // Test: Get tags for team-1 assistants tagsTeam1, err := store.GetAssistantTags(types.AssistantFilter{ QueryFilter: func(qb query.Query) { qb.Where("__yao_team_id", "team-1") }, }) if err != nil { t.Fatalf("Failed to get tags for team-1: %v", err) } t.Logf("Found %d tags for 
team-1 assistants", len(tagsTeam1)) // Test: Complex permission filter (public OR team-1 with share=team) tagsComplex, err := store.GetAssistantTags(types.AssistantFilter{ QueryFilter: func(qb query.Query) { qb.Where(func(qb query.Query) { qb.Where("public", true) }).OrWhere(func(qb query.Query) { qb.Where("__yao_team_id", "team-1"). Where("share", "team") }) }, }) if err != nil { t.Fatalf("Failed to get tags with complex filter: %v", err) } t.Logf("Found %d tags with complex permission filter", len(tagsComplex)) }) } // TestAssistantPermissionFields tests permission management fields func TestAssistantPermissionFields(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("SaveWithPermissionFields", func(t *testing.T) { assistant := &types.AssistantModel{ Name: "Permission Test Assistant", Type: "assistant", Connector: "openai", Description: "Testing permission fields", Share: "private", YaoCreatedBy: "user-123", YaoUpdatedBy: "user-123", YaoTeamID: "team-456", YaoTenantID: "tenant-789", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant with permission fields: %v", err) } // Retrieve and verify - default fields include permission fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to get assistant: %v", err) } if retrieved.YaoCreatedBy != "user-123" { t.Errorf("Expected YaoCreatedBy 'user-123', got '%s'", retrieved.YaoCreatedBy) } if retrieved.YaoUpdatedBy != "user-123" { t.Errorf("Expected YaoUpdatedBy 'user-123', got '%s'", retrieved.YaoUpdatedBy) } if retrieved.YaoTeamID != "team-456" { t.Errorf("Expected YaoTeamID 'team-456', got '%s'", retrieved.YaoTeamID) } if retrieved.YaoTenantID != "tenant-789" { t.Errorf("Expected YaoTenantID 'tenant-789', got '%s'", retrieved.YaoTenantID) } t.Logf("Permission fields saved and retrieved successfully for 
assistant %s", id) }) t.Run("UpdatePermissionFields", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "Update Permission Test", Type: "assistant", Connector: "openai", Share: "private", YaoCreatedBy: "user-original", YaoTeamID: "team-original", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with new permission fields assistant.ID = id assistant.YaoUpdatedBy = "user-updater" assistant.YaoTenantID = "tenant-new" _, err = store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to update assistant: %v", err) } // Verify update - default fields include permission fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to get updated assistant: %v", err) } if retrieved.YaoCreatedBy != "user-original" { t.Errorf("Expected YaoCreatedBy to remain 'user-original', got '%s'", retrieved.YaoCreatedBy) } if retrieved.YaoUpdatedBy != "user-updater" { t.Errorf("Expected YaoUpdatedBy 'user-updater', got '%s'", retrieved.YaoUpdatedBy) } if retrieved.YaoTenantID != "tenant-new" { t.Errorf("Expected YaoTenantID 'tenant-new', got '%s'", retrieved.YaoTenantID) } }) t.Run("EmptyPermissionFields", func(t *testing.T) { // Create assistant without permission fields assistant := &types.AssistantModel{ Name: "No Permission Fields", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to save assistant: %v", err) } // Retrieve and verify fields are empty retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to get assistant: %v", err) } if retrieved.YaoCreatedBy != "" { t.Errorf("Expected empty YaoCreatedBy, got '%s'", retrieved.YaoCreatedBy) } if retrieved.YaoUpdatedBy != "" { t.Errorf("Expected empty YaoUpdatedBy, got '%s'", retrieved.YaoUpdatedBy) } if retrieved.YaoTeamID != "" { t.Errorf("Expected empty YaoTeamID, got '%s'", 
retrieved.YaoTeamID) } if retrieved.YaoTenantID != "" { t.Errorf("Expected empty YaoTenantID, got '%s'", retrieved.YaoTenantID) } }) }

// TestEmptyStringAsNull tests that empty strings are stored as NULL in database.
//
// EmptyStringsStoredAsNull saves an assistant whose nullable string fields are
// empty and expects them to read back as empty strings, with Share falling
// back to "private". NonEmptyStringsPreserved saves non-empty values for the
// same fields and expects them back unchanged.
func TestEmptyStringAsNull(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("EmptyStringsStoredAsNull", func(t *testing.T) {
		// Create assistant with empty strings for nullable fields
		// According to assistant.mod.yao, nullable string fields are:
		// - name (nullable: true, but required by validation)
		// - avatar, description, path (nullable: true)
		// - share (nullable: false, but empty should trigger default)
		//
		// Avatar, Path and Description: empty string should become NULL.
		// Share: empty string should become NULL, then default "private" applied.
		assistant := &types.AssistantModel{
			Name:        "Test Null Fields",
			Type:        "assistant",
			Connector:   "openai",
			Avatar:      "",
			Path:        "",
			Description: "",
			Share:       "",
		}

		id, err := store.SaveAssistant(assistant)
		if err != nil {
			t.Fatalf("Failed to save assistant: %v", err)
		}

		// Retrieve and verify empty strings are returned (not stored as empty
		// strings). Full fields are requested because path is only returned
		// with the full field set (see NonEmptyStringsPreserved); with the
		// default set the Path assertion below would pass vacuously.
		retrieved, err := store.GetAssistant(id, types.AssistantFullFields)
		if err != nil {
			t.Fatalf("Failed to get assistant: %v", err)
		}

		// Name should be preserved (required field)
		if retrieved.Name != "Test Null Fields" {
			t.Errorf("Expected Name 'Test Null Fields', got '%s'", retrieved.Name)
		}

		// These nullable fields should be empty strings in Go (converted from NULL)
		if retrieved.Avatar != "" {
			t.Errorf("Expected empty Avatar, got '%s'", retrieved.Avatar)
		}
		if retrieved.Path != "" {
			t.Errorf("Expected empty Path, got '%s'", retrieved.Path)
		}
		if retrieved.Description != "" {
			t.Errorf("Expected empty Description, got '%s'", retrieved.Description)
		}

		// Share should have default value "private" applied
		if retrieved.Share != "private" {
			t.Errorf("Expected Share to be 'private', got '%s'", retrieved.Share)
		}

		t.Logf("Successfully verified empty strings are stored as NULL for assistant %s", id)
	})

	t.Run("NonEmptyStringsPreserved", func(t *testing.T) {
		// Create assistant with non-empty values
		assistant := &types.AssistantModel{
			Name:        "Test Non-Empty Fields",
			Type:        "assistant",
			Connector:   "openai",
			Avatar:      "https://example.com/avatar.png",
			Path:        "/path/to/assistant",
			Description: "This is a description",
			Share:       "private",
		}

		id, err := store.SaveAssistant(assistant)
		if err != nil {
			t.Fatalf("Failed to save assistant: %v", err)
		}

		// Retrieve and verify values are preserved - path is sensitive, need full fields
		retrieved, err := store.GetAssistant(id, types.AssistantFullFields)
		if err != nil {
			t.Fatalf("Failed to get assistant: %v", err)
		}

		if retrieved.Avatar != "https://example.com/avatar.png" {
			t.Errorf("Expected Avatar 'https://example.com/avatar.png', got '%s'", retrieved.Avatar)
		}
		if retrieved.Path != "/path/to/assistant" {
			t.Errorf("Expected Path '/path/to/assistant', got '%s'", retrieved.Path)
		}
		if retrieved.Description != "This is a description" {
			t.Errorf("Expected Description 'This is a description', got '%s'", retrieved.Description)
		}
		if retrieved.Share != "private" {
			t.Errorf("Expected Share 'private', got '%s'", retrieved.Share)
		}

		t.Logf("Successfully verified non-empty strings are preserved for assistant %s", id)
	})
}

// TestGetAssistantWithLocale tests retrieving assistant with locale translation func TestGetAssistantWithLocale(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("GetAssistantWithLocaleTranslation", func(t *testing.T) { // Create assistant with i18n locales assistant := &types.AssistantModel{ Name: "{{name}}", Type: "assistant", Connector: "openai",
Description: "{{description}}", Tags: []string{"test"}, Share: "private", Placeholder: &types.Placeholder{ Title: "{{chat.title}}", Description: "{{chat.description}}", Prompts: []string{"{{chat.prompts.0}}", "{{chat.prompts.1}}"}, }, } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Setup i18n for testing i18n.Locales[id] = map[string]i18n.I18n{ "en": { Locale: "en", Messages: map[string]any{ "name": "Test Assistant", "description": "This is a test assistant", "chat.title": "Chat with me", "chat.description": "Start a conversation", "chat.prompts.0": "How can I help you?", "chat.prompts.1": "What would you like to know?", }, }, "zh-cn": { Locale: "zh-cn", Messages: map[string]any{ "name": "测试助手", "description": "这是一个测试助手", "chat.title": "与我聊天", "chat.description": "开始对话", "chat.prompts.0": "我能帮你什么?", "chat.prompts.1": "你想了解什么?", }, }, } // Test English locale - request all fields for placeholder retrievedEN, err := store.GetAssistant(id, types.AssistantFullFields, "en") if err != nil { t.Fatalf("Failed to get assistant with EN locale: %v", err) } if retrievedEN.Name != "Test Assistant" { t.Errorf("Expected name 'Test Assistant', got '%s'", retrievedEN.Name) } if retrievedEN.Description != "This is a test assistant" { t.Errorf("Expected description 'This is a test assistant', got '%s'", retrievedEN.Description) } if retrievedEN.Placeholder == nil { t.Fatal("Expected placeholder to be set") } if retrievedEN.Placeholder.Title != "Chat with me" { t.Errorf("Expected placeholder title 'Chat with me', got '%s'", retrievedEN.Placeholder.Title) } if retrievedEN.Placeholder.Description != "Start a conversation" { t.Errorf("Expected placeholder description 'Start a conversation', got '%s'", retrievedEN.Placeholder.Description) } if len(retrievedEN.Placeholder.Prompts) != 2 { t.Errorf("Expected 2 placeholder prompts, got %d", len(retrievedEN.Placeholder.Prompts)) } if retrievedEN.Placeholder.Prompts[0] != "How can 
I help you?" { t.Errorf("Expected first prompt 'How can I help you?', got '%s'", retrievedEN.Placeholder.Prompts[0]) } // Test Chinese locale - request all fields for placeholder retrievedZH, err := store.GetAssistant(id, types.AssistantFullFields, "zh-cn") if err != nil { t.Fatalf("Failed to get assistant with ZH locale: %v", err) } if retrievedZH.Name != "测试助手" { t.Errorf("Expected name '测试助手', got '%s'", retrievedZH.Name) } if retrievedZH.Description != "这是一个测试助手" { t.Errorf("Expected description '这是一个测试助手', got '%s'", retrievedZH.Description) } if retrievedZH.Placeholder == nil { t.Fatal("Expected placeholder to be set") } if retrievedZH.Placeholder.Title != "与我聊天" { t.Errorf("Expected placeholder title '与我聊天', got '%s'", retrievedZH.Placeholder.Title) } // Test without locale (should return original {{...}} values) - request all fields for placeholder retrievedNoLocale, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to get assistant without locale: %v", err) } if retrievedNoLocale.Name != "{{name}}" { t.Errorf("Expected original name '{{name}}', got '%s'", retrievedNoLocale.Name) } if retrievedNoLocale.Description != "{{description}}" { t.Errorf("Expected original description '{{description}}', got '%s'", retrievedNoLocale.Description) } // Cleanup delete(i18n.Locales, id) t.Logf("Successfully tested locale translation for assistant %s", id) }) }

// TestGetAssistantsWithLocale tests retrieving multiple assistants with locale
// translation.
//
// It saves one assistant whose name and description are {{...}} placeholders,
// registers "en" and "zh-cn" messages for it in i18n.Locales, and checks that
// GetAssistants returns the translated strings for each locale.
func TestGetAssistantsWithLocale(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("GetAssistantsWithLocaleTranslation", func(t *testing.T) {
		// Create assistant with i18n locales
		assistant := &types.AssistantModel{
			Name:        "{{name}}",
			Type:        "assistant",
			Connector:   "openai",
			Description: "{{description}}",
			Tags:        []string{"locale-test"},
			Share:       "private",
		}

		id, err := store.SaveAssistant(assistant)
		if err != nil {
			t.Fatalf("Failed to create assistant: %v", err)
		}

		// Setup i18n for testing
		i18n.Locales[id] = map[string]i18n.I18n{
			"en": {
				Locale: "en",
				Messages: map[string]any{
					"name":        "List Test Assistant",
					"description": "This appears in the list",
				},
			},
			"zh-cn": {
				Locale: "zh-cn",
				Messages: map[string]any{
					"name":        "列表测试助手",
					"description": "这出现在列表中",
				},
			},
		}
		// Cleanup is deferred so the package-level locale map is restored
		// even when a t.Fatal below ends the subtest early.
		defer delete(i18n.Locales, id)

		// Test GetAssistants with English locale
		responseEN, err := store.GetAssistants(types.AssistantFilter{
			Tags:     []string{"locale-test"},
			Page:     1,
			PageSize: 20,
		}, "en")
		if err != nil {
			t.Fatalf("Failed to get assistants with EN locale: %v", err)
		}

		if len(responseEN.Data) < 1 {
			t.Fatal("Expected at least 1 assistant in response")
		}

		found := false
		for _, asst := range responseEN.Data {
			if asst.ID == id {
				found = true
				if asst.Name != "List Test Assistant" {
					t.Errorf("Expected name 'List Test Assistant', got '%s'", asst.Name)
				}
				if asst.Description != "This appears in the list" {
					t.Errorf("Expected description 'This appears in the list', got '%s'", asst.Description)
				}
				break
			}
		}
		if !found {
			t.Error("Expected to find the test assistant in the list")
		}

		// Test GetAssistants with Chinese locale
		responseZH, err := store.GetAssistants(types.AssistantFilter{
			Tags:     []string{"locale-test"},
			Page:     1,
			PageSize: 20,
		}, "zh-cn")
		if err != nil {
			t.Fatalf("Failed to get assistants with ZH locale: %v", err)
		}

		found = false
		for _, asst := range responseZH.Data {
			if asst.ID == id {
				found = true
				if asst.Name != "列表测试助手" {
					t.Errorf("Expected name '列表测试助手', got '%s'", asst.Name)
				}
				if asst.Description != "这出现在列表中" {
					t.Errorf("Expected description '这出现在列表中', got '%s'", asst.Description)
				}
				break
			}
		}
		if !found {
			t.Error("Expected to find the test assistant in the list")
		}

		t.Logf("Successfully tested locale translation for assistants list")
	})
}

// TestGetAssistantsWithQueryFilter tests using QueryFilter for permission
// filtering.
//
// Four fixtures sharing the "query-filter-test" tag are created with different
// public/share/creator/team values. Each subtest passes a QueryFilter callback
// that adds conditions to the query and checks which fixtures come back.
func TestGetAssistantsWithQueryFilter(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	// Create test assistants with different permission settings
	assistants := []types.AssistantModel{
		{
			Name:         "Public Assistant",
			Type:         "assistant",
			Connector:    "openai",
			Description:  "Public assistant visible to all",
			Tags:         []string{"query-filter-test"},
			Public:       true,
			Share:        "private",
			YaoCreatedBy: "user-1",
			YaoTeamID:    "team-1",
		},
		{
			Name:         "Team Shared Assistant",
			Type:         "assistant",
			Connector:    "openai",
			Description:  "Team shared assistant",
			Tags:         []string{"query-filter-test"},
			Public:       false,
			Share:        "team",
			YaoCreatedBy: "user-2",
			YaoTeamID:    "team-1",
		},
		{
			Name:         "Private Assistant Owner",
			Type:         "assistant",
			Connector:    "openai",
			Description:  "Private assistant owned by user-1",
			Tags:         []string{"query-filter-test"},
			Public:       false,
			Share:        "private",
			YaoCreatedBy: "user-1",
			YaoTeamID:    "team-1",
		},
		{
			Name:         "Private Assistant Other",
			Type:         "assistant",
			Connector:    "openai",
			Description:  "Private assistant owned by user-3",
			Tags:         []string{"query-filter-test"},
			Public:       false,
			Share:        "private",
			YaoCreatedBy: "user-3",
			YaoTeamID:    "team-2",
		},
	}

	// Cleanup is registered before creation so that fixtures saved before a
	// failing SaveAssistant are still removed. A defer is used instead of
	// t.Cleanup so it runs before the deferred test.Clean() above, matching
	// the order of the previous end-of-function cleanup.
	createdIDs := make([]string, 0, len(assistants))
	defer func() {
		for _, id := range createdIDs {
			if err := store.DeleteAssistant(id); err != nil {
				t.Logf("Warning: Failed to delete assistant %s: %v", id, err)
			}
		}
	}()

	// Pass a pointer to the slice element rather than to the range variable.
	for i := range assistants {
		id, err := store.SaveAssistant(&assistants[i])
		if err != nil {
			t.Fatalf("Failed to create assistant: %v", err)
		}
		createdIDs = append(createdIDs, id)
	}

	t.Run("FilterByPublic", func(t *testing.T) {
		// QueryFilter: only public assistants
		response, err := store.GetAssistants(types.AssistantFilter{
			Tags:     []string{"query-filter-test"},
			Page:     1,
			PageSize: 20,
			QueryFilter: func(qb query.Query) {
				qb.Where("public", true)
			},
		})
		if err != nil {
			t.Fatalf("Failed to get public assistants: %v", err)
		}

		if len(response.Data) != 1 {
			t.Errorf("Expected 1 public assistant, got %d", len(response.Data))
		}
		if len(response.Data) > 0 && response.Data[0].Name != "Public Assistant" {
			t.Errorf("Expected 'Public Assistant', got '%s'", response.Data[0].Name)
		}
	})

	t.Run("FilterByTeamAndShare", func(t *testing.T) {
		// QueryFilter: team-1 assistants that are shared with team
		response, err := store.GetAssistants(types.AssistantFilter{
			Tags:     []string{"query-filter-test"},
			Page:     1,
			PageSize: 20,
			QueryFilter: func(qb query.Query) {
				qb.Where("__yao_team_id", "team-1").
					Where("share", "team")
			},
		})
		if err != nil {
			t.Fatalf("Failed to get team shared assistants: %v", err)
		}

		if len(response.Data) != 1 {
			t.Errorf("Expected 1 team shared assistant, got %d", len(response.Data))
		}
		if len(response.Data) > 0 && response.Data[0].Name != "Team Shared Assistant" {
			t.Errorf("Expected 'Team Shared Assistant', got '%s'", response.Data[0].Name)
		}
	})

	t.Run("FilterByOwner", func(t *testing.T) {
		// QueryFilter: assistants created by user-1
		response, err := store.GetAssistants(types.AssistantFilter{
			Tags:     []string{"query-filter-test"},
			Page:     1,
			PageSize: 20,
			QueryFilter: func(qb query.Query) {
				qb.Where("__yao_created_by", "user-1")
			},
		})
		if err != nil {
			t.Fatalf("Failed to get user-1 assistants: %v", err)
		}

		if len(response.Data) != 2 {
			t.Errorf("Expected 2 assistants for user-1, got %d", len(response.Data))
		}
		for _, asst := range response.Data {
			if asst.YaoCreatedBy != "user-1" {
				t.Errorf("Expected creator 'user-1', got '%s'", asst.YaoCreatedBy)
			}
		}
	})

	t.Run("ComplexQueryFilter", func(t *testing.T) {
		// Complex QueryFilter: (public = true) OR (team_id = team-1 AND (created_by = user-1 OR share = team))
		response, err := store.GetAssistants(types.AssistantFilter{
			Tags:     []string{"query-filter-test"},
			Page:     1,
			PageSize: 20,
			QueryFilter: func(qb query.Query) {
				qb.Where(func(qb query.Query) {
					// Public assistants
					qb.Where("public", true)
				}).OrWhere(func(qb query.Query) {
					// Team assistants where user is creator or shared with team
					qb.Where("__yao_team_id", "team-1").Where(func(qb query.Query) {
						qb.Where("__yao_created_by", "user-1").
							OrWhere("share", "team")
					})
				})
			},
		})
		if err != nil {
			t.Fatalf("Failed to get filtered assistants: %v", err)
		}

		// Should find: Public Assistant, Team Shared Assistant, Private Assistant Owner
		if len(response.Data) != 3 {
			t.Errorf("Expected 3 assistants, got %d", len(response.Data))
		}

		// Verify we got the right assistants
		names := make(map[string]bool)
		for _, asst := range response.Data {
			names[asst.Name] = true
		}

		expectedNames := []string{"Public Assistant", "Team Shared Assistant", "Private Assistant Owner"}
		for _, name := range expectedNames {
			if !names[name] {
				t.Errorf("Expected to find '%s' in results", name)
			}
		}

		// Should NOT find Private Assistant Other
		if names["Private Assistant Other"] {
			t.Error("Should not find 'Private Assistant Other' in results")
		}
	})

	t.Run("QueryFilterWithNullCheck", func(t *testing.T) {
		// QueryFilter: assistants where team_id is null
		response, err := store.GetAssistants(types.AssistantFilter{
			Tags:     []string{"query-filter-test"},
			Page:     1,
			PageSize: 20,
			QueryFilter: func(qb query.Query) {
				qb.WhereNull("__yao_team_id")
			},
		})
		if err != nil {
			t.Fatalf("Failed to get assistants with null team_id: %v", err)
		}

		// All test assistants have team_id, so should find 0
		if len(response.Data) != 0 {
			t.Errorf("Expected 0 assistants with null team_id, got %d", len(response.Data))
		}
	})

	t.Run("QueryFilterCombinedWithOtherFilters", func(t *testing.T) {
		// Combine QueryFilter with other filters
		response, err := store.GetAssistants(types.AssistantFilter{
			Tags:      []string{"query-filter-test"},
			Connector: "openai",
			Page:      1,
			PageSize:  20,
			QueryFilter: func(qb query.Query) {
				qb.Where("public", true)
			},
		})
		if err != nil {
			t.Fatalf("Failed to get combined filtered assistants: %v", err)
		}

		// Should only find public openai assistants
		if len(response.Data) != 1 {
			t.Errorf("Expected 1 assistant, got %d", len(response.Data))
		}
		if len(response.Data) > 0 {
			if response.Data[0].Connector != "openai" {
				t.Errorf("Expected connector 'openai', got '%s'", response.Data[0].Connector)
			}
			if !response.Data[0].Public {
				t.Error("Expected public assistant")
			}
		}
	})
}

// TestUpdateAssistant tests the UpdateAssistant method for incremental updates func TestUpdateAssistant(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("UpdateSingleField", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "Original Name", Type: "assistant", Connector: "openai", Description: "Original description", Tags: []string{"original"}, Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update only description updates := map[string]interface{}{ "description": "Updated description", } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update assistant: %v", err) } // Verify update - need full fields to see tags retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Description != "Updated description" { t.Errorf("Expected description 'Updated description', got '%s'", retrieved.Description) } // Other fields should remain unchanged if retrieved.Name != "Original Name" { t.Errorf("Expected name 'Original Name', got '%s'", retrieved.Name) } if len(retrieved.Tags) != 1 || retrieved.Tags[0] != "original" { t.Errorf("Expected tags [original], got %v", retrieved.Tags) } }) t.Run("UpdateMultipleFields", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "Test Assistant", Type: "assistant", Connector: "openai", Description: "Test description", Sort: 100, Mentionable: false, Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create
assistant: %v", err) } // Update multiple fields updates := map[string]interface{}{ "name": "Updated Name", "description": "Updated description", "sort": 200, "mentionable": true, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update assistant: %v", err) } // Verify all updates - use default fields (includes name, description, sort, mentionable) retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Name != "Updated Name" { t.Errorf("Expected name 'Updated Name', got '%s'", retrieved.Name) } if retrieved.Description != "Updated description" { t.Errorf("Expected description 'Updated description', got '%s'", retrieved.Description) } if retrieved.Sort != 200 { t.Errorf("Expected sort 200, got %d", retrieved.Sort) } if !retrieved.Mentionable { t.Error("Expected mentionable to be true") } }) t.Run("UpdateJSONFields", func(t *testing.T) { // Create assistant with complex fields assistant := &types.AssistantModel{ Name: "JSON Test", Type: "assistant", Connector: "openai", Tags: []string{"tag1", "tag2"}, Options: map[string]interface{}{"temperature": 0.7}, Prompts: []types.Prompt{ {Role: "system", Content: "Original system prompt"}, }, Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update JSON fields updates := map[string]interface{}{ "tags": []string{"updated", "new-tags"}, "options": map[string]interface{}{ "temperature": 0.9, "max_tokens": 2000, }, "prompts": []types.Prompt{ {Role: "system", Content: "Updated system prompt"}, {Role: "user", Content: "New user prompt"}, }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update JSON fields: %v", err) } // Verify updates - need full fields for tags, options, prompts retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if 
len(retrieved.Tags) != 2 || retrieved.Tags[0] != "updated" { t.Errorf("Expected tags [updated, new-tags], got %v", retrieved.Tags) } if temp, ok := retrieved.Options["temperature"].(float64); !ok || temp != 0.9 { t.Errorf("Expected temperature 0.9, got %v", retrieved.Options["temperature"]) } if len(retrieved.Prompts) != 2 { t.Errorf("Expected 2 prompts, got %d", len(retrieved.Prompts)) } if retrieved.Prompts[0].Content != "Updated system prompt" { t.Errorf("Expected updated system prompt, got '%s'", retrieved.Prompts[0].Content) } }) t.Run("UpdateKBAndMCP", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "KB MCP Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update KB and MCP updates := map[string]interface{}{ "kb": map[string]interface{}{ "collections": []string{"collection1", "collection2"}, }, "mcp": map[string]interface{}{ "servers": []string{"server1", "server2"}, }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update KB and MCP: %v", err) } // Verify updates - KB and MCP are in default fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.KB == nil || len(retrieved.KB.Collections) != 2 { t.Errorf("Expected 2 KB collections, got %v", retrieved.KB) } if retrieved.MCP == nil || len(retrieved.MCP.Servers) != 2 { t.Errorf("Expected 2 MCP servers, got %v", retrieved.MCP) } if retrieved.MCP.Servers[0].ServerID != "server1" { t.Errorf("Expected first server 'server1', got '%s'", retrieved.MCP.Servers[0].ServerID) } }) t.Run("UpdateKBDBAndMCP", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "KB DB MCP Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create 
assistant: %v", err) } // Update KB, DB and MCP updates := map[string]interface{}{ "kb": map[string]interface{}{ "collections": []string{"collection1", "collection2"}, }, "db": map[string]interface{}{ "models": []string{"model1", "model2"}, }, "mcp": map[string]interface{}{ "servers": []string{"server1", "server2"}, }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update KB, DB and MCP: %v", err) } // Verify updates - KB, DB and MCP are in default fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.KB == nil || len(retrieved.KB.Collections) != 2 { t.Errorf("Expected 2 KB collections, got %v", retrieved.KB) } if retrieved.DB == nil || len(retrieved.DB.Models) != 2 { t.Errorf("Expected 2 DB models, got %v", retrieved.DB) } if retrieved.DB.Models[0] != "model1" { t.Errorf("Expected first model 'model1', got '%s'", retrieved.DB.Models[0]) } if retrieved.MCP == nil || len(retrieved.MCP.Servers) != 2 { t.Errorf("Expected 2 MCP servers, got %v", retrieved.MCP) } if retrieved.MCP.Servers[0].ServerID != "server1" { t.Errorf("Expected first server 'server1', got '%s'", retrieved.MCP.Servers[0].ServerID) } }) t.Run("UpdateDBWithOptions", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "DB Advanced Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with DB using advanced configuration updates := map[string]interface{}{ "db": map[string]interface{}{ "models": []string{"user", "product", "order"}, "options": map[string]interface{}{ "limit": 100, "offset": 0, }, }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update DB: %v", err) } // Verify updates - DB is in default fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve 
assistant: %v", err) } if retrieved.DB == nil { t.Fatal("Expected DB to be set") } if len(retrieved.DB.Models) != 3 { t.Errorf("Expected 3 DB models, got %d", len(retrieved.DB.Models)) } if retrieved.DB.Models[0] != "user" { t.Errorf("Expected first model 'user', got '%s'", retrieved.DB.Models[0]) } if retrieved.DB.Options == nil { t.Error("Expected DB options to be set") } else { if limit, ok := retrieved.DB.Options["limit"].(float64); !ok || limit != 100 { t.Errorf("Expected DB limit 100, got %v", retrieved.DB.Options["limit"]) } } }) t.Run("UpdateModesAndDefaultMode", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "Modes Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with modes and default_mode updates := map[string]interface{}{ "modes": []string{"chat", "task", "analyze"}, "default_mode": "chat", } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update modes: %v", err) } // Verify updates - modes and default_mode are in default fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Modes == nil || len(retrieved.Modes) != 3 { t.Errorf("Expected 3 modes, got %v", retrieved.Modes) } if retrieved.Modes[0] != "chat" { t.Errorf("Expected first mode 'chat', got '%s'", retrieved.Modes[0]) } if retrieved.DefaultMode != "chat" { t.Errorf("Expected default_mode 'chat', got '%s'", retrieved.DefaultMode) } }) t.Run("UpdateModesOnly", func(t *testing.T) { // Create assistant with default_mode assistant := &types.AssistantModel{ Name: "Modes Only Test", Type: "assistant", Connector: "openai", Share: "private", DefaultMode: "task", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update only modes updates := map[string]interface{}{ 
"modes": []string{"chat", "task"}, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update modes: %v", err) } // Verify updates - default_mode should remain unchanged retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if len(retrieved.Modes) != 2 { t.Errorf("Expected 2 modes, got %d", len(retrieved.Modes)) } if retrieved.DefaultMode != "task" { t.Errorf("Expected default_mode to remain 'task', got '%s'", retrieved.DefaultMode) } }) t.Run("UpdateMCPWithToolsAndResources", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "MCP Advanced Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with MCP servers using advanced configuration updates := map[string]interface{}{ "mcp": map[string]interface{}{ "servers": []interface{}{ "server1", // Simple format map[string]interface{}{ "server2": []string{"tool1", "tool2"}, // Tools only }, map[string]interface{}{ "server3": map[string]interface{}{ "resources": []string{"res1", "res2"}, "tools": []string{"tool3", "tool4"}, }, }, }, }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update MCP: %v", err) } // Verify updates - MCP is in default fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.MCP == nil || len(retrieved.MCP.Servers) != 3 { t.Fatalf("Expected 3 MCP servers, got %d", len(retrieved.MCP.Servers)) } // Verify server1 (simple format) if retrieved.MCP.Servers[0].ServerID != "server1" { t.Errorf("Expected server1, got '%s'", retrieved.MCP.Servers[0].ServerID) } if len(retrieved.MCP.Servers[0].Tools) != 0 { t.Errorf("Expected no tools for server1, got %v", retrieved.MCP.Servers[0].Tools) } // Verify server2 (tools only) if retrieved.MCP.Servers[1].ServerID 
!= "server2" { t.Errorf("Expected server2, got '%s'", retrieved.MCP.Servers[1].ServerID) } if len(retrieved.MCP.Servers[1].Tools) != 2 { t.Errorf("Expected 2 tools for server2, got %d", len(retrieved.MCP.Servers[1].Tools)) } if retrieved.MCP.Servers[1].Tools[0] != "tool1" { t.Errorf("Expected tool1, got '%s'", retrieved.MCP.Servers[1].Tools[0]) } // Verify server3 (full config) if retrieved.MCP.Servers[2].ServerID != "server3" { t.Errorf("Expected server3, got '%s'", retrieved.MCP.Servers[2].ServerID) } if len(retrieved.MCP.Servers[2].Resources) != 2 { t.Errorf("Expected 2 resources for server3, got %d", len(retrieved.MCP.Servers[2].Resources)) } if len(retrieved.MCP.Servers[2].Tools) != 2 { t.Errorf("Expected 2 tools for server3, got %d", len(retrieved.MCP.Servers[2].Tools)) } if retrieved.MCP.Servers[2].Resources[0] != "res1" { t.Errorf("Expected res1, got '%s'", retrieved.MCP.Servers[2].Resources[0]) } if retrieved.MCP.Servers[2].Tools[0] != "tool3" { t.Errorf("Expected tool3, got '%s'", retrieved.MCP.Servers[2].Tools[0]) } t.Logf("Successfully verified MCP advanced configuration for assistant %s", id) }) t.Run("UpdateUses", func(t *testing.T) { // Create assistant without uses assistant := &types.AssistantModel{ Name: "Uses Update Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with uses configuration updates := map[string]interface{}{ "uses": &context.Uses{ Vision: "mcp:new-vision", Audio: "mcp:new-audio", Search: "agent", Fetch: "mcp:fetch-server", }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update uses: %v", err) } // Verify updates - uses is NOT in default fields retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Uses == nil { t.Fatal("Expected uses to be set") } if retrieved.Uses.Vision 
!= "mcp:new-vision" { t.Errorf("Expected vision 'mcp:new-vision', got '%s'", retrieved.Uses.Vision) } if retrieved.Uses.Audio != "mcp:new-audio" { t.Errorf("Expected audio 'mcp:new-audio', got '%s'", retrieved.Uses.Audio) } if retrieved.Uses.Search != "agent" { t.Errorf("Expected search 'agent', got '%s'", retrieved.Uses.Search) } if retrieved.Uses.Fetch != "mcp:fetch-server" { t.Errorf("Expected fetch 'mcp:fetch-server', got '%s'", retrieved.Uses.Fetch) } // Update to change uses updates2 := map[string]interface{}{ "uses": &context.Uses{ Vision: "agent", Audio: "agent", }, } err = store.UpdateAssistant(id, updates2) if err != nil { t.Fatalf("Failed to update uses again: %v", err) } // Verify second update - uses is NOT in default fields retrieved2, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved2.Uses.Vision != "agent" { t.Errorf("Expected vision 'agent', got '%s'", retrieved2.Uses.Vision) } if retrieved2.Uses.Audio != "agent" { t.Errorf("Expected audio 'agent', got '%s'", retrieved2.Uses.Audio) } // Update to remove uses (set to nil) updates3 := map[string]interface{}{ "uses": nil, } err = store.UpdateAssistant(id, updates3) if err != nil { t.Fatalf("Failed to set uses to nil: %v", err) } // Verify uses is nil - uses is NOT in default fields retrieved3, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved3.Uses != nil { t.Errorf("Expected uses to be nil, got %+v", retrieved3.Uses) } }) t.Run("UpdateSearch", func(t *testing.T) { // Create assistant without search assistant := &types.AssistantModel{ Name: "Search Update Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with search configuration updates := map[string]interface{}{ "search": &searchTypes.Config{ 
Web: &searchTypes.WebConfig{ Provider: "tavily", MaxResults: 20, }, KB: &searchTypes.KBConfig{ Collections: []string{"knowledge"}, Threshold: 0.75, }, }, } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update search: %v", err) } // Verify updates - search is NOT in default fields retrieved, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Search == nil { t.Fatal("Expected search to be set") } if retrieved.Search.Web == nil { t.Fatal("Expected search.web to be set") } if retrieved.Search.Web.Provider != "tavily" { t.Errorf("Expected web provider 'tavily', got '%s'", retrieved.Search.Web.Provider) } if retrieved.Search.Web.MaxResults != 20 { t.Errorf("Expected web max_results 20, got %d", retrieved.Search.Web.MaxResults) } if retrieved.Search.KB == nil { t.Fatal("Expected search.kb to be set") } if len(retrieved.Search.KB.Collections) != 1 { t.Errorf("Expected 1 KB collection, got %d", len(retrieved.Search.KB.Collections)) } if retrieved.Search.KB.Threshold != 0.75 { t.Errorf("Expected KB threshold 0.75, got %f", retrieved.Search.KB.Threshold) } // Update to change search configuration updates2 := map[string]interface{}{ "search": &searchTypes.Config{ Web: &searchTypes.WebConfig{ Provider: "serper", MaxResults: 30, }, Citation: &searchTypes.CitationConfig{ Format: "#cite:{id}", AutoInjectPrompt: false, }, }, } err = store.UpdateAssistant(id, updates2) if err != nil { t.Fatalf("Failed to update search again: %v", err) } // Verify second update - search is NOT in default fields retrieved2, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved2.Search.Web.Provider != "serper" { t.Errorf("Expected web provider 'serper', got '%s'", retrieved2.Search.Web.Provider) } if retrieved2.Search.Web.MaxResults != 30 { t.Errorf("Expected web max_results 30, got %d", 
retrieved2.Search.Web.MaxResults) } if retrieved2.Search.Citation == nil { t.Fatal("Expected search.citation to be set") } if retrieved2.Search.Citation.Format != "#cite:{id}" { t.Errorf("Expected citation format '#cite:{id}', got '%s'", retrieved2.Search.Citation.Format) } // Update to remove search (set to nil) updates3 := map[string]interface{}{ "search": nil, } err = store.UpdateAssistant(id, updates3) if err != nil { t.Fatalf("Failed to set search to nil: %v", err) } // Verify search is nil - search is NOT in default fields retrieved3, err := store.GetAssistant(id, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved3.Search != nil { t.Errorf("Expected search to be nil, got %+v", retrieved3.Search) } }) t.Run("UpdatePermissionFields", func(t *testing.T) { // Create assistant with permission fields assistant := &types.AssistantModel{ Name: "Permission Test", Type: "assistant", Connector: "openai", Share: "private", YaoCreatedBy: "user-1", YaoTeamID: "team-1", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update permission fields updates := map[string]interface{}{ "__yao_updated_by": "user-2", "__yao_tenant_id": "tenant-1", } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update permission fields: %v", err) } // Verify updates - permission fields are in default fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.YaoUpdatedBy != "user-2" { t.Errorf("Expected YaoUpdatedBy 'user-2', got '%s'", retrieved.YaoUpdatedBy) } if retrieved.YaoTenantID != "tenant-1" { t.Errorf("Expected YaoTenantID 'tenant-1', got '%s'", retrieved.YaoTenantID) } // Created by should remain unchanged if retrieved.YaoCreatedBy != "user-1" { t.Errorf("Expected YaoCreatedBy 'user-1', got '%s'", retrieved.YaoCreatedBy) } }) t.Run("UpdateWithEmptyStrings", func(t 
*testing.T) { // Create assistant with values assistant := &types.AssistantModel{ Name: "Empty String Test", Type: "assistant", Connector: "openai", Avatar: "https://example.com/avatar.png", Description: "Some description", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Update with empty strings (should become NULL) updates := map[string]interface{}{ "avatar": "", "description": "", } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update with empty strings: %v", err) } // Verify empty strings are stored as NULL - default fields include avatar, description retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.Avatar != "" { t.Errorf("Expected empty avatar, got '%s'", retrieved.Avatar) } if retrieved.Description != "" { t.Errorf("Expected empty description, got '%s'", retrieved.Description) } // Name should remain unchanged if retrieved.Name != "Empty String Test" { t.Errorf("Expected name 'Empty String Test', got '%s'", retrieved.Name) } }) t.Run("UpdateNonExistentAssistant", func(t *testing.T) { updates := map[string]interface{}{ "name": "Updated Name", } err := store.UpdateAssistant("nonexistent-id", updates) if err == nil { t.Error("Expected error when updating non-existent assistant") } if !strings.Contains(err.Error(), "not found") { t.Errorf("Expected 'not found' error, got: %v", err) } }) t.Run("UpdateWithEmptyID", func(t *testing.T) { updates := map[string]interface{}{ "name": "Updated Name", } err := store.UpdateAssistant("", updates) if err == nil { t.Error("Expected error when updating with empty ID") } if !strings.Contains(err.Error(), "required") { t.Errorf("Expected 'required' error, got: %v", err) } }) t.Run("UpdateWithEmptyUpdates", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "Empty Updates Test", Type: "assistant", Connector: 
"openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Try to update with empty map updates := map[string]interface{}{} err = store.UpdateAssistant(id, updates) if err == nil { t.Error("Expected error when updating with no fields") } if !strings.Contains(err.Error(), "no fields to update") { t.Errorf("Expected 'no fields to update' error, got: %v", err) } }) t.Run("UpdateTimestampAutomaticallySet", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "Timestamp Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Get original updated_at - default fields include updated_at original, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } // Wait a bit to ensure timestamp difference time.Sleep(100 * time.Millisecond) // Update assistant updates := map[string]interface{}{ "description": "Updated to test timestamp", } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update assistant: %v", err) } // Get updated assistant - default fields include description, updated_at updated, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve updated assistant: %v", err) } // Verify description was updated (main test objective) if updated.Description != "Updated to test timestamp" { t.Errorf("Expected description 'Updated to test timestamp', got '%s'", updated.Description) } // Only check timestamp if both are set (some stores may not return timestamps) if original.UpdatedAt > 0 && updated.UpdatedAt > 0 { if updated.UpdatedAt <= original.UpdatedAt { t.Errorf("Expected updated_at to increase, original=%d, updated=%d", original.UpdatedAt, updated.UpdatedAt) } } else { t.Logf("Skipping timestamp comparison (original=%d, updated=%d)", 
original.UpdatedAt, updated.UpdatedAt) } }) t.Run("UpdateSkipsSystemFields", func(t *testing.T) { // Create assistant assistant := &types.AssistantModel{ Name: "System Fields Test", Type: "assistant", Connector: "openai", Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant: %v", err) } // Get original - default fields original, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } // Try to update system fields (should be ignored) updates := map[string]interface{}{ "assistant_id": "new-id-123", // Should be ignored "created_at": int64(123456789), // Should be ignored "name": "Valid Update", // Should be applied } err = store.UpdateAssistant(id, updates) if err != nil { t.Fatalf("Failed to update assistant: %v", err) } // Verify system fields unchanged, but name updated - default fields retrieved, err := store.GetAssistant(id, nil) if err != nil { t.Fatalf("Failed to retrieve assistant: %v", err) } if retrieved.ID != id { t.Errorf("Expected ID to remain %s, got %s", id, retrieved.ID) } if retrieved.CreatedAt != original.CreatedAt { t.Errorf("Expected created_at to remain unchanged") } if retrieved.Name != "Valid Update" { t.Errorf("Expected name 'Valid Update', got '%s'", retrieved.Name) } }) } // TestAssistantCompleteWorkflow tests a complete workflow func TestAssistantCompleteWorkflow(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("CompleteWorkflow", func(t *testing.T) { // Step 1: Create multiple assistants assistantIDs := []string{} for i := 0; i < 3; i++ { assistant := &types.AssistantModel{ Name: fmt.Sprintf("Workflow Assistant %d", i), Type: "assistant", Connector: "openai", Description: fmt.Sprintf("Workflow test assistant %d", i), Tags: []string{"workflow", fmt.Sprintf("test-%d", i)}, Sort: i * 100, 
Share: "private", } id, err := store.SaveAssistant(assistant) if err != nil { t.Fatalf("Failed to create assistant %d: %v", i, err) } assistantIDs = append(assistantIDs, id) } t.Logf("Created %d assistants", len(assistantIDs)) // Step 2: Retrieve all assistants response, err := store.GetAssistants(types.AssistantFilter{ Tags: []string{"workflow"}, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to get assistants: %v", err) } if len(response.Data) < 3 { t.Errorf("Expected at least 3 assistants, got %d", len(response.Data)) } // Step 3: Update one assistant - need full fields for tags updatedID := assistantIDs[1] updatedAssistant, err := store.GetAssistant(updatedID, types.AssistantFullFields) if err != nil { t.Fatalf("Failed to get assistant for update: %v", err) } updatedAssistant.Description = "Updated workflow description" updatedAssistant.Tags = append(updatedAssistant.Tags, "updated") _, err = store.SaveAssistant(updatedAssistant) if err != nil { t.Fatalf("Failed to update assistant: %v", err) } // Verify update - default fields include description verifyAssistant, err := store.GetAssistant(updatedID, nil) if err != nil { t.Fatalf("Failed to verify update: %v", err) } if verifyAssistant.Description != "Updated workflow description" { t.Errorf("Update not applied correctly") } // Step 4: Delete one assistant err = store.DeleteAssistant(assistantIDs[0]) if err != nil { t.Fatalf("Failed to delete assistant: %v", err) } // Verify deletion _, err = store.GetAssistant(assistantIDs[0], nil) if err == nil { t.Error("Expected error when getting deleted assistant") } // Step 5: Get tags tags, err := store.GetAssistantTags(types.AssistantFilter{}) if err != nil { t.Fatalf("Failed to get tags: %v", err) } // Should find "workflow" tag found := false for _, tag := range tags { if tag.Value == "workflow" { found = true break } } if !found { t.Error("Expected to find 'workflow' tag") } // Step 6: Bulk delete remaining assistants count, err := 
store.DeleteAssistants(types.AssistantFilter{ Tags: []string{"workflow"}, }) if err != nil { t.Fatalf("Failed to bulk delete: %v", err) } t.Logf("Bulk deleted %d assistants", count) // Verify bulk deletion finalResponse, err := store.GetAssistants(types.AssistantFilter{ Tags: []string{"workflow"}, Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to verify bulk deletion: %v", err) } if len(finalResponse.Data) > 0 { t.Logf("Warning: Still found %d assistants after bulk delete", len(finalResponse.Data)) } }) } ================================================ FILE: agent/store/xun/chat.go ================================================ package xun import ( "fmt" "math" "time" "github.com/google/uuid" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/yao/agent/store/types" ) // ============================================================================= // Chat Management // ============================================================================= // CreateChat creates a new chat session func (store *Xun) CreateChat(chat *types.Chat) error { if chat == nil { return fmt.Errorf("chat cannot be nil") } // Validate required fields if chat.AssistantID == "" { return fmt.Errorf("assistant_id is required") } // Generate chat_id if not provided if chat.ChatID == "" { chat.ChatID = uuid.New().String() } // Check if chat already exists exists, err := store.newQueryChat(). Where("chat_id", chat.ChatID). 
Exists() if err != nil { return err } if exists { return fmt.Errorf("chat %s already exists", chat.ChatID) } // Set defaults if chat.Status == "" { chat.Status = "active" } if chat.Share == "" { chat.Share = "private" } // Prepare data data := map[string]interface{}{ "chat_id": chat.ChatID, "assistant_id": chat.AssistantID, "status": chat.Status, "public": chat.Public, "share": chat.Share, "sort": chat.Sort, "created_at": time.Now(), "updated_at": time.Now(), } // Handle last_mode (nullable) if chat.LastMode != "" { data["last_mode"] = chat.LastMode } // Handle nullable fields if chat.Title != "" { data["title"] = chat.Title } if chat.LastConnector != "" { data["last_connector"] = chat.LastConnector } if chat.LastMode != "" { data["last_mode"] = chat.LastMode } if chat.LastMessageAt != nil { data["last_message_at"] = *chat.LastMessageAt } if chat.Metadata != nil { metadataJSON, err := jsoniter.MarshalToString(chat.Metadata) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } data["metadata"] = metadataJSON } // Handle permission fields (Yao framework permission: true) if chat.CreatedBy != "" { data["__yao_created_by"] = chat.CreatedBy } if chat.UpdatedBy != "" { data["__yao_updated_by"] = chat.UpdatedBy } if chat.TeamID != "" { data["__yao_team_id"] = chat.TeamID } if chat.TenantID != "" { data["__yao_tenant_id"] = chat.TenantID } // Insert return store.newQueryChat().Insert(data) } // GetChat retrieves a single chat by ID func (store *Xun) GetChat(chatID string) (*types.Chat, error) { if chatID == "" { return nil, fmt.Errorf("chat_id is required") } row, err := store.newQueryChat(). Where("chat_id", chatID). WhereNull("deleted_at"). 
First() if err != nil { return nil, err } if row == nil { return nil, fmt.Errorf("chat %s not found", chatID) } data := row.ToMap() if len(data) == 0 || data["chat_id"] == nil { return nil, fmt.Errorf("chat %s not found", chatID) } return store.rowToChat(data) } // UpdateChat updates chat fields func (store *Xun) UpdateChat(chatID string, updates map[string]interface{}) error { if chatID == "" { return fmt.Errorf("chat_id is required") } if len(updates) == 0 { return fmt.Errorf("no fields to update") } // Check if chat exists exists, err := store.newQueryChat(). Where("chat_id", chatID). WhereNull("deleted_at"). Exists() if err != nil { return err } if !exists { return fmt.Errorf("chat %s not found", chatID) } // Prepare update data data := make(map[string]interface{}) // Process each update field for key, value := range updates { // Skip system fields if key == "chat_id" || key == "created_at" { continue } // Handle metadata specially if key == "metadata" { if value != nil { metadataJSON, err := jsoniter.MarshalToString(value) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } data["metadata"] = metadataJSON } else { data["metadata"] = nil } continue } data[key] = value } // Always update updated_at data["updated_at"] = time.Now() if len(data) == 0 { return fmt.Errorf("no valid fields to update") } _, err = store.newQueryChat(). Where("chat_id", chatID). Update(data) return err } // DeleteChat deletes a chat and its associated messages (soft delete) func (store *Xun) DeleteChat(chatID string) error { if chatID == "" { return fmt.Errorf("chat_id is required") } // Check if chat exists exists, err := store.newQueryChat(). Where("chat_id", chatID). WhereNull("deleted_at"). Exists() if err != nil { return err } if !exists { return fmt.Errorf("chat %s not found", chatID) } // Soft delete the chat _, err = store.newQueryChat(). Where("chat_id", chatID). 
Update(map[string]interface{}{ "deleted_at": time.Now(), "updated_at": time.Now(), }) return err } // ListChats retrieves a paginated list of chats with optional grouping func (store *Xun) ListChats(filter types.ChatFilter) (*types.ChatList, error) { // Set defaults if filter.Page <= 0 { filter.Page = 1 } if filter.PageSize <= 0 { filter.PageSize = 20 } if filter.OrderBy == "" { filter.OrderBy = "last_message_at" } if filter.Order == "" { filter.Order = "desc" } if filter.TimeField == "" { filter.TimeField = "last_message_at" } // Build base query qb := store.newQueryChat().WhereNull("deleted_at") // Apply permission filters (UserID and TeamID) if filter.UserID != "" { qb.Where("__yao_created_by", filter.UserID) } if filter.TeamID != "" { qb.Where("__yao_team_id", filter.TeamID) } // Apply business filters if filter.AssistantID != "" { qb.Where("assistant_id", filter.AssistantID) } if filter.Status != "" { qb.Where("status", filter.Status) } if filter.Keywords != "" { qb.Where("title", "like", fmt.Sprintf("%%%s%%", filter.Keywords)) } if filter.ChatIDPrefix != "" { qb.Where("chat_id", "like", filter.ChatIDPrefix+"%") } // Apply time range filter if filter.StartTime != nil { qb.Where(filter.TimeField, ">=", *filter.StartTime) } if filter.EndTime != nil { qb.Where(filter.TimeField, "<=", *filter.EndTime) } // Apply custom query filter (for advanced permission filtering) // This allows flexible combinations like: (created_by = user OR team_id = team) if filter.QueryFilter != nil { qb.Where(filter.QueryFilter) } // Get total count total, err := qb.Clone().Count() if err != nil { return nil, err } // Calculate pagination pageCount := int(math.Ceil(float64(total) / float64(filter.PageSize))) if pageCount < 1 { pageCount = 1 } offset := (filter.Page - 1) * filter.PageSize // Get paginated results rows, err := qb.OrderBy(filter.OrderBy, filter.Order). Offset(offset). Limit(filter.PageSize). 
Get() if err != nil { return nil, err } // Convert rows to Chat objects chats := make([]*types.Chat, 0, len(rows)) for _, row := range rows { data := row.ToMap() if data == nil || data["chat_id"] == nil { continue } chat, err := store.rowToChat(data) if err != nil { continue } chats = append(chats, chat) } result := &types.ChatList{ Data: chats, Page: filter.Page, PageSize: filter.PageSize, PageCount: pageCount, Total: int(total), } // Apply time-based grouping if requested if filter.GroupBy == "time" { result.Groups = store.groupChatsByTime(chats) } return result, nil } // ============================================================================= // Helper Functions // ============================================================================= // rowToChat converts a database row to a Chat struct func (store *Xun) rowToChat(data map[string]interface{}) (*types.Chat, error) { chat := &types.Chat{ ChatID: getString(data, "chat_id"), Title: getString(data, "title"), AssistantID: getString(data, "assistant_id"), LastConnector: getString(data, "last_connector"), LastMode: getString(data, "last_mode"), Status: getString(data, "status"), Public: getBool(data, "public"), Share: getString(data, "share"), Sort: getInt(data, "sort"), } // Handle timestamps if createdAt := getTime(data, "created_at"); createdAt != nil { chat.CreatedAt = *createdAt } if updatedAt := getTime(data, "updated_at"); updatedAt != nil { chat.UpdatedAt = *updatedAt } if lastMsgAt := getTime(data, "last_message_at"); lastMsgAt != nil { chat.LastMessageAt = lastMsgAt } // Handle metadata if metadata := data["metadata"]; metadata != nil { if metaStr, ok := metadata.(string); ok && metaStr != "" { var meta map[string]interface{} if err := jsoniter.UnmarshalFromString(metaStr, &meta); err == nil { chat.Metadata = meta } } else if metaMap, ok := metadata.(map[string]interface{}); ok { chat.Metadata = metaMap } } // Handle permission fields chat.CreatedBy = getString(data, "__yao_created_by") 
chat.UpdatedBy = getString(data, "__yao_updated_by") chat.TeamID = getString(data, "__yao_team_id") chat.TenantID = getString(data, "__yao_tenant_id") return chat, nil } // groupChatsByTime groups chats by time periods func (store *Xun) groupChatsByTime(chats []*types.Chat) []*types.ChatGroup { now := time.Now() today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) yesterday := today.AddDate(0, 0, -1) thisWeekStart := today.AddDate(0, 0, -int(today.Weekday())) thisMonthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()) groups := map[string]*types.ChatGroup{ "today": {Key: "today", Label: "Today", Chats: []*types.Chat{}}, "yesterday": {Key: "yesterday", Label: "Yesterday", Chats: []*types.Chat{}}, "this_week": {Key: "this_week", Label: "This Week", Chats: []*types.Chat{}}, "this_month": {Key: "this_month", Label: "This Month", Chats: []*types.Chat{}}, "earlier": {Key: "earlier", Label: "Earlier", Chats: []*types.Chat{}}, } for _, chat := range chats { // Use last_message_at if available, otherwise created_at var chatTime time.Time if chat.LastMessageAt != nil { chatTime = *chat.LastMessageAt } else { chatTime = chat.CreatedAt } chatDate := time.Date(chatTime.Year(), chatTime.Month(), chatTime.Day(), 0, 0, 0, 0, chatTime.Location()) switch { case chatDate.Equal(today) || chatDate.After(today): groups["today"].Chats = append(groups["today"].Chats, chat) case chatDate.Equal(yesterday): groups["yesterday"].Chats = append(groups["yesterday"].Chats, chat) case chatDate.After(thisWeekStart) || chatDate.Equal(thisWeekStart): groups["this_week"].Chats = append(groups["this_week"].Chats, chat) case chatDate.After(thisMonthStart) || chatDate.Equal(thisMonthStart): groups["this_month"].Chats = append(groups["this_month"].Chats, chat) default: groups["earlier"].Chats = append(groups["earlier"].Chats, chat) } } // Update counts and filter empty groups result := make([]*types.ChatGroup, 0) for _, key := range []string{"today", 
"yesterday", "this_week", "this_month", "earlier"} { group := groups[key] group.Count = len(group.Chats) if group.Count > 0 { result = append(result, group) } } return result } // getTime helper function to convert database value to time.Time pointer func getTime(data map[string]interface{}, key string) *time.Time { if v := data[key]; v != nil { switch t := v.(type) { case time.Time: return &t case *time.Time: return t case string: // Try parsing various formats formats := []string{ time.RFC3339, "2006-01-02 15:04:05", "2006-01-02 15:04:05.999999-07:00", "2006-01-02T15:04:05Z", } for _, format := range formats { if parsed, err := time.Parse(format, t); err == nil { return &parsed } } } } return nil } // UpdateChatLastMessageAt updates the last_message_at timestamp for a chat func (store *Xun) UpdateChatLastMessageAt(chatID string, timestamp time.Time) error { if chatID == "" { return fmt.Errorf("chat_id is required") } _, err := store.newQueryChat(). Where("chat_id", chatID). Update(map[string]interface{}{ "last_message_at": timestamp, "updated_at": time.Now(), }) return err } ================================================ FILE: agent/store/xun/chat_test.go ================================================ package xun_test import ( "fmt" "testing" "time" goumodel "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/dbal/query" "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/store/xun" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestCreateChat tests creating chat sessions func TestCreateChat(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("CreateNewChat", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", Title: "Test Chat", Status: "active", Share: "private", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } if 
chat.ChatID == "" { t.Error("Expected chat_id to be generated") } t.Logf("Created chat with ID: %s", chat.ChatID) // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("CreateChatWithAllFields", func(t *testing.T) { now := time.Now() chat := &types.Chat{ AssistantID: "test_assistant", LastConnector: "openai", Title: "Full Chat", LastMode: "task", Status: "active", Public: true, Share: "team", Sort: 100, LastMessageAt: &now, Metadata: map[string]interface{}{ "source": "test", "tags": []string{"test", "chat"}, }, } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } // Retrieve and verify retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.Title != "Full Chat" { t.Errorf("Expected title 'Full Chat', got '%s'", retrieved.Title) } if retrieved.LastConnector != "openai" { t.Errorf("Expected last_connector 'openai', got '%s'", retrieved.LastConnector) } if retrieved.LastMode != "task" { t.Errorf("Expected last_mode 'task', got '%s'", retrieved.LastMode) } if !retrieved.Public { t.Error("Expected public to be true") } if retrieved.Share != "team" { t.Errorf("Expected share 'team', got '%s'", retrieved.Share) } if retrieved.Sort != 100 { t.Errorf("Expected sort 100, got %d", retrieved.Sort) } if retrieved.Metadata == nil { t.Error("Expected metadata to be set") } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("CreateChatWithCustomID", func(t *testing.T) { customID := fmt.Sprintf("custom_chat_%d", time.Now().UnixNano()) chat := &types.Chat{ ChatID: customID, AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } if chat.ChatID != customID { t.Errorf("Expected chat_id '%s', got '%s'", customID, chat.ChatID) } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("CreateDuplicateChatFails", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", } err := 
store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create first chat: %v", err) } // Try to create with same ID duplicateChat := &types.Chat{ ChatID: chat.ChatID, AssistantID: "test_assistant", } err = store.CreateChat(duplicateChat) if err == nil { t.Error("Expected error when creating duplicate chat") } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("CreateChatWithoutAssistantIDFails", func(t *testing.T) { chat := &types.Chat{ Title: "No Assistant", } err := store.CreateChat(chat) if err == nil { t.Error("Expected error when creating chat without assistant_id") } }) t.Run("CreateNilChatFails", func(t *testing.T) { err := store.CreateChat(nil) if err == nil { t.Error("Expected error when creating nil chat") } }) t.Run("CreateChatWithDefaults", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } // Retrieve and verify defaults retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } // last_mode is nullable, so it should be empty by default if retrieved.LastMode != "" { t.Errorf("Expected default last_mode to be empty, got '%s'", retrieved.LastMode) } if retrieved.Status != "active" { t.Errorf("Expected default status 'active', got '%s'", retrieved.Status) } if retrieved.Share != "private" { t.Errorf("Expected default share 'private', got '%s'", retrieved.Share) } // Clean up _ = store.DeleteChat(chat.ChatID) }) } // TestGetChat tests retrieving chat sessions func TestGetChat(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("GetExistingChat", func(t *testing.T) { // Create chat first chat := &types.Chat{ AssistantID: "test_assistant", Title: "Get Test Chat", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", 
err) } // Get it retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to get chat: %v", err) } if retrieved.ChatID != chat.ChatID { t.Errorf("Expected chat_id '%s', got '%s'", chat.ChatID, retrieved.ChatID) } if retrieved.Title != "Get Test Chat" { t.Errorf("Expected title 'Get Test Chat', got '%s'", retrieved.Title) } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("GetNonExistentChat", func(t *testing.T) { _, err := store.GetChat("nonexistent_chat_id") if err == nil { t.Error("Expected error when getting non-existent chat") } }) t.Run("GetChatWithEmptyID", func(t *testing.T) { _, err := store.GetChat("") if err == nil { t.Error("Expected error when getting chat with empty ID") } }) t.Run("GetDeletedChatFails", func(t *testing.T) { // Create and delete chat chat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } err = store.DeleteChat(chat.ChatID) if err != nil { t.Fatalf("Failed to delete chat: %v", err) } // Try to get deleted chat _, err = store.GetChat(chat.ChatID) if err == nil { t.Error("Expected error when getting deleted chat") } }) } // TestUpdateChat tests updating chat sessions func TestUpdateChat(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("UpdateTitle", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", Title: "Original Title", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } err = store.UpdateChat(chat.ChatID, map[string]interface{}{ "title": "Updated Title", }) if err != nil { t.Fatalf("Failed to update chat: %v", err) } retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.Title != "Updated Title" { t.Errorf("Expected title 'Updated Title', got 
'%s'", retrieved.Title) } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("UpdateLastConnector", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", LastConnector: "openai", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } // Verify initial connector retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.LastConnector != "openai" { t.Errorf("Expected last_connector 'openai', got '%s'", retrieved.LastConnector) } // Update to different connector (simulating user switching connector) err = store.UpdateChat(chat.ChatID, map[string]interface{}{ "last_connector": "anthropic", }) if err != nil { t.Fatalf("Failed to update chat: %v", err) } // Verify updated connector retrieved, err = store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.LastConnector != "anthropic" { t.Errorf("Expected last_connector 'anthropic', got '%s'", retrieved.LastConnector) } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("UpdateLastConnectorAndLastMessageAt", func(t *testing.T) { // This simulates what FlushBuffer does chat := &types.Chat{ AssistantID: "test_assistant", LastConnector: "openai", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } // Update both fields together (like FlushBuffer does) now := time.Now() err = store.UpdateChat(chat.ChatID, map[string]interface{}{ "last_message_at": now, "last_connector": "claude", }) if err != nil { t.Fatalf("Failed to update chat: %v", err) } retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.LastConnector != "claude" { t.Errorf("Expected last_connector 'claude', got '%s'", retrieved.LastConnector) } if retrieved.LastMessageAt == nil { t.Error("Expected last_message_at to be set") } // Clean up _ = store.DeleteChat(chat.ChatID) }) 
t.Run("UpdateMultipleFields", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", Title: "Original", Status: "active", Share: "private", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } err = store.UpdateChat(chat.ChatID, map[string]interface{}{ "title": "Updated", "status": "archived", "share": "team", "public": true, "sort": 50, }) if err != nil { t.Fatalf("Failed to update chat: %v", err) } retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.Title != "Updated" { t.Errorf("Expected title 'Updated', got '%s'", retrieved.Title) } if retrieved.Status != "archived" { t.Errorf("Expected status 'archived', got '%s'", retrieved.Status) } if retrieved.Share != "team" { t.Errorf("Expected share 'team', got '%s'", retrieved.Share) } if !retrieved.Public { t.Error("Expected public to be true") } if retrieved.Sort != 50 { t.Errorf("Expected sort 50, got %d", retrieved.Sort) } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("UpdateMetadata", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } err = store.UpdateChat(chat.ChatID, map[string]interface{}{ "metadata": map[string]interface{}{ "key1": "value1", "key2": 123, }, }) if err != nil { t.Fatalf("Failed to update metadata: %v", err) } retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.Metadata == nil { t.Fatal("Expected metadata to be set") } if retrieved.Metadata["key1"] != "value1" { t.Errorf("Expected metadata key1 'value1', got '%v'", retrieved.Metadata["key1"]) } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("UpdateNonExistentChatFails", func(t *testing.T) { err := store.UpdateChat("nonexistent_chat", map[string]interface{}{ "title": "Test", }) if err == nil { t.Error("Expected error when 
updating non-existent chat") } }) t.Run("UpdateWithEmptyIDFails", func(t *testing.T) { err := store.UpdateChat("", map[string]interface{}{ "title": "Test", }) if err == nil { t.Error("Expected error when updating with empty ID") } }) t.Run("UpdateWithEmptyFieldsFails", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } err = store.UpdateChat(chat.ChatID, map[string]interface{}{}) if err == nil { t.Error("Expected error when updating with empty fields") } // Clean up _ = store.DeleteChat(chat.ChatID) }) t.Run("UpdateSkipsSystemFields", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } originalID := chat.ChatID // Try to update system fields err = store.UpdateChat(chat.ChatID, map[string]interface{}{ "chat_id": "new_id", "title": "Valid Update", }) if err != nil { t.Fatalf("Failed to update chat: %v", err) } // Verify chat_id unchanged retrieved, err := store.GetChat(originalID) if err != nil { t.Fatalf("Failed to retrieve chat: %v", err) } if retrieved.ChatID != originalID { t.Errorf("Expected chat_id to remain '%s', got '%s'", originalID, retrieved.ChatID) } if retrieved.Title != "Valid Update" { t.Errorf("Expected title 'Valid Update', got '%s'", retrieved.Title) } // Clean up _ = store.DeleteChat(chat.ChatID) }) } // TestDeleteChat tests deleting chat sessions func TestDeleteChat(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("DeleteExistingChat", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } err = store.DeleteChat(chat.ChatID) if err != nil { t.Fatalf("Failed to delete 
chat: %v", err) } // Verify deleted _, err = store.GetChat(chat.ChatID) if err == nil { t.Error("Expected error when getting deleted chat") } }) t.Run("DeleteNonExistentChatFails", func(t *testing.T) { err := store.DeleteChat("nonexistent_chat") if err == nil { t.Error("Expected error when deleting non-existent chat") } }) t.Run("DeleteWithEmptyIDFails", func(t *testing.T) { err := store.DeleteChat("") if err == nil { t.Error("Expected error when deleting with empty ID") } }) t.Run("DeleteAlreadyDeletedChatFails", func(t *testing.T) { chat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } // Delete first time err = store.DeleteChat(chat.ChatID) if err != nil { t.Fatalf("Failed to delete chat: %v", err) } // Try to delete again err = store.DeleteChat(chat.ChatID) if err == nil { t.Error("Expected error when deleting already deleted chat") } }) } // TestListChats tests listing chat sessions func TestListChats(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } // Create test chats chatIDs := []string{} for i := 0; i < 5; i++ { chat := &types.Chat{ AssistantID: "test_assistant", Title: fmt.Sprintf("Chat %d", i), Status: "active", } if i >= 3 { chat.Status = "archived" } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } chatIDs = append(chatIDs, chat.ChatID) // Add small delay to ensure different timestamps time.Sleep(10 * time.Millisecond) } // Clean up at the end defer func() { for _, id := range chatIDs { _ = store.DeleteChat(id) } }() t.Run("ListAllChats", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } if len(result.Data) < 5 { t.Errorf("Expected at least 5 chats, got %d", len(result.Data)) } 
}) t.Run("ListChatsByStatus", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ Status: "active", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } for _, chat := range result.Data { if chat.Status != "active" { t.Errorf("Expected status 'active', got '%s'", chat.Status) } } }) t.Run("ListChatsByAssistant", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ AssistantID: "test_assistant", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } for _, chat := range result.Data { if chat.AssistantID != "test_assistant" { t.Errorf("Expected assistant_id 'test_assistant', got '%s'", chat.AssistantID) } } }) t.Run("ListChatsByKeywords", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ Keywords: "Chat 1", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } found := false for _, chat := range result.Data { if chat.Title == "Chat 1" { found = true break } } if !found { t.Error("Expected to find chat with title 'Chat 1'") } }) t.Run("ListChatsPagination", func(t *testing.T) { // First page result1, err := store.ListChats(types.ChatFilter{ Page: 1, PageSize: 2, }) if err != nil { t.Fatalf("Failed to list first page: %v", err) } if len(result1.Data) > 2 { t.Errorf("Expected max 2 chats, got %d", len(result1.Data)) } if result1.Page != 1 { t.Errorf("Expected page 1, got %d", result1.Page) } if result1.PageSize != 2 { t.Errorf("Expected pagesize 2, got %d", result1.PageSize) } // Second page if result1.Total > 2 { result2, err := store.ListChats(types.ChatFilter{ Page: 2, PageSize: 2, }) if err != nil { t.Fatalf("Failed to list second page: %v", err) } if result2.Page != 2 { t.Errorf("Expected page 2, got %d", result2.Page) } } }) t.Run("ListChatsWithGrouping", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ GroupBy: "time", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats 
with grouping: %v", err) } // Should have groups when GroupBy is "time" if result.Groups == nil { t.Error("Expected groups to be set when GroupBy='time'") } // Verify group structure for _, group := range result.Groups { if group.Key == "" { t.Error("Expected group key to be set") } if group.Label == "" { t.Error("Expected group label to be set") } if group.Count != len(group.Chats) { t.Errorf("Expected count %d to match chats length %d", group.Count, len(group.Chats)) } } }) t.Run("ListChatsWithTimeRange", func(t *testing.T) { now := time.Now() yesterday := now.AddDate(0, 0, -1) result, err := store.ListChats(types.ChatFilter{ StartTime: &yesterday, EndTime: &now, TimeField: "created_at", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats with time range: %v", err) } // Should return chats created within the time range t.Logf("Found %d chats in time range", len(result.Data)) }) t.Run("ListChatsWithSorting", func(t *testing.T) { // Ascending order resultAsc, err := store.ListChats(types.ChatFilter{ OrderBy: "created_at", Order: "asc", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats ascending: %v", err) } // Descending order resultDesc, err := store.ListChats(types.ChatFilter{ OrderBy: "created_at", Order: "desc", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats descending: %v", err) } // Verify different order if len(resultAsc.Data) > 1 && len(resultDesc.Data) > 1 { if resultAsc.Data[0].ChatID == resultDesc.Data[0].ChatID { // This is fine if there's only one chat, but otherwise order should differ if len(resultAsc.Data) > 1 { t.Logf("First chat in asc: %s, first in desc: %s", resultAsc.Data[0].ChatID, resultDesc.Data[0].ChatID) } } } }) t.Run("ListChatsWithQueryFilter", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ Page: 1, PageSize: 20, QueryFilter: func(qb query.Query) { qb.Where("status", "active") }, }) if err != nil { t.Fatalf("Failed to list chats with query 
filter: %v", err) } for _, chat := range result.Data { if chat.Status != "active" { t.Errorf("Expected status 'active', got '%s'", chat.Status) } } }) } // TestListChatsByUserAndTeam tests filtering chats by UserID and TeamID func TestListChatsByUserAndTeam(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } // Create chats with different user/team combinations // Note: __yao_created_by and __yao_team_id are managed by Yao's permission system // For testing, we'll create chats and then update these fields directly via raw query chat1 := &types.Chat{AssistantID: "test_assistant", Title: "User1 Team1 Chat"} chat2 := &types.Chat{AssistantID: "test_assistant", Title: "User1 Team2 Chat"} chat3 := &types.Chat{AssistantID: "test_assistant", Title: "User2 Team1 Chat"} chat4 := &types.Chat{AssistantID: "test_assistant", Title: "User2 Team2 Chat"} for _, chat := range []*types.Chat{chat1, chat2, chat3, chat4} { err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } } defer func() { store.DeleteChat(chat1.ChatID) store.DeleteChat(chat2.ChatID) store.DeleteChat(chat3.ChatID) store.DeleteChat(chat4.ChatID) }() // Update permission fields directly for testing // In production, these would be set by Yao's permission middleware updatePermissionFields := func(chatID, userID, teamID string) error { // Use Yao model to update permission fields m := goumodel.Select("__yao.agent.chat") if m == nil { return fmt.Errorf("model __yao.agent.chat not found") } _, err := m.UpdateWhere( goumodel.QueryParam{Wheres: []goumodel.QueryWhere{{Column: "chat_id", Value: chatID}}}, map[string]interface{}{ "__yao_created_by": userID, "__yao_team_id": teamID, }, ) return err } // Set up permission fields updatePermissionFields(chat1.ChatID, "user1", "team1") updatePermissionFields(chat2.ChatID, "user1", "team2") 
updatePermissionFields(chat3.ChatID, "user2", "team1") updatePermissionFields(chat4.ChatID, "user2", "team2") t.Run("FilterByUserID", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ UserID: "user1", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats by user: %v", err) } if len(result.Data) != 2 { t.Errorf("Expected 2 chats for user1, got %d", len(result.Data)) } // Verify all returned chats belong to user1 for _, chat := range result.Data { if chat.Title != "User1 Team1 Chat" && chat.Title != "User1 Team2 Chat" { t.Errorf("Unexpected chat title: %s", chat.Title) } } }) t.Run("FilterByTeamID", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ TeamID: "team1", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats by team: %v", err) } if len(result.Data) != 2 { t.Errorf("Expected 2 chats for team1, got %d", len(result.Data)) } // Verify all returned chats belong to team1 for _, chat := range result.Data { if chat.Title != "User1 Team1 Chat" && chat.Title != "User2 Team1 Chat" { t.Errorf("Unexpected chat title: %s", chat.Title) } } }) t.Run("FilterByUserIDAndTeamID", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ UserID: "user1", TeamID: "team1", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats by user and team: %v", err) } if len(result.Data) != 1 { t.Errorf("Expected 1 chat for user1+team1, got %d", len(result.Data)) } if len(result.Data) > 0 && result.Data[0].Title != "User1 Team1 Chat" { t.Errorf("Expected 'User1 Team1 Chat', got '%s'", result.Data[0].Title) } }) t.Run("FilterByUserIDWithOtherFilters", func(t *testing.T) { // Combine UserID with Status filter result, err := store.ListChats(types.ChatFilter{ UserID: "user1", Status: "active", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } // All user1's chats should be active (default status) if len(result.Data) != 2 { t.Errorf("Expected 2 active 
chats for user1, got %d", len(result.Data)) } }) t.Run("FilterByTeamIDWithQueryFilter", func(t *testing.T) { // Combine TeamID with custom QueryFilter result, err := store.ListChats(types.ChatFilter{ TeamID: "team2", Page: 1, PageSize: 20, QueryFilter: func(qb query.Query) { // Additional filter: only chats with "User1" in title qb.Where("title", "like", "%User1%") }, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } if len(result.Data) != 1 { t.Errorf("Expected 1 chat (User1 in team2), got %d", len(result.Data)) } if len(result.Data) > 0 && result.Data[0].Title != "User1 Team2 Chat" { t.Errorf("Expected 'User1 Team2 Chat', got '%s'", result.Data[0].Title) } }) t.Run("FilterByNonExistentUser", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ UserID: "nonexistent_user", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } if len(result.Data) != 0 { t.Errorf("Expected 0 chats for nonexistent user, got %d", len(result.Data)) } }) t.Run("FilterByNonExistentTeam", func(t *testing.T) { result, err := store.ListChats(types.ChatFilter{ TeamID: "nonexistent_team", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } if len(result.Data) != 0 { t.Errorf("Expected 0 chats for nonexistent team, got %d", len(result.Data)) } }) t.Run("QueryFilterForOrCondition", func(t *testing.T) { // Use QueryFilter for complex OR condition: // Get chats where user is user1 OR team is team2 result, err := store.ListChats(types.ChatFilter{ Page: 1, PageSize: 20, QueryFilter: func(qb query.Query) { qb.Where(func(sub query.Query) { sub.Where("__yao_created_by", "user1"). 
OrWhere("__yao_team_id", "team2") }) }, }) if err != nil { t.Fatalf("Failed to list chats with OR condition: %v", err) } // Should return: user1+team1, user1+team2, user2+team2 = 3 chats if len(result.Data) != 3 { t.Errorf("Expected 3 chats (user1 OR team2), got %d", len(result.Data)) } }) } // TestChatCompleteWorkflow tests a complete chat workflow func TestChatCompleteWorkflow(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } t.Run("CompleteWorkflow", func(t *testing.T) { // 1. Create chat chat := &types.Chat{ AssistantID: "workflow_assistant", Title: "Workflow Test Chat", Status: "active", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } t.Logf("Created chat: %s", chat.ChatID) // 2. Get chat retrieved, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to get chat: %v", err) } if retrieved.Title != "Workflow Test Chat" { t.Errorf("Expected title 'Workflow Test Chat', got '%s'", retrieved.Title) } // 3. Update chat err = store.UpdateChat(chat.ChatID, map[string]interface{}{ "title": "Updated Workflow Chat", "status": "archived", }) if err != nil { t.Fatalf("Failed to update chat: %v", err) } // 4. Verify update updated, err := store.GetChat(chat.ChatID) if err != nil { t.Fatalf("Failed to get updated chat: %v", err) } if updated.Title != "Updated Workflow Chat" { t.Errorf("Expected title 'Updated Workflow Chat', got '%s'", updated.Title) } if updated.Status != "archived" { t.Errorf("Expected status 'archived', got '%s'", updated.Status) } // 5. 
List chats result, err := store.ListChats(types.ChatFilter{ AssistantID: "workflow_assistant", Page: 1, PageSize: 20, }) if err != nil { t.Fatalf("Failed to list chats: %v", err) } found := false for _, c := range result.Data { if c.ChatID == chat.ChatID { found = true break } } if !found { t.Error("Expected to find chat in list") } // 6. Delete chat err = store.DeleteChat(chat.ChatID) if err != nil { t.Fatalf("Failed to delete chat: %v", err) } // 7. Verify deletion _, err = store.GetChat(chat.ChatID) if err == nil { t.Error("Expected error when getting deleted chat") } t.Log("Complete workflow passed!") }) } ================================================ FILE: agent/store/xun/message.go ================================================ package xun import ( "fmt" "time" "github.com/google/uuid" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/yao/agent/store/types" ) // ============================================================================= // Message Management // ============================================================================= // SaveMessages batch saves messages for a chat using a single database call // This is the primary write method - messages are buffered during execution // and batch-written at the end of a request func (store *Xun) SaveMessages(chatID string, messages []*types.Message) error { if chatID == "" { return fmt.Errorf("chat_id is required") } if len(messages) == 0 { return nil // Nothing to save } // Prepare batch insert data now := time.Now() rows := make([]map[string]interface{}, 0, len(messages)) for _, msg := range messages { if msg == nil { continue } // Generate message_id if not provided messageID := msg.MessageID if messageID == "" { messageID = uuid.New().String() } // Validate required fields if msg.Role == "" { return fmt.Errorf("message role is required") } if msg.Type == "" { return fmt.Errorf("message type is required") } if msg.Props == nil { return fmt.Errorf("message props is required") } // 
Serialize JSON fields propsJSON, err := jsoniter.MarshalToString(msg.Props) if err != nil { return fmt.Errorf("failed to marshal props: %w", err) } // Build row with all fields (including nullable ones for consistent batch insert) row := map[string]interface{}{ "message_id": messageID, "chat_id": chatID, "role": msg.Role, "type": msg.Type, "props": propsJSON, "sequence": msg.Sequence, "request_id": nil, "block_id": nil, "thread_id": nil, "assistant_id": nil, "connector": nil, "mode": nil, "metadata": nil, "created_at": now, "updated_at": now, } // Set nullable fields if they have values if msg.RequestID != "" { row["request_id"] = msg.RequestID } if msg.BlockID != "" { row["block_id"] = msg.BlockID } if msg.ThreadID != "" { row["thread_id"] = msg.ThreadID } if msg.AssistantID != "" { row["assistant_id"] = msg.AssistantID } if msg.Connector != "" { row["connector"] = msg.Connector } if msg.Mode != "" { row["mode"] = msg.Mode } if msg.Metadata != nil { metadataJSON, err := jsoniter.MarshalToString(msg.Metadata) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } row["metadata"] = metadataJSON } rows = append(rows, row) } if len(rows) == 0 { return nil } // Single batch insert - one database call for all messages return store.newQueryMessage().Insert(rows) } // GetMessages retrieves messages for a chat with filtering func (store *Xun) GetMessages(chatID string, filter types.MessageFilter) ([]*types.Message, error) { if chatID == "" { return nil, fmt.Errorf("chat_id is required") } qb := store.newQueryMessage(). Where("chat_id", chatID). 
WhereNull("deleted_at") // Apply filters if filter.RequestID != "" { qb.Where("request_id", filter.RequestID) } if filter.Role != "" { qb.Where("role", filter.Role) } if filter.BlockID != "" { qb.Where("block_id", filter.BlockID) } if filter.ThreadID != "" { qb.Where("thread_id", filter.ThreadID) } if filter.Type != "" { qb.Where("type", filter.Type) } // When Limit is specified WITHOUT Offset, we want the N most-recent // messages. Strategy: query DESC to get the latest rows, then reverse // the slice so the caller receives them in chronological (ASC) order. // When Offset is also present, the caller is doing forward pagination, // so we keep ASC order and apply Limit+Offset normally. needReverse := false if filter.Limit > 0 && filter.Offset <= 0 { qb.Limit(filter.Limit) qb.OrderBy("id", "desc") needReverse = true } else if filter.Limit > 0 && filter.Offset > 0 { qb.Limit(filter.Limit).Offset(filter.Offset) qb.OrderBy("id", "asc") } else { if filter.Offset > 0 { qb.Limit(1000000).Offset(filter.Offset) } qb.OrderBy("id", "asc") } rows, err := qb.Get() if err != nil { return nil, err } messages := make([]*types.Message, 0, len(rows)) for _, row := range rows { data := row.ToMap() if data == nil || data["message_id"] == nil { continue } msg, err := store.rowToMessage(data) if err != nil { continue } messages = append(messages, msg) } if needReverse { for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { messages[i], messages[j] = messages[j], messages[i] } } return messages, nil } // UpdateMessage updates a single message func (store *Xun) UpdateMessage(messageID string, updates map[string]interface{}) error { if messageID == "" { return fmt.Errorf("message_id is required") } if len(updates) == 0 { return fmt.Errorf("no fields to update") } // Check if message exists exists, err := store.newQueryMessage(). Where("message_id", messageID). WhereNull("deleted_at"). 
Exists() if err != nil { return err } if !exists { return fmt.Errorf("message %s not found", messageID) } // Prepare update data data := make(map[string]interface{}) for key, value := range updates { // Skip system fields if key == "message_id" || key == "chat_id" || key == "created_at" { continue } // Handle JSON fields if key == "props" || key == "metadata" { if value != nil { jsonStr, err := jsoniter.MarshalToString(value) if err != nil { return fmt.Errorf("failed to marshal %s: %w", key, err) } data[key] = jsonStr } else { data[key] = nil } continue } data[key] = value } // Always update updated_at data["updated_at"] = time.Now() if len(data) == 0 { return fmt.Errorf("no valid fields to update") } _, err = store.newQueryMessage(). Where("message_id", messageID). Update(data) return err } // DeleteMessages soft deletes specific messages from a chat func (store *Xun) DeleteMessages(chatID string, messageIDs []string) error { if chatID == "" { return fmt.Errorf("chat_id is required") } if len(messageIDs) == 0 { return nil // Nothing to delete } // Soft delete all specified messages in one query _, err := store.newQueryMessage(). Where("chat_id", chatID). WhereIn("message_id", messageIDs). WhereNull("deleted_at"). Update(map[string]interface{}{ "deleted_at": time.Now(), "updated_at": time.Now(), }) return err } // GetMessageByID retrieves a single message by ID func (store *Xun) GetMessageByID(messageID string) (*types.Message, error) { if messageID == "" { return nil, fmt.Errorf("message_id is required") } row, err := store.newQueryMessage(). Where("message_id", messageID). WhereNull("deleted_at"). 
First() if err != nil { return nil, err } if row == nil { return nil, fmt.Errorf("message %s not found", messageID) } data := row.ToMap() if len(data) == 0 || data["message_id"] == nil { return nil, fmt.Errorf("message %s not found", messageID) } return store.rowToMessage(data) } // GetMessageCount returns the count of messages for a chat func (store *Xun) GetMessageCount(chatID string) (int64, error) { if chatID == "" { return 0, fmt.Errorf("chat_id is required") } return store.newQueryMessage(). Where("chat_id", chatID). WhereNull("deleted_at"). Count() } // GetLastSequence returns the last sequence number for a chat func (store *Xun) GetLastSequence(chatID string) (int, error) { if chatID == "" { return 0, fmt.Errorf("chat_id is required") } row, err := store.newQueryMessage(). Where("chat_id", chatID). WhereNull("deleted_at"). OrderBy("sequence", "desc"). First() if err != nil { return 0, err } if row == nil { return 0, nil } data := row.ToMap() return getInt(data, "sequence"), nil } // ============================================================================= // Helper Functions // ============================================================================= // rowToMessage converts a database row to a Message struct func (store *Xun) rowToMessage(data map[string]interface{}) (*types.Message, error) { msg := &types.Message{ MessageID: getString(data, "message_id"), ChatID: getString(data, "chat_id"), RequestID: getString(data, "request_id"), Role: getString(data, "role"), Type: getString(data, "type"), BlockID: getString(data, "block_id"), ThreadID: getString(data, "thread_id"), AssistantID: getString(data, "assistant_id"), Connector: getString(data, "connector"), Mode: getString(data, "mode"), Sequence: getInt(data, "sequence"), } // Handle timestamps if createdAt := getTime(data, "created_at"); createdAt != nil { msg.CreatedAt = *createdAt } if updatedAt := getTime(data, "updated_at"); updatedAt != nil { msg.UpdatedAt = *updatedAt } // Handle props 
(required) if props := data["props"]; props != nil { if propsStr, ok := props.(string); ok && propsStr != "" { var propsMap map[string]interface{} if err := jsoniter.UnmarshalFromString(propsStr, &propsMap); err == nil { msg.Props = propsMap } } else if propsMap, ok := props.(map[string]interface{}); ok { msg.Props = propsMap } } // Handle metadata (optional) if metadata := data["metadata"]; metadata != nil { if metaStr, ok := metadata.(string); ok && metaStr != "" { var metaMap map[string]interface{} if err := jsoniter.UnmarshalFromString(metaStr, &metaMap); err == nil { msg.Metadata = metaMap } } else if metaMap, ok := metadata.(map[string]interface{}); ok { msg.Metadata = metaMap } } return msg, nil } ================================================ FILE: agent/store/xun/message_test.go ================================================ package xun_test import ( "fmt" "testing" "time" "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/store/xun" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestSaveMessages tests batch saving messages func TestSaveMessages(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } // Create a chat first chat := &types.Chat{ AssistantID: "test_assistant", Title: "Message Test Chat", } err = store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(chat.ChatID) t.Run("SaveSingleMessage", func(t *testing.T) { messages := []*types.Message{ { Role: "user", Type: "text", Props: map[string]interface{}{"content": "Hello, world!"}, Sequence: 1, }, } err := store.SaveMessages(chat.ChatID, messages) if err != nil { t.Fatalf("Failed to save message: %v", err) } // Verify retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) < 1 { 
t.Fatal("Expected at least 1 message") } // Find the message we just saved var found *types.Message for _, msg := range retrieved { if msg.Sequence == 1 && msg.Type == "text" { found = msg break } } if found == nil { t.Fatal("Could not find saved message") } if found.Role != "user" { t.Errorf("Expected role 'user', got '%s'", found.Role) } if found.Props["content"] != "Hello, world!" { t.Errorf("Expected content 'Hello, world!', got '%v'", found.Props["content"]) } }) t.Run("SaveBatchMessages", func(t *testing.T) { // Create a new chat for this test batchChat := &types.Chat{ AssistantID: "test_assistant", Title: "Batch Message Test", } err := store.CreateChat(batchChat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(batchChat.ChatID) // Save multiple messages in one batch messages := []*types.Message{ { Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "What's the weather?"}, Sequence: 1, RequestID: "req_001", AssistantID: "weather_assistant", }, { Role: "assistant", Type: "loading", Props: map[string]interface{}{"message": "Checking weather..."}, Sequence: 2, RequestID: "req_001", BlockID: "B1", AssistantID: "weather_assistant", }, { Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "The weather is sunny, 25°C."}, Sequence: 3, RequestID: "req_001", BlockID: "B1", AssistantID: "weather_assistant", }, } err = store.SaveMessages(batchChat.ChatID, messages) if err != nil { t.Fatalf("Failed to save batch messages: %v", err) } // Verify all messages saved retrieved, err := store.GetMessages(batchChat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 3 { t.Errorf("Expected 3 messages, got %d", len(retrieved)) } // Verify order (should be by sequence) if len(retrieved) >= 3 { if retrieved[0].Sequence != 1 { t.Errorf("Expected first message sequence 1, got %d", retrieved[0].Sequence) } if retrieved[2].Sequence != 3 { 
t.Errorf("Expected last message sequence 3, got %d", retrieved[2].Sequence) } } t.Logf("Saved %d messages in single batch call", len(messages)) }) t.Run("SaveMessageWithAllFields", func(t *testing.T) { fullChat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(fullChat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(fullChat.ChatID) messages := []*types.Message{ { Role: "assistant", Type: "tool_call", Props: map[string]interface{}{"id": "call_123", "name": "get_weather", "arguments": `{"location":"SF"}`}, Sequence: 1, RequestID: "req_full", BlockID: "B1", ThreadID: "T1", AssistantID: "weather_assistant", Metadata: map[string]interface{}{"tool_call_id": "call_123", "is_tool_result": false}, }, } err = store.SaveMessages(fullChat.ChatID, messages) if err != nil { t.Fatalf("Failed to save message: %v", err) } retrieved, err := store.GetMessages(fullChat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 1 { t.Fatalf("Expected 1 message, got %d", len(retrieved)) } msg := retrieved[0] if msg.RequestID != "req_full" { t.Errorf("Expected request_id 'req_full', got '%s'", msg.RequestID) } if msg.BlockID != "B1" { t.Errorf("Expected block_id 'B1', got '%s'", msg.BlockID) } if msg.ThreadID != "T1" { t.Errorf("Expected thread_id 'T1', got '%s'", msg.ThreadID) } if msg.AssistantID != "weather_assistant" { t.Errorf("Expected assistant_id 'weather_assistant', got '%s'", msg.AssistantID) } if msg.Metadata == nil { t.Error("Expected metadata to be set") } else if msg.Metadata["tool_call_id"] != "call_123" { t.Errorf("Expected metadata tool_call_id 'call_123', got '%v'", msg.Metadata["tool_call_id"]) } }) t.Run("SaveMessageWithConnector", func(t *testing.T) { connChat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(connChat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(connChat.ChatID) // Save 
messages with different connectors messages := []*types.Message{ { Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Hello"}, Sequence: 1, Connector: "openai", AssistantID: "test_assistant", }, { Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Hi there!"}, Sequence: 2, Connector: "openai", AssistantID: "test_assistant", }, { Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Switch to Claude"}, Sequence: 3, Connector: "anthropic", AssistantID: "test_assistant", }, { Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Now using Claude!"}, Sequence: 4, Connector: "anthropic", AssistantID: "test_assistant", }, } err = store.SaveMessages(connChat.ChatID, messages) if err != nil { t.Fatalf("Failed to save messages: %v", err) } // Retrieve and verify connectors retrieved, err := store.GetMessages(connChat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 4 { t.Fatalf("Expected 4 messages, got %d", len(retrieved)) } // Verify each message has correct connector for _, msg := range retrieved { if msg.Sequence <= 2 && msg.Connector != "openai" { t.Errorf("Expected connector 'openai' for sequence %d, got '%s'", msg.Sequence, msg.Connector) } if msg.Sequence > 2 && msg.Connector != "anthropic" { t.Errorf("Expected connector 'anthropic' for sequence %d, got '%s'", msg.Sequence, msg.Connector) } } t.Logf("Successfully saved and retrieved messages with different connectors") }) t.Run("SaveMessageWithMode", func(t *testing.T) { modeChat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(modeChat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(modeChat.ChatID) // Save messages with different modes messages := []*types.Message{ { Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Hello in chat mode"}, Sequence: 1, Mode: "chat", Connector: 
"deepseek.v3", AssistantID: "test_assistant", }, { Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Hi there in chat mode!"}, Sequence: 2, Mode: "chat", Connector: "deepseek.v3", AssistantID: "test_assistant", }, { Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Now run a task"}, Sequence: 3, Mode: "task", Connector: "deepseek.v3", AssistantID: "test_assistant", }, { Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Running task!"}, Sequence: 4, Mode: "task", Connector: "deepseek.v3", AssistantID: "test_assistant", }, } err = store.SaveMessages(modeChat.ChatID, messages) if err != nil { t.Fatalf("Failed to save messages: %v", err) } // Retrieve and verify modes retrieved, err := store.GetMessages(modeChat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 4 { t.Fatalf("Expected 4 messages, got %d", len(retrieved)) } // Verify each message has correct mode for _, msg := range retrieved { if msg.Sequence <= 2 && msg.Mode != "chat" { t.Errorf("Expected mode 'chat' for sequence %d, got '%s'", msg.Sequence, msg.Mode) } if msg.Sequence > 2 && msg.Mode != "task" { t.Errorf("Expected mode 'task' for sequence %d, got '%s'", msg.Sequence, msg.Mode) } } t.Logf("Successfully saved and retrieved messages with different modes") }) t.Run("SaveMessageWithEmptyConnector", func(t *testing.T) { emptyConnChat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(emptyConnChat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(emptyConnChat.ChatID) // Save message without connector messages := []*types.Message{ { Role: "user", Type: "text", Props: map[string]interface{}{"content": "No connector"}, Sequence: 1, // Connector is empty }, } err = store.SaveMessages(emptyConnChat.ChatID, messages) if err != nil { t.Fatalf("Failed to save message: %v", err) } retrieved, err := 
store.GetMessages(emptyConnChat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 1 { t.Fatalf("Expected 1 message, got %d", len(retrieved)) } // Empty connector should be stored as empty string if retrieved[0].Connector != "" { t.Errorf("Expected empty connector, got '%s'", retrieved[0].Connector) } }) t.Run("SaveEmptyMessages", func(t *testing.T) { err := store.SaveMessages(chat.ChatID, []*types.Message{}) if err != nil { t.Errorf("Expected no error for empty messages, got: %v", err) } }) t.Run("SaveMessagesWithoutChatID", func(t *testing.T) { messages := []*types.Message{{Role: "user", Type: "text", Props: map[string]interface{}{"content": "test"}}} err := store.SaveMessages("", messages) if err == nil { t.Error("Expected error when saving without chat_id") } }) t.Run("SaveMessageWithoutRole", func(t *testing.T) { messages := []*types.Message{{Type: "text", Props: map[string]interface{}{"content": "test"}, Sequence: 1}} err := store.SaveMessages(chat.ChatID, messages) if err == nil { t.Error("Expected error when saving message without role") } }) t.Run("SaveMessageWithoutType", func(t *testing.T) { messages := []*types.Message{{Role: "user", Props: map[string]interface{}{"content": "test"}, Sequence: 1}} err := store.SaveMessages(chat.ChatID, messages) if err == nil { t.Error("Expected error when saving message without type") } }) t.Run("SaveMessageWithoutProps", func(t *testing.T) { messages := []*types.Message{{Role: "user", Type: "text", Sequence: 1}} err := store.SaveMessages(chat.ChatID, messages) if err == nil { t.Error("Expected error when saving message without props") } }) } // TestGetMessages tests retrieving messages with filters func TestGetMessages(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } // Create chat and messages chat := &types.Chat{ 
AssistantID: "test_assistant", } err = store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(chat.ChatID) // Save test messages messages := []*types.Message{ {Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Hello"}, Sequence: 1, RequestID: "req_001"}, {Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Hi there!"}, Sequence: 2, RequestID: "req_001", BlockID: "B1"}, {Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Weather?"}, Sequence: 3, RequestID: "req_002"}, {Role: "assistant", Type: "loading", Props: map[string]interface{}{"message": "Checking..."}, Sequence: 4, RequestID: "req_002", BlockID: "B2"}, {Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Sunny!"}, Sequence: 5, RequestID: "req_002", BlockID: "B2", ThreadID: "T1"}, } err = store.SaveMessages(chat.ChatID, messages) if err != nil { t.Fatalf("Failed to save messages: %v", err) } t.Run("GetAllMessages", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 5 { t.Errorf("Expected 5 messages, got %d", len(retrieved)) } // Verify order by sequence for i := 1; i < len(retrieved); i++ { if retrieved[i].Sequence < retrieved[i-1].Sequence { t.Error("Messages not ordered by sequence") } } }) t.Run("FilterByRole", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{Role: "user"}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 2 { t.Errorf("Expected 2 user messages, got %d", len(retrieved)) } for _, msg := range retrieved { if msg.Role != "user" { t.Errorf("Expected role 'user', got '%s'", msg.Role) } } }) t.Run("FilterByRequestID", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{RequestID: "req_002"}) if err != nil { 
t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 3 { t.Errorf("Expected 3 messages for req_002, got %d", len(retrieved)) } }) t.Run("FilterByBlockID", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{BlockID: "B2"}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 2 { t.Errorf("Expected 2 messages in block B2, got %d", len(retrieved)) } }) t.Run("FilterByThreadID", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{ThreadID: "T1"}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 1 { t.Errorf("Expected 1 message in thread T1, got %d", len(retrieved)) } }) t.Run("FilterByType", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{Type: "loading"}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 1 { t.Errorf("Expected 1 loading message, got %d", len(retrieved)) } }) t.Run("FilterWithLimit", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{Limit: 2}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 2 { t.Errorf("Expected 2 messages with limit, got %d", len(retrieved)) } }) t.Run("FilterWithOffset", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{Offset: 3}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 2 { t.Errorf("Expected 2 messages with offset 3, got %d", len(retrieved)) } }) t.Run("FilterWithLimitAndOffset", func(t *testing.T) { retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{Limit: 2, Offset: 1}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 2 { t.Errorf("Expected 2 messages, got %d", len(retrieved)) } // Should be sequence 2 and 3 if len(retrieved) >= 2 { if retrieved[0].Sequence != 2 { t.Errorf("Expected first message 
sequence 2, got %d", retrieved[0].Sequence) } } }) t.Run("GetMessagesWithEmptyChatID", func(t *testing.T) { _, err := store.GetMessages("", types.MessageFilter{}) if err == nil { t.Error("Expected error when getting messages without chat_id") } }) t.Run("GetMessagesFromNonExistentChat", func(t *testing.T) { retrieved, err := store.GetMessages("nonexistent_chat", types.MessageFilter{}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if len(retrieved) != 0 { t.Errorf("Expected 0 messages from non-existent chat, got %d", len(retrieved)) } }) t.Run("OrderByCreatedAtThenSequence", func(t *testing.T) { // This test verifies that messages are ordered by created_at first, then by sequence // This is important when there are multiple request_ids with overlapping sequence numbers orderChat := &types.Chat{ AssistantID: "test_assistant", } err := store.CreateChat(orderChat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(orderChat.ChatID) // Simulate two separate requests with overlapping sequence numbers // SaveMessages uses time.Now() for created_at, so we need to call it twice // with a small delay to ensure different timestamps // Request 1: sequences 1, 2 req1Messages := []*types.Message{ { Role: "user", Type: "user_input", Props: map[string]interface{}{"content": "Request 1 - Message 1"}, Sequence: 1, RequestID: "order_req_001", }, { Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Request 1 - Response 1"}, Sequence: 2, RequestID: "order_req_001", }, } err = store.SaveMessages(orderChat.ChatID, req1Messages) if err != nil { t.Fatalf("Failed to save request 1 messages: %v", err) } // Delay to ensure different created_at timestamps // Database timestamp precision may only be to second level time.Sleep(1100 * time.Millisecond) // Request 2: sequences 1, 2 (same as request 1, but later created_at) req2Messages := []*types.Message{ { Role: "user", Type: "user_input", Props: map[string]interface{}{"content": 
"Request 2 - Message 1"}, Sequence: 1, // Same sequence as req1, but later created_at RequestID: "order_req_002", }, { Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Request 2 - Response 1"}, Sequence: 2, // Same sequence as req1, but later created_at RequestID: "order_req_002", }, } err = store.SaveMessages(orderChat.ChatID, req2Messages) if err != nil { t.Fatalf("Failed to save request 2 messages: %v", err) } // Retrieve messages retrieved, err := store.GetMessages(orderChat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(retrieved) != 4 { t.Fatalf("Expected 4 messages, got %d", len(retrieved)) } // Verify order: should be chronological by created_at, then by sequence // Messages from req_001 should come before req_002 expectedOrder := []struct { requestID string sequence int content string }{ {"order_req_001", 1, "Request 1 - Message 1"}, {"order_req_001", 2, "Request 1 - Response 1"}, {"order_req_002", 1, "Request 2 - Message 1"}, {"order_req_002", 2, "Request 2 - Response 1"}, } for i, expected := range expectedOrder { msg := retrieved[i] if msg.RequestID != expected.requestID { t.Errorf("Message %d: expected RequestID '%s', got '%s'", i, expected.requestID, msg.RequestID) } if msg.Sequence != expected.sequence { t.Errorf("Message %d: expected Sequence %d, got %d", i, expected.sequence, msg.Sequence) } content, _ := msg.Props["content"].(string) if content != expected.content { t.Errorf("Message %d: expected content '%s', got '%s'", i, expected.content, content) } } // Additional verification: ensure created_at is non-decreasing for i := 1; i < len(retrieved); i++ { if retrieved[i].CreatedAt.Before(retrieved[i-1].CreatedAt) { t.Errorf("Message %d created_at (%v) is before message %d created_at (%v)", i, retrieved[i].CreatedAt, i-1, retrieved[i-1].CreatedAt) } // If same created_at, sequence should be increasing if retrieved[i].CreatedAt.Equal(retrieved[i-1].CreatedAt) { if 
retrieved[i].Sequence < retrieved[i-1].Sequence {
					t.Errorf("Messages with same created_at: message %d sequence (%d) < message %d sequence (%d)", i, retrieved[i].Sequence, i-1, retrieved[i-1].Sequence)
				}
			}
		}
		t.Logf("Successfully verified message ordering: created_at first, then sequence")
	})
}

// TestUpdateMessage tests updating messages.
//
// It seeds one chat containing a single "loading" assistant message and then
// runs UpdateMessage against it: props, type and metadata updates are each
// verified by reading the message back through GetMessages, and three error
// cases (unknown ID, empty ID, empty field map) must return an error.
// The subtests share the same message, so they mutate it cumulatively.
func TestUpdateMessage(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	// Create chat and message
	chat := &types.Chat{
		AssistantID: "test_assistant",
	}
	err = store.CreateChat(chat)
	if err != nil {
		t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	messages := []*types.Message{
		{
			// Nanosecond timestamp keeps the ID distinct between runs.
			MessageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
			Role:      "assistant",
			Type:      "loading",
			Props:     map[string]interface{}{"message": "Loading..."},
			Sequence:  1,
		},
	}
	err = store.SaveMessages(chat.ChatID, messages)
	if err != nil {
		t.Fatalf("Failed to save message: %v", err)
	}
	messageID := messages[0].MessageID

	t.Run("UpdateProps", func(t *testing.T) {
		err := store.UpdateMessage(messageID, map[string]interface{}{
			"props": map[string]interface{}{"content": "Updated content"},
		})
		if err != nil {
			t.Fatalf("Failed to update message: %v", err)
		}

		retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{})
		if err != nil {
			t.Fatalf("Failed to get messages: %v", err)
		}

		// Locate the seeded message among everything stored for the chat.
		var found *types.Message
		for _, msg := range retrieved {
			if msg.MessageID == messageID {
				found = msg
				break
			}
		}
		if found == nil {
			t.Fatal("Could not find updated message")
		}
		if found.Props["content"] != "Updated content" {
			t.Errorf("Expected props content 'Updated content', got '%v'", found.Props["content"])
		}
	})

	t.Run("UpdateType", func(t *testing.T) {
		err := store.UpdateMessage(messageID, map[string]interface{}{
			"type": "text",
		})
		if err != nil {
			t.Fatalf("Failed to update message: %v", err)
		}

		retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{})
		if err != nil {
			t.Fatalf("Failed to get messages: %v", err)
		}

		var found *types.Message
		for _, msg := range retrieved {
			if msg.MessageID == messageID {
				found = msg
				break
			}
		}
		if found == nil {
			t.Fatal("Could not find updated message")
		}
		if found.Type != "text" {
			t.Errorf("Expected type 'text', got '%s'", found.Type)
		}
	})

	t.Run("UpdateMetadata", func(t *testing.T) {
		err := store.UpdateMessage(messageID, map[string]interface{}{
			"metadata": map[string]interface{}{"updated": true},
		})
		if err != nil {
			t.Fatalf("Failed to update metadata: %v", err)
		}

		retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{})
		if err != nil {
			t.Fatalf("Failed to get messages: %v", err)
		}

		var found *types.Message
		for _, msg := range retrieved {
			if msg.MessageID == messageID {
				found = msg
				break
			}
		}
		if found == nil {
			t.Fatal("Could not find updated message")
		}
		if found.Metadata == nil || found.Metadata["updated"] != true {
			t.Errorf("Expected metadata updated=true, got %v", found.Metadata)
		}
	})

	t.Run("UpdateNonExistentMessage", func(t *testing.T) {
		err := store.UpdateMessage("nonexistent_msg", map[string]interface{}{
			"type": "text",
		})
		if err == nil {
			t.Error("Expected error when updating non-existent message")
		}
	})

	t.Run("UpdateWithEmptyID", func(t *testing.T) {
		err := store.UpdateMessage("", map[string]interface{}{
			"type": "text",
		})
		if err == nil {
			t.Error("Expected error when updating with empty ID")
		}
	})

	t.Run("UpdateWithEmptyFields", func(t *testing.T) {
		err := store.UpdateMessage(messageID, map[string]interface{}{})
		if err == nil {
			t.Error("Expected error when updating with empty fields")
		}
	})
}

// TestDeleteMessages tests deleting messages.
//
// Each subtest creates its own chat so the message counts it asserts on are
// not affected by the other subtests. Covered cases: deleting one message,
// deleting two of three, passing an empty ID list (must not error), and
// passing an empty chat ID (must error).
func TestDeleteMessages(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("DeleteSingleMessage", func(t *testing.T) {
		chat := &types.Chat{AssistantID: "test_assistant"}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		msgID := fmt.Sprintf("msg_del_%d", time.Now().UnixNano())
		messages := []*types.Message{
			{MessageID: msgID, Role: "user", Type: "text", Props: map[string]interface{}{"content": "test"}, Sequence: 1},
		}
		err = store.SaveMessages(chat.ChatID, messages)
		if err != nil {
			t.Fatalf("Failed to save message: %v", err)
		}

		err = store.DeleteMessages(chat.ChatID, []string{msgID})
		if err != nil {
			t.Fatalf("Failed to delete message: %v", err)
		}

		// Verify deleted
		retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{})
		if err != nil {
			t.Fatalf("Failed to get messages: %v", err)
		}
		for _, msg := range retrieved {
			if msg.MessageID == msgID {
				t.Error("Message should have been deleted")
			}
		}
	})

	t.Run("DeleteMultipleMessages", func(t *testing.T) {
		chat := &types.Chat{AssistantID: "test_assistant"}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		// Distinct prefixes keep the three IDs unique even if the clock
		// returns the same nanosecond value for consecutive calls.
		msgID1 := fmt.Sprintf("msg_del1_%d", time.Now().UnixNano())
		msgID2 := fmt.Sprintf("msg_del2_%d", time.Now().UnixNano())
		msgID3 := fmt.Sprintf("msg_del3_%d", time.Now().UnixNano())
		messages := []*types.Message{
			{MessageID: msgID1, Role: "user", Type: "text", Props: map[string]interface{}{"content": "1"}, Sequence: 1},
			{MessageID: msgID2, Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "2"}, Sequence: 2},
			{MessageID: msgID3, Role: "user", Type: "text", Props: map[string]interface{}{"content": "3"}, Sequence: 3},
		}
		err = store.SaveMessages(chat.ChatID, messages)
		if err != nil {
			t.Fatalf("Failed to save messages: %v", err)
		}

		// Delete first two
		err = store.DeleteMessages(chat.ChatID, []string{msgID1, msgID2})
		if err != nil {
			t.Fatalf("Failed to delete messages: %v", err)
		}

		// Verify
		retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{})
		if err != nil {
			t.Fatalf("Failed to get messages: %v", err)
		}
		if len(retrieved) != 1 {
			t.Errorf("Expected 1 remaining message, got %d", len(retrieved))
		}
		if len(retrieved) > 0 && retrieved[0].MessageID != msgID3 {
			t.Errorf("Expected remaining message to be %s, got %s", msgID3, retrieved[0].MessageID)
		}
	})

	t.Run("DeleteEmptyList", func(t *testing.T) {
		chat := &types.Chat{AssistantID: "test_assistant"}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		err = store.DeleteMessages(chat.ChatID, []string{})
		if err != nil {
			t.Errorf("Expected no error for empty delete list, got: %v", err)
		}
	})

	t.Run("DeleteWithEmptyChatID", func(t *testing.T) {
		err := store.DeleteMessages("", []string{"msg_123"})
		if err == nil {
			t.Error("Expected error when deleting with empty chat_id")
		}
	})
}

// TestMessageCompleteWorkflow tests a complete message workflow.
//
// A single chat goes through: batch save of four messages (one user message
// plus three assistant messages in block "B1"), retrieval of all messages,
// filtering by request and by block, updating the "loading" message, deleting
// the first retrieved message, and a final count check.
func TestMessageCompleteWorkflow(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("CompleteWorkflow", func(t *testing.T) {
		// 1. Create chat
		chat := &types.Chat{
			AssistantID: "workflow_assistant",
			Title:       "Message Workflow Test",
		}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		// 2. Save batch messages (simulating a request)
		// No MessageID is set on these messages; later steps read the IDs
		// back from the retrieved records.
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		messages := []*types.Message{
			{
				Role:        "user",
				Type:        "user_input",
				Props:       map[string]interface{}{"content": "What's the weather in SF?"},
				Sequence:    1,
				RequestID:   requestID,
				AssistantID: "workflow_assistant",
			},
			{
				Role:        "assistant",
				Type:        "loading",
				Props:       map[string]interface{}{"message": "Checking weather..."},
				Sequence:    2,
				RequestID:   requestID,
				BlockID:     "B1",
				AssistantID: "workflow_assistant",
			},
			{
				Role:        "assistant",
				Type:        "tool_call",
				Props:       map[string]interface{}{"id": "call_weather", "name": "get_weather", "arguments": `{"location":"SF"}`},
				Sequence:    3,
				RequestID:   requestID,
				BlockID:     "B1",
				AssistantID: "workflow_assistant",
			},
			{
				Role:        "assistant",
				Type:        "text",
				Props:       map[string]interface{}{"content": "The weather in San Francisco is 18°C and sunny."},
				Sequence:    4,
				RequestID:   requestID,
				BlockID:     "B1",
				AssistantID: "workflow_assistant",
				Metadata:    map[string]interface{}{"tool_call_id": "call_weather"},
			},
		}

		err = store.SaveMessages(chat.ChatID, messages)
		if err != nil {
			t.Fatalf("Failed to save messages: %v", err)
		}
		t.Logf("Saved %d messages in single batch", len(messages))

		// 3. Get all messages
		retrieved, err := store.GetMessages(chat.ChatID, types.MessageFilter{})
		if err != nil {
			t.Fatalf("Failed to get messages: %v", err)
		}
		if len(retrieved) != 4 {
			t.Errorf("Expected 4 messages, got %d", len(retrieved))
		}

		// 4. Filter by request
		byRequest, err := store.GetMessages(chat.ChatID, types.MessageFilter{RequestID: requestID})
		if err != nil {
			t.Fatalf("Failed to filter by request: %v", err)
		}
		if len(byRequest) != 4 {
			t.Errorf("Expected 4 messages for request, got %d", len(byRequest))
		}

		// 5. Filter by block
		// The user message carries no BlockID, so only the three assistant
		// messages are expected here.
		byBlock, err := store.GetMessages(chat.ChatID, types.MessageFilter{BlockID: "B1"})
		if err != nil {
			t.Fatalf("Failed to filter by block: %v", err)
		}
		if len(byBlock) != 3 {
			t.Errorf("Expected 3 messages in block B1, got %d", len(byBlock))
		}

		// 6. Update loading message to text (simulating stream completion)
		var loadingMsgID string
		for _, msg := range retrieved {
			if msg.Type == "loading" {
				loadingMsgID = msg.MessageID
				break
			}
		}
		if loadingMsgID != "" {
			err = store.UpdateMessage(loadingMsgID, map[string]interface{}{
				"type":  "text",
				"props": map[string]interface{}{"content": "Weather check complete."},
			})
			if err != nil {
				t.Fatalf("Failed to update message: %v", err)
			}
		}

		// 7. Delete a message
		if len(retrieved) > 0 {
			err = store.DeleteMessages(chat.ChatID, []string{retrieved[0].MessageID})
			if err != nil {
				t.Fatalf("Failed to delete message: %v", err)
			}
		}

		// 8. Verify final state
		final, err := store.GetMessages(chat.ChatID, types.MessageFilter{})
		if err != nil {
			t.Fatalf("Failed to get final messages: %v", err)
		}
		if len(final) != 3 {
			t.Errorf("Expected 3 messages after delete, got %d", len(final))
		}

		t.Log("Complete message workflow passed!")
	})
}

// TestConcurrentMessages tests concurrent message storage
func TestConcurrentMessages(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("ConcurrentThreadMessages", func(t *testing.T) {
		chat := &types.Chat{AssistantID: "test_assistant"}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		// Simulate concurrent operations with different threads
		messages := []*types.Message{
			{Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Weather result"}, Sequence: 1, BlockID: "B1", ThreadID: "T1"},
			{Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "News result"}, Sequence: 2, BlockID: "B1", ThreadID: "T2"},
			{Role: "assistant", Type: "text", Props: map[string]interface{}{"content": "Stock result"}, Sequence: 3, BlockID: "B1", ThreadID: "T3"},
			{Role: "assistant", Type: "text", Props: map[string]interface{}{"content":
"Summary"}, Sequence: 4, BlockID: "B2"}, } err = store.SaveMessages(chat.ChatID, messages) if err != nil { t.Fatalf("Failed to save concurrent messages: %v", err) } // Verify all saved all, err := store.GetMessages(chat.ChatID, types.MessageFilter{}) if err != nil { t.Fatalf("Failed to get messages: %v", err) } if len(all) != 4 { t.Errorf("Expected 4 messages, got %d", len(all)) } // Filter by thread t1Messages, err := store.GetMessages(chat.ChatID, types.MessageFilter{ThreadID: "T1"}) if err != nil { t.Fatalf("Failed to filter by thread: %v", err) } if len(t1Messages) != 1 { t.Errorf("Expected 1 message in thread T1, got %d", len(t1Messages)) } // Filter by block b1Messages, err := store.GetMessages(chat.ChatID, types.MessageFilter{BlockID: "B1"}) if err != nil { t.Fatalf("Failed to filter by block: %v", err) } if len(b1Messages) != 3 { t.Errorf("Expected 3 messages in block B1, got %d", len(b1Messages)) } t.Log("Concurrent thread messages test passed!") }) } ================================================ FILE: agent/store/xun/resume.go ================================================ package xun import ( "fmt" "time" "github.com/google/uuid" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/yao/agent/store/types" ) // ============================================================================= // Resume Management (only called on failure/interrupt) // ============================================================================= // SaveResume batch saves resume records using a single database call // Only called when request is interrupted or failed func (store *Xun) SaveResume(records []*types.Resume) error { if len(records) == 0 { return nil // Nothing to save } // Prepare batch insert data now := time.Now() rows := make([]map[string]interface{}, 0, len(records)) for _, record := range records { if record == nil { continue } // Generate resume_id if not provided resumeID := record.ResumeID if resumeID == "" { resumeID = uuid.New().String() } // 
Validate required fields if record.ChatID == "" { return fmt.Errorf("chat_id is required") } if record.RequestID == "" { return fmt.Errorf("request_id is required") } if record.AssistantID == "" { return fmt.Errorf("assistant_id is required") } if record.StackID == "" { return fmt.Errorf("stack_id is required") } if record.Type == "" { return fmt.Errorf("type is required") } if record.Status == "" { return fmt.Errorf("status is required") } // Build row with all fields (including nullable ones for consistent batch insert) row := map[string]interface{}{ "resume_id": resumeID, "chat_id": record.ChatID, "request_id": record.RequestID, "assistant_id": record.AssistantID, "stack_id": record.StackID, "stack_parent_id": nil, "stack_depth": record.StackDepth, "type": record.Type, "status": record.Status, "input": nil, "output": nil, "space_snapshot": nil, "error": nil, "sequence": record.Sequence, "metadata": nil, "created_at": now, "updated_at": now, } // Set nullable fields if they have values if record.StackParentID != "" { row["stack_parent_id"] = record.StackParentID } if record.Input != nil { inputJSON, err := jsoniter.MarshalToString(record.Input) if err != nil { return fmt.Errorf("failed to marshal input: %w", err) } row["input"] = inputJSON } if record.Output != nil { outputJSON, err := jsoniter.MarshalToString(record.Output) if err != nil { return fmt.Errorf("failed to marshal output: %w", err) } row["output"] = outputJSON } if record.SpaceSnapshot != nil { snapshotJSON, err := jsoniter.MarshalToString(record.SpaceSnapshot) if err != nil { return fmt.Errorf("failed to marshal space_snapshot: %w", err) } row["space_snapshot"] = snapshotJSON } if record.Error != "" { row["error"] = record.Error } if record.Metadata != nil { metadataJSON, err := jsoniter.MarshalToString(record.Metadata) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } row["metadata"] = metadataJSON } rows = append(rows, row) } if len(rows) == 0 { return nil } // Single 
batch insert - one database call for all records return store.newQueryResume().Insert(rows) } // GetResume retrieves all resume records for a chat func (store *Xun) GetResume(chatID string) ([]*types.Resume, error) { if chatID == "" { return nil, fmt.Errorf("chat_id is required") } rows, err := store.newQueryResume(). Where("chat_id", chatID). WhereNull("deleted_at"). OrderBy("sequence", "asc"). Get() if err != nil { return nil, err } records := make([]*types.Resume, 0, len(rows)) for _, row := range rows { data := row.ToMap() if data == nil || data["resume_id"] == nil { continue } record, err := store.rowToResume(data) if err != nil { continue } records = append(records, record) } return records, nil } // GetLastResume retrieves the last (most recent) resume record for a chat func (store *Xun) GetLastResume(chatID string) (*types.Resume, error) { if chatID == "" { return nil, fmt.Errorf("chat_id is required") } row, err := store.newQueryResume(). Where("chat_id", chatID). WhereNull("deleted_at"). OrderBy("sequence", "desc"). First() if err != nil { return nil, err } if row == nil { return nil, nil // No resume records found } data := row.ToMap() if len(data) == 0 || data["resume_id"] == nil { return nil, nil } return store.rowToResume(data) } // GetResumeByStackID retrieves resume records for a specific stack func (store *Xun) GetResumeByStackID(stackID string) ([]*types.Resume, error) { if stackID == "" { return nil, fmt.Errorf("stack_id is required") } rows, err := store.newQueryResume(). Where("stack_id", stackID). WhereNull("deleted_at"). OrderBy("sequence", "asc"). 
Get() if err != nil { return nil, err } records := make([]*types.Resume, 0, len(rows)) for _, row := range rows { data := row.ToMap() if data == nil || data["resume_id"] == nil { continue } record, err := store.rowToResume(data) if err != nil { continue } records = append(records, record) } return records, nil } // GetStackPath returns the stack path from root to the given stack // Returns: [root_stack_id, ..., current_stack_id] func (store *Xun) GetStackPath(stackID string) ([]string, error) { if stackID == "" { return nil, fmt.Errorf("stack_id is required") } path := []string{stackID} currentStackID := stackID // Walk up the stack tree by following stack_parent_id for { row, err := store.newQueryResume(). Where("stack_id", currentStackID). WhereNull("deleted_at"). First() if err != nil { return nil, err } if row == nil { break } data := row.ToMap() parentID := getString(data, "stack_parent_id") if parentID == "" { break // Reached root } // Prepend parent to path path = append([]string{parentID}, path...) currentStackID = parentID } return path, nil } // DeleteResume soft deletes all resume records for a chat // Called after successful resume to clean up func (store *Xun) DeleteResume(chatID string) error { if chatID == "" { return fmt.Errorf("chat_id is required") } _, err := store.newQueryResume(). Where("chat_id", chatID). WhereNull("deleted_at"). Update(map[string]interface{}{ "deleted_at": time.Now(), "updated_at": time.Now(), }) return err } // GetResumeByRequestID retrieves resume records for a specific request func (store *Xun) GetResumeByRequestID(requestID string) ([]*types.Resume, error) { if requestID == "" { return nil, fmt.Errorf("request_id is required") } rows, err := store.newQueryResume(). Where("request_id", requestID). WhereNull("deleted_at"). OrderBy("sequence", "asc"). 
Get() if err != nil { return nil, err } records := make([]*types.Resume, 0, len(rows)) for _, row := range rows { data := row.ToMap() if data == nil || data["resume_id"] == nil { continue } record, err := store.rowToResume(data) if err != nil { continue } records = append(records, record) } return records, nil } // ============================================================================= // Helper Functions // ============================================================================= // rowToResume converts a database row to a Resume struct func (store *Xun) rowToResume(data map[string]interface{}) (*types.Resume, error) { record := &types.Resume{ ResumeID: getString(data, "resume_id"), ChatID: getString(data, "chat_id"), RequestID: getString(data, "request_id"), AssistantID: getString(data, "assistant_id"), StackID: getString(data, "stack_id"), StackParentID: getString(data, "stack_parent_id"), StackDepth: getInt(data, "stack_depth"), Type: getString(data, "type"), Status: getString(data, "status"), Error: getString(data, "error"), Sequence: getInt(data, "sequence"), } // Handle timestamps if createdAt := getTime(data, "created_at"); createdAt != nil { record.CreatedAt = *createdAt } if updatedAt := getTime(data, "updated_at"); updatedAt != nil { record.UpdatedAt = *updatedAt } // Handle JSON fields if input := data["input"]; input != nil { if inputStr, ok := input.(string); ok && inputStr != "" { var inputMap map[string]interface{} if err := jsoniter.UnmarshalFromString(inputStr, &inputMap); err == nil { record.Input = inputMap } } else if inputMap, ok := input.(map[string]interface{}); ok { record.Input = inputMap } } if output := data["output"]; output != nil { if outputStr, ok := output.(string); ok && outputStr != "" { var outputMap map[string]interface{} if err := jsoniter.UnmarshalFromString(outputStr, &outputMap); err == nil { record.Output = outputMap } } else if outputMap, ok := output.(map[string]interface{}); ok { record.Output = outputMap } } 
if snapshot := data["space_snapshot"]; snapshot != nil { if snapshotStr, ok := snapshot.(string); ok && snapshotStr != "" { var snapshotMap map[string]interface{} if err := jsoniter.UnmarshalFromString(snapshotStr, &snapshotMap); err == nil { record.SpaceSnapshot = snapshotMap } } else if snapshotMap, ok := snapshot.(map[string]interface{}); ok { record.SpaceSnapshot = snapshotMap } } if metadata := data["metadata"]; metadata != nil { if metaStr, ok := metadata.(string); ok && metaStr != "" { var metaMap map[string]interface{} if err := jsoniter.UnmarshalFromString(metaStr, &metaMap); err == nil { record.Metadata = metaMap } } else if metaMap, ok := metadata.(map[string]interface{}); ok { record.Metadata = metaMap } } return record, nil } ================================================ FILE: agent/store/xun/resume_test.go ================================================ package xun_test import ( "fmt" "testing" "time" "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/store/xun" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestSaveResume tests batch saving resume records func TestSaveResume(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } // Create a chat first chat := &types.Chat{ AssistantID: "test_assistant", Title: "Resume Test Chat", } err = store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(chat.ChatID) t.Run("SaveSingleRecord", func(t *testing.T) { requestID := fmt.Sprintf("req_%d", time.Now().UnixNano()) records := []*types.Resume{ { ChatID: chat.ChatID, RequestID: requestID, AssistantID: "test_assistant", StackID: "stack_001", StackDepth: 0, Type: types.ResumeTypeLLM, Status: types.ResumeStatusInterrupted, Sequence: 1, }, } err := store.SaveResume(records) if err != nil { t.Fatalf("Failed to save resume record: %v", err) 
}

		// Verify
		retrieved, err := store.GetResume(chat.ChatID)
		if err != nil {
			t.Fatalf("Failed to get resume records: %v", err)
		}

		// The record is located by its request ID rather than by position.
		found := false
		for _, r := range retrieved {
			if r.RequestID == requestID {
				found = true
				if r.Type != types.ResumeTypeLLM {
					t.Errorf("Expected type '%s', got '%s'", types.ResumeTypeLLM, r.Type)
				}
				if r.Status != types.ResumeStatusInterrupted {
					t.Errorf("Expected status '%s', got '%s'", types.ResumeStatusInterrupted, r.Status)
				}
				break
			}
		}
		if !found {
			t.Error("Could not find saved resume record")
		}

		// Clean up
		store.DeleteResume(chat.ChatID)
	})

	t.Run("SaveBatchRecords", func(t *testing.T) {
		// Create a new chat for this test
		batchChat := &types.Chat{
			AssistantID: "test_assistant",
		}
		err := store.CreateChat(batchChat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(batchChat.ChatID)

		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		records := []*types.Resume{
			{
				ChatID:      batchChat.ChatID,
				RequestID:   requestID,
				AssistantID: "test_assistant",
				StackID:     "stack_001",
				StackDepth:  0,
				Type:        types.ResumeTypeInput,
				Status:      types.ResumeStatusInterrupted,
				Sequence:    1,
			},
			{
				ChatID:      batchChat.ChatID,
				RequestID:   requestID,
				AssistantID: "test_assistant",
				StackID:     "stack_001",
				StackDepth:  0,
				Type:        types.ResumeTypeHookCreate,
				Status:      types.ResumeStatusInterrupted,
				Sequence:    2,
			},
			{
				ChatID:      batchChat.ChatID,
				RequestID:   requestID,
				AssistantID: "test_assistant",
				StackID:     "stack_001",
				StackDepth:  0,
				Type:        types.ResumeTypeLLM,
				Status:      types.ResumeStatusFailed,
				Sequence:    3,
				Error:       "Connection timeout",
			},
		}

		err = store.SaveResume(records)
		if err != nil {
			t.Fatalf("Failed to save batch resume records: %v", err)
		}

		// Verify all records saved
		retrieved, err := store.GetResume(batchChat.ChatID)
		if err != nil {
			t.Fatalf("Failed to get resume records: %v", err)
		}
		if len(retrieved) != 3 {
			t.Errorf("Expected 3 records, got %d", len(retrieved))
		}

		// Verify order (should be by sequence)
		if len(retrieved) >= 3 {
			if retrieved[0].Sequence != 1 {
				t.Errorf("Expected first record sequence 1, got %d", retrieved[0].Sequence)
			}
			if retrieved[2].Sequence != 3 {
				t.Errorf("Expected last record sequence 3, got %d", retrieved[2].Sequence)
			}
			if retrieved[2].Error != "Connection timeout" {
				t.Errorf("Expected error 'Connection timeout', got '%s'", retrieved[2].Error)
			}
		}

		t.Logf("Saved %d resume records in single batch call", len(records))
	})

	t.Run("SaveRecordWithAllFields", func(t *testing.T) {
		fullChat := &types.Chat{
			AssistantID: "test_assistant",
		}
		err := store.CreateChat(fullChat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(fullChat.ChatID)

		// Every optional field is populated so the round trip through
		// SaveResume/GetResume exercises all of them.
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		records := []*types.Resume{
			{
				ChatID:        fullChat.ChatID,
				RequestID:     requestID,
				AssistantID:   "test_assistant",
				StackID:       "stack_001",
				StackParentID: "stack_000",
				StackDepth:    1,
				Type:          types.ResumeTypeDelegate,
				Status:        types.ResumeStatusInterrupted,
				Input:         map[string]interface{}{"agent_id": "sub_agent", "messages": []interface{}{}},
				Output:        map[string]interface{}{"partial": true},
				SpaceSnapshot: map[string]interface{}{"key1": "value1", "key2": 123},
				Error:         "User cancelled",
				Sequence:      1,
				Metadata:      map[string]interface{}{"retry_count": 0},
			},
		}

		err = store.SaveResume(records)
		if err != nil {
			t.Fatalf("Failed to save record: %v", err)
		}

		retrieved, err := store.GetResume(fullChat.ChatID)
		if err != nil {
			t.Fatalf("Failed to get records: %v", err)
		}
		if len(retrieved) != 1 {
			t.Fatalf("Expected 1 record, got %d", len(retrieved))
		}

		r := retrieved[0]
		if r.StackParentID != "stack_000" {
			t.Errorf("Expected stack_parent_id 'stack_000', got '%s'", r.StackParentID)
		}
		if r.StackDepth != 1 {
			t.Errorf("Expected stack_depth 1, got %d", r.StackDepth)
		}
		if r.Input == nil {
			t.Error("Expected input to be set")
		}
		if r.Output == nil {
			t.Error("Expected output to be set")
		}
		if r.SpaceSnapshot == nil {
			t.Error("Expected space_snapshot to be set")
		} else if r.SpaceSnapshot["key1"] != "value1" {
			t.Errorf("Expected space_snapshot key1='value1', got '%v'", r.SpaceSnapshot["key1"])
		}
		if r.Metadata == nil {
			t.Error("Expected metadata to be set")
		}
	})

	t.Run("SaveEmptyRecords", func(t *testing.T) {
		err := store.SaveResume([]*types.Resume{})
		if err != nil {
			t.Errorf("Expected no error for empty records, got: %v", err)
		}
	})

	// Each of the following subtests omits exactly one required field and
	// expects SaveResume to reject the record.
	t.Run("SaveRecordWithoutChatID", func(t *testing.T) {
		records := []*types.Resume{{RequestID: "req", AssistantID: "ast", StackID: "stk", Type: "llm", Status: "failed", Sequence: 1}}
		err := store.SaveResume(records)
		if err == nil {
			t.Error("Expected error when saving without chat_id")
		}
	})

	t.Run("SaveRecordWithoutRequestID", func(t *testing.T) {
		records := []*types.Resume{{ChatID: chat.ChatID, AssistantID: "ast", StackID: "stk", Type: "llm", Status: "failed", Sequence: 1}}
		err := store.SaveResume(records)
		if err == nil {
			t.Error("Expected error when saving without request_id")
		}
	})

	t.Run("SaveRecordWithoutAssistantID", func(t *testing.T) {
		records := []*types.Resume{{ChatID: chat.ChatID, RequestID: "req", StackID: "stk", Type: "llm", Status: "failed", Sequence: 1}}
		err := store.SaveResume(records)
		if err == nil {
			t.Error("Expected error when saving without assistant_id")
		}
	})

	t.Run("SaveRecordWithoutStackID", func(t *testing.T) {
		records := []*types.Resume{{ChatID: chat.ChatID, RequestID: "req", AssistantID: "ast", Type: "llm", Status: "failed", Sequence: 1}}
		err := store.SaveResume(records)
		if err == nil {
			t.Error("Expected error when saving without stack_id")
		}
	})

	t.Run("SaveRecordWithoutType", func(t *testing.T) {
		records := []*types.Resume{{ChatID: chat.ChatID, RequestID: "req", AssistantID: "ast", StackID: "stk", Status: "failed", Sequence: 1}}
		err := store.SaveResume(records)
		if err == nil {
			t.Error("Expected error when saving without type")
		}
	})

	t.Run("SaveRecordWithoutStatus", func(t *testing.T) {
		records := []*types.Resume{{ChatID: chat.ChatID, RequestID: "req", AssistantID: "ast", StackID: "stk", Type: "llm", Sequence: 1}}
		err := store.SaveResume(records)
		if err == nil {
			t.Error("Expected error when saving without status")
		}
	})
}

// TestGetResume tests retrieving resume records.
//
// Three records with sequences 1..3 are saved for one chat; the subtests
// check that GetResume returns all of them in non-decreasing sequence order,
// rejects an empty chat ID, and returns an empty result (not an error) for a
// chat ID that has no records.
func TestGetResume(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	// Create chat and resume records
	chat := &types.Chat{
		AssistantID: "test_assistant",
	}
	err = store.CreateChat(chat)
	if err != nil {
		t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
	records := []*types.Resume{
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast1", StackID: "stk1", Type: types.ResumeTypeInput, Status: types.ResumeStatusInterrupted, Sequence: 1},
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast1", StackID: "stk1", Type: types.ResumeTypeHookCreate, Status: types.ResumeStatusInterrupted, Sequence: 2},
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast1", StackID: "stk1", Type: types.ResumeTypeLLM, Status: types.ResumeStatusFailed, Sequence: 3},
	}
	err = store.SaveResume(records)
	if err != nil {
		t.Fatalf("Failed to save records: %v", err)
	}
	defer store.DeleteResume(chat.ChatID)

	t.Run("GetAllRecords", func(t *testing.T) {
		retrieved, err := store.GetResume(chat.ChatID)
		if err != nil {
			t.Fatalf("Failed to get records: %v", err)
		}
		if len(retrieved) != 3 {
			t.Errorf("Expected 3 records, got %d", len(retrieved))
		}

		// Verify order by sequence
		for i := 1; i < len(retrieved); i++ {
			if retrieved[i].Sequence < retrieved[i-1].Sequence {
				t.Error("Records not ordered by sequence")
			}
		}
	})

	t.Run("GetRecordsWithEmptyChatID", func(t *testing.T) {
		_, err := store.GetResume("")
		if err == nil {
			t.Error("Expected error when getting records without chat_id")
		}
	})

	t.Run("GetRecordsFromNonExistentChat", func(t *testing.T) {
		retrieved, err := store.GetResume("nonexistent_chat")
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if len(retrieved) != 0 {
			t.Errorf("Expected 0 records from non-existent chat, got %d", len(retrieved))
		}
	})
}

// TestGetLastResume tests retrieving the last resume record.
//
// With three saved records the one with sequence 3 must be returned; a chat
// without records must yield a nil record and no error; an empty chat ID must
// yield an error.
func TestGetLastResume(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	chat := &types.Chat{
		AssistantID: "test_assistant",
	}
	err = store.CreateChat(chat)
	if err != nil {
		t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	t.Run("GetLastRecordFromMultiple", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		records := []*types.Resume{
			{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast", StackID: "stk", Type: types.ResumeTypeInput, Status: types.ResumeStatusInterrupted, Sequence: 1},
			{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast", StackID: "stk", Type: types.ResumeTypeHookCreate, Status: types.ResumeStatusInterrupted, Sequence: 2},
			{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast", StackID: "stk", Type: types.ResumeTypeLLM, Status: types.ResumeStatusFailed, Sequence: 3, Error: "Last error"},
		}
		err := store.SaveResume(records)
		if err != nil {
			t.Fatalf("Failed to save records: %v", err)
		}
		defer store.DeleteResume(chat.ChatID)

		last, err := store.GetLastResume(chat.ChatID)
		if err != nil {
			t.Fatalf("Failed to get last record: %v", err)
		}
		if last == nil {
			t.Fatal("Expected last record, got nil")
		}
		if last.Sequence != 3 {
			t.Errorf("Expected sequence 3, got %d", last.Sequence)
		}
		if last.Type != types.ResumeTypeLLM {
			t.Errorf("Expected type '%s', got '%s'", types.ResumeTypeLLM, last.Type)
		}
		if last.Error != "Last error" {
			t.Errorf("Expected error 'Last error', got '%s'", last.Error)
		}
	})

	t.Run("GetLastRecordFromEmpty", func(t *testing.T) {
		// A separate chat is used so records saved by other subtests cannot
		// leak into this check.
		emptyChat := &types.Chat{AssistantID: "test_assistant"}
		err := store.CreateChat(emptyChat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(emptyChat.ChatID)

		last, err := store.GetLastResume(emptyChat.ChatID)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if last != nil {
			t.Error("Expected nil for empty chat, got record")
		}
	})

	t.Run("GetLastRecordWithEmptyChatID", func(t *testing.T) {
		_, err := store.GetLastResume("")
		if err == nil {
			t.Error("Expected error when getting last record without chat_id")
		}
	})
}

// TestGetResumeByStackID tests retrieving records by stack ID.
//
// Two records are saved for "stack_A" and one for "stack_B" (whose parent is
// "stack_A", depth 1). The subtests check the per-stack counts, the parent
// and depth of the "stack_B" record, an unknown stack ID (empty result) and
// an empty stack ID (error).
func TestGetResumeByStackID(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	chat := &types.Chat{
		AssistantID: "test_assistant",
	}
	err = store.CreateChat(chat)
	if err != nil {
		t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
	records := []*types.Resume{
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast1", StackID: "stack_A", Type: types.ResumeTypeInput, Status: types.ResumeStatusInterrupted, Sequence: 1},
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast1", StackID: "stack_A", Type: types.ResumeTypeLLM, Status: types.ResumeStatusInterrupted, Sequence: 2},
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast2", StackID: "stack_B", StackParentID: "stack_A", StackDepth: 1, Type: types.ResumeTypeDelegate, Status: types.ResumeStatusFailed, Sequence: 3},
	}
	err = store.SaveResume(records)
	if err != nil {
		t.Fatalf("Failed to save records: %v", err)
	}
	defer store.DeleteResume(chat.ChatID)

	t.Run("GetRecordsByStackA", func(t *testing.T) {
		retrieved, err := store.GetResumeByStackID("stack_A")
		if err != nil {
			t.Fatalf("Failed to get records: %v", err)
		}
		if len(retrieved) != 2 {
			t.Errorf("Expected 2 records for stack_A, got %d", len(retrieved))
		}
	})

	t.Run("GetRecordsByStackB", func(t *testing.T) {
		retrieved, err := store.GetResumeByStackID("stack_B")
		if err != nil {
			t.Fatalf("Failed to get records: %v", err)
		}
		if len(retrieved) != 1 {
			t.Errorf("Expected 1 record for stack_B, got %d", len(retrieved))
		}
		if len(retrieved) > 0 {
			if retrieved[0].StackParentID != "stack_A" {
				t.Errorf("Expected stack_parent_id 'stack_A', got '%s'", retrieved[0].StackParentID)
			}
			if retrieved[0].StackDepth != 1 {
				t.Errorf("Expected stack_depth 1, got %d", retrieved[0].StackDepth)
			}
		}
	})

	t.Run("GetRecordsByNonExistentStack", func(t *testing.T) {
		retrieved, err := store.GetResumeByStackID("nonexistent_stack")
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if len(retrieved) != 0 {
			t.Errorf("Expected 0 records, got %d", len(retrieved))
		}
	})

	t.Run("GetRecordsByEmptyStackID", func(t *testing.T) {
		_, err := store.GetResumeByStackID("")
		if err == nil {
			t.Error("Expected error when getting records without stack_id")
		}
	})
}

// TestGetStackPath tests retrieving the stack path.
//
// A three-level chain root_stack -> child_stack -> grandchild_stack is saved
// and GetStackPath is queried from each level; the expected result always
// starts at the root and ends at the queried stack. An empty stack ID must
// produce an error.
func TestGetStackPath(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	chat := &types.Chat{
		AssistantID: "test_assistant",
	}
	err = store.CreateChat(chat)
	if err != nil {
		t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	// Create a nested stack structure: root -> child -> grandchild
	requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
	records := []*types.Resume{
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast1", StackID: "root_stack", Type: types.ResumeTypeInput, Status: types.ResumeStatusInterrupted, Sequence: 1},
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast2", StackID: "child_stack", StackParentID: "root_stack", StackDepth: 1, Type: types.ResumeTypeDelegate, Status: types.ResumeStatusInterrupted, Sequence: 2},
		{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast3", StackID: "grandchild_stack", StackParentID: "child_stack", StackDepth: 2, Type: types.ResumeTypeLLM, Status: types.ResumeStatusFailed, Sequence: 3},
	}
	err = store.SaveResume(records)
	if err != nil {
		t.Fatalf("Failed to save records: %v", err)
	}
	defer store.DeleteResume(chat.ChatID)

	t.Run("GetPathFromGrandchild", func(t *testing.T) {
		path, err := store.GetStackPath("grandchild_stack")
		if err != nil {
			t.Fatalf("Failed to get stack path: %v", err)
		}
		if len(path) != 3 {
			t.Errorf("Expected path length 3, got %d", len(path))
		}
		if len(path) >= 3 {
			if path[0] != "root_stack" {
				t.Errorf("Expected first element 'root_stack', got '%s'", path[0])
			}
			if path[1] != "child_stack" {
				t.Errorf("Expected second element 'child_stack', got '%s'", path[1])
			}
			if path[2] != "grandchild_stack" {
				t.Errorf("Expected third element 'grandchild_stack', got '%s'", path[2])
			}
		}
		t.Logf("Stack path: %v", path)
	})

	t.Run("GetPathFromChild", func(t *testing.T) {
		path, err := store.GetStackPath("child_stack")
		if err != nil {
			t.Fatalf("Failed to get stack path: %v", err)
		}
		if len(path) != 2 {
			t.Errorf("Expected path length 2, got %d", len(path))
		}
		if len(path) >= 2 {
			if path[0] != "root_stack" {
				t.Errorf("Expected first element 'root_stack', got '%s'", path[0])
			}
			if path[1] != "child_stack" {
				t.Errorf("Expected second element 'child_stack', got '%s'", path[1])
			}
		}
	})

	t.Run("GetPathFromRoot", func(t *testing.T) {
		path, err := store.GetStackPath("root_stack")
		if err != nil {
			t.Fatalf("Failed to get stack path: %v", err)
		}
		if len(path) != 1 {
			t.Errorf("Expected path length 1, got %d", len(path))
		}
		if len(path) >= 1 && path[0] != "root_stack" {
			t.Errorf("Expected 'root_stack', got '%s'", path[0])
		}
	})

	t.Run("GetPathWithEmptyStackID", func(t *testing.T) {
		_, err := store.GetStackPath("")
		if err == nil {
			t.Error("Expected error when getting path without stack_id")
		}
	})
}

// TestDeleteResume tests deleting resume records.
//
// After DeleteResume, GetResume must return no records for the chat.
// Deleting from a chat without records must succeed, and an empty chat ID
// must produce an error.
func TestDeleteResume(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("DeleteExistingRecords", func(t *testing.T) {
		chat := &types.Chat{AssistantID: "test_assistant"}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		records := []*types.Resume{
			{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast", StackID: "stk", Type: types.ResumeTypeLLM, Status: types.ResumeStatusFailed, Sequence: 1},
			{ChatID: chat.ChatID, RequestID: requestID, AssistantID: "ast", StackID: "stk", Type: types.ResumeTypeTool, Status: types.ResumeStatusFailed, Sequence: 2},
		}
		err = store.SaveResume(records)
		if err != nil {
			t.Fatalf("Failed to save records: %v", err)
		}

		// Delete
		err = store.DeleteResume(chat.ChatID)
		if err != nil {
			t.Fatalf("Failed to delete records: %v", err)
		}

		// Verify deleted
		retrieved, err := store.GetResume(chat.ChatID)
		if err != nil {
			t.Fatalf("Failed to get records: %v", err)
		}
		if len(retrieved) != 0 {
			t.Errorf("Expected 0 records after delete, got %d", len(retrieved))
		}
	})

	t.Run("DeleteFromEmptyChat", func(t *testing.T) {
		chat := &types.Chat{AssistantID: "test_assistant"}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		// Delete from chat with no records - should not error
		err = store.DeleteResume(chat.ChatID)
		if err != nil {
			t.Errorf("Expected no error when deleting from empty chat, got: %v", err)
		}
	})

	t.Run("DeleteWithEmptyChatID", func(t *testing.T) {
		err := store.DeleteResume("")
		if err == nil {
			t.Error("Expected error when deleting with empty chat_id")
		}
	})
}

// TestResumeCompleteWorkflow tests a complete resume/retry workflow
func TestResumeCompleteWorkflow(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
t.Fatalf("Failed to create store: %v", err) } t.Run("CompleteA2AWorkflow", func(t *testing.T) { // Create chat chat := &types.Chat{ AssistantID: "main_assistant", Title: "A2A Workflow Test", } err := store.CreateChat(chat) if err != nil { t.Fatalf("Failed to create chat: %v", err) } defer store.DeleteChat(chat.ChatID) // Simulate A2A call that gets interrupted // Main assistant -> Sub assistant (interrupted during LLM call) requestID := fmt.Sprintf("req_%d", time.Now().UnixNano()) records := []*types.Resume{ // Main assistant steps { ChatID: chat.ChatID, RequestID: requestID, AssistantID: "main_assistant", StackID: "main_stack", StackDepth: 0, Type: types.ResumeTypeInput, Status: types.ResumeStatusInterrupted, Input: map[string]interface{}{"messages": []interface{}{map[string]interface{}{"role": "user", "content": "Analyze this"}}}, Sequence: 1, }, { ChatID: chat.ChatID, RequestID: requestID, AssistantID: "main_assistant", StackID: "main_stack", StackDepth: 0, Type: types.ResumeTypeDelegate, Status: types.ResumeStatusInterrupted, SpaceSnapshot: map[string]interface{}{"task": "analyze", "data_id": "123"}, Sequence: 2, }, // Sub assistant steps { ChatID: chat.ChatID, RequestID: requestID, AssistantID: "sub_assistant", StackID: "sub_stack", StackParentID: "main_stack", StackDepth: 1, Type: types.ResumeTypeInput, Status: types.ResumeStatusInterrupted, Sequence: 3, }, { ChatID: chat.ChatID, RequestID: requestID, AssistantID: "sub_assistant", StackID: "sub_stack", StackParentID: "main_stack", StackDepth: 1, Type: types.ResumeTypeLLM, Status: types.ResumeStatusInterrupted, Input: map[string]interface{}{"messages": []interface{}{}}, Output: map[string]interface{}{"partial_content": "The analysis shows..."}, SpaceSnapshot: map[string]interface{}{"task": "analyze", "data_id": "123"}, Sequence: 4, }, } err = store.SaveResume(records) if err != nil { t.Fatalf("Failed to save resume records: %v", err) } t.Logf("Saved %d resume records for A2A workflow", len(records)) // 1. 
Get last resume record (should be the interrupted LLM call) last, err := store.GetLastResume(chat.ChatID) if err != nil { t.Fatalf("Failed to get last resume: %v", err) } if last == nil { t.Fatal("Expected last resume record") } if last.Type != types.ResumeTypeLLM { t.Errorf("Expected type '%s', got '%s'", types.ResumeTypeLLM, last.Type) } if last.StackDepth != 1 { t.Errorf("Expected stack_depth 1, got %d", last.StackDepth) } // 2. Get stack path to understand the call hierarchy path, err := store.GetStackPath(last.StackID) if err != nil { t.Fatalf("Failed to get stack path: %v", err) } if len(path) != 2 { t.Errorf("Expected path length 2, got %d", len(path)) } t.Logf("Stack path: %v", path) // 3. Get all records for the sub stack subRecords, err := store.GetResumeByStackID("sub_stack") if err != nil { t.Fatalf("Failed to get sub stack records: %v", err) } if len(subRecords) != 2 { t.Errorf("Expected 2 records for sub_stack, got %d", len(subRecords)) } // 4. Verify space snapshot is preserved if last.SpaceSnapshot == nil { t.Error("Expected space_snapshot to be set") } else { if last.SpaceSnapshot["task"] != "analyze" { t.Errorf("Expected task='analyze', got '%v'", last.SpaceSnapshot["task"]) } } // 5. Clean up after successful resume err = store.DeleteResume(chat.ChatID) if err != nil { t.Fatalf("Failed to delete resume records: %v", err) } // 6. 
Verify cleanup remaining, err := store.GetResume(chat.ChatID) if err != nil { t.Fatalf("Failed to get remaining records: %v", err) } if len(remaining) != 0 { t.Errorf("Expected 0 records after cleanup, got %d", len(remaining)) } t.Log("Complete A2A workflow test passed!") }) } ================================================ FILE: agent/store/xun/search.go ================================================ package xun import ( "fmt" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/dbal/query" "github.com/yaoapp/yao/agent/store/types" ) // ============================================================================= // Search Management // ============================================================================= // SaveSearch saves a search record for a request func (store *Xun) SaveSearch(search *types.Search) error { if search == nil { return fmt.Errorf("search is nil") } if search.RequestID == "" { return fmt.Errorf("request_id is required") } if search.ChatID == "" { return fmt.Errorf("chat_id is required") } if search.Source == "" { return fmt.Errorf("source is required") } now := time.Now() // Build row data row := map[string]interface{}{ "request_id": search.RequestID, "chat_id": search.ChatID, "query": search.Query, "source": search.Source, "duration": search.Duration, "created_at": now, "updated_at": now, } // Handle JSON fields if search.Config != nil { configJSON, err := jsoniter.MarshalToString(search.Config) if err != nil { return fmt.Errorf("failed to marshal config: %w", err) } row["config"] = configJSON } if len(search.Keywords) > 0 { keywordsJSON, err := jsoniter.MarshalToString(search.Keywords) if err != nil { return fmt.Errorf("failed to marshal keywords: %w", err) } row["keywords"] = keywordsJSON } if len(search.Entities) > 0 { entitiesJSON, err := jsoniter.MarshalToString(search.Entities) if err != nil { return fmt.Errorf("failed to marshal entities: %w", err) } row["entities"] = entitiesJSON } 
if len(search.Relations) > 0 { relationsJSON, err := jsoniter.MarshalToString(search.Relations) if err != nil { return fmt.Errorf("failed to marshal relations: %w", err) } row["relations"] = relationsJSON } if search.DSL != nil { dslJSON, err := jsoniter.MarshalToString(search.DSL) if err != nil { return fmt.Errorf("failed to marshal dsl: %w", err) } row["dsl"] = dslJSON } if len(search.References) > 0 { refsJSON, err := jsoniter.MarshalToString(search.References) if err != nil { return fmt.Errorf("failed to marshal references: %w", err) } row["references"] = refsJSON } if len(search.Graph) > 0 { graphJSON, err := jsoniter.MarshalToString(search.Graph) if err != nil { return fmt.Errorf("failed to marshal graph: %w", err) } row["graph"] = graphJSON } if search.XML != "" { row["xml"] = search.XML } if search.Prompt != "" { row["prompt"] = search.Prompt } if search.Error != "" { row["error"] = search.Error } return store.newQuerySearch().Insert(row) } // GetSearches retrieves all search records for a request func (store *Xun) GetSearches(requestID string) ([]*types.Search, error) { if requestID == "" { return nil, fmt.Errorf("request_id is required") } rows, err := store.newQuerySearch(). Where("request_id", requestID). WhereNull("deleted_at"). OrderBy("created_at", "asc"). 
Get() if err != nil { return nil, err } searches := make([]*types.Search, 0, len(rows)) for _, row := range rows { data := row.ToMap() if data == nil { continue } search, err := store.rowToSearch(data) if err != nil { continue } searches = append(searches, search) } return searches, nil } // GetReference retrieves a single reference by request ID and index func (store *Xun) GetReference(requestID string, index int) (*types.Reference, error) { if requestID == "" { return nil, fmt.Errorf("request_id is required") } if index < 1 { return nil, fmt.Errorf("index must be >= 1") } // Get all searches for this request searches, err := store.GetSearches(requestID) if err != nil { return nil, err } // Find the reference with matching index for _, search := range searches { for _, ref := range search.References { if ref.Index == index { return &ref, nil } } } return nil, fmt.Errorf("reference not found: request_id=%s, index=%d", requestID, index) } // DeleteSearches deletes all search records for a chat (soft delete) func (store *Xun) DeleteSearches(chatID string) error { if chatID == "" { return fmt.Errorf("chat_id is required") } _, err := store.newQuerySearch(). Where("chat_id", chatID). WhereNull("deleted_at"). 
Update(map[string]interface{}{ "deleted_at": time.Now(), "updated_at": time.Now(), }) return err } // ============================================================================= // Query Builder // ============================================================================= // newQuerySearch creates a new query builder for the search table func (store *Xun) newQuerySearch() query.Query { qb := store.query.New() qb.Table(store.getSearchTable()) return qb } // getSearchTable returns the search table name func (store *Xun) getSearchTable() string { m := model.Select("__yao.agent.search") if m != nil && m.MetaData.Table.Name != "" { return m.MetaData.Table.Name } return "agent_search" } // ============================================================================= // Helper Functions // ============================================================================= // rowToSearch converts a database row to a Search struct func (store *Xun) rowToSearch(data map[string]interface{}) (*types.Search, error) { search := &types.Search{ ID: getInt64(data, "id"), RequestID: getString(data, "request_id"), ChatID: getString(data, "chat_id"), Query: getString(data, "query"), Source: getString(data, "source"), XML: getString(data, "xml"), Prompt: getString(data, "prompt"), Duration: getInt64(data, "duration"), Error: getString(data, "error"), } // Handle timestamps if createdAt := getTime(data, "created_at"); createdAt != nil { search.CreatedAt = *createdAt } // Parse JSON fields if config := data["config"]; config != nil { if configStr, ok := config.(string); ok && configStr != "" { var configMap map[string]any if err := jsoniter.UnmarshalFromString(configStr, &configMap); err == nil { search.Config = configMap } } } if keywords := data["keywords"]; keywords != nil { if keywordsStr, ok := keywords.(string); ok && keywordsStr != "" { var keywordsList []string if err := jsoniter.UnmarshalFromString(keywordsStr, &keywordsList); err == nil { search.Keywords = keywordsList } } } if 
entities := data["entities"]; entities != nil { if entitiesStr, ok := entities.(string); ok && entitiesStr != "" { var entitiesList []types.Entity if err := jsoniter.UnmarshalFromString(entitiesStr, &entitiesList); err == nil { search.Entities = entitiesList } } } if relations := data["relations"]; relations != nil { if relationsStr, ok := relations.(string); ok && relationsStr != "" { var relationsList []types.Relation if err := jsoniter.UnmarshalFromString(relationsStr, &relationsList); err == nil { search.Relations = relationsList } } } if dsl := data["dsl"]; dsl != nil { if dslStr, ok := dsl.(string); ok && dslStr != "" { var dslMap map[string]any if err := jsoniter.UnmarshalFromString(dslStr, &dslMap); err == nil { search.DSL = dslMap } } } if refs := data["references"]; refs != nil { if refsStr, ok := refs.(string); ok && refsStr != "" { var refsList []types.Reference if err := jsoniter.UnmarshalFromString(refsStr, &refsList); err == nil { search.References = refsList } } } if graph := data["graph"]; graph != nil { if graphStr, ok := graph.(string); ok && graphStr != "" { var graphList []types.GraphNode if err := jsoniter.UnmarshalFromString(graphStr, &graphList); err == nil { search.Graph = graphList } } } return search, nil } ================================================ FILE: agent/store/xun/search_test.go ================================================ package xun_test import ( "fmt" "testing" "time" "github.com/yaoapp/yao/agent/store/types" "github.com/yaoapp/yao/agent/store/xun" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestSaveSearch tests saving search records func TestSaveSearch(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() store, err := xun.NewXun(types.Setting{ Connector: "default", }) if err != nil { t.Fatalf("Failed to create store: %v", err) } // Create a chat first chat := &types.Chat{ AssistantID: "test_assistant", Title: "Search Test Chat", } err = store.CreateChat(chat) if err != nil { 
t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	// Each subtest below uses its own time-derived request ID and reads the
	// record back with GetSearches to check the round trip.
	t.Run("SaveBasicSearch", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "What is the weather today?",
			Source:    "web",
			Duration:  150,
		}

		err := store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		// Verify
		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}
		if searches[0].Query != "What is the weather today?" {
			t.Errorf("Expected query 'What is the weather today?', got '%s'", searches[0].Query)
		}
		if searches[0].Source != "web" {
			t.Errorf("Expected source 'web', got '%s'", searches[0].Source)
		}
		if searches[0].Duration != 150 {
			t.Errorf("Expected duration 150, got %d", searches[0].Duration)
		}
	})

	t.Run("SaveSearchWithKeywords", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "Latest news about AI",
			Keywords:  []string{"AI", "news", "latest"},
			Source:    "web",
			Duration:  200,
		}

		err := store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}
		if len(searches[0].Keywords) != 3 {
			t.Errorf("Expected 3 keywords, got %d", len(searches[0].Keywords))
		}
		if searches[0].Keywords[0] != "AI" {
			t.Errorf("Expected first keyword 'AI', got '%s'", searches[0].Keywords[0])
		}
	})

	t.Run("SaveSearchWithReferences", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "How to learn Go programming?",
			Source:    "web",
			References: []types.Reference{
				{
					Index:   1,
					Type:    "web",
					Title:   "Go Programming Tutorial",
					URL:     "https://go.dev/tour/",
					Snippet: "An interactive introduction to Go",
				},
				{
					Index:   2,
					Type:    "web",
					Title:   "Effective Go",
					URL:     "https://go.dev/doc/effective_go",
					Snippet: "Tips for writing clear, idiomatic Go code",
				},
			},
			XML:      "...",
			Prompt:   "Please cite sources using [1], [2]...",
			Duration: 300,
		}

		err := store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}
		if len(searches[0].References) != 2 {
			t.Errorf("Expected 2 references, got %d", len(searches[0].References))
		}
		if searches[0].References[0].Title != "Go Programming Tutorial" {
			t.Errorf("Expected first reference title 'Go Programming Tutorial', got '%s'", searches[0].References[0].Title)
		}
		if searches[0].XML != "..." {
			t.Errorf("Expected XML '...', got '%s'", searches[0].XML)
		}
		if searches[0].Prompt != "Please cite sources using [1], [2]..." {
			t.Errorf("Expected prompt to be set")
		}
	})

	t.Run("SaveSearchWithConfig", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "Config test",
			Source:    "auto",
			Config: map[string]any{
				"uses": map[string]any{
					"search":  "builtin",
					"web":     "builtin",
					"keyword": "builtin",
				},
				"web": map[string]any{
					"provider":    "tavily",
					"max_results": 5,
				},
			},
			Duration: 100,
		}

		err := store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}
		if searches[0].Config == nil {
			t.Fatal("Expected config to be set")
		}
		// The nested "uses" object is expected to come back as map[string]any.
		uses, ok := searches[0].Config["uses"].(map[string]any)
		if !ok {
			t.Fatal("Expected uses in config")
		}
		if uses["search"] != "builtin" {
			t.Errorf("Expected uses.search='builtin', got '%v'", uses["search"])
		}
	})

	t.Run("SaveSearchWithEntitiesAndRelations", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "Who is the CEO of Apple?",
			Source:    "kb",
			Entities: []types.Entity{
				{Name: "Apple", Type: "Organization"},
				{Name: "Tim Cook", Type: "Person"},
			},
			Relations: []types.Relation{
				{Subject: "Tim Cook", Predicate: "CEO_of", Object: "Apple"},
			},
			Graph: []types.GraphNode{
				{ID: "node1", Type: "Organization", Label: "Apple", Score: 0.95},
				{ID: "node2", Type: "Person", Label: "Tim Cook", Score: 0.92},
			},
			Duration: 250,
		}

		err := store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}
		if len(searches[0].Entities) != 2 {
			t.Errorf("Expected 2 entities, got %d", len(searches[0].Entities))
		}
		if searches[0].Entities[0].Name != "Apple" {
			t.Errorf("Expected first entity 'Apple', got '%s'", searches[0].Entities[0].Name)
		}
		if len(searches[0].Relations) != 1 {
			t.Errorf("Expected 1 relation, got %d", len(searches[0].Relations))
		}
		if searches[0].Relations[0].Predicate != "CEO_of" {
			t.Errorf("Expected predicate 'CEO_of', got '%s'", searches[0].Relations[0].Predicate)
		}
		if len(searches[0].Graph) != 2 {
			t.Errorf("Expected 2 graph nodes, got %d", len(searches[0].Graph))
		}
	})

	t.Run("SaveSearchWithDSL", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "Find orders over $1000",
			Source:    "db",
			DSL: map[string]any{
				"wheres": []map[string]any{
					{"column": "amount", "op": ">", "value": 1000},
				},
				"orders": []map[string]any{
					{"column": "created_at", "option": "desc"},
				},
			},
			Duration: 50,
		}

		err := store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}
		if searches[0].DSL == nil {
			t.Fatal("Expected DSL to be set")
		}
		// After the JSON round trip the list is asserted as []any, not the
		// []map[string]any it was saved as.
		wheres, ok := searches[0].DSL["wheres"].([]any)
		if !ok || len(wheres) == 0 {
			t.Error("Expected wheres in DSL")
		}
	})

	t.Run("SaveSearchWithError", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "Failed search",
			Source:    "web",
			Error:     "API rate limit exceeded",
			Duration:  10,
		}

		err := store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}
		if searches[0].Error != "API rate limit exceeded" {
			t.Errorf("Expected error 'API rate limit exceeded', got '%s'", searches[0].Error)
		}
	})

	// The remaining subtests check that required fields are enforced.
	t.Run("SaveSearchWithoutRequestID", func(t *testing.T) {
		search := &types.Search{
			ChatID: chat.ChatID,
			Query:  "Test",
			Source: "web",
		}

		err := store.SaveSearch(search)
		if err == nil {
			t.Error("Expected error when saving without request_id")
		}
	})

	t.Run("SaveSearchWithoutChatID", func(t *testing.T) {
		search := &types.Search{
			RequestID: "req_test",
			Query:     "Test",
			Source:    "web",
		}

		err := store.SaveSearch(search)
		if err == nil {
			t.Error("Expected error when saving without chat_id")
		}
	})

	t.Run("SaveSearchWithoutSource", func(t *testing.T) {
		search := &types.Search{
			RequestID: "req_test",
			ChatID:    chat.ChatID,
			Query:     "Test",
		}

		err := store.SaveSearch(search)
		if err == nil {
			t.Error("Expected error when saving without source")
		}
	})

	t.Run("SaveNilSearch", func(t *testing.T) {
		err := store.SaveSearch(nil)
		if err == nil {
			t.Error("Expected error when saving nil search")
		}
	})
}

// TestGetSearches tests retrieving search records.
//
// It covers several records for one request (checked for created_at
// ordering), an unknown request ID, and an empty request ID.
func TestGetSearches(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	// Create a chat
	chat := &types.Chat{
		AssistantID: "test_assistant",
	}
	err = store.CreateChat(chat)
	if err != nil {
		t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	t.Run("GetMultipleSearches", func(t *testing.T) {
		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())

		// Save multiple searches for the same request
		for i := 1; i <= 3; i++ {
			search := &types.Search{
				RequestID: requestID,
				ChatID:    chat.ChatID,
				Query:     fmt.Sprintf("Query %d", i),
				Source:    "web",
				Duration:  int64(i * 100),
			}
			err := store.SaveSearch(search)
			if err != nil {
				t.Fatalf("Failed to save search %d: %v", i, err)
			}
			time.Sleep(10 * time.Millisecond) // Ensure different created_at
		}

		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 3 {
			t.Errorf("Expected 3 searches, got %d", len(searches))
		}

		// Verify order (by created_at asc)
		for i := 0; i < len(searches)-1; i++ {
			if searches[i].CreatedAt.After(searches[i+1].CreatedAt) {
				t.Error("Searches not ordered by created_at asc")
			}
		}
	})

	t.Run("GetSearchesForNonExistentRequest", func(t *testing.T) {
		searches, err := store.GetSearches("nonexistent_request")
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if len(searches) != 0 {
			t.Errorf("Expected 0 searches, got %d", len(searches))
		}
	})

	t.Run("GetSearchesWithEmptyRequestID", func(t *testing.T) {
		_, err := store.GetSearches("")
		if err == nil {
			t.Error("Expected error when getting searches without request_id")
		}
	})
}

// TestGetReference tests retrieving a single reference.
//
// One search with three references (indexes 1-3) is saved up front and
// shared by the subtests.
func TestGetReference(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	// Create a chat
	chat := &types.Chat{
		AssistantID: "test_assistant",
	}
	err = store.CreateChat(chat)
	if err != nil {
		t.Fatalf("Failed to create chat: %v", err)
	}
	defer store.DeleteChat(chat.ChatID)

	requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())

	// Save search with references
	search := &types.Search{
		RequestID: requestID,
		ChatID:    chat.ChatID,
		Query:     "Test query",
		Source:    "web",
		References: []types.Reference{
			{Index: 1, Type: "web", Title: "Reference 1", URL: "https://example.com/1"},
			{Index: 2, Type: "web", Title: "Reference 2", URL: "https://example.com/2"},
			{Index: 3, Type: "kb", Title: "Reference 3", Content: "KB content"},
		},
		Duration: 100,
	}
	err = store.SaveSearch(search)
	if err != nil {
		t.Fatalf("Failed to save search: %v", err)
	}

	t.Run("GetExistingReference", func(t *testing.T) {
		ref, err := store.GetReference(requestID, 1)
		if err != nil {
			t.Fatalf("Failed to get reference: %v", err)
		}
		if ref.Title != "Reference 1" {
			t.Errorf("Expected title 'Reference 1', got '%s'", ref.Title)
		}
		if ref.URL != "https://example.com/1" {
			t.Errorf("Expected URL 'https://example.com/1', got '%s'", ref.URL)
		}
	})

	t.Run("GetReferenceByIndex", func(t *testing.T) {
		ref, err := store.GetReference(requestID, 3)
		if err != nil {
			t.Fatalf("Failed to get reference: %v", err)
		}
		if ref.Type != "kb" {
			t.Errorf("Expected type 'kb', got '%s'", ref.Type)
		}
		if ref.Content != "KB content" {
			t.Errorf("Expected content 'KB content', got '%s'", ref.Content)
		}
	})

	t.Run("GetNonExistentReference", func(t *testing.T) {
		_, err := store.GetReference(requestID, 999)
		if err == nil {
			t.Error("Expected error when getting non-existent reference")
		}
	})

	t.Run("GetReferenceWithInvalidIndex", func(t *testing.T) {
		_, err := store.GetReference(requestID, 0)
		if err == nil {
			t.Error("Expected error when getting reference with index 0")
		}

		_, err = store.GetReference(requestID, -1)
		if err == nil {
			t.Error("Expected error when getting reference with negative index")
		}
	})

	t.Run("GetReferenceWithEmptyRequestID", func(t *testing.T) {
		_, err := store.GetReference("", 1)
		if err == nil {
			t.Error("Expected error when getting reference without request_id")
		}
	})
}

// TestDeleteSearches tests deleting search records.
//
// The first subtest only checks that DeleteSearches returns no error; it does
// not read the rows back (see the note inside it).
func TestDeleteSearches(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("DeleteSearchesForChat", func(t *testing.T) {
		// Create a chat
		chat := &types.Chat{
			AssistantID: "test_assistant",
		}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		// Save multiple searches
		for i := 1; i <= 3; i++ {
			requestID := fmt.Sprintf("req_%d_%d", time.Now().UnixNano(), i)
			search := &types.Search{
				RequestID: requestID,
				ChatID:    chat.ChatID,
				Query:     fmt.Sprintf("Query %d", i),
				Source:    "web",
				Duration:  100,
			}
			err := store.SaveSearch(search)
			if err != nil {
				t.Fatalf("Failed to save search: %v", err)
			}
		}

		// Delete all searches for the chat
		err = store.DeleteSearches(chat.ChatID)
		if err != nil {
			t.Fatalf("Failed to delete searches: %v", err)
		}

		// Note: GetSearches filters by request_id, not chat_id
		// We can't easily verify deletion without a GetSearchesByChatID method
		// But the soft delete should have been applied
	})

	t.Run("DeleteSearchesWithEmptyChatID", func(t *testing.T) {
		err := store.DeleteSearches("")
		if err == nil {
			t.Error("Expected error when deleting searches without chat_id")
		}
	})
}

// TestSearchCompleteWorkflow tests a complete search workflow:
// create chat, save a fully populated search, read it back, fetch one
// reference, soft-delete, and confirm the record is no longer returned.
func TestSearchCompleteWorkflow(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	store, err := xun.NewXun(types.Setting{
		Connector: "default",
	})
	if err != nil {
		t.Fatalf("Failed to create store: %v", err)
	}

	t.Run("CompleteWorkflow", func(t *testing.T) {
		// 1. Create chat
		chat := &types.Chat{
			AssistantID: "workflow_assistant",
			Title:       "Search Workflow Test",
		}
		err := store.CreateChat(chat)
		if err != nil {
			t.Fatalf("Failed to create chat: %v", err)
		}
		defer store.DeleteChat(chat.ChatID)

		requestID := fmt.Sprintf("req_%d", time.Now().UnixNano())

		// 2. Save search with full data
		search := &types.Search{
			RequestID: requestID,
			ChatID:    chat.ChatID,
			Query:     "What are the best practices for Go programming?",
			Config: map[string]any{
				"uses": map[string]any{"search": "builtin", "web": "builtin"},
				"web":  map[string]any{"provider": "tavily", "max_results": 5},
			},
			Keywords: []string{"Go", "programming", "best practices"},
			Source:   "auto",
			References: []types.Reference{
				{Index: 1, Type: "web", Title: "Effective Go", URL: "https://go.dev/doc/effective_go"},
				{Index: 2, Type: "web", Title: "Go Proverbs", URL: "https://go-proverbs.github.io/"},
				{Index: 3, Type: "kb", Title: "Internal Go Guide", Content: "Our team's Go coding standards..."},
			},
			XML:      "...",
			Prompt:   "When citing, use [1], [2], [3] format.",
			Duration: 350,
		}

		err = store.SaveSearch(search)
		if err != nil {
			t.Fatalf("Failed to save search: %v", err)
		}

		// 3. Retrieve searches
		searches, err := store.GetSearches(requestID)
		if err != nil {
			t.Fatalf("Failed to get searches: %v", err)
		}
		if len(searches) != 1 {
			t.Fatalf("Expected 1 search, got %d", len(searches))
		}

		// 4. Verify all fields
		s := searches[0]
		if s.Query != "What are the best practices for Go programming?" {
			t.Errorf("Query mismatch")
		}
		if len(s.Keywords) != 3 {
			t.Errorf("Expected 3 keywords, got %d", len(s.Keywords))
		}
		if len(s.References) != 3 {
			t.Errorf("Expected 3 references, got %d", len(s.References))
		}
		if s.Config == nil {
			t.Error("Config should not be nil")
		}

		// 5. Get specific reference
		ref, err := store.GetReference(requestID, 2)
		if err != nil {
			t.Fatalf("Failed to get reference: %v", err)
		}
		if ref.Title != "Go Proverbs" {
			t.Errorf("Expected 'Go Proverbs', got '%s'", ref.Title)
		}

		// 6. Delete searches
		err = store.DeleteSearches(chat.ChatID)
		if err != nil {
			t.Fatalf("Failed to delete searches: %v", err)
		}

		// 7.
Verify deletion (soft delete, so GetSearches should return empty) deletedSearches, err := store.GetSearches(requestID) if err != nil { t.Fatalf("Failed to get searches after delete: %v", err) } if len(deletedSearches) != 0 { t.Errorf("Expected 0 searches after delete, got %d", len(deletedSearches)) } t.Log("Complete search workflow passed!") }) } ================================================ FILE: agent/store/xun/utils.go ================================================ package xun import ( "fmt" "reflect" "time" jsoniter "github.com/json-iterator/go" ) // isNil checks whether a value is truly nil, handling the Go typed-nil-in-interface pitfall. // A nil map, slice, or pointer stored in an interface{} is not == nil in Go; // this helper uses reflect to detect that case. func isNil(v interface{}) bool { if v == nil { return true } rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Interface, reflect.Chan, reflect.Func: return rv.IsNil() } return false } // marshalJSONFields serialises each value in fields to a JSON string and writes // it into data. Truly-nil values (including typed nils) are skipped so the // database column keeps its SQL NULL / default. 
// Helper functions for type conversion

// getString returns data[key] when it is a string; any other type, a nil
// value or a missing key yields "".
func getString(data map[string]interface{}, key string) string {
	text, ok := data[key].(string)
	if !ok {
		return ""
	}
	return text
}

// getBool interprets data[key] as a boolean. A bool is returned as-is; int,
// int64 and float64 values are true when non-zero. Everything else
// (including a missing key) is false.
func getBool(data map[string]interface{}, key string) bool {
	raw := data[key]
	if flag, ok := raw.(bool); ok {
		return flag
	}
	switch num := raw.(type) {
	case int64:
		return num != 0
	case int:
		return num != 0
	case float64:
		return num != 0
	default:
		return false
	}
}

// getInt returns data[key] as an int. int, int64 and float64 values are
// converted (float64 is truncated by the Go conversion); any other type or a
// missing key yields 0.
func getInt(data map[string]interface{}, key string) int {
	result := 0
	switch num := data[key].(type) {
	case int:
		result = num
	case int64:
		result = int(num)
	case float64:
		result = int(num)
	}
	return result
}

// getInt64 returns data[key] as an int64.
//
// Supported inputs:
//   - int64, int, float64: converted directly
//   - string: scanned as a decimal integer (common with MySQL BIGINT);
//     a string that cannot be scanned yields 0
//   - time.Time: converted with UnixNano
//
// Any other type or a missing key yields 0.
func getInt64(data map[string]interface{}, key string) int64 {
	switch raw := data[key].(type) {
	case int64:
		return raw
	case int:
		return int64(raw)
	case float64:
		return int64(raw)
	case time.Time:
		// Handle time.Time from database
		return raw.UnixNano()
	case string:
		// Handle string representation of numbers (common with MySQL BIGINT)
		var parsed int64
		if _, err := fmt.Sscanf(raw, "%d", &parsed); err != nil {
			return 0
		}
		return parsed
	default:
		return 0
	}
}

// toMySQLTime converts a UnixNano timestamp to the value stored in the MySQL
// BIGINT column. The value is passed through unchanged.
func toMySQLTime(unixNano int64) int64 {
	return unixNano
}

// fromMySQLTime converts a MySQL BIGINT timestamp back to UnixNano. The value
// is passed through unchanged.
func fromMySQLTime(mysqlTime int64) int64 {
	return mysqlTime
}
t.Run("TypedNilPointer", func(t *testing.T) { var p *testStruct assert.True(t, isNil(p)) }) // Typed nil map t.Run("TypedNilMap", func(t *testing.T) { var m map[string]string assert.True(t, isNil(m)) }) // Typed nil slice t.Run("TypedNilSlice", func(t *testing.T) { var s []string assert.True(t, isNil(s)) }) // Non-nil pointer t.Run("NonNilPointer", func(t *testing.T) { p := &testStruct{Name: "test"} assert.False(t, isNil(p)) }) // Non-nil map (empty) t.Run("NonNilEmptyMap", func(t *testing.T) { m := map[string]string{} assert.False(t, isNil(m)) }) // Non-nil map with values t.Run("NonNilMap", func(t *testing.T) { m := map[string]string{"a": "1"} assert.False(t, isNil(m)) }) // Non-nil slice (empty) t.Run("NonNilEmptySlice", func(t *testing.T) { s := []string{} assert.False(t, isNil(s)) }) // Non-nil slice with values t.Run("NonNilSlice", func(t *testing.T) { s := []string{"a"} assert.False(t, isNil(s)) }) // Scalar types (never nil) t.Run("String", func(t *testing.T) { assert.False(t, isNil("hello")) }) t.Run("EmptyString", func(t *testing.T) { assert.False(t, isNil("")) }) t.Run("Int", func(t *testing.T) { assert.False(t, isNil(42)) }) t.Run("Bool", func(t *testing.T) { assert.False(t, isNil(false)) }) } func TestMarshalJSONFields(t *testing.T) { t.Run("SkipUntypedNil", func(t *testing.T) { data := make(map[string]interface{}) err := marshalJSONFields(data, map[string]interface{}{ "field1": nil, }) require.NoError(t, err) _, exists := data["field1"] assert.False(t, exists, "untyped nil should be skipped") }) t.Run("SkipTypedNilMap", func(t *testing.T) { data := make(map[string]interface{}) var m map[string]string err := marshalJSONFields(data, map[string]interface{}{ "deps": m, }) require.NoError(t, err) _, exists := data["deps"] assert.False(t, exists, "typed nil map should be skipped") }) t.Run("SkipTypedNilSlice", func(t *testing.T) { data := make(map[string]interface{}) var s []string err := marshalJSONFields(data, map[string]interface{}{ "tags": s, }) 
require.NoError(t, err) _, exists := data["tags"] assert.False(t, exists, "typed nil slice should be skipped") }) t.Run("SkipTypedNilPointer", func(t *testing.T) { data := make(map[string]interface{}) var p *testStruct err := marshalJSONFields(data, map[string]interface{}{ "kb": p, }) require.NoError(t, err) _, exists := data["kb"] assert.False(t, exists, "typed nil pointer should be skipped") }) t.Run("MarshalNonNilMap", func(t *testing.T) { data := make(map[string]interface{}) err := marshalJSONFields(data, map[string]interface{}{ "deps": map[string]string{"echo": "^1.0.0"}, }) require.NoError(t, err) assert.Equal(t, `{"echo":"^1.0.0"}`, data["deps"]) }) t.Run("MarshalEmptyMap", func(t *testing.T) { data := make(map[string]interface{}) err := marshalJSONFields(data, map[string]interface{}{ "deps": map[string]string{}, }) require.NoError(t, err) assert.Equal(t, `{}`, data["deps"]) }) t.Run("MarshalSlice", func(t *testing.T) { data := make(map[string]interface{}) err := marshalJSONFields(data, map[string]interface{}{ "tags": []string{"ai", "bot"}, }) require.NoError(t, err) assert.Equal(t, `["ai","bot"]`, data["tags"]) }) t.Run("MarshalPointer", func(t *testing.T) { data := make(map[string]interface{}) err := marshalJSONFields(data, map[string]interface{}{ "kb": &testStruct{Name: "test"}, }) require.NoError(t, err) assert.Equal(t, `{"Name":"test"}`, data["kb"]) }) t.Run("MixedNilAndNonNil", func(t *testing.T) { data := make(map[string]interface{}) var nilMap map[string]string var nilSlice []string var nilPtr *testStruct err := marshalJSONFields(data, map[string]interface{}{ "nil_map": nilMap, "nil_slice": nilSlice, "nil_ptr": nilPtr, "nil_raw": nil, "good_map": map[string]string{"k": "v"}, "good_list": []string{"a"}, }) require.NoError(t, err) assert.Len(t, data, 2, "only non-nil fields should be written") assert.Equal(t, `{"k":"v"}`, data["good_map"]) assert.Equal(t, `["a"]`, data["good_list"]) _, exists := data["nil_map"] assert.False(t, exists) _, exists = 
data["nil_slice"] assert.False(t, exists) _, exists = data["nil_ptr"] assert.False(t, exists) _, exists = data["nil_raw"] assert.False(t, exists) }) } ================================================ FILE: agent/store/xun/xun.go ================================================ package xun import ( "fmt" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/model" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/xun/dbal/query" "github.com/yaoapp/xun/dbal/schema" "github.com/yaoapp/yao/agent/store/types" ) // Package store provides functionality for managing chat conversations and assistants. // Xun implements the Store interface using a database backend. // It provides functionality for: // - Managing chat sessions and their messages // - Organizing chats with pagination and date-based grouping // - Handling chat metadata like titles and creation dates // - Managing AI assistants with their configurations and metadata // - Managing resume records for recovery from interruptions type Xun struct { query query.Query schema schema.Schema setting types.Setting } // Public interface methods: // // NewXun creates a new store instance with the given settings // // Chat Management: // CreateChat creates a new chat session // GetChat retrieves a single chat by ID // UpdateChat updates chat fields // DeleteChat deletes a chat and its associated messages // ListChats retrieves a paginated list of chats with optional grouping // // Message Management: // SaveMessages batch saves messages for a chat // GetMessages retrieves messages for a chat with filtering // UpdateMessage updates a single message // DeleteMessages deletes specific messages from a chat // // Resume Management: // SaveResume batch saves resume records (only on failure/interrupt) // GetResume retrieves all resume records for a chat // GetLastResume retrieves the last resume record for a chat // GetResumeByStackID retrieves resume records for a specific stack // 
GetStackPath returns the stack path from root to the given stack // DeleteResume deletes all resume records for a chat // // Assistant Management: // SaveAssistant creates or updates an assistant // UpdateAssistant updates assistant fields // DeleteAssistant deletes an assistant by assistant_id // GetAssistants retrieves a paginated list of assistants with filtering // GetAssistant retrieves a single assistant by assistant_id // DeleteAssistants deletes assistants based on filter conditions // GetAssistantTags retrieves all unique tags from assistants // NewXun create a new xun store func NewXun(setting types.Setting) (types.Store, error) { store := &Xun{setting: setting} if setting.Connector == "default" || setting.Connector == "" { store.query = capsule.Global.Query() store.schema = capsule.Global.Schema() } else { conn, err := connector.Select(setting.Connector) if err != nil { return nil, fmt.Errorf("select store connector %s error: %s", setting.Connector, err.Error()) } store.query, err = conn.Query() if err != nil { return nil, fmt.Errorf("query store connector %s error: %s", setting.Connector, err.Error()) } store.schema, err = conn.Schema() if err != nil { return nil, err } } return store, nil } // ============================================================================= // Query Builders // ============================================================================= // newQueryChat creates a new query builder for the chat table func (store *Xun) newQueryChat() query.Query { qb := store.query.New() qb.Table(store.getChatTable()) return qb } // newQueryMessage creates a new query builder for the message table func (store *Xun) newQueryMessage() query.Query { qb := store.query.New() qb.Table(store.getMessageTable()) return qb } // newQueryResume creates a new query builder for the resume table func (store *Xun) newQueryResume() query.Query { qb := store.query.New() qb.Table(store.getResumeTable()) return qb } // newQueryAssistant creates a new query 
builder for the assistant table func (store *Xun) newQueryAssistant() query.Query { qb := store.query.New() qb.Table(store.getAssistantTable()) return qb } // ============================================================================= // Table Name Getters // ============================================================================= // getChatTable returns the chat table name func (store *Xun) getChatTable() string { m := model.Select("__yao.agent.chat") if m != nil && m.MetaData.Table.Name != "" { return m.MetaData.Table.Name } return "agent_chat" } // getMessageTable returns the message table name func (store *Xun) getMessageTable() string { m := model.Select("__yao.agent.message") if m != nil && m.MetaData.Table.Name != "" { return m.MetaData.Table.Name } return "agent_message" } // getResumeTable returns the resume table name func (store *Xun) getResumeTable() string { m := model.Select("__yao.agent.resume") if m != nil && m.MetaData.Table.Name != "" { return m.MetaData.Table.Name } return "agent_resume" } // getAssistantTable returns the assistant table name func (store *Xun) getAssistantTable() string { m := model.Select("__yao.agent.assistant") if m != nil && m.MetaData.Table.Name != "" { return m.MetaData.Table.Name } return "agent_assistant" } // ============================================================================= // Utility Functions // ============================================================================= // parseJSONFields parses JSON string fields into their corresponding Go types func (store *Xun) parseJSONFields(data map[string]interface{}, fields []string) { for _, field := range fields { if val := data[field]; val != nil { var jsonStr string switch v := val.(type) { case string: jsonStr = v case []byte: jsonStr = string(v) default: continue } if jsonStr != "" { var parsed interface{} if err := jsoniter.UnmarshalFromString(jsonStr, &parsed); err == nil { data[field] = parsed } } } } } // GenerateAssistantID generates a 
random-looking 6-digit ID func (store *Xun) GenerateAssistantID() (string, error) { maxAttempts := 10 // Maximum number of attempts to generate a unique ID for i := 0; i < maxAttempts; i++ { // Generate a random number using timestamp and some bit operations timestamp := time.Now().UnixNano() random := (timestamp ^ (timestamp >> 12)) % 1000000 hash := fmt.Sprintf("%06d", random) // Check if this ID already exists exists, err := store.query.New(). Table(store.getAssistantTable()). Where("assistant_id", hash). Exists() if err != nil { return "", err } if !exists { return hash, nil } // If ID exists, wait a bit and try again time.Sleep(time.Millisecond) } return "", fmt.Errorf("failed to generate unique ID after %d attempts", maxAttempts) } ================================================ FILE: agent/test/DESIGN.md ================================================ # Agent Test Package Design ## Overview Agent Test Package provides a framework for testing AI agents with structured test cases. It supports batch testing, report generation, stability analysis, and CI integration. Additionally, it supports **Script Testing** for testing Agent handler scripts (hooks, tools, etc.) with a Go-like testing interface. 
### Quick Start ```bash # Quick test with a single message (auto-detect agent from current directory) cd assistants/keyword yao agent test -i "hello world" # Or specify agent explicitly yao agent test -i "hello world" -n keyword.agent # Run tests from JSONL file (auto-detect agent from path) yao agent test -i assistants/keyword/tests/inputs.jsonl # Run with stability analysis (5 runs per test case) yao agent test -i assistants/keyword/tests/inputs.jsonl --runs 5 # Generate HTML report yao agent test -i assistants/keyword/tests/inputs.jsonl -r report.html -o report.html # Run script tests (test agent handler scripts) yao agent test -i scripts.expense.setup -v # Run script tests with specific user/team context yao agent test -i scripts.expense.tools -u admin -t ops-team -v ``` ## Usage ```bash # Basic usage - auto-detect agent, output to same directory as input # Output: tests/output-20241217100000.jsonl yao agent test -i tests/inputs.jsonl # Override connector yao agent test -i tests/inputs.jsonl -c openai.gpt4 # Specify agent explicitly yao agent test -i tests/inputs.jsonl -n my.agent # Specify test environment (user and team) yao agent test -i tests/inputs.jsonl -u test-user -t test-team # Run multiple times for stability analysis yao agent test -i tests/inputs.jsonl --runs 5 # Custom timeout per test case (default: 5m) yao agent test -i tests/inputs.jsonl --timeout 10m # Run tests in parallel (4 concurrent test cases) yao agent test -i tests/inputs.jsonl --parallel 4 # Combine parallel and timeout for faster execution yao agent test -i tests/inputs.jsonl --parallel 8 --timeout 2m # Custom output file path yao agent test -i tests/inputs.jsonl -o /path/to/results.jsonl # Use custom reporter agent for personalized report (HTML) yao agent test -i tests/inputs.jsonl -r report.html -o report.html # Use custom reporter agent for personalized report (Markdown) yao agent test -i tests/inputs.jsonl -r report.markdown -o report.md # Full example with all options yao agent 
test -i tests/inputs.jsonl \ -n keyword.agent \ -c deepseek.v3 \ -u test-user \ -t test-team \ --runs 3 \ --timeout 10m \ --parallel 4 \ -r report.html \ -o report.html ``` ### Input Modes The `-i` flag supports three input modes: **1. JSONL File Mode** - Load test cases from a file: ```bash yao agent test -i tests/inputs.jsonl ``` **2. Direct Message Mode** - Test with a single message: ```bash # Auto-detect agent from current working directory cd assistants/keyword yao agent test -i "Extract keywords from this text" # Or specify agent explicitly yao agent test -i "Extract keywords from this text" -n keyword.agent yao agent test -i "你好世界" -n keyword.agent -c deepseek.v3 ``` When using direct message mode: - Agent is resolved from current working directory (looks for `package.yao` upward) - If not found, use `-n` flag to specify agent explicitly - Output is printed to stdout (or saved to `-o` if specified) - Useful for quick testing and debugging **3. Script Test Mode** - Test agent handler scripts: ```bash # Run all tests in a script module yao agent test -i scripts.expense.setup -v # Run with specific user/team context yao agent test -i scripts.expense.tools -u admin -t ops-team # Run with timeout yao agent test -i scripts.expense.setup --timeout 30s # Run specific tests by pattern (like go test -run) yao agent test -i scripts.expense.setup --run TestSystemReady # Run tests matching a regex pattern yao agent test -i scripts.expense.setup --run "TestSystem.*" ``` When using script test mode: - Input starts with `scripts.` prefix to indicate script testing - Maps to the script file (e.g., `scripts.expense.setup` → `expense/src/setup_test.ts`) - Automatically discovers and runs all `Test*` functions in the script - Uses Go-like testing interface with assertions - See [Script Testing](#script-testing) section for details ### Default Output Path When `-o` is not specified and using JSONL file mode, the output file is automatically generated in the same directory as the 
input file: ``` {input_directory}/output-{timestamp}.jsonl ``` Example: - Input: `/app/assistants/keyword/tests/inputs.jsonl` - Output: `/app/assistants/keyword/tests/output-20241217100000.jsonl` The timestamp format is `YYYYMMDDHHMMSS` (e.g., `20241217100000` for 2024-12-17 10:00:00). When using direct message mode without `-o`, output is printed to stdout. ## Command Line Options | Flag | Long Flag | Description | Default | Example | | ---- | ------------- | ------------------------------------- | -------------------------- | --------------------------------------- | | `-i` | `--input` | Input: JSONL file path or message | - | `-i tests/inputs.jsonl` or `-i "hello"` | | `-o` | `--output` | Path to output file (format by ext) | `output-{timestamp}.jsonl` | `-o report.html` | | `-n` | `--name` | Explicit agent ID | auto-detect | `-n keyword.agent` | | `-c` | `--connector` | Override connector | agent default | `-c openai.gpt4` | | `-u` | `--user` | Test user ID (global override) | "test-user" | `-u admin` | | `-t` | `--team` | Test team ID (global override) | "test-team" | `-t ops-team` | | | `--ctx` | Path to context JSON file | - | `--ctx tests/context.json` | | `-r` | `--reporter` | Custom reporter agent ID | - (use built-in) | `-r report.beautiful` | | | `--runs` | Number of runs for stability analysis | 1 | `--runs 5` | | | `--run` | Regex pattern to filter tests | - | `--run "TestSystem.*"` | | | `--timeout` | Default timeout per test case | 5m | `--timeout 10m` | | | `--parallel` | Number of parallel test cases | 1 | `--parallel 4` | | `-v` | `--verbose` | Verbose output | false | `-v` | | | `--fail-fast` | Stop on first failure | false | `--fail-fast` | **Notes**: - Without `-o` flag, output is saved to `{input_dir}/output-{timestamp}.jsonl` - Output format is determined by `-o` file extension: `.jsonl`, `.json`, `.md`, `.html` - Use `-r` to specify a custom reporter agent for personalized report generation ## Agent Resolution The agent is resolved in the 
following order: 1. **Explicit specification** (`-n` flag): Use the specified agent ID 2. **Path-based detection**: Traverse up from `tests/inputs.jsonl` to find `package.yao` ### Path-based Detection Example ``` /app/assistants/workers/system/keyword/ ├── package.yao <- Agent definition ├── prompts.yml ├── src/ │ └── index.ts └── tests/ └── inputs.jsonl <- Test input file ``` Given input path `/app/assistants/workers/system/keyword/tests/inputs.jsonl`: 1. Check `/app/assistants/workers/system/keyword/tests/package.yao` - not found 2. Check `/app/assistants/workers/system/keyword/package.yao` - **found!** 3. Load agent from `/app/assistants/workers/system/keyword/` ## Test Environment Agent calls require a `Context` with user and tenant information. The test framework creates a test context with configurable environment: ```go // TestEnvironment configures the test execution context type TestEnvironment struct { UserID string // User ID for authorized info (-u flag) TeamID string // Team ID for authorized info (-t flag) Locale string // Locale (default: "en-us") ClientType string // Client type (default: "test") ClientIP string // Client IP (default: "127.0.0.1") Referer string // Request referer (default: "test") Accept string // Accept format (default: "standard") } ``` Example context creation (similar to `agent_next_test.go`): ```go func newTestContext(env *TestEnvironment, chatID, assistantID string) *context.Context { authorized := &types.AuthorizedInfo{ Subject: env.UserID, UserID: env.UserID, TeamID: env.TeamID, } ctx := context.New(stdContext.Background(), authorized, chatID) ctx.AssistantID = assistantID ctx.Locale = env.Locale ctx.Client = context.Client{ Type: env.ClientType, IP: env.ClientIP, } ctx.Referer = env.Referer ctx.Accept = env.Accept return ctx } ``` ## Stability Analysis (Multiple Runs) When `--runs N` is specified (N > 1), the framework runs each test case N times and collects stability metrics: ### Stability Metrics | Metric | Description 
| | ------------------ | ------------------------------------------ | | `pass_rate` | Percentage of runs that passed (0-100%) | | `consistency` | How consistent the outputs are across runs | | `avg_duration_ms` | Average execution time | | `min_duration_ms` | Minimum execution time | | `max_duration_ms` | Maximum execution time | | `std_deviation_ms` | Standard deviation of execution time | ### Stability Report Structure ```json { "summary": { "total_cases": 42, "total_runs": 126, "runs_per_case": 3, "overall_pass_rate": 95.2, "stable_cases": 38, "unstable_cases": 4, "duration_ms": 45678 }, "results": [ { "id": "T001", "runs": 3, "passed": 3, "failed": 0, "pass_rate": 100.0, "consistency": 1.0, "stable": true, "avg_duration_ms": 234, "min_duration_ms": 210, "max_duration_ms": 256, "std_deviation_ms": 18.5, "run_details": [ {"run": 1, "status": "passed", "duration_ms": 234, "output": {...}}, {"run": 2, "status": "passed", "duration_ms": 210, "output": {...}}, {"run": 3, "status": "passed", "duration_ms": 256, "output": {...}} ] }, { "id": "T002", "runs": 3, "passed": 2, "failed": 1, "pass_rate": 66.7, "consistency": 0.67, "stable": false, "run_details": [...] } ] } ``` ### Stability Classification | Pass Rate | Classification | | --------- | --------------- | | 100% | Stable | | 80-99% | Mostly Stable | | 50-79% | Unstable | | < 50% | Highly Unstable | ## Custom Reporter Agent By default, the framework outputs JSONL format. 
You can specify a reporter agent (`-r` flag) for personalized report generation: ### Reporter Agent Interface The reporter agent receives the test results and generates a custom report: ```json // Input to reporter agent { "report": { "summary": {...}, "results": [...], "metadata": {...} }, "format": "html", // or "markdown", "text" "options": { "verbose": true, "include_outputs": true } } ``` ### Built-in Reporter Agents | Agent ID | Description | | ----------------- | -------------------------------------- | | `report.json` | JSON format (default, no agent needed) | | `report.markdown` | Markdown format with tables | | `report.html` | Interactive HTML report | | `report.summary` | Brief text summary | ### Custom Reporter Example Create a custom reporter agent at `assistants/reporters/my-reporter/`: ```yaml # prompts.yml - role: system content: | You are a test report generator. Generate a beautiful report from test results. Output format: HTML with embedded CSS Requirements: - Show summary statistics prominently - Use color coding (green=pass, red=fail) - Include charts for stability metrics - Make it printable ``` ## Script Testing Script Testing allows you to test Agent handler scripts (hooks, tools, setup functions, etc.) using a Go-like testing interface. This is useful for unit testing individual functions in your agent's TypeScript/JavaScript code. ### Quick Start ```bash # Run script tests yao agent test -i scripts.expense.setup -v # With user/team context yao agent test -i scripts.expense.setup -u admin -t ops-team -v # With timeout yao agent test -i scripts.expense.setup --timeout 30s -v ``` ### Script Resolution The `scripts.` prefix indicates script test mode. 
The script is resolved as follows: | Input | Script Path | Test File | | ----------------------- | ---------------------- | --------------------------- | | `scripts.expense.setup` | `expense/src/setup.ts` | `expense/src/setup_test.ts` | | `scripts.expense.tools` | `expense/src/tools.ts` | `expense/src/tools_test.ts` | | `scripts.keyword.index` | `keyword/src/index.ts` | `keyword/src/index_test.ts` | The test file naming convention is `{module}_test.ts` (similar to Go's `_test.go` convention). ### Test Function Signature Test functions must follow this signature: ```typescript function TestFunctionName(t: testing.T, ctx: agent.Context) { // Test logic here } ``` **Requirements:** - Function name must start with `Test` (case-sensitive) - First parameter `t` is the testing object with assertions - Second parameter `ctx` is the agent context (same as used in hooks/tools) - Functions not starting with `Test` are ignored (can be used as helpers) ### Example Test File ```typescript // setup_test.ts // @ts-nocheck // Test the SystemReady function function TestSystemReady(t: testing.T, ctx: agent.Context) { const { assert } = t; // Call the function being tested const result = SystemReady(ctx); // Assert the result assert.True(result, "SystemReady should return true"); } // Test error case function TestSystemReadyWithInvalidContext(t: testing.T, ctx: agent.Context) { const { assert } = t; // Modify context to simulate error condition ctx.User = null; const result = SystemReady(ctx); assert.False(result, "SystemReady should return false when user is null"); } // Helper function (not a test - doesn't start with "Test") function createMockData() { return { id: 1, name: "test" }; } // Test with helper function TestSetupWithMockData(t: testing.T, ctx: agent.Context) { const { assert } = t; const mockData = createMockData(); const result = Setup(ctx, mockData); assert.NotNil(result, "Setup should return a result"); assert.Equal(result.id, 1, "Result ID should match"); } ``` ### 
Testing Object (`t`) The `t` parameter provides the testing interface: ```typescript interface testing.T { // Assertions object assert: testing.Assert; // Test metadata name: string; // Current test function name failed: boolean; // Whether the test has failed // Logging (output appears in test report) log(...args: any[]): void; // Log info message error(...args: any[]): void; // Log error message // Control flow skip(reason?: string): void; // Skip this test fail(reason?: string): void; // Mark test as failed fatal(reason?: string): void; // Mark as failed and stop execution } ``` ### Assertions (`t.assert`) The `assert` object provides assertion methods: | Method | Description | | ----------------------------------- | ---------------------------------- | | `True(value, message?)` | Assert value is true | | `False(value, message?)` | Assert value is false | | `Equal(actual, expected, message?)` | Assert deep equality | | `NotEqual(actual, expected, msg?)` | Assert not equal | | `Nil(value, message?)` | Assert value is null/undefined | | `NotNil(value, message?)` | Assert value is not null/undefined | | `Contains(str, substr, message?)` | Assert string contains substring | | `NotContains(str, substr, msg?)` | Assert string does not contain | | `Len(value, length, message?)` | Assert array/string length | | `Greater(a, b, message?)` | Assert a > b | | `GreaterOrEqual(a, b, message?)` | Assert a >= b | | `Less(a, b, message?)` | Assert a < b | | `LessOrEqual(a, b, message?)` | Assert a <= b | | `Error(err, message?)` | Assert err is an error | | `NoError(err, message?)` | Assert err is null/undefined | | `Panic(fn, message?)` | Assert function throws | | `NoPanic(fn, message?)` | Assert function does not throw | | `Match(value, pattern, message?)` | Assert value matches regex | | `NotMatch(value, pattern, msg?)` | Assert value does not match regex | | `JSONPath(obj, path, expected, m?)` | Assert JSON path value | | `Type(value, typeName, message?)` | Assert value 
type | ### Agent Context (`ctx`) The `ctx` parameter is the same `agent.Context` used in agent hooks and tools: ```typescript interface agent.Context { // User information (from -u flag or default) User: { ID: string; Name?: string; }; // Team information (from -t flag or default) Team: { ID: string; Name?: string; }; // Locale (default: "en-us") Locale: string; // Client information Client: { Type: string; // "test" IP: string; // "127.0.0.1" }; // Metadata (can be set via test case) Metadata: Record<string, any>; // Chat/Session ID ChatID: string; // Assistant ID (resolved from script path) AssistantID: string; } ``` ### Script Test Output Script test results are reported in the same format as agent tests: ``` ═══════════════════════════════════════════════════════════════════════════════ Script Test: scripts.expense.setup ═══════════════════════════════════════════════════════════════════════════════ Script: expense/src/setup_test.ts Tests: 3 functions User: test-user Team: test-team ─────────────────────────────────────────────────────────────────────────────── Running Tests ─────────────────────────────────────────────────────────────────────────────── ► [TestSystemReady] ... ✓ PASSED (12ms) ► [TestSystemReadyWithInvalidContext] ... ✓ PASSED (8ms) ► [TestSetupWithMockData] ... 
✗ FAILED (15ms) └─ assertion failed: Result ID should match expected: 1 actual: 2 ═══════════════════════════════════════════════════════════════════════════════ Summary: 2 passed, 1 failed, 0 skipped (35ms) ═══════════════════════════════════════════════════════════════════════════════ ``` ### Script Test Options Script tests support the following command line options: | Flag | Description | Default | Example | | ------------- | -------------------------------- | ----------- | -------------------- | | `-u` | User ID for context | "test-user" | `-u admin` | | `-t` | Team ID for context | "test-team" | `-t ops-team` | | `--ctx` | Path to context JSON file | - | `--ctx context.json` | | `-v` | Verbose output | false | `-v` | | `--run` | Regex to filter tests | - | `--run "TestSystem"` | | `--timeout` | Timeout per test function | 30s | `--timeout 1m` | | `--fail-fast` | Stop on first failure | false | `--fail-fast` | | `-o` | Output file for report | stdout | `-o report.json` | | `-r` | Reporter agent for custom report | - | `-r report.html` | The `--run` flag accepts a Go-style regex pattern to filter which tests to run: ```bash # Run only TestSystemReady yao agent test -i scripts.expense.setup --run TestSystemReady # Run all tests starting with "TestSystem" yao agent test -i scripts.expense.setup --run "TestSystem.*" # Run tests containing "Error" yao agent test -i scripts.expense.setup --run ".*Error.*" ``` ### Custom Context Configuration The `--ctx` flag allows you to provide a JSON file with custom context configuration, giving full control over authorization data, metadata, and client information: ```bash # Use custom context file yao agent test -i scripts.expense.setup --ctx tests/context.json -v ``` **Context JSON Format:** ```json { "authorized": { "sub": "user-12345", "client_id": "my-app", "scope": "read write", "session_id": "sess-abc123", "user_id": "admin", "team_id": "team-001", "tenant_id": "acme-corp", "remember_me": false, "constraints": { 
"owner_only": false, "creator_only": false, "editor_only": false, "team_only": true, "extra": { "department": "engineering", "region": "us-west" } } }, "metadata": { "request_id": "req-123", "trace_id": "trace-456", "custom_field": "custom_value" }, "client": { "type": "web", "user_agent": "Mozilla/5.0", "ip": "192.168.1.100" }, "locale": "zh-cn", "referer": "https://example.com/dashboard" } ``` **Field Descriptions:** | Field | Description | | -------------------------- | --------------------------------------------------- | | `authorized.sub` | Subject identifier (JWT sub claim) | | `authorized.client_id` | OAuth client ID | | `authorized.scope` | Access scope | | `authorized.session_id` | Session identifier | | `authorized.user_id` | User identifier (overrides -u flag) | | `authorized.team_id` | Team identifier (overrides -t flag) | | `authorized.tenant_id` | Tenant identifier | | `authorized.remember_me` | Remember me flag | | `authorized.constraints` | Data access constraints (set by ACL enforcement) | | `constraints.owner_only` | Only access owner's data | | `constraints.creator_only` | Only access creator's data | | `constraints.editor_only` | Only access editor's data | | `constraints.team_only` | Only access team's data (filter by team_id) | | `constraints.extra` | User-defined constraints (department, region, etc.) | | `metadata` | Custom metadata passed to context | | `client.type` | Client type (web, mobile, test, etc.) | | `client.user_agent` | Client user agent string | | `client.ip` | Client IP address | | `locale` | Locale setting (e.g., "en-us", "zh-cn") | | `referer` | Request referer URL | **Priority:** When both `-u`/`-t` flags and `--ctx` file are provided, the context file values take precedence. 
### Script Test Report Format When using `-o` to save results: ```json { "type": "script_test", "script": "scripts.expense.setup", "script_path": "expense/src/setup_test.ts", "summary": { "total": 3, "passed": 2, "failed": 1, "skipped": 0, "duration_ms": 35 }, "environment": { "user_id": "test-user", "team_id": "test-team", "locale": "en-us" }, "results": [ { "name": "TestSystemReady", "status": "passed", "duration_ms": 12, "logs": [] }, { "name": "TestSystemReadyWithInvalidContext", "status": "passed", "duration_ms": 8, "logs": [] }, { "name": "TestSetupWithMockData", "status": "failed", "duration_ms": 15, "error": "assertion failed: Result ID should match", "assertion": { "type": "Equal", "expected": 1, "actual": 2, "message": "Result ID should match" }, "logs": [] } ], "metadata": { "started_at": "2024-12-17T10:00:00Z", "completed_at": "2024-12-17T10:00:00Z", "version": "0.10.5" } } ``` ### Best Practices 1. **Naming Convention**: Use descriptive test names that explain what's being tested - Good: `TestSystemReadyWithValidUser`, `TestSetupReturnsErrorOnMissingConfig` - Bad: `Test1`, `TestIt` 2. **One Assertion Per Concept**: Each test should verify one behavior ```typescript // Good: Focused tests function TestSetupCreatesDatabase(t, ctx) { ... } function TestSetupInitializesCache(t, ctx) { ... } // Bad: Testing too many things function TestSetup(t, ctx) { // tests database, cache, config, etc. } ``` 3. **Use Helper Functions**: Extract common setup logic ```typescript function setupTestContext(ctx) { ctx.Metadata.testMode = true; return ctx; } function TestFeatureA(t, ctx) { ctx = setupTestContext(ctx); // ... } ``` 4. **Test Error Cases**: Don't just test happy paths ```typescript function TestSetupWithMissingConfig(t, ctx) { const { assert } = t; ctx.Metadata.config = null; const result = Setup(ctx); assert.Error(result.error, "Should return error for missing config"); } ``` 5. 
**Clean Up**: If your test modifies global state, clean up after ```typescript function TestWithGlobalState(t, ctx) { const originalValue = GlobalConfig.value; try { GlobalConfig.value = "test"; // ... test logic } finally { GlobalConfig.value = originalValue; } } ``` ## Input Format (JSONL) Each line in the input file is a JSON object with the following structure: ```jsonl {"id": "T001", "input": "Simple text input"} {"id": "T002", "input": {"role": "user", "content": "Message with role"}} {"id": "T003", "input": {"role": "user", "content": [{"type": "text", "text": "ContentPart array"}]}} {"id": "T004", "input": [{"role": "user", "content": "First message"}, {"role": "assistant", "content": "Response"}, {"role": "user", "content": "Follow-up"}]} {"id": "T005", "input": "Text input", "expected": {"keywords": ["keyword1", "keyword2"]}} {"id": "T006", "input": "Test with specific user", "user": "admin", "team": "ops-team"} ``` ### Input Types | Type | Description | Example | | ----------- | ------------------------ | ----------------------------------------------------- | | `string` | Simple text input | `"Hello world"` | | `Message` | Single message with role | `{"role": "user", "content": "..."}` | | `[]Message` | Conversation history | `[{"role": "user", ...}, {"role": "assistant", ...}]` | ### Fields | Field | Type | Required | Description | | ---------- | ------------------------------ | -------- | ---------------------------------------------------- | | `id` | string | Yes | Unique test case identifier (e.g., "T001") | | `input` | string \| Message \| []Message | Yes | Test input | | `expected` | any | No | Expected output for exact match validation | | `assert` | Assertion \| []Assertion | No | Custom assertion rules (see Assertions section) | | `user` | string | No | User ID for this test case (overridden by `-u` flag) | | `team` | string | No | Team ID for this test case (overridden by `-t` flag) | | `metadata` | map | No | Additional metadata for the test 
case | | `skip` | bool | No | Skip this test case | | `timeout` | string | No | Override timeout (e.g., "30s", "1m") | ### Assertions The `assert` field allows flexible validation of agent output. If `assert` is defined, it takes precedence over `expected`. #### Assertion Types | Type | Description | Example | | -------------- | ----------------------------------------------- | ---------------------------------------------------------------- | | `equals` | Exact match (default if only `expected` is set) | `{"type": "equals", "value": {"need_search": false}}` | | `contains` | Output contains the expected string/value | `{"type": "contains", "value": "keyword"}` | | `not_contains` | Output does not contain the string/value | `{"type": "not_contains", "value": "error"}` | | `json_path` | Extract value using JSON path and compare | `{"type": "json_path", "path": "$.need_search", "value": false}` | | `regex` | Match output against regex pattern | `{"type": "regex", "value": "\\d{3}-\\d{4}"}` | | `type` | Check output type (string, object, array, etc.) 
| `{"type": "type", "value": "object"}` | | `script` | Run a custom assertion script | `{"type": "script", "script": "scripts.test.Assert"}` | #### Assertion Structure ```typescript interface Assertion { type: string; // Assertion type (required) value?: any; // Expected value or pattern path?: string; // JSON path for json_path assertions script?: string; // Script name for script assertions message?: string; // Custom failure message negate?: boolean; // Invert the assertion result } ``` #### Examples **Simple contains check:** ```jsonl { "id": "T001", "input": "Hello", "assert": { "type": "contains", "value": "need_search" } } ``` **JSON path validation (for agents returning JSON):** ```jsonl { "id": "T002", "input": "What's the weather?", "assert": { "type": "json_path", "path": "$.need_search", "value": true } } ``` **Multiple assertions (all must pass):** ```jsonl { "id": "T003", "input": "Calculate 2+2", "assert": [ { "type": "json_path", "path": "$.need_search", "value": false }, { "type": "json_path", "path": "$.confidence", "value": 0.99 }, { "type": "not_contains", "value": "error" } ] } ``` **Custom script assertion:** ```jsonl { "id": "T004", "input": "Complex test", "assert": { "type": "script", "script": "scripts.test.ValidateOutput" } } ``` The script receives `(output, input, expected)` and should return: ```typescript // Simple boolean return true; // or false // Or detailed result return { pass: true, message: "Validation passed: output contains expected keywords", }; ``` **Negated assertion:** ```jsonl { "id": "T005", "input": "Hello", "assert": { "type": "contains", "value": "error", "negate": true } } ``` #### JSON Path Notes - Supports simple dot-notation paths: `$.field.subfield` or `field.subfield` - Automatically extracts JSON from markdown code blocks (e.g., ` ```json ... 
``` `) - Works with both string output and structured objects ### Environment Override Priority The test environment (user/team) is determined by the following priority (highest first): 1. **Command line flags** (`-u`, `-t`): Global override for all test cases 2. **Test case fields** (`user`, `team`): Per-test case configuration 3. **Default values**: "test-user", "test-team" Example: ```bash # All tests run as "admin" user in "prod-team", regardless of test case settings yao agent test -i tests/inputs.jsonl -u admin -t prod-team -o report.json ``` ```jsonl # T001 uses default user/team {"id": "T001", "input": "Hello"} # T002 uses specific user/team (unless overridden by -u/-t flags) {"id": "T002", "input": "Admin action", "user": "admin", "team": "admin-team"} # T003 uses specific user only, team uses default {"id": "T003", "input": "User specific test", "user": "special-user"} ``` ## Output Format ### Default: JSONL (without `-r` flag) By default (without `-r` flag), the output is JSONL format - one JSON object per line, suitable for streaming and CI integration: ```jsonl {"type": "start", "timestamp": "2024-12-17T10:00:00Z", "agent_id": "keyword", "total_cases": 42} {"type": "result", "id": "T001", "status": "passed", "duration_ms": 234, "output": {"keywords": ["AI", "ML"]}} {"type": "result", "id": "T002", "status": "passed", "duration_ms": 189, "output": {"keywords": ["cloud"]}} {"type": "result", "id": "T003", "status": "failed", "duration_ms": 0, "error": "timeout after 30s"} {"type": "summary", "total": 42, "passed": 40, "failed": 2, "duration_ms": 12345} ``` This format is: - **Streamable**: Results are output as they complete - **Parseable**: Each line is valid JSON, easy to process with `jq` or scripts - **CI-friendly**: Exit code indicates pass/fail status ### Custom Report (with `-r` flag) ```json { "summary": { "total": 42, "passed": 40, "failed": 2, "skipped": 0, "duration_ms": 12345, "agent_id": "keyword", "connector": "deepseek.v3", 
"runs_per_case": 1, "overall_pass_rate": 95.2 }, "environment": { "user_id": "test-user", "team_id": "test-team", "locale": "en-us" }, "results": [ { "id": "T001", "status": "passed", "input": "...", "output": { "keywords": ["AI", "machine learning"] }, "expected": null, "duration_ms": 234, "error": null } ], "metadata": { "started_at": "2024-12-17T10:00:00Z", "completed_at": "2024-12-17T10:00:12Z", "version": "0.10.5" } } ``` ### HTML Report Beautiful, interactive HTML report with: - Summary statistics (pass/fail/skip counts, duration) - Stability charts (when runs > 1) - Filterable test results table - Expandable input/output details - Error highlighting - Export options ### Markdown Report ```markdown # Agent Test Report ## Summary | Metric | Value | | --------- | ----------- | | Agent | keyword | | Connector | deepseek.v3 | | Total | 42 | | Passed | 40 | | Failed | 2 | | Pass Rate | 95.2% | | Duration | 12.3s | ## Environment | Setting | Value | | ------- | --------- | | User | test-user | | Team | test-team | | Locale | en-us | ## Results ### ✅ T001 - Passed (234ms) ... ``` ## Architecture ``` agent/test/ ├── DESIGN.md # This file ├── types.go # Core types and interfaces ├── interfaces.go # Runner and Reporter interfaces ├── runner.go # Test runner implementation ├── loader.go # Test case loader ├── resolver.go # Agent resolver ├── context.go # Test context creation ├── assert.go # Assertion implementation ├── input.go # Input parsing ├── output.go # Output formatting ├── script.go # Script test runner (NEW) ├── script_types.go # Script test types (NEW) ├── script_assert.go # Script assertion bindings (NEW) └── reporter/ ├── json.go # JSON reporter ├── html.go # HTML reporter ├── markdown.go # Markdown reporter └── agent.go # Agent-based custom reporter ``` ## Core Components ### 1. TestCase Represents a single test case loaded from JSONL. ### 2. TestResult Represents the result of running a single test case. ### 3. 
TestReport Represents the complete test report with summary and results. ### 4. Runner Executes test cases against an agent: - Loads test cases from JSONL - Resolves agent from path or explicit ID - Creates test context with environment - Executes each test case (optionally multiple runs) - Collects results and stability metrics ### 5. ScriptRunner (NEW) Executes script tests for agent handler scripts: - Resolves script path from `scripts.` prefix - Discovers `Test*` functions in the script - Creates test context with environment - Executes each test function with testing object and context - Collects results and generates report ### 6. ScriptTestCase (NEW) Represents a single script test function: ```go type ScriptTestCase struct { Name string // Function name (e.g., "TestSystemReady") Function string // Full function reference } ``` ### 7. ScriptTestResult (NEW) Represents the result of running a script test function: ```go type ScriptTestResult struct { Name string `json:"name"` Status Status `json:"status"` DurationMs int64 `json:"duration_ms"` Error string `json:"error,omitempty"` Assertion *AssertionInfo `json:"assertion,omitempty"` Logs []string `json:"logs,omitempty"` } type AssertionInfo struct { Type string `json:"type"` Expected interface{} `json:"expected,omitempty"` Actual interface{} `json:"actual,omitempty"` Message string `json:"message,omitempty"` } ``` ### 8. ScriptTestReport (NEW) Represents the complete script test report: ```go type ScriptTestReport struct { Type string `json:"type"` // "script_test" Script string `json:"script"` ScriptPath string `json:"script_path"` Summary *ScriptTestSummary `json:"summary"` Environment *Environment `json:"environment"` Results []*ScriptTestResult `json:"results"` Metadata *ReportMetadata `json:"metadata"` } type ScriptTestSummary struct { Total int `json:"total"` Passed int `json:"passed"` Failed int `json:"failed"` Skipped int `json:"skipped"` DurationMs int64 `json:"duration_ms"` } ``` ### 9. 
Reporter Generates reports in various formats. The format is determined by the `-o` file extension: | Extension | Format | Description | | --------- | -------- | -------------------------- | | `.jsonl` | JSONL | Streaming, line-by-line | | `.json` | JSON | Full structured report | | `.md` | Markdown | Human-readable with tables | | `.html` | HTML | Interactive web report | #### Custom Reporter Agent (`-r` flag) When `-r <agent-id>` is specified, the framework calls the specified agent to generate the report: 1. Test execution completes, `TestReport` is generated 2. Framework calls the reporter agent with input: ```json { "report": { /* TestReport object */ }, "format": "html", "options": { "verbose": true } } ``` 3. Agent processes the report and returns formatted content 4. Framework writes the returned content to the output file Example usage: ```bash # Use custom reporter agent to generate a beautiful HTML report yao agent test -i tests/inputs.jsonl -r report.beautiful -o report.html # Use custom reporter agent to generate Slack-formatted summary yao agent test -i tests/inputs.jsonl -r report.slack -o summary.txt ``` This allows for fully customizable report generation using AI agents. ## Configuration ### Test Options ```go type Options struct { // Input/Output Input string // Input source: file path, message, or scripts.xxx InputMode InputMode // Auto-detected: file, message, or script OutputFile string // Path to output report // Agent Selection AgentID string // Explicit agent ID (optional) Connector string // Override connector (optional) // Test Environment UserID string // Test user ID (-u flag) TeamID string // Test team ID (-t flag) Locale string // Locale (default: "en-us") // Execution Timeout time.Duration // Default timeout per test Parallel int // Number of parallel tests (default: 1) Runs int // Number of runs per test case (default: 1) // Reporting ReporterID string // Reporter agent ID for custom report // Behavior Verbose bool // Verbose output FailFast
bool // Stop on first failure } // InputMode represents the input mode for test cases type InputMode string const ( InputModeFile InputMode = "file" // JSONL file input InputModeMessage InputMode = "message" // Direct message input InputModeScript InputMode = "script" // Script test mode (NEW) ) ``` ### Input Mode Detection The input mode is automatically detected based on the input value: | Input Pattern | Mode | Description | | ----------------- | --------- | -------------------------- | | `scripts.xxx.yyy` | `script` | Script test mode | | `*.jsonl` | `file` | JSONL file mode | | `path/to/file` | `file` | File path (if file exists) | | `"any text"` | `message` | Direct message mode | ```go func DetectInputMode(input string) InputMode { // Check for script test prefix if strings.HasPrefix(input, "scripts.") { return InputModeScript } // Check if it's a file path if strings.HasSuffix(input, ".jsonl") || fileExists(input) { return InputModeFile } // Default to message mode return InputModeMessage } ``` ## Script Testing Implementation ### Script Resolution ```go // ResolveScript resolves the script path from scripts.xxx.yyy format func ResolveScript(input string) (*ScriptInfo, error) { // Remove "scripts." 
prefix path := strings.TrimPrefix(input, "scripts.") // Split into parts: "expense.setup" -> ["expense", "setup"] parts := strings.Split(path, ".") if len(parts) < 2 { return nil, fmt.Errorf("invalid script path: %s", input) } // Build paths // assistantDir: expense // moduleName: setup // scriptPath: expense/src/setup.ts // testPath: expense/src/setup_test.ts assistantDir := parts[0] moduleName := parts[1] return &ScriptInfo{ ID: input, Assistant: assistantDir, Module: moduleName, ScriptPath: filepath.Join(assistantDir, "src", moduleName+".ts"), TestPath: filepath.Join(assistantDir, "src", moduleName+"_test.ts"), }, nil } ``` ### Test Function Discovery Test functions are discovered by scanning the script for functions starting with `Test`: ```go // DiscoverTests finds all Test* functions in the script func DiscoverTests(scriptPath string) ([]*ScriptTestCase, error) { // Use the JavaScript runtime to list exported functions // Filter for functions starting with "Test" // Return list of test cases } ``` ### Testing Object Binding The `testing.T` object is provided to test functions via JavaScript runtime binding: ```go // TestingT represents the testing object passed to test functions type TestingT struct { name string failed bool skipped bool logs []string assert *AssertObject } // AssertObject provides assertion methods type AssertObject struct { t *TestingT } func (a *AssertObject) True(value bool, message ...string) { if !value { a.t.fail(formatMessage("expected true, got false", message)) } } func (a *AssertObject) Equal(actual, expected interface{}, message ...string) { if !reflect.DeepEqual(actual, expected) { a.t.fail(formatMessage( fmt.Sprintf("expected %v, got %v", expected, actual), message, )) } } // ... other assertion methods ``` ### Script Execution Flow ``` 1. Parse input: "scripts.expense.setup" 2. Resolve script info: - TestPath: expense/src/setup_test.ts - ScriptPath: expense/src/setup.ts 3. 
Discover test functions: [TestSystemReady, TestSetupWithMockData, ...] 4. For each test function: a. Create testing.T object b. Create agent.Context with environment c. Execute: TestFunction(t, ctx) d. Collect result (passed/failed/skipped) 5. Generate report ``` ### Integration with Existing Runner ```go func (r *Executor) Run() (*Report, error) { switch r.opts.InputMode { case InputModeScript: return r.RunScriptTests() case InputModeMessage: return r.RunDirect() default: return r.RunTests() } } func (r *Executor) RunScriptTests() (*Report, error) { // 1. Resolve script scriptInfo, err := ResolveScript(r.opts.Input) if err != nil { return nil, err } // 2. Discover tests tests, err := DiscoverTests(scriptInfo.TestPath) if err != nil { return nil, err } // 3. Run each test results := make([]*ScriptTestResult, 0, len(tests)) for _, tc := range tests { result := r.runScriptTest(tc, scriptInfo) results = append(results, result) if r.opts.FailFast && result.Status == StatusFailed { break } } // 4. Generate report return r.buildScriptReport(scriptInfo, results), nil } ``` ## Exit Codes | Code | Description | | ---- | ------------------- | | 0 | All tests passed | | 1 | Some tests failed | | 2 | Configuration error | | 3 | Runtime error | ## CI Integration ### GitHub Actions Example ```yaml - name: Run Agent Tests run: | yao agent test -i assistants/keyword/tests/inputs.jsonl \ -u ci-user -t ci-team \ --runs 3 \ -o report.json - name: Check Stability run: | # Fail if any test has pass rate below 80% jq -e '.results | all(.pass_rate >= 80)' report.json - name: Upload Test Report uses: actions/upload-artifact@v3 with: name: agent-test-report path: report.json ``` ### Exit Code Handling The command exits with code 1 if any tests fail, making it easy to integrate with CI pipelines. ## Future Enhancements 1. **Snapshot Testing**: Compare outputs against saved snapshots 2. **Fuzzing**: Generate random inputs for robustness testing 3. 
**Coverage**: Track which agent code paths are exercised 4. **Benchmarking**: Performance metrics and regression detection 5. **Diff Reports**: Compare results between runs 6. **Flaky Test Detection**: Automatic identification of unstable tests 7. **Test Prioritization**: Run most important/failing tests first 8. **Script Test Enhancements**: - Parallel script test execution - Setup/Teardown hooks (`TestMain`, `BeforeEach`, `AfterEach`) - Mocking utilities for external dependencies - Code coverage for TypeScript/JavaScript scripts ================================================ FILE: agent/test/DESIGN_V2.md ================================================ # Agent Test Framework V2 Design ## Overview This document describes the design for Agent Test Framework V2, which extends the existing testing capabilities with: - **Message history support** - Test agents with conversation context via `input` array (already implemented) - **Agent-driven testing** - Use agents to generate test cases and validate responses - **Dynamic testing** - Simulator-driven testing with checkpoint validation ## Quick Reference: Format Rules | Context | Format | Example | | --------------------- | ------------------------ | ------------------------------------------------------- | | `-i` flag (CLI) | Prefix required | `agents:workers.test.gen`, `scripts:tests.gen` | | JSONL assertion `use` | Prefix required | `"use": "agents:workers.test.validator"` | | JSONL `simulator.use` | No prefix (agent only) | `"use": "workers.test.user-simulator"` | | `--simulator` flag | No prefix (agent only) | `--simulator workers.test.user-simulator` | | `t.assert.Agent()` | No prefix (method-bound) | `t.assert.Agent(resp, "workers.test.validator", {...})` | | JSONL `before/after` | No prefix (in src/) | `"before": "env_test.Before"` | | `--before/--after` | No prefix (in src/) | `--before env_test.BeforeAll` | ## Design Goals 1. 
**Simple** - Single-turn with optional message history, no complex multi-turn state 2. **Stateless** - Each test is independent, no session management needed 3. **Parallel** - Tests can run in parallel since they don't share state 4. **Flexible** - Support both static (messages) and dynamic (simulator) testing 5. **Agent-driven** - Input generation, simulation, and validation can all be agent-powered ## Architecture Overview ``` ┌─────────────────────────────────────────────────────────────────────────┐ │ yao agent test │ ├─────────────────────────────────────────────────────────────────────────┤ │ │ │ INPUT SOURCES (-i flag) │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ │ JSONL File │ │ Message │ │ Generator │ │ │ │ ./test.jsonl│ │ "Hello..." │ │ agents:xxx │ │ │ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ │ │ │ │ │ │ └────────────────┴────────────────┘ │ │ │ │ │ ▼ │ │ ┌───────────────────────────────────────────────────────────────────┐ │ │ │ Test Case Parser │ │ │ │ │ │ │ │ Standard Mode: {input: "..." | [...], assertions} │ │ │ │ Dynamic Mode: {simulator: {...}, checkpoints: [...]} │ │ │ └───────────────────────────────────────────────────────────────────┘ │ │ │ │ │ ┌───────────────┴───────────────┐ │ │ ▼ ▼ │ │ ┌───────────────────┐ ┌───────────────────────┐ │ │ │ STANDARD MODE │ │ DYNAMIC MODE │ │ │ │ │ │ │ │ │ │ 1. Build messages │ │ LOOP: │ │ │ │ 2. Call Agent │ │ 1. Simulator→input │ │ │ │ 3. Run assertions │ │ 2. Call Agent │ │ │ │ │ │ 3. Check checkpoints │ │ │ │ → PASS/FAIL │ │ 4. 
Until done │ │ │ └───────────────────┘ └───────────────────────┘ │ │ │ │ │ ▼ │ │ ┌───────────────────────────────────────────────────────────────────┐ │ │ │ Reporter │ │ │ │ - Console output │ │ │ │ - JSON file output │ │ │ └───────────────────────────────────────────────────────────────────┘ │ │ │ └─────────────────────────────────────────────────────────────────────────┘ ``` ## Test Modes ### Standard Mode (Default) Single call to agent with optional message history. **No multi-turn state management needed.** | Field | Type | Description | | ------------ | ------------------------------ | --------------------------------------------- | | `input` | string \| Message \| Message[] | Text, single message, or conversation history | | `assertions` | array | Assertions to validate response | | `options` | object | `context.Options` passed to agent | ### Dynamic Mode Simulator-driven testing with checkpoint validation. | Field | Type | Description | | ------------- | ------ | -------------------------------- | | `simulator` | object | Simulator agent configuration | | `checkpoints` | array | Functional checkpoints to verify | | `max_turns` | int | Maximum turns before timeout | | `timeout` | string | Maximum time (e.g., "5m") | ## Test Case Format ### Simple Input (Existing) ```jsonl { "id": "T001", "input": "Hello", "assertions": [ { "type": "contains", "value": "Hi" } ] } ``` ### With Message History (Existing) The `input` field already supports message arrays for conversation context: ```jsonl { "id": "T002", "name": "Expense submission - final confirmation", "input": [ { "role": "user", "content": "I want to submit an expense" }, { "role": "assistant", "content": "What type of expense would you like to submit?" }, { "role": "user", "content": "Business travel to Beijing, $3500" }, { "role": "assistant", "content": "I'll create an expense for business travel, $3500. Please confirm." 
}, { "role": "user", "content": "Yes, confirm" } ], "assertions": [ { "type": "contains", "value": "submitted" }, { "type": "tool_called", "name": "create_expense" } ] } ``` **Key insight**: Instead of executing 3 turns sequentially, we pass the full conversation history. The agent sees the context and responds to the last message. This is: - **Simpler** - No turn-by-turn execution, no session state - **Faster** - Single API call instead of multiple - **Parallelizable** - Each test is independent - **Debuggable** - Clear input/output for each test ### Testing Different Points in a Conversation To test agent behavior at different conversation stages, create separate test cases: ```jsonl // Test 1: First turn - agent should ask for expense type { "id": "expense-turn1", "input": [{"role": "user", "content": "I want to submit an expense"}], "assertions": [{"type": "contains", "value": "type"}] } // Test 2: Second turn - agent should create expense { "id": "expense-turn2", "input": [ {"role": "user", "content": "I want to submit an expense"}, {"role": "assistant", "content": "What type of expense would you like to submit?"}, {"role": "user", "content": "Business travel, $3500"} ], "assertions": [{"type": "tool_called", "name": "create_expense"}] } // Test 3: Final turn - agent should confirm submission { "id": "expense-turn3", "input": [ {"role": "user", "content": "I want to submit an expense"}, {"role": "assistant", "content": "What type of expense?"}, {"role": "user", "content": "Business travel, $3500"}, {"role": "assistant", "content": "Confirm $3500 expense?"}, {"role": "user", "content": "Yes"} ], "assertions": [{"type": "contains", "value": "submitted"}] } ``` ### With Attachments ```jsonl { "id": "T003", "input": [ { "role": "user", "content": [ { "type": "text", "text": "What's in this receipt?" 
}, { "type": "image", "source": "file://./fixtures/receipt.jpg" } ] } ], "assertions": [ { "type": "contains", "value": "amount" } ] } ``` ### Dynamic Mode (Simulator + Checkpoints) For coverage testing where conversation flow is unpredictable: ```jsonl { "id": "T004", "name": "Expense Submission Coverage", "simulator": { "use": "workers.test.user-simulator", "options": { "metadata": { "persona": "New employee unfamiliar with expense process", "goal": "Submit a $3500 travel expense" } } }, "checkpoints": [ { "id": "ask_type", "description": "Agent asks for expense type", "assertion": { "type": "contains", "value": "type" } }, { "id": "call_create", "description": "Agent calls create_expense", "after": [ "ask_type" ], "assertion": { "type": "tool_called", "name": "create_expense" } }, { "id": "confirm", "description": "Agent confirms submission", "after": [ "call_create" ], "assertion": { "type": "contains", "value": "submitted" } } ], "max_turns": 10, "timeout": "2m" } ``` ## Field Descriptions ### Standard Mode Fields | Field | Type | Required | Description | | ------------ | ------------------------------ | -------- | ------------------------------------------------- | | `id` | string | Yes | Unique test identifier | | `name` | string | No | Human-readable test name | | `input` | string \| Message \| Message[] | Yes | Input: text, single message, or message array | | `assertions` | array | No | Assertions to validate response (alias: `assert`) | | `options` | object | No | `context.Options` passed to agent | | `before` | string | No | Before script (e.g., `env_test.Before`) | | `after` | string | No | After script (e.g., `env_test.After`) | **Note**: The `input` field supports three formats: - `string`: Simple text (converted to `[{role: "user", content: "..."}]`) - `object`: Single message `{role: "...", content: "..."}` - `array`: Message history `[{role: "user", ...}, {role: "assistant", ...}, ...]` ### Dynamic Mode Fields | Field | Type | Required | 
Description | | --------------------------- | ------ | -------- | ------------------------------------------ | | `id` | string | Yes | Unique test identifier | | `name` | string | No | Human-readable test name | | `simulator` | object | Yes | User simulator configuration | | `simulator.use` | string | Yes | Simulator agent ID (no prefix) | | `simulator.options` | object | No | `context.Options` passed to simulator | | `checkpoints` | array | Yes | Functionality checkpoints to verify | | `checkpoints[].id` | string | Yes | Unique checkpoint identifier | | `checkpoints[].description` | string | No | Human-readable description | | `checkpoints[].assertion` | object | Yes | Assertion to verify | | `checkpoints[].after` | array | No | Checkpoint IDs that must occur first | | `max_turns` | int | No | Maximum turns before timeout (default: 20) | | `timeout` | string | No | Maximum time (default: "5m") | | `options` | object | No | `context.Options` passed to target agent | | `before` | string | No | Before script function | | `after` | string | No | After script function | ## Before and After Scripts JSONL test cases can reference `*_test.ts` scripts for environment preparation: ### Script Location Scripts are located in the agent's `src/` directory (as `*_test.ts` files): ``` assistants/expense/ ├── package.yao ├── prompts.yml ├── src/ │ ├── index.ts # Main agent script │ └── env_test.ts # Before/after functions └── tests/ ├── inputs.jsonl # Test cases └── fixtures/ └── receipt.jpg ``` ### Script Interface ```typescript // src/env_test.ts // Before function - called before test case runs // Returns context data that will be passed to After export function Before(ctx: Context, testCase: TestCase): BeforeResult { // Prepare database const userId = Process("models.user.Create", { name: "Test User", email: "test@example.com", }); // Prepare knowledge base Process("knowledge.expense.Index", { documents: [{ title: "Policy", content: "Max expense $5000" }], }); return { data: { 
userId, testId: testCase.id }, }; } // After function - called after test case completes (pass or fail) export function After( ctx: Context, testCase: TestCase, result: TestResult, beforeData: any ) { // Clean up database if (beforeData?.userId) { Process("models.user.Delete", beforeData.userId); } // Clean up knowledge base Process("knowledge.expense.Clear"); } // Global before - called once before all test cases export function BeforeAll(ctx: Context, testCases: TestCase[]): BeforeResult { // One-time initialization Process("models.migrate"); return { data: { initialized: true } }; } // Global after - called once after all test cases export function AfterAll(ctx: Context, results: TestResult[], beforeData: any) { // Final cleanup Process("models.cleanup"); } ``` ### Test Case with Before/After ```jsonl { "id": "T001", "name": "Submit expense with user context", "before": "env_test.Before", "after": "env_test.After", "input": "Submit a $500 travel expense", "assertions": [ { "type": "tool_called", "name": "create_expense" } ] } ``` ### Global Before/After via CLI ```bash # Run with global before/after yao agent test -i ./tests/inputs.jsonl \ --before env_test.BeforeAll \ --after env_test.AfterAll ``` ### Execution Order ``` ┌─────────────────────────────────────────────────────────────────┐ │ Test Execution with Before/After │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ 1. BeforeAll() - Global initialization (once) │ │ ↓ │ │ FOR EACH test case: │ │ 2. Before() - Per-test initialization │ │ ↓ │ │ 3. Run test (call agent, check assertions) │ │ ↓ │ │ 4. After() - Per-test cleanup (always runs) │ │ ↓ │ │ 5. AfterAll() - Global cleanup (once) │ │ │ └─────────────────────────────────────────────────────────────────┘ ``` **Note**: Script tests (`*_test.ts`) don't need before/after fields since they can call functions directly within the test. 
## Execution Flow ### Standard Mode ``` ┌─────────────────────────────────────────────────────────────────┐ │ Standard Mode Execution │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ 1. Parse test case │ │ ├─ `input` is array? → Use as messages │ │ └─ `input` is string? → Convert to [{role: "user", content}] │ │ ↓ │ │ 2. Call Agent.Stream(ctx, messages, options) │ │ ↓ │ │ 3. Run assertions against response │ │ ├─ All PASS → Test PASSED ✅ │ │ └─ Any FAIL → Test FAILED ❌ │ │ │ └─────────────────────────────────────────────────────────────────┘ ``` ### Dynamic Mode ``` ┌─────────────────────────────────────────────────────────────────┐ │ Dynamic Mode Execution │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ Initialize: │ │ - pending_checkpoints = all checkpoints │ │ - messages = [] │ │ - turn_count = 0 │ │ ↓ │ │ LOOP: │ │ 1. Call Simulator → get user input │ │ 2. Append user message to messages │ │ 3. Call Agent.Stream(ctx, messages, options) │ │ 4. Append assistant response to messages │ │ 5. Check response against pending_checkpoints │ │ └─ If matched (and `after` satisfied) → move to reached │ │ 6. 
Check termination: │ │ ├─ All checkpoints reached → PASSED ✅ │ │ ├─ Simulator signals goal_achieved → FAILED ❌ │ │ ├─ turn_count >= max_turns → FAILED ❌ │ │ └─ timeout exceeded → FAILED ❌ │ │ │ └─────────────────────────────────────────────────────────────────┘ ``` ## Assertion Types ### Static Assertions | Type | Description | Example | | ------------- | ---------------------- | ---------------------------------------------------------- | | `contains` | Response contains text | `{"type": "contains", "value": "success"}` | | `equals` | Exact match | `{"type": "equals", "value": "OK"}` | | `regex` | Regex pattern match | `{"type": "regex", "pattern": "order-\\d+"}` | | `json_path` | JSONPath value check | `{"type": "json_path", "path": "$.status", "value": "ok"}` | | `tool_called` | Tool was invoked | `{"type": "tool_called", "name": "create_expense"}` | | `type` | Value type check | `{"type": "type", "path": "$.count", "value": "number"}` | ### Agent-Driven Assertions For semantic or fuzzy validation: ```jsonl { "type": "agent", "use": "agents:workers.test.validator", "options": { "metadata": { "criteria": "Response should be helpful and answer the user's question", "tone": "professional and friendly" } } } ``` ### Script Assertions For custom validation logic: ```jsonl { "type": "script", "use": "scripts:tests.validate-expense", "options": { "metadata": { "min_amount": 100, "max_amount": 10000 } } } ``` ## Script Testing with Agent Assertions Script tests can use Agent-driven assertions via `t.assert.Agent()`: ```typescript export function TestExpenseResponse(t: TestingT, ctx: Context) { const messages = [ { role: "user", content: "I want to submit an expense" }, { role: "assistant", content: "What type of expense?" 
}, { role: "user", content: "Travel, $500" }, ]; const response = Process("agents.expense.Stream", ctx, messages); // Static assertion t.assert.Contains(response.content, "confirm"); // Agent-driven assertion t.assert.Agent(response.content, "workers.test.validator", { metadata: { criteria: "Response should ask for confirmation before creating expense", conversation: messages, }, }); } ``` ## Standard Agent Interface All agent-driven features use `context.Options`: ```go type Options struct { Skip *Skip `json:"skip,omitempty"` Connector string `json:"connector,omitempty"` Search any `json:"search,omitempty"` Mode string `json:"mode,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } ``` ### Generator Agent Called when `-i agents:xxx` is used: ```go options := &context.Options{ Metadata: map[string]any{ "test_mode": "generator", "target_agent": "assistants.expense", "count": 10, "focus": "edge-cases", }, } ``` ### Simulator Agent Called in dynamic mode to generate user input: ```go options := &context.Options{ Metadata: map[string]any{ "test_mode": "simulator", "persona": "New employee", "goal": "Submit expense", "turn_number": 3, }, } ``` ### Validator Agent Called for agent-driven assertions: ```go options := &context.Options{ Metadata: map[string]any{ "test_mode": "validator", "criteria": "Response should be helpful", }, } ``` ## Command Line Interface ### Flags Reference | Flag | Long | Description | | ---- | ------------- | ------------------------------------------------------------ | | `-i` | `--input` | Input source: file path, message, or `agents:`/`scripts:` ID | | `-n` | `--name` | Target agent ID (the agent being tested) | | `-o` | `--output` | Output file path for results | | `-c` | `--connector` | Override connector for the target agent | | `-u` | `--user` | Test user ID (default: test-user) | | `-t` | `--team` | Test team ID (default: test-team) | | `-v` | `--verbose` | Verbose output | | | `--ctx` | Path to context JSON file for custom 
authorization | | | `--simulator` | Default simulator agent ID for dynamic mode | | | `--before` | Global before script (e.g., `env_test.BeforeAll`) | | | `--after` | Global after script (e.g., `env_test.AfterAll`) | | | `--timeout` | Timeout per test case (default: 2m) | | | `--parallel` | Number of parallel test cases | | | `--runs` | Number of runs for stability analysis | | | `--run` | Regex pattern to filter which tests to run | | | `--fail-fast` | Stop on first failure | | | `--dry-run` | Generate/parse tests without running | ### Examples ```bash # Simple test yao agent test -i "Hello, how are you?" -n assistants.chat # From JSONL file yao agent test -i ./tests/expense.jsonl # Agent-generated tests yao agent test -i "agents:workers.test.generator?count=10" -n assistants.expense # With simulator for dynamic mode yao agent test -i ./tests/dynamic.jsonl --simulator workers.test.user-simulator # Parallel execution yao agent test -i ./tests/expense.jsonl --parallel 5 # Verbose output yao agent test -i ./tests/expense.jsonl -v ``` ## Output Format ### Console Output (Standard Mode) Standard mode shows each test case as a single line with input preview: ``` ═══════════════════════════════════════════════════════════════ Agent Test ═══════════════════════════════════════════════════════════════ ℹ Agent: workers.system.keyword ℹ Connector: deepseek.v3 ℹ Input: ./tests/inputs.jsonl (42 test cases) ℹ Timeout: 5m0s ─────────────────────────────────────────────────────────────── Running Tests ─────────────────────────────────────────────────────────────── ► [T001] 人工智能和机器学习正在改变我们�... PASSED (2.7s) ► [T002] The rapid development of cloud computing has re... PASSED (3.0s) ► [T003] 区块链技术是一种分布式账本技术�... PASSED (2.7s) ... 
─────────────────────────────────────────────────────────────── Summary ─────────────────────────────────────────────────────────────── Agent: workers.system.keyword Connector: deepseek.v3 Total: 42 Passed: 42 Failed: 0 Pass Rate: 100.0% Duration: 1.8m Output: ./tests/output-20251225185335.jsonl ═══════════════════════════════════════════════════════════════ ✨ ALL TESTS PASSED ✨ ═══════════════════════════════════════════════════════════════ ``` ### Console Output (Dynamic Mode) Dynamic mode shows each test case as a tree with turns and checkpoints: ``` ═══════════════════════════════════════════════════════════════ Agent Test (Dynamic Mode) ═══════════════════════════════════════════════════════════════ ℹ Agent: assistants.expense ℹ Connector: openai.gpt4 ℹ Input: ./tests/dynamic.jsonl (2 test cases) ℹ Simulator: workers.test.user-simulator ─────────────────────────────────────────────────────────────── Running Tests ─────────────────────────────────────────────────────────────── ► [T001] Expense Submission Coverage ├─ Turn 1: "Help me file an expense" → "What type of expense?" │ └─ ✓ checkpoint: ask_type ├─ Turn 2: "Client dinner, $250" → "I'll create... Please confirm." │ └─ ✓ checkpoint: call_create (tool: create_expense) └─ Turn 3: "Yes, confirm" → "Expense submitted! Reference: EXP-001" └─ ✓ checkpoint: confirm PASSED (6.8s) - 3 turns, 3/3 checkpoints ► [T002] Expense with Attachment ├─ Turn 1: "Submit receipt" + [receipt.jpg] → "What type?" │ └─ ✓ checkpoint: ask_type ├─ Turn 2: "Business lunch" → "Amount from receipt: $85.50. Confirm?" │ └─ ✓ checkpoint: extract_amount └─ Turn 3: "Yes" → "Submitted! 
Reference: EXP-002" └─ ✓ checkpoint: confirm PASSED (8.2s) - 3 turns, 3/3 checkpoints ─────────────────────────────────────────────────────────────── Summary ─────────────────────────────────────────────────────────────── Agent: assistants.expense Connector: openai.gpt4 Simulator: workers.test.user-simulator Total: 2 Passed: 2 Failed: 0 Pass Rate: 100.0% Duration: 15.0s Output: ./tests/output-20251225190000.jsonl ═══════════════════════════════════════════════════════════════ ✨ ALL TESTS PASSED ✨ ═══════════════════════════════════════════════════════════════ ``` ### Console Output (Parallel Mode) When `--parallel N` is enabled, tests run concurrently. Output is buffered and displayed as complete test trees: ``` ═══════════════════════════════════════════════════════════════ Agent Test (Parallel: 5) ═══════════════════════════════════════════════════════════════ ℹ Agent: assistants.expense ℹ Input: ./tests/dynamic.jsonl (10 test cases) ℹ Parallel: 5 concurrent ─────────────────────────────────────────────────────────────── Running Tests (5 parallel) ─────────────────────────────────────────────────────────────── ► [T003] Quick approval flow ├─ Turn 1: "Approve expense EXP-001" → "Approved!" └─ ✓ checkpoint: approved PASSED (1.2s) - 1 turn, 1/1 checkpoints ► [T001] Expense Submission Coverage ├─ Turn 1: "Help me file an expense" → "What type?" │ └─ ✓ checkpoint: ask_type ├─ Turn 2: "Client dinner, $250" → "Confirm?" │ └─ ✓ checkpoint: call_create └─ Turn 3: "Yes" → "Submitted!" └─ ✓ checkpoint: confirm PASSED (6.8s) - 3 turns, 3/3 checkpoints ► [T002] Expense with Attachment ├─ Turn 1: "Submit receipt" + [receipt.jpg] → "What type?" ... PASSED (8.2s) - 3 turns, 3/3 checkpoints [Progress: 3/10 completed, 5 running...] ► [T004] Rejection flow ... 
PASSED (4.5s) - 2 turns, 2/2 checkpoints ─────────────────────────────────────────────────────────────── Summary ─────────────────────────────────────────────────────────────── Total: 10 Passed: 10 Failed: 0 Pass Rate: 100.0% Duration: 25.3s (effective: 2.5s/test with 5 parallel) ═══════════════════════════════════════════════════════════════ ✨ ALL TESTS PASSED ✨ ═══════════════════════════════════════════════════════════════ ``` **Note**: In parallel mode, test results appear in completion order (not input order). Each test's output is buffered and displayed as a complete tree to maintain readability. ### JSON Output (Standard Mode) Output file is a JSON object with `summary`, `environment`, `results`, and `metadata`: ```json { "summary": { "total": 3, "passed": 3, "failed": 0, "skipped": 0, "errors": 0, "timeouts": 0, "duration_ms": 5100, "agent_id": "assistants.expense", "agent_path": "/path/to/expense" }, "environment": { "user_id": "test-user", "team_id": "test-team", "locale": "en-us" }, "results": [ { "id": "expense-turn1", "status": "passed", "input": [{ "role": "user", "content": "I want to submit an expense" }], "output": "What type of expense would you like to submit?", "duration_ms": 1200 }, { "id": "expense-turn2", "status": "passed", "input": [ { "role": "user", "content": "I want to submit an expense" }, { "role": "assistant", "content": "What type?" 
}, { "role": "user", "content": "Business travel, $3500" } ], "output": "Confirm $3500 expense?", "duration_ms": 2100 } ], "metadata": { "started_at": "2025-12-25T10:00:00Z", "completed_at": "2025-12-25T10:00:05Z", "input_file": "./tests/expense.jsonl" } } ``` ### JSON Output (Dynamic Mode) Dynamic mode adds `turns` and `checkpoints` to each result: ```json { "summary": { "total": 1, "passed": 1, "failed": 0, "duration_ms": 6800, "agent_id": "assistants.expense" }, "results": [ { "id": "expense-dynamic", "name": "Expense Coverage Test", "status": "passed", "turns": [ { "turn": 1, "input": "Help me file an expense", "output": "What type?" }, { "turn": 2, "input": "Client dinner, $250", "output": "Confirm?" }, { "turn": 3, "input": "Yes", "output": "Submitted!" } ], "checkpoints": [ { "id": "ask_type", "reached_at_turn": 1, "passed": true }, { "id": "call_create", "reached_at_turn": 2, "passed": true }, { "id": "confirm", "reached_at_turn": 3, "passed": true } ], "total_turns": 3, "duration_ms": 6800 } ], "metadata": { "started_at": "2025-12-25T10:00:00Z", "completed_at": "2025-12-25T10:00:07Z" } } ``` ## User Simulator Agent ### Interface ```typescript interface SimulatorInput { persona: string; goal: string; conversation: Message[]; turn_number: number; max_turns: number; } interface SimulatorOutput { input: string; goal_achieved: boolean; reasoning?: string; } ``` ### Example Prompt ``` You are simulating a user with the following characteristics: Persona: {{persona}} Goal: {{goal}} Current conversation: {{conversation}} Generate the next user message to continue toward the goal. If the goal has been achieved, set goal_achieved to true. 
Respond in JSON format: { "input": "your response as the user", "goal_achieved": true/false, "reasoning": "brief explanation" } ``` ## Backward Compatibility Existing single-turn tests work unchanged: ```jsonl // Simple string input {"id": "T001", "input": "Hello", "assertions": [...]} // Equivalent to array format {"id": "T001", "input": [{"role": "user", "content": "Hello"}], "assertions": [...]} ``` ## Error Handling ### Standard Mode Errors | Error Type | Behavior | Output | | ---------------- | ----------- | ---------------------------- | | Agent timeout | Test FAILED | `error: "timeout after 30s"` | | Agent error | Test FAILED | `error: "agent error: ..."` | | Assertion failed | Test FAILED | `assertion_errors: [...]` | ### Dynamic Mode Errors | Error Type | Behavior | Output | | --------------------------- | ----------- | ----------------------------------- | | All checkpoints reached | Test PASSED | `status: "passed"` | | Checkpoints missing | Test FAILED | `error: "missing checkpoints: ..."` | | Max turns exceeded | Test FAILED | `error: "max turns (20) exceeded"` | | Timeout exceeded | Test FAILED | `error: "timeout after 5m"` | | Simulator error | Test FAILED | `error: "simulator error: ..."` | | Checkpoint assertion failed | Test FAILED | `error: "checkpoint X failed"` | ## Current Implementation Status | Feature | Status | Notes | | ----------------------- | ------- | -------------------------------------------------- | | Simple text input | ✅ Done | `input: "Hello"` | | Message history | ✅ Done | `input: [{role, content}, ...]` | | File attachments | ✅ Done | `file://` protocol in content parts | | Static assertions | ✅ Done | contains, equals, regex, json_path, etc. 
| | Before/After hooks | ✅ Done | `before/after` in JSONL, `--before/--after` in CLI | | Agent-driven assertions | ✅ Done | `type: "agent"` + `t.assert.Agent()` JSAPI | | Agent-driven input | ✅ Done | `-i agents:xxx` for test generation | | Dry-run mode | ✅ Done | `--dry-run` to preview generated tests | | Dynamic mode | ✅ Done | Simulator + Checkpoints | | Console output | ✅ Done | Dynamic mode tree output, checkpoint display | ## Open Questions 1. **Message Generation**: Should we provide a helper to generate message history from a script? 2. **Snapshot Testing**: Should we support "golden file" comparison for responses? 3. **Retry Logic**: If a test fails, should we support automatic retry? ================================================ FILE: agent/test/README.md ================================================ # Agent Test Framework A comprehensive testing framework for Yao AI agents with support for standard testing, dynamic (simulator-driven) testing, agent-driven assertions, and CI integration. 
## Quick Start ### Standard Tests ```bash # Test with direct message (auto-detect agent from current directory) cd assistants/keyword yao agent test -i "Extract keywords from: AI and machine learning" # Test with direct message (specify agent explicitly) yao agent test -i "Hello world" -n workers.system.keyword # Test with JSONL file (auto-detect agent from path) yao agent test -i assistants/keyword/tests/inputs.jsonl # Generate HTML report yao agent test -i tests/inputs.jsonl -o report.html # Stability analysis (run each test 5 times) yao agent test -i tests/inputs.jsonl --runs 5 ``` ### Agent-Driven Input ```bash # Generate test cases using an agent yao agent test -i "agents:tests.generator-agent?count=10" -n assistants.expense # Preview generated tests without running (dry-run) yao agent test -i "agents:tests.generator-agent?count=5" -n assistants.expense --dry-run ``` ### Dynamic Mode (Simulator) ```bash # Run dynamic tests with simulator yao agent test -i tests/dynamic.jsonl --simulator tests.simulator-agent # See detailed turn-by-turn output yao agent test -i tests/dynamic.jsonl -v ``` ### Script Tests ```bash # Test agent handler scripts (hooks, tools, setup functions) yao agent test -i scripts.expense.setup -v # Run specific tests with regex filter yao agent test -i scripts.expense.setup --run "TestSystemReady" -v # Run with custom context (authorization, metadata) yao agent test -i scripts.expense.setup --ctx tests/context.json -v ``` ## Input Modes The `-i` flag supports multiple input modes: ### 1. JSONL File Mode Load test cases from a file: ```bash yao agent test -i tests/inputs.jsonl ``` Agent is auto-detected by traversing up from the input file to find `package.yao`. ### 2. Direct Message Mode Test with a single message: ```bash # Auto-detect agent from current working directory cd assistants/keyword yao agent test -i "Extract keywords from this text" # Or specify agent explicitly yao agent test -i "Hello" -n workers.system.keyword ``` ### 3. 
Agent-Driven Input Mode Generate test cases using a generator agent: ```bash # Basic usage (-n specifies the target agent to test) yao agent test -i "agents:tests.generator-agent" -n assistants.expense # With parameters yao agent test -i "agents:tests.generator-agent?count=10&focus=edge-cases" -n assistants.expense # Dry-run to preview generated tests yao agent test -i "agents:tests.generator-agent?count=5" -n assistants.expense --dry-run ``` **Note**: The `-n` flag is **required** for agent-driven input mode to specify which agent to test. The generator agent creates test cases for the target agent. ### 4. Script Test Mode Test agent handler scripts: ```bash yao agent test -i scripts.expense.setup -v ``` Script test input format: `scripts..` (e.g., `scripts.expense.setup` → `assistants/expense/src/setup_test.ts`). ### 5. Script-Generated Input Mode Generate test cases using a script: ```bash yao agent test -i "scripts:tests.gen.Generate" -n assistants.expense ``` **Note**: `scripts.xxx` (with dot) runs script tests, while `scripts:xxx` (with colon) generates test cases from a script. ## Test Modes ### Standard Mode Single call to agent with optional message history. Each test is independent and stateless. ```jsonl { "id": "T001", "input": "Hello", "assert": { "type": "contains", "value": "Hi" } } ``` ### Dynamic Mode Simulator-driven testing with checkpoint validation. A simulator agent generates user messages while checkpoints verify agent behavior. 
```jsonl { "id": "T001", "input": "I want to order coffee", "simulator": { "use": "tests.simulator-agent", "options": { "metadata": { "persona": "Customer", "goal": "Order a latte" } } }, "checkpoints": [ { "id": "greeting", "assert": { "type": "regex", "value": "(?i)hello" } }, { "id": "ask_size", "after": [ "greeting" ], "assert": { "type": "regex", "value": "(?i)size" } } ], "max_turns": 10 } ``` ## Command Line Options | Flag | Description | Default | | ------------- | -------------------------------------------------------- | -------------------------- | | `-i` | Input: JSONL file, message, `agents:xxx`, or `scripts:x` | (required) | | `-o` | Output file path | `output-{timestamp}.jsonl` | | `-n` | Agent ID (optional, auto-detected) | auto-detect | | `-a` | Application directory | auto-detect | | `-e` | Environment file | - | | `-c` | Override connector | agent default | | `-u` | Test user ID | `test-user` | | `-t` | Test team ID | `test-team` | | `-r` | Reporter agent ID for custom report | built-in | | `-v` | Verbose output | false | | `--ctx` | Path to context JSON file for custom authorization | - | | `--simulator` | Default simulator agent ID for dynamic mode | - | | `--before` | Global BeforeAll hook (e.g., `env_test.BeforeAll`) | - | | `--after` | Global AfterAll hook (e.g., `env_test.AfterAll`) | - | | `--runs` | Runs per test (stability analysis) | 1 | | `--run` | Regex pattern to filter which tests to run | - | | `--timeout` | Timeout per test | 2m | | `--parallel` | Parallel test cases | 1 | | `--fail-fast` | Stop on first failure | false | | `--dry-run` | Generate test cases without running them | false | ## Custom Context File Create a JSON file for custom authorization: ```json { "chat_id": "test-chat-001", "authorized": { "user_id": "test-user-123", "team_id": "test-team-456", "constraints": { "owner_only": true, "extra": { "department": "engineering" } } }, "metadata": { "mode": "test" } } ``` Use with `--ctx`: ```bash yao agent test -i 
scripts.expense.setup --ctx tests/context.json -v ``` ## Input Format (JSONL) Each line is a JSON object. Below are examples organized by scenario. ### Scenario 1: Simple Text Input Basic test with string input: ```jsonl {"id": "greeting-basic", "input": "Hello, how are you?"} {"id": "greeting-chinese", "input": "你好,请问有什么可以帮助你的?"} ``` ### Scenario 2: With Assertions Validate response content: ```jsonl {"id": "keyword-extract", "input": "Extract keywords from: AI and machine learning", "assert": {"type": "contains", "value": "AI"}} {"id": "json-response", "input": "What's the weather?", "assert": {"type": "json_path", "path": "need_search", "value": true}} {"id": "no-error", "input": "Help me", "assert": {"type": "not_contains", "value": "error"}} ``` ### Scenario 3: Multiple Assertions All assertions must pass: ```jsonl { "id": "expense-submit", "input": "Submit $500 travel expense", "assert": [ { "type": "contains", "value": "expense" }, { "type": "not_contains", "value": "error" }, { "type": "regex", "value": "(?i)(submitted|created|confirmed)" } ] } ``` ### Scenario 4: Conversation History Test with multi-turn context: ```jsonl { "id": "expense-confirm", "input": [ { "role": "user", "content": "Submit an expense" }, { "role": "assistant", "content": "What type of expense?" 
}, { "role": "user", "content": "Travel, $500" }, { "role": "assistant", "content": "Please confirm: $500 travel expense" }, { "role": "user", "content": "Yes, confirm" } ], "assert": { "type": "regex", "value": "(?i)(submitted|created)" } } ``` ### Scenario 5: With File Attachments Test with images or documents: ```jsonl { "id": "receipt-analyze", "input": { "role": "user", "content": [ { "type": "text", "text": "Analyze this receipt" }, { "type": "image", "source": "file://fixtures/receipt.jpg" } ] }, "assert": { "type": "contains", "value": "amount" } } ``` ### Scenario 6: Agent-Driven Assertion Use LLM to validate response semantics: ```jsonl { "id": "helpful-response", "input": "How do I reset my password?", "assert": { "type": "agent", "use": "agents:tests.validator-agent", "value": "Response should provide clear step-by-step instructions" } } ``` ### Scenario 7: With Options Override connector or skip features: ```jsonl {"id": "fast-model", "input": "Quick question", "options": {"connector": "deepseek.v3", "skip": {"history": true, "trace": true}}} {"id": "scenario-test", "input": "Query users", "options": {"metadata": {"scenario": "filter"}}, "assert": {"type": "json_path", "path": "from", "value": "users"}} ``` ### Scenario 8: With Before/After Hooks Setup and teardown for each test: ```jsonl { "id": "with-user-data", "input": "Show my expenses", "before": "env_test.Before", "after": "env_test.After", "assert": { "type": "contains", "value": "expense" } } ``` ### Scenario 9: Skip Test Temporarily disable a test: ```jsonl { "id": "wip-feature", "input": "New feature test", "skip": true } ``` ### Scenario 10: Dynamic Mode (Simulator) Multi-turn testing with user simulator: ```jsonl { "id": "coffee-order", "input": "I want to order coffee", "simulator": { "use": "tests.simulator-agent", "options": { "metadata": { "persona": "Regular customer", "goal": "Order a medium latte" } } }, "checkpoints": [ { "id": "greeting", "assert": { "type": "regex", "value": 
"(?i)(hello|hi|help)" } }, { "id": "ask-size", "after": [ "greeting" ], "assert": { "type": "regex", "value": "(?i)size" } }, { "id": "confirm", "after": [ "ask-size" ], "assert": { "type": "regex", "value": "(?i)confirm" } } ], "max_turns": 10 } ``` ### Scenario 11: Dynamic Mode with Optional Checkpoint Some checkpoints are optional: ```jsonl { "id": "expense-flow", "input": "Submit expense", "simulator": { "use": "tests.simulator-agent", "options": { "metadata": { "persona": "New employee", "goal": "Submit $500 travel expense" } } }, "checkpoints": [ { "id": "ask-type", "assert": { "type": "regex", "value": "(?i)type" } }, { "id": "suggest-category", "required": false, "assert": { "type": "contains", "value": "category" } }, { "id": "confirm", "after": [ "ask-type" ], "assert": { "type": "regex", "value": "(?i)confirm" } } ], "max_turns": 15 } ``` ### Standard Mode Fields | Field | Type | Required | Description | | ---------- | ------------------------------ | -------- | ------------------------------------- | | `id` | string | Yes | Test case ID | | `input` | string \| Message \| []Message | Yes | Test input | | `assert` | Assertion \| []Assertion | No | Assertion rules | | `expected` | any | No | Expected output (exact match) | | `user` | string | No | Override user ID for this test | | `team` | string | No | Override team ID for this test | | `metadata` | map | No | Additional metadata for hooks | | `options` | Options | No | Context options | | `timeout` | string | No | Override timeout (e.g., "30s") | | `skip` | bool | No | Skip this test | | `before` | string | No | Before hook (e.g., `env_test.Before`) | | `after` | string | No | After hook (e.g., `env_test.After`) | ### Dynamic Mode Fields | Field | Type | Required | Description | | ----------------------------- | ------ | -------- | -------------------------------------- | | `id` | string | Yes | Test case ID | | `input` | string | Yes | Initial user message | | `simulator` | object | Yes | Simulator 
configuration | | `simulator.use` | string | Yes | Simulator agent ID (no prefix) | | `simulator.options` | object | No | Simulator options | | `simulator.options.metadata` | map | No | Metadata (persona, goal, etc.) | | `simulator.options.connector` | string | No | Override simulator connector | | `checkpoints` | array | Yes | Checkpoints to verify | | `checkpoints[].id` | string | Yes | Checkpoint identifier | | `checkpoints[].description` | string | No | Human-readable description | | `checkpoints[].assert` | object | Yes | Assertion to validate | | `checkpoints[].after` | array | No | Checkpoint IDs that must occur first | | `checkpoints[].required` | bool | No | Is checkpoint required (default: true) | | `max_turns` | int | No | Maximum turns (default: 20) | | `timeout` | string | No | Override timeout (e.g., "2m") | ### Options The `options` field allows per-test-case configuration: | Field | Type | Description | | ------------------------ | ------ | ------------------------------------------ | | `connector` | string | Override connector (e.g., `"deepseek.v3"`) | | `mode` | string | Agent mode (default: `"chat"`) | | `search` | bool | Enable/disable search mode | | `disable_global_prompts` | bool | Temporarily disable global prompts | | `metadata` | map | Custom data passed to hooks | | `skip` | object | Skip configuration (see below) | ### Options.skip | Field | Type | Description | | --------- | ---- | ----------------------- | | `history` | bool | Skip history loading | | `trace` | bool | Skip trace logging | | `output` | bool | Skip output to client | | `keyword` | bool | Skip keyword extraction | | `search` | bool | Skip auto search | ### Input Types | Type | Description | Example | | ----------- | -------------------- | ----------------------------------------------------- | | `string` | Simple text | `"Hello world"` | | `Message` | Single message | `{"role": "user", "content": "..."}` | | `[]Message` | Conversation history | `[{"role": "user", ...}, 
{"role": "assistant", ...}]` | ## Assertions Use `assert` for flexible validation. If `assert` is defined, it takes precedence over `expected`. ### Static Assertions | Type | Description | Example | | -------------- | ----------------------------- | --------------------------------------------------------- | | `equals` | Exact match | `{"type": "equals", "value": {"key": "val"}}` | | `contains` | Output contains value | `{"type": "contains", "value": "keyword"}` | | `not_contains` | Output does not contain value | `{"type": "not_contains", "value": "error"}` | | `json_path` | Extract JSON path and compare | `{"type": "json_path", "path": "$.field", "value": true}` | | `regex` | Match regex pattern | `{"type": "regex", "value": "\\d+"}` | | `type` | Check output type | `{"type": "type", "value": "object"}` | | `tool_called` | Check if a tool was called | `{"type": "tool_called", "value": "setup"}` | | `tool_result` | Check tool execution result | `{"type": "tool_result", "value": {"tool": "setup", "result": {"success": true}}}` | ### Assertion Fields | Field | Type | Description | | --------- | ------ | -------------------------------------------------------- | | `type` | string | Assertion type (required) | | `value` | any | Expected value or pattern | | `path` | string | JSON path for `json_path` type | | `script` | string | Script name for `script` type | | `use` | string | Agent/script ID for `agent` type (with `agents:` prefix) | | `options` | object | Options for agent assertions | | `message` | string | Custom failure message | | `negate` | bool | Invert the assertion result | ### Agent-Driven Assertions For semantic or fuzzy validation using an LLM: ```jsonl { "id": "T001", "input": "Hello", "assert": { "type": "agent", "use": "agents:tests.validator-agent", "value": "Response should be friendly and helpful" } } ``` The validator agent receives the output and criteria, then returns `{"passed": true/false, "reason": "..."}`. **How it works:** 1. 
The framework builds a validation request with the agent's response (including tool result messages) 2. The validator agent evaluates the response against the criteria 3. The validator returns a JSON response with `passed` and `reason` **Output in test report (for checkpoints):** ```json { "agent_validation": { "passed": true, "reason": "Response explicitly confirms setup completion", "criteria": "Response should be friendly and helpful", "input": "Hello! How can I help you today?", "response": { "passed": true, "reason": "Response explicitly confirms setup completion" } } } ``` - `input`: The content sent to the validator (agent response + tool result messages) - `response`: The raw JSON response from the validator agent - `criteria`: The validation criteria from the test case ### Tool Assertions For validating that specific tools were called and their results: #### tool_called Check if a specific tool was called: ```jsonl { "id": "T001", "input": "Set up my expense system", "assert": { "type": "tool_called", "value": "setup" } } ``` **Value formats:** - **String**: Tool name (supports suffix matching, e.g., `"setup"` matches `"agents_expense_tools__setup"`) - **Array**: Any of the specified tools must be called - **Object**: Match tool name and optionally arguments ```jsonl // Match any of these tools {"type": "tool_called", "value": ["setup", "init"]} // Match tool with specific arguments {"type": "tool_called", "value": {"name": "setup", "arguments": {"action": "init"}}} ``` #### tool_result Check the result of a tool execution: ```jsonl { "id": "T001", "input": "Set up my expense system", "assert": { "type": "tool_result", "value": { "tool": "setup", "result": { "success": true } } } } ``` **Result matching:** - If `result` is omitted, only checks that the tool executed without error - Supports partial matching (only specified fields are checked) - Supports regex patterns with `regex:` prefix for string values ```jsonl // Just check tool executed without error 
{"type": "tool_result", "value": {"tool": "setup"}} // Check specific result fields {"type": "tool_result", "value": {"tool": "setup", "result": {"success": true}}} // Use regex for message matching {"type": "tool_result", "value": {"tool": "setup", "result": {"message": "regex:(?i)setup.*complete"}}} ``` ### Script Assertions For custom validation logic: ```jsonl { "id": "T001", "input": "Test", "assert": { "type": "script", "script": "scripts.test.Validate" } } ``` ### Multiple Assertions All assertions must pass: ```jsonl { "id": "T001", "input": "Hello", "assert": [ { "type": "contains", "value": "Hi" }, { "type": "not_contains", "value": "error" }, { "type": "json_path", "path": "status", "value": "ok" } ] } ``` ## File Attachments Test inputs support file attachments using the `file://` protocol: ```jsonl { "id": "T001", "input": { "role": "user", "content": [ { "type": "text", "text": "Analyze this image" }, { "type": "image", "source": "file://fixtures/receipt.jpg" } ] } } ``` Supported types: images (jpg, png, gif, webp), audio (wav, mp3), documents (pdf, doc, txt). ## Before/After Hooks Hooks allow you to run setup and teardown code before and after tests. Hook scripts must be placed in the agent's `src/` directory with `_test.ts` suffix. ### Hook Types | Hook | Scope | When Called | Use Case | | ----------- | -------- | --------------------- | ------------------------------- | | `Before` | Per-test | Before each test case | Create test data, setup context | | `After` | Per-test | After each test case | Cleanup test data, log results | | `BeforeAll` | Global | Once before all tests | Database migration, init | | `AfterAll` | Global | Once after all tests | Global cleanup, report | ### Execution Order ``` BeforeAll (global) ├─ Before (test 1) │ └─ Test 1 execution │ └─ After (test 1) ├─ Before (test 2) │ └─ Test 2 execution │ └─ After (test 2) └─ ... 
AfterAll (global) ``` ### Per-Test Hooks Defined in JSONL, scripts located in agent's `src/` directory: ```jsonl { "id": "T001", "input": "Test", "before": "env_test.Before", "after": "env_test.After" } ``` ### Global Hooks Via CLI flags: ```bash yao agent test -i tests/inputs.jsonl --before env_test.BeforeAll --after env_test.AfterAll ``` ### Hook Function Signatures ```typescript // assistants/expense/src/env_test.ts /** * Before - Called before each test case * @param ctx - Agent context with user/team info * @param testCase - The test case about to run * @returns any - Data passed to After hook (optional) */ export function Before(ctx: Context, testCase: TestCase): any { const userId = Process("models.user.Create", { name: "Test User" }); return { userId }; // This data is passed to After } /** * After - Called after each test case (pass or fail) * @param ctx - Agent context * @param testCase - The test case that ran * @param result - Test result with status, output, duration * @param beforeData - Data returned from Before hook */ export function After( ctx: Context, testCase: TestCase, result: TestResult, beforeData: any ) { if (beforeData?.userId) { Process("models.user.Delete", beforeData.userId); } if (result.status === "failed") { console.log(`Test ${testCase.id} failed: ${result.error}`); } } /** * BeforeAll - Called once before all tests * @param ctx - Agent context * @param testCases - Array of all test cases * @returns any - Data passed to AfterAll hook (optional) */ export function BeforeAll(ctx: Context, testCases: TestCase[]): any { Process("models.migrate"); return { initialized: true, count: testCases.length }; } /** * AfterAll - Called once after all tests complete * @param ctx - Agent context * @param results - Array of all test results * @param beforeData - Data returned from BeforeAll hook */ export function AfterAll(ctx: Context, results: TestResult[], beforeData: any) { const passed = results.filter((r) => r.status === "passed").length; 
console.log(`Tests completed: ${passed}/${results.length} passed`); Process("models.cleanup"); } ``` ### Hook Parameters **Context** - Agent execution context: ```typescript interface Context { locale: string; // Locale (e.g., "en-us") authorized: { user_id: string; // Test user ID team_id: string; // Test team ID constraints?: object; // Access constraints }; metadata: object; // Custom metadata from test case } ``` **TestCase** - Test case definition: ```typescript interface TestCase { id: string; // Test case ID input: any; // Test input (string, Message, or Message[]) assert?: object; // Assertion rules expected?: any; // Expected output user?: string; // Override user ID team?: string; // Override team ID metadata?: object; // Custom metadata options?: object; // Context options timeout?: string; // Timeout (e.g., "30s") skip?: boolean; // Skip flag before?: string; // Before hook reference after?: string; // After hook reference } ``` **TestResult** - Test execution result: ```typescript interface TestResult { id: string; // Test case ID status: string; // "passed" | "failed" | "error" | "skipped" | "timeout" input: any; // Actual input sent output: any; // Agent response expected?: any; // Expected output (if defined) error?: string; // Error message (if failed) duration_ms: number; // Execution time in milliseconds assertions?: object[]; // Assertion results } ``` ### Common Use Cases **Database Setup/Teardown**: ```typescript export function Before(ctx: Context, testCase: TestCase): any { // Create test records const user = Process("models.user.Create", { name: "Test", email: "test@example.com", }); const expense = Process("models.expense.Create", { user_id: user.id, amount: 100, }); return { user, expense }; } export function After( ctx: Context, testCase: TestCase, result: TestResult, data: any ) { // Clean up in reverse order if (data?.expense) Process("models.expense.Delete", data.expense.id); if (data?.user) Process("models.user.Delete", 
data.user.id); } ``` **Conditional Setup Based on Metadata**: ```typescript export function Before(ctx: Context, testCase: TestCase): any { const scenario = testCase.metadata?.scenario || "default"; if (scenario === "empty_db") { Process("models.expense.DeleteAll"); } else if (scenario === "with_data") { Process("scripts.tests.seed.LoadTestData"); } return { scenario }; } ``` **Logging and Debugging**: ```typescript export function After( ctx: Context, testCase: TestCase, result: TestResult, data: any ) { if (result.status === "failed") { console.log("=== Test Failed ==="); console.log("Test ID:", testCase.id); console.log("Input:", JSON.stringify(testCase.input)); console.log("Output:", JSON.stringify(result.output)); console.log("Error:", result.error); } } ``` ## Script Testing Test agent handler scripts with the `t.assert` API: ```typescript // assistants/expense/src/setup_test.ts import { SystemReady } from "./setup"; export function TestSystemReady(t: TestingT, ctx: Context) { const result = SystemReady(ctx); t.assert.True(result.success, "Should succeed"); t.assert.Equal(result.status, "ready", "Status should be ready"); t.assert.NotNil(result.data, "Data should not be nil"); } export function TestWithAgentAssertion(t: TestingT, ctx: Context) { const response = Process("agents.expense.Stream", ctx, messages); // Static assertion t.assert.Contains(response.content, "confirm"); // Agent-driven assertion t.assert.Agent(response.content, "tests.validator-agent", { criteria: "Response should ask for confirmation", }); } ``` ### Available Assertions | Method | Description | | -------------------------------- | ------------------------------ | | `t.assert.True(value, msg)` | Assert value is true | | `t.assert.False(value, msg)` | Assert value is false | | `t.assert.Equal(a, b, msg)` | Assert a equals b | | `t.assert.NotEqual(a, b, msg)` | Assert a not equals b | | `t.assert.Nil(value, msg)` | Assert value is null/undefined | | `t.assert.NotNil(value, msg)` | Assert 
value is not nil | | `t.assert.Contains(s, sub, msg)` | Assert string contains substr | | `t.assert.Len(arr, n, msg)` | Assert array/string length | | `t.assert.Agent(resp, id, opts)` | Agent-driven assertion | ## Dynamic Mode For testing complex conversation flows where the path is unpredictable: ```jsonl { "id": "coffee-order", "input": "I want to order coffee", "simulator": { "use": "tests.simulator-agent", "options": { "metadata": { "persona": "Customer ordering a latte", "goal": "Complete the coffee order" } } }, "checkpoints": [ { "id": "greeting", "description": "Agent greets customer", "assert": { "type": "regex", "value": "(?i)(hello|hi|help)" } }, { "id": "ask_size", "description": "Agent asks for size", "after": [ "greeting" ], "assert": { "type": "regex", "value": "(?i)size" } }, { "id": "confirm", "description": "Agent confirms order", "after": [ "ask_size" ], "assert": { "type": "regex", "value": "(?i)confirm" } } ], "max_turns": 10 } ``` ### Console Output (Dynamic Mode) ``` ► [coffee-order] (dynamic, 3 checkpoints) ℹ Dynamic test: coffee-order (max 10 turns) ℹ Turn 1: User: I want to order coffee ℹ Turn 1: Agent: Hello! What can I get for you? ℹ ✓ checkpoint: greeting ℹ Turn 2: User: A medium latte please ℹ Turn 2: Agent: What size would you like? ℹ ✓ checkpoint: ask_size ℹ Turn 3: User: Medium ℹ Turn 3: Agent: Let me confirm: Medium latte. Correct? 
ℹ ✓ checkpoint: confirm └─ PASSED (3 turns, 3 checkpoints, 8.5s) ``` ### Dynamic Mode Output Structure Each turn in the output includes: ```typescript interface TurnResult { turn: number; // Turn number (1-based) input: string; // User message output: any; // Agent response summary (for display) response: { // Full agent response (for detailed analysis) content: string; // LLM text content tool_calls: [ { // Tool calls made tool: string; // Tool name arguments: any; // Call arguments result: any; // Execution result } ]; next: any; // Next hook data }; checkpoints_reached: string[]; // Checkpoint IDs reached duration_ms: number; // Execution time } ``` ### Checkpoint Result Structure Each checkpoint in the output includes: ```typescript interface CheckpointResult { id: string; // Checkpoint identifier reached: boolean; // Whether checkpoint was reached reached_at_turn?: number; // Turn number when reached (if reached) required: boolean; // Whether checkpoint is required passed: boolean; // Whether assertion passed message?: string; // Assertion result message agent_validation?: { // Agent assertion details (for type: "agent") passed: boolean; // Validator's determination reason: string; // Explanation from validator criteria: string; // Validation criteria checked input: any; // Content sent to validator response: { // Raw validator response passed: boolean; reason: string; }; }; } ``` **Note**: For agent-based assertions (`type: "agent"`), the `agent_validation` field provides full transparency into the validation process. The `input` field contains the combined output (agent text response + tool result messages) that was validated. 
## Output Formats Determined by `-o` file extension: | Extension | Format | Description | | --------- | -------- | ---------------------- | | `.jsonl` | JSONL | Streaming (default) | | `.json` | JSON | Complete structured | | `.md` | Markdown | Human-readable | | `.html` | HTML | Interactive web report | ## Stability Analysis Run each test multiple times to measure consistency: ```bash yao agent test -i tests/inputs.jsonl --runs 5 -o stability.json ``` | Pass Rate | Classification | | --------- | --------------- | | 100% | Stable | | 80-99% | Mostly Stable | | 50-79% | Unstable | | < 50% | Highly Unstable | ## CI Integration ```bash # Exit code: 0 = all passed, 1 = failures yao agent test -i tests/inputs.jsonl --fail-fast # Run with parallel execution yao agent test -i tests/inputs.jsonl --parallel 4 ``` ### GitHub Actions Example ```yaml - name: Run Agent Tests run: | yao agent test -i assistants/expense/tests/inputs.jsonl \ -u ci-user -t ci-team \ --runs 3 \ -o report.json - name: Run Dynamic Tests run: | yao agent test -i assistants/expense/tests/dynamic.jsonl \ --simulator tests.simulator-agent \ -v - name: Run Script Tests run: | yao agent test -i scripts.expense.setup -v ``` ## Format Rules Reference | Context | Format | Example | | ---------------------- | ------------------------ | ----------------------------------------- | | `-i agents:xxx` (CLI) | Colon prefix | `agents:tests.generator` | | `-i scripts:xxx` (CLI) | Colon prefix | `scripts:tests.gen.Generate` | | `-i scripts.xxx` (CLI) | Dot prefix (test mode) | `scripts.expense.setup` | | JSONL assertion `use` | Prefix required | `"use": "agents:tests.validator"` | | JSONL `simulator.use` | No prefix (agent only) | `"use": "tests.simulator-agent"` | | `--simulator` flag | No prefix (agent only) | `--simulator tests.simulator-agent` | | `t.assert.Agent()` | No prefix (method-bound) | `t.assert.Agent(resp, "tests.validator")` | | JSONL `before/after` | No prefix (in src/) | `"before": "env_test.Before"` | 
| `--before/--after` | No prefix (in src/) | `--before env_test.BeforeAll` | **Script input modes**: - `scripts.xxx` (dot) - Run script tests (`*_test.ts` functions) - `scripts:xxx` (colon) - Generate test cases from a script ## Built-in Test Agents The framework provides three specialized agents for testing: ### Generator Agent (`tests.generator-agent`) Generates test cases based on target agent description. **package.yao**: ```json { "name": "Test Case Generator", "connector": "gpt-4o", "description": "Generates test cases for agent testing", "options": { "temperature": 0.7 }, "automated": true } ``` **prompts.yml**: ```yaml - role: system content: | You are a test case generator. Generate test cases based on the target agent. ## Input Format - `target_agent`: Agent info (id, description, tools) - `count`: Number of test cases (default: 5) - `focus`: Focus area (e.g., "edge-cases", "happy-path") ## Output Format JSON array of test cases: [ { "id": "test-id", "input": "User message", "assert": [{"type": "contains", "value": "expected"}] } ] ``` **Usage**: ```bash yao agent test -i "agents:tests.generator-agent?count=10" -n assistants.expense ``` ### Validator Agent (`tests.validator-agent`) Validates agent responses for agent-driven assertions. **package.yao**: ```json { "name": "Response Validator", "connector": "gpt-4o", "description": "Validates responses against criteria", "options": { "temperature": 0 }, "automated": true } ``` **prompts.yml**: ```yaml - role: system content: | You are a response validator. Evaluate whether the response meets the criteria. 
## Input Format - `output`: The response to validate - `criteria`: The validation rules - `input`: Original input (optional) ## Output Format JSON object (no markdown): {"passed": true/false, "reason": "explanation"} ## Examples Input: {"output": "Paris is the capital", "criteria": "factually accurate"} Output: {"passed": true, "reason": "Statement is correct"} ``` **Usage in JSONL**: ```jsonl { "id": "T001", "input": "Hello", "assert": { "type": "agent", "use": "agents:tests.validator-agent", "value": "Response should be friendly" } } ``` **Usage in script tests**: ```typescript t.assert.Agent(response, "tests.validator-agent", { criteria: "Response should be helpful", }); ``` ### Simulator Agent (`tests.simulator-agent`) Simulates user behavior for dynamic mode testing. **package.yao**: ```json { "name": "User Simulator", "connector": "gpt-4o", "description": "Simulates user behavior for dynamic testing", "options": { "temperature": 0.7 }, "automated": true } ``` **prompts.yml**: ```yaml - role: system content: | You are a user simulator. Generate realistic user messages based on persona and goal. ## Input Format - `persona`: User description (e.g., "New employee") - `goal`: What user wants to achieve - `conversation`: Previous messages - `turn_number`: Current turn - `max_turns`: Maximum turns ## Output Format JSON object: { "message": "User response", "goal_achieved": false, "reasoning": "Strategy explanation" } ## Guidelines 1. Stay in character 2. Work toward the goal 3. Be realistic (include natural variations) 4. 
Set goal_achieved: true when done ``` **Usage in JSONL**: ```jsonl { "id": "dynamic-test", "input": "I need help", "simulator": { "use": "tests.simulator-agent", "options": { "metadata": { "persona": "New employee", "goal": "Submit expense report" } } }, "checkpoints": [ { "id": "greeting", "assert": { "type": "regex", "value": "(?i)hello" } } ], "max_turns": 10 } ``` **Usage via CLI**: ```bash yao agent test -i tests/dynamic.jsonl --simulator tests.simulator-agent ``` ## Extract Command Extract test results from output JSONL file to individual Markdown or JSON files for human review: ```bash # Extract to Markdown files (default) yao agent extract output-20260127104118.jsonl # Specify output directory yao agent extract output.jsonl -o ./reports/ # Extract to JSON format yao agent extract output.jsonl --format json ``` ### Extract Command Options | Flag | Description | Default | | ---------- | ---------------------------------------- | ---------- | | `-o` | Output directory | same as input | | `--format` | Output format: `markdown`, `json` | `markdown` | ### Output Format (Markdown) Each test result is extracted to a separate file: ```markdown # T001-销售分析师-月末周五 **Status**: ✅ PASSED **Duration**: 16743ms ## Input (Full input content in markdown code block) ## Output (Agent's response content) ``` This is useful for: - Human review of agent outputs - Comparing results across test runs - Documentation and reporting ## Exit Codes | Code | Description | | ---- | --------------------------------------------------- | | 0 | All tests passed | | 1 | Tests failed, configuration error, or runtime error | ================================================ FILE: agent/test/assert.go ================================================ package test import ( "encoding/json" "fmt" "regexp" "strconv" "strings" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/process" goutext "github.com/yaoapp/gou/text" "github.com/yaoapp/yao/agent/assistant" 
"github.com/yaoapp/yao/agent/context" ) // Asserter handles test assertions type Asserter struct { // response holds the current response for tool-related assertions response *context.Response } // NewAsserter creates a new asserter func NewAsserter() *Asserter { return &Asserter{} } // WithResponse sets the response for tool-related assertions func (a *Asserter) WithResponse(response *context.Response) *Asserter { a.response = response return a } // Validate validates the output against the test case's assertions // Returns (passed, error message) func (a *Asserter) Validate(tc *Case, output interface{}) (bool, string) { // If assert is defined, use assertion rules if tc.Assert != nil { return a.validateAssertions(tc, output) } // If expected is defined, use simple comparison if tc.Expected != nil { if validateOutput(output, tc.Expected) { return true, "" } return false, "output does not match expected" } // No assertions defined - pass if we got output without error return true, "" } // ValidateWithDetails validates the output and returns detailed results // This is useful for agent assertions where we want to capture the validator's response func (a *Asserter) ValidateWithDetails(tc *Case, output interface{}) *AssertionResult { if tc.Assert == nil { return &AssertionResult{Passed: true} } assertions := a.parseAssertions(tc.Assert) if len(assertions) == 0 { return &AssertionResult{Passed: true} } // For single assertion, return its full result if len(assertions) == 1 { return a.evaluateAssertion(assertions[0], output, tc.Input) } // For multiple assertions, combine results var failures []string for _, assertion := range assertions { result := a.evaluateAssertion(assertion, output, tc.Input) if !result.Passed { msg := result.Message if assertion.Message != "" { msg = assertion.Message } failures = append(failures, msg) } } if len(failures) > 0 { return &AssertionResult{ Passed: false, Message: strings.Join(failures, "; "), } } return &AssertionResult{Passed: true} 
} // validateAssertions validates output against assertion rules func (a *Asserter) validateAssertions(tc *Case, output interface{}) (bool, string) { assertions := a.parseAssertions(tc.Assert) if len(assertions) == 0 { return true, "" } var failures []string for _, assertion := range assertions { result := a.evaluateAssertion(assertion, output, tc.Input) if !result.Passed { msg := result.Message if assertion.Message != "" { msg = assertion.Message } failures = append(failures, msg) } } if len(failures) > 0 { return false, strings.Join(failures, "; ") } return true, "" } // parseAssertions parses the assert field into a list of assertions func (a *Asserter) parseAssertions(assert interface{}) []*Assertion { if assert == nil { return nil } var assertions []*Assertion switch v := assert.(type) { case map[string]interface{}: // Single assertion object assertion := a.mapToAssertion(v) if assertion != nil { assertions = append(assertions, assertion) } case []interface{}: // Array of assertions for _, item := range v { if m, ok := item.(map[string]interface{}); ok { assertion := a.mapToAssertion(m) if assertion != nil { assertions = append(assertions, assertion) } } } case string: // Shorthand: just a type name (e.g., "contains") assertions = append(assertions, &Assertion{Type: v}) } return assertions } // mapToAssertion converts a map to an Assertion func (a *Asserter) mapToAssertion(m map[string]interface{}) *Assertion { assertion := &Assertion{} if t, ok := m["type"].(string); ok { assertion.Type = t } if v, ok := m["value"]; ok { assertion.Value = v } if p, ok := m["path"].(string); ok { assertion.Path = p } if s, ok := m["script"].(string); ok { assertion.Script = s } if u, ok := m["use"].(string); ok { assertion.Use = u } if msg, ok := m["message"].(string); ok { assertion.Message = msg } if n, ok := m["negate"].(bool); ok { assertion.Negate = n } // Parse options for agent assertions if opts, ok := m["options"].(map[string]interface{}); ok { assertion.Options = 
&AssertionOptions{} if c, ok := opts["connector"].(string); ok { assertion.Options.Connector = c } if meta, ok := opts["metadata"].(map[string]interface{}); ok { assertion.Options.Metadata = meta } } return assertion } // evaluateAssertion evaluates a single assertion func (a *Asserter) evaluateAssertion(assertion *Assertion, output, input interface{}) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Expected: assertion.Value, } switch assertion.Type { case "equals", "": result = a.assertEquals(assertion, output) case "contains": result = a.assertContains(assertion, output) case "not_contains": result = a.assertNotContains(assertion, output) case "json_path": result = a.assertJSONPath(assertion, output) case "regex": result = a.assertRegex(assertion, output) case "type": result = a.assertType(assertion, output) case "script": result = a.assertScript(assertion, output, input) case "agent": result = a.assertAgent(assertion, output, input) case "tool_called": result = a.assertToolCalled(assertion) case "tool_result": result = a.assertToolResult(assertion) default: result.Passed = false result.Message = fmt.Sprintf("unknown assertion type: %s", assertion.Type) } // Apply negate if assertion.Negate { result.Passed = !result.Passed if result.Passed { result.Message = "negated assertion passed" } else { result.Message = "negated: " + result.Message } } return result } // assertEquals checks for exact equality func (a *Asserter) assertEquals(assertion *Assertion, output interface{}) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Actual: output, Expected: assertion.Value, } if validateOutput(output, assertion.Value) { result.Passed = true result.Message = "values are equal" } else { result.Passed = false result.Message = fmt.Sprintf("expected %v, got %v", assertion.Value, output) } return result } // assertContains checks if output contains the expected value func (a *Asserter) assertContains(assertion *Assertion, output 
interface{}) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Actual: output, Expected: assertion.Value, } outputStr := a.toString(output) expectedStr := a.toString(assertion.Value) if strings.Contains(outputStr, expectedStr) { result.Passed = true result.Message = fmt.Sprintf("output contains '%s'", expectedStr) } else { result.Passed = false result.Message = fmt.Sprintf("output does not contain '%s'", expectedStr) } return result } // assertNotContains checks if output does not contain the expected value func (a *Asserter) assertNotContains(assertion *Assertion, output interface{}) *AssertionResult { result := a.assertContains(assertion, output) result.Passed = !result.Passed if result.Passed { result.Message = fmt.Sprintf("output does not contain '%s'", a.toString(assertion.Value)) } else { result.Message = fmt.Sprintf("output should not contain '%s'", a.toString(assertion.Value)) } return result } // assertJSONPath extracts a value using JSON path and compares func (a *Asserter) assertJSONPath(assertion *Assertion, output interface{}) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Expected: assertion.Value, } // Convert output to JSON if needed var jsonData interface{} switch v := output.(type) { case string: // Use gou/text to extract JSON (handles markdown, auto-repair, etc.) 
extracted := goutext.ExtractJSON(v) if extracted != nil { jsonData = extracted } else { result.Passed = false result.Message = fmt.Sprintf("output is not valid JSON: %s", v) return result } case map[string]interface{}, []interface{}: jsonData = v default: result.Passed = false result.Message = fmt.Sprintf("output is not a JSON object or array, got: %T = %v", output, truncateOutput(output, 200)) return result } // Extract value using simple path (e.g., "$.need_search" or "need_search") path := strings.TrimPrefix(assertion.Path, "$.") actual := a.extractPath(jsonData, path) result.Actual = actual // Compare expected value with actual value // First try direct comparison (handles array-to-array comparison) if validateOutput(actual, assertion.Value) { result.Passed = true result.Message = fmt.Sprintf("path '%s' equals expected value", assertion.Path) return result } // If expected is an array and direct comparison failed, check if actual matches ANY element (IN semantics) // This is for cases like: expected: ["a", "b"], actual: "a" (actual is one of expected) if expectedArr, ok := assertion.Value.([]interface{}); ok { // Only apply IN semantics if actual is NOT an array (otherwise it was already compared above) if _, actualIsArr := actual.([]interface{}); !actualIsArr { for _, expectedItem := range expectedArr { if validateOutput(actual, expectedItem) { result.Passed = true result.Message = fmt.Sprintf("path '%s' equals one of expected values", assertion.Path) return result } } } result.Passed = false result.Message = fmt.Sprintf("path '%s': expected %v, got %v", assertion.Path, assertion.Value, actual) } else { // Direct comparison already failed above result.Passed = false result.Message = fmt.Sprintf("path '%s': expected %v, got %v", assertion.Path, assertion.Value, actual) } return result } // truncateOutput truncates output for error messages func truncateOutput(output interface{}, maxLen int) string { var s string switch v := output.(type) { case string: s = v 
case nil: return "" default: bytes, err := jsoniter.Marshal(v) if err != nil { s = fmt.Sprintf("%v", v) } else { s = string(bytes) } } if len(s) > maxLen { return s[:maxLen] + "..." } return s } // extractPath extracts a value from JSON using dot-notation path with array index support // Supports: "field", "field.nested", "field[0]", "field[0].nested", "field.nested[0].value" func (a *Asserter) extractPath(data interface{}, path string) interface{} { current := data // Parse path into segments, handling both dots and array indices // e.g., "wheres[0].like" -> ["wheres", "[0]", "like"] segments := parsePathSegments(path) for _, segment := range segments { if segment == "" { continue } // Check if this is an array index like "[0]" if strings.HasPrefix(segment, "[") && strings.HasSuffix(segment, "]") { indexStr := segment[1 : len(segment)-1] index, err := strconv.Atoi(indexStr) if err != nil { return nil } arr, ok := current.([]interface{}) if !ok { return nil } if index < 0 || index >= len(arr) { return nil } current = arr[index] } else { // Regular field access switch v := current.(type) { case map[string]interface{}: current = v[segment] default: return nil } } } return current } // parsePathSegments splits a path like "wheres[0].like" into ["wheres", "[0]", "like"] func parsePathSegments(path string) []string { var segments []string var current strings.Builder for i := 0; i < len(path); i++ { ch := path[i] switch ch { case '.': if current.Len() > 0 { segments = append(segments, current.String()) current.Reset() } case '[': if current.Len() > 0 { segments = append(segments, current.String()) current.Reset() } // Find the closing bracket j := i + 1 for j < len(path) && path[j] != ']' { j++ } if j < len(path) { segments = append(segments, path[i:j+1]) // Include "[" and "]" i = j } default: current.WriteByte(ch) } } if current.Len() > 0 { segments = append(segments, current.String()) } return segments } // assertRegex checks if output matches a regex pattern func (a 
*Asserter) assertRegex(assertion *Assertion, output interface{}) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Actual: output, Expected: assertion.Value, } pattern, ok := assertion.Value.(string) if !ok { result.Passed = false result.Message = "regex pattern must be a string" return result } re, err := regexp.Compile(pattern) if err != nil { result.Passed = false result.Message = fmt.Sprintf("invalid regex pattern: %s", err.Error()) return result } outputStr := a.toString(output) if re.MatchString(outputStr) { result.Passed = true result.Message = fmt.Sprintf("output matches pattern '%s'", pattern) } else { result.Passed = false result.Message = fmt.Sprintf("output does not match pattern '%s'", pattern) } return result } // assertType checks the type of the output func (a *Asserter) assertType(assertion *Assertion, output interface{}) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Actual: output, Expected: assertion.Value, } expectedType, ok := assertion.Value.(string) if !ok { result.Passed = false result.Message = "type assertion value must be a string" return result } actualType := a.getType(output) result.Actual = actualType if actualType == expectedType { result.Passed = true result.Message = fmt.Sprintf("output is of type '%s'", expectedType) } else { result.Passed = false result.Message = fmt.Sprintf("expected type '%s', got '%s'", expectedType, actualType) } return result } // getType returns the type name of a value func (a *Asserter) getType(v interface{}) string { if v == nil { return "null" } switch v.(type) { case string: return "string" case float64, float32, int, int64, int32: return "number" case bool: return "boolean" case []interface{}: return "array" case map[string]interface{}: return "object" default: return fmt.Sprintf("%T", v) } } // assertAgent uses an agent to validate the output func (a *Asserter) assertAgent(assertion *Assertion, output, input interface{}) *AssertionResult { result := 
&AssertionResult{ Assertion: assertion, Actual: output, } // Parse use field: "agents:tests.validator-agent" if !strings.HasPrefix(assertion.Use, "agents:") { result.Passed = false result.Message = "agent assertion requires 'use' field with 'agents:' prefix" return result } agentID := strings.TrimPrefix(assertion.Use, "agents:") // Get assistant ast, err := assistant.Get(agentID) if err != nil { result.Passed = false result.Message = fmt.Sprintf("failed to get validator agent: %s", err.Error()) return result } // Build validation request validationInput := map[string]interface{}{ "output": output, "input": input, } // Add criteria from Value field if assertion.Value != nil { validationInput["criteria"] = assertion.Value } // Add metadata from options if assertion.Options != nil && assertion.Options.Metadata != nil { for k, v := range assertion.Options.Metadata { validationInput[k] = v } } // Build context options - skip history and trace for validator opts := &context.Options{ Skip: &context.Skip{ History: true, Trace: true, Output: true, }, Metadata: map[string]interface{}{ "test_mode": "validator", }, } if assertion.Options != nil && assertion.Options.Connector != "" { opts.Connector = assertion.Options.Connector } // Create context and call agent env := NewEnvironment("", "") ctx := NewTestContext("validator", agentID, env) defer ctx.Release() // Convert validation input to JSON string for the message inputJSON, err := json.Marshal(validationInput) if err != nil { result.Passed = false result.Message = fmt.Sprintf("failed to marshal validation input: %s", err.Error()) return result } messages := []context.Message{{ Role: context.RoleUser, Content: string(inputJSON), }} response, err := ast.Stream(ctx, messages, opts) if err != nil { result.Passed = false result.Message = fmt.Sprintf("validator agent error: %s", err.Error()) return result } // Parse response return a.parseValidatorResponse(response, result) } // parseValidatorResponse parses the validator agent's 
response func (a *Asserter) parseValidatorResponse(response *context.Response, result *AssertionResult) *AssertionResult { output := extractValidatorOutput(response) // Expected format: { "passed": bool, "reason": string, "score": float, "suggestions": [] } if outputMap, ok := output.(map[string]interface{}); ok { if passed, ok := outputMap["passed"].(bool); ok { result.Passed = passed } else { result.Passed = false result.Message = "validator response missing 'passed' field" return result } if reason, ok := outputMap["reason"].(string); ok { result.Message = reason } // Store score and suggestions in expected field for reference result.Expected = outputMap } else { result.Passed = false result.Message = "validator agent returned invalid response format" } return result } // extractValidatorOutput extracts the output from a validator response func extractValidatorOutput(response *context.Response) interface{} { if response == nil || response.Completion == nil { return nil } // Get content from completion content := response.Completion.Content if content == nil { return nil } // Try to get text content var text string switch v := content.(type) { case string: text = v default: // Try to marshal and use as-is data, err := json.Marshal(content) if err != nil { return nil } text = string(data) } if text == "" { return nil } // Use gou/text to extract JSON (handles markdown code blocks, auto-repair, etc.) 
result := goutext.ExtractJSON(text) if result != nil { return result } // Return raw text if extraction fails return text } // assertScript runs a custom assertion script func (a *Asserter) assertScript(assertion *Assertion, output, input interface{}) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Actual: output, } if assertion.Script == "" { result.Passed = false result.Message = "script assertion requires a script name" return result } // Build script arguments args := []interface{}{ output, input, assertion.Value, } // Run the script as a process p, err := process.Of(assertion.Script, args...) if err != nil { result.Passed = false result.Message = fmt.Sprintf("failed to create process: %s", err.Error()) return result } res, err := p.Exec() if err != nil { result.Passed = false result.Message = fmt.Sprintf("script execution failed: %s", err.Error()) return result } // Parse script result // Expected format: { "pass": bool, "message": string } switch v := res.(type) { case bool: result.Passed = v if v { result.Message = "script assertion passed" } else { result.Message = "script assertion failed" } case map[string]interface{}: if pass, ok := v["pass"].(bool); ok { result.Passed = pass } if msg, ok := v["message"].(string); ok { result.Message = msg } default: result.Passed = false result.Message = fmt.Sprintf("script returned unexpected type: %T", res) } return result } // assertToolCalled checks if a specific tool was called // value can be: // - string: exact tool name to match // - []string: any of the tool names // - map with "name" and optional "arguments" for more specific matching func (a *Asserter) assertToolCalled(assertion *Assertion) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Expected: assertion.Value, } if a.response == nil { result.Passed = false result.Message = "no response available for tool_called assertion" return result } if len(a.response.Tools) == 0 { result.Passed = false result.Message = "no 
tools were called" return result } // Get tool names that were called calledTools := make([]string, 0, len(a.response.Tools)) for _, tool := range a.response.Tools { calledTools = append(calledTools, tool.Tool) } result.Actual = calledTools switch v := assertion.Value.(type) { case string: // Simple case: check if tool name matches (supports prefix matching) for _, tool := range a.response.Tools { if matchToolName(tool.Tool, v) { result.Passed = true result.Message = fmt.Sprintf("tool '%s' was called", v) return result } } result.Passed = false result.Message = fmt.Sprintf("tool '%s' was not called, called: %v", v, calledTools) case []interface{}: // Check if any of the specified tools were called for _, expected := range v { if expectedStr, ok := expected.(string); ok { for _, tool := range a.response.Tools { if matchToolName(tool.Tool, expectedStr) { result.Passed = true result.Message = fmt.Sprintf("tool '%s' was called", expectedStr) return result } } } } result.Passed = false result.Message = fmt.Sprintf("none of the expected tools were called, called: %v", calledTools) case map[string]interface{}: // Advanced case: match name and optionally arguments expectedName, _ := v["name"].(string) expectedArgs := v["arguments"] for _, tool := range a.response.Tools { if matchToolName(tool.Tool, expectedName) { // If arguments specified, check them too if expectedArgs != nil { if matchArguments(tool.Arguments, expectedArgs) { result.Passed = true result.Message = fmt.Sprintf("tool '%s' was called with matching arguments", expectedName) return result } } else { result.Passed = true result.Message = fmt.Sprintf("tool '%s' was called", expectedName) return result } } } result.Passed = false if expectedArgs != nil { result.Message = fmt.Sprintf("tool '%s' was not called with expected arguments", expectedName) } else { result.Message = fmt.Sprintf("tool '%s' was not called, called: %v", expectedName, calledTools) } default: result.Passed = false result.Message = 
fmt.Sprintf("invalid tool_called value type: %T", assertion.Value) } return result } // assertToolResult checks the result of a tool call // value should be a map with "tool" (name) and "result" (expected result pattern) func (a *Asserter) assertToolResult(assertion *Assertion) *AssertionResult { result := &AssertionResult{ Assertion: assertion, Expected: assertion.Value, } if a.response == nil { result.Passed = false result.Message = "no response available for tool_result assertion" return result } if len(a.response.Tools) == 0 { result.Passed = false result.Message = "no tools were called" return result } spec, ok := assertion.Value.(map[string]interface{}) if !ok { result.Passed = false result.Message = "tool_result assertion requires a map with 'tool' and 'result' fields" return result } toolName, _ := spec["tool"].(string) expectedResult := spec["result"] if toolName == "" { result.Passed = false result.Message = "tool_result assertion requires 'tool' field" return result } // Find the tool call for _, tool := range a.response.Tools { if matchToolName(tool.Tool, toolName) { result.Actual = tool.Result // Check if there was an error if tool.Error != "" { result.Passed = false result.Message = fmt.Sprintf("tool '%s' returned error: %s", toolName, tool.Error) return result } // If no expected result specified, just check success (no error) if expectedResult == nil { result.Passed = true result.Message = fmt.Sprintf("tool '%s' executed successfully", toolName) return result } // Match result if matchResult(tool.Result, expectedResult) { result.Passed = true result.Message = fmt.Sprintf("tool '%s' result matches expected", toolName) return result } result.Passed = false result.Message = fmt.Sprintf("tool '%s' result does not match expected", toolName) return result } } result.Passed = false result.Message = fmt.Sprintf("tool '%s' was not called", toolName) return result } // matchToolName checks if a tool name matches the expected pattern // Supports exact match 
// matchToolName reports whether the name of a called tool matches the
// expected pattern.
//
// Matching is deliberately lenient so test specs can use short names:
//   - exact match;
//   - namespaced suffix match, e.g. "setup" matches
//     "agents_expense_tools__setup" or "tools.setup";
//   - substring match for partial names.
//
// An empty expected pattern never matches. Every string contains the empty
// string, so without this guard a missing or misspelled tool name in a test
// spec would silently match whatever tool happened to be called.
func matchToolName(actual, expected string) bool {
	if expected == "" {
		return false
	}
	if actual == expected {
		return true
	}
	// Tool name without its namespace prefix.
	if strings.HasSuffix(actual, "__"+expected) || strings.HasSuffix(actual, "."+expected) {
		return true
	}
	// Partial name anywhere in the tool name.
	return strings.Contains(actual, expected)
}
:= actual.(bool) return ok && actualBool == exp default: return validateOutput(actual, expected) } } // toString converts a value to string for comparison func (a *Asserter) toString(v interface{}) string { if v == nil { return "" } switch val := v.(type) { case string: return val case []byte: return string(val) default: b, err := json.Marshal(v) if err != nil { return fmt.Sprintf("%v", v) } return string(b) } } ================================================ FILE: agent/test/assert_agent_test.go ================================================ package test_test import ( "testing" "github.com/stretchr/testify/assert" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/yao/agent" agenttest "github.com/yaoapp/yao/agent/test" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" "rogchap.com/v8go" ) func TestAsserter_AgentAssertion(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent (includes assistants) err := agent.Load(config.Conf) if err != nil { t.Fatalf("Failed to load agent: %v", err) } asserter := agenttest.NewAsserter() tests := []struct { name string tc *agenttest.Case output interface{} expected bool skipMsg string }{ { name: "agent assertion - pass", tc: &agenttest.Case{ Assert: map[string]interface{}{ "type": "agent", "use": "agents:tests.validator-agent", "value": "Response should be a greeting", }, }, output: "Hello! 
How can I help you today?", expected: true, }, { name: "agent assertion - fail", tc: &agenttest.Case{ Assert: map[string]interface{}{ "type": "agent", "use": "agents:tests.validator-agent", "value": "Response should provide a detailed technical answer", }, }, output: "I don't know.", expected: false, }, { name: "agent assertion - missing prefix", tc: &agenttest.Case{ Assert: map[string]interface{}{ "type": "agent", "use": "tests.validator-agent", // Missing agents: prefix "value": "Should pass", }, }, output: "Hello", expected: false, // Should fail due to missing prefix }, { name: "agent assertion - with metadata", tc: &agenttest.Case{ Assert: map[string]interface{}{ "type": "agent", "use": "agents:tests.validator-agent", "value": "Response is helpful", "options": map[string]interface{}{ "metadata": map[string]interface{}{ "context": "customer support", }, }, }, }, output: "I'd be happy to help you with your order. Let me look that up for you.", expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.skipMsg != "" { t.Skip(tt.skipMsg) } passed, errMsg := asserter.Validate(tt.tc, tt.output) if passed != tt.expected { t.Errorf("Expected passed=%v, got passed=%v, error: %s", tt.expected, passed, errMsg) } }) } } func TestAsserter_AgentAssertion_InvalidAgent(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent (includes assistants) err := agent.Load(config.Conf) if err != nil { t.Fatalf("Failed to load agent: %v", err) } asserter := agenttest.NewAsserter() tc := &agenttest.Case{ Assert: map[string]interface{}{ "type": "agent", "use": "agents:nonexistent.agent", "value": "Should fail", }, } passed, errMsg := asserter.Validate(tc, "Hello") assert.False(t, passed, "Should fail for nonexistent agent") assert.Contains(t, errMsg, "failed to get validator agent", "Error should mention agent loading failure") } func TestAsserter_MapToAssertion_WithUseAndOptions(t *testing.T) { test.Prepare(t, config.Conf) defer 
test.Clean() // Load agent (includes assistants) err := agent.Load(config.Conf) if err != nil { t.Fatalf("Failed to load agent: %v", err) } asserter := agenttest.NewAsserter() // Test that mapToAssertion correctly parses use and options fields tc := &agenttest.Case{ Assert: map[string]interface{}{ "type": "agent", "use": "agents:tests.validator-agent", "value": "criteria here", "options": map[string]interface{}{ "connector": "gpt-4o", "metadata": map[string]interface{}{ "key": "value", }, }, }, } // Validate triggers parseAssertions internally // We just verify it doesn't panic and processes correctly _, _ = asserter.Validate(tc, "test output") // If we get here without panic, the parsing worked } // TestTestingT_AssertAgent tests the JSAPI t.assert.Agent() method func TestTestingT_AssertAgent(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent (includes assistants) err := agent.Load(config.Conf) if err != nil { t.Fatalf("Failed to load agent: %v", err) } tests := []struct { name string script string shouldFail bool }{ { name: "JSAPI agent assertion - pass", script: ` function test(t) { var response = "Hello! 
How can I help you today?"; t.assert.Agent(response, "tests.validator-agent", { criteria: "Response should be a friendly greeting" }); } test(__test_t); `, shouldFail: false, }, { name: "JSAPI agent assertion - JSON response", script: ` function test(t) { var response = { status: "success", data: { user: "john", email: "john@example.com" }, message: "User created successfully" }; t.assert.Agent(response, "tests.validator-agent", { criteria: "Response should be a successful API response with user data" }); } test(__test_t); `, shouldFail: false, }, { name: "JSAPI agent assertion - with metadata", script: ` function test(t) { var response = "I'd be happy to help you with your order."; t.assert.Agent(response, "tests.validator-agent", { criteria: "Response is helpful and professional", metadata: { context: "customer support" } }); } test(__test_t); `, shouldFail: false, }, { name: "JSAPI agent assertion - fail case", script: ` function test(t) { var response = "I don't know."; t.assert.Agent(response, "tests.validator-agent", { criteria: "Response should provide a detailed technical explanation" }); } test(__test_t); `, shouldFail: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create TestingT testingT := agenttest.NewTestingT(tt.name) // Create V8 isolate and context iso := v8go.NewIsolate() defer iso.Dispose() v8ctx := v8go.NewContext(iso) defer v8ctx.Close() // Create testing object testObj, err := agenttest.NewTestingTObject(v8ctx, testingT) if err != nil { t.Fatalf("Failed to create testing object: %v", err) } // Set testing object as global global := v8ctx.Global() global.Set("__test_t", testObj) // Run the test script _, err = v8ctx.RunScript(tt.script, "test.js") // Check results if tt.shouldFail { assert.True(t, testingT.Failed(), "Test should have failed") } else { if err != nil { t.Errorf("Script execution error: %v", err) } assert.False(t, testingT.Failed(), "Test should have passed, errors: %v", testingT.Errors()) } }) } } // 
Ensure v8 is used (for script loading) var _ = v8.Scripts ================================================ FILE: agent/test/assert_test.go ================================================ package test import ( "testing" "github.com/yaoapp/yao/agent/context" ) func TestAsserter_JSONPath_ArrayEquality(t *testing.T) { asserter := NewAsserter() tests := []struct { name string tc *Case output interface{} expected bool errMsg string }{ { name: "array equals array - same content", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "search_types", "value": []interface{}{"db"}, }, }, output: map[string]interface{}{ "search_types": []interface{}{"db"}, }, expected: true, }, { name: "array equals array - different content", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "search_types", "value": []interface{}{"db"}, }, }, output: map[string]interface{}{ "search_types": []interface{}{"web"}, }, expected: false, }, { name: "array equals array - multiple elements", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "search_types", "value": []interface{}{"web", "db"}, }, }, output: map[string]interface{}{ "search_types": []interface{}{"web", "db"}, }, expected: true, }, { name: "array equals array - different order", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "search_types", "value": []interface{}{"db", "web"}, }, }, output: map[string]interface{}{ "search_types": []interface{}{"web", "db"}, }, expected: false, // Order matters for array equality }, { name: "scalar in array - match", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "status", "value": []interface{}{"active", "pending"}, }, }, output: map[string]interface{}{ "status": "active", }, expected: true, // "active" is one of ["active", "pending"] }, { name: "scalar in array - no match", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "status", "value": []interface{}{"active", 
"pending"}, }, }, output: map[string]interface{}{ "status": "inactive", }, expected: false, }, { name: "simple value comparison", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "need_search", "value": true, }, }, output: map[string]interface{}{ "need_search": true, }, expected: true, }, { name: "nested path", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "result.count", "value": float64(5), }, }, output: map[string]interface{}{ "result": map[string]interface{}{ "count": float64(5), }, }, expected: true, }, { name: "array index access", tc: &Case{ Assert: map[string]interface{}{ "type": "json_path", "path": "items[0].name", "value": "first", }, }, output: map[string]interface{}{ "items": []interface{}{ map[string]interface{}{"name": "first"}, map[string]interface{}{"name": "second"}, }, }, expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { passed, errMsg := asserter.Validate(tt.tc, tt.output) if passed != tt.expected { t.Errorf("Expected passed=%v, got passed=%v, error: %s", tt.expected, passed, errMsg) } }) } } func TestAsserter_Contains(t *testing.T) { asserter := NewAsserter() tests := []struct { name string tc *Case output interface{} expected bool }{ { name: "string contains substring", tc: &Case{ Assert: map[string]interface{}{ "type": "contains", "value": "hello", }, }, output: "hello world", expected: true, }, { name: "string does not contain", tc: &Case{ Assert: map[string]interface{}{ "type": "contains", "value": "goodbye", }, }, output: "hello world", expected: false, }, { name: "JSON contains field", tc: &Case{ Assert: map[string]interface{}{ "type": "contains", "value": "success", }, }, output: map[string]interface{}{ "status": "success", }, expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { passed, _ := asserter.Validate(tt.tc, tt.output) if passed != tt.expected { t.Errorf("Expected passed=%v, got passed=%v", tt.expected, passed) } }) } 
} func TestAsserter_MultipleAssertions(t *testing.T) { asserter := NewAsserter() tc := &Case{ Assert: []interface{}{ map[string]interface{}{ "type": "json_path", "path": "need_search", "value": true, }, map[string]interface{}{ "type": "json_path", "path": "search_types", "value": []interface{}{"web"}, }, }, } output := map[string]interface{}{ "need_search": true, "search_types": []interface{}{"web"}, } passed, errMsg := asserter.Validate(tc, output) if !passed { t.Errorf("Expected all assertions to pass, got error: %s", errMsg) } } func TestAsserter_Negate(t *testing.T) { asserter := NewAsserter() tc := &Case{ Assert: map[string]interface{}{ "type": "contains", "value": "error", "negate": true, }, } output := "success message" passed, _ := asserter.Validate(tc, output) if !passed { t.Error("Expected negated assertion to pass") } } func TestAsserter_Regex(t *testing.T) { asserter := NewAsserter() tests := []struct { name string tc *Case output interface{} expected bool }{ { name: "regex matches", tc: &Case{ Assert: map[string]interface{}{ "type": "regex", "value": `\d{3}-\d{4}`, }, }, output: "Phone: 123-4567", expected: true, }, { name: "regex does not match", tc: &Case{ Assert: map[string]interface{}{ "type": "regex", "value": `\d{3}-\d{4}`, }, }, output: "No phone number here", expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { passed, _ := asserter.Validate(tt.tc, tt.output) if passed != tt.expected { t.Errorf("Expected passed=%v, got passed=%v", tt.expected, passed) } }) } } func TestAsserter_ToolCalled(t *testing.T) { tests := []struct { name string tc *Case response *context.Response expected bool errMsg string }{ { name: "tool called - exact match", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": "agents_expense_tools__setup", }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ {Tool: "agents_expense_tools__setup", Result: map[string]interface{}{"success": true}}, }, }, expected: 
true, }, { name: "tool called - suffix match", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": "setup", }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ {Tool: "agents_expense_tools__setup", Result: map[string]interface{}{"success": true}}, }, }, expected: true, }, { name: "tool not called", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": "setup", }, }, response: &context.Response{ Tools: []context.ToolCallResponse{}, }, expected: false, }, { name: "tool called - wrong tool", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": "setup", }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ {Tool: "agents_expense_tools__submit", Result: map[string]interface{}{"success": true}}, }, }, expected: false, }, { name: "tool called - any of multiple", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": []interface{}{"setup", "init"}, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ {Tool: "agents_expense_tools__init", Result: map[string]interface{}{"success": true}}, }, }, expected: true, }, { name: "tool called - with arguments (map)", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": map[string]interface{}{ "name": "setup", "arguments": map[string]interface{}{ "action": "init", }, }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Arguments: map[string]interface{}{"action": "init", "config": map[string]interface{}{}}, Result: map[string]interface{}{"success": true}, }, }, }, expected: true, }, { name: "tool called - with arguments (JSON string)", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": map[string]interface{}{ "name": "setup", "arguments": map[string]interface{}{ "action": "init", }, }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", 
Arguments: `{"action":"init","config":{"default_currency":"USD"}}`, Result: map[string]interface{}{"success": true}, }, }, }, expected: true, }, { name: "tool called - wrong arguments", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": map[string]interface{}{ "name": "setup", "arguments": map[string]interface{}{ "action": "update", }, }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Arguments: map[string]interface{}{"action": "init"}, Result: map[string]interface{}{"success": true}, }, }, }, expected: false, }, { name: "no response", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": "setup", }, }, response: nil, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { asserter := NewAsserter().WithResponse(tt.response) passed, errMsg := asserter.Validate(tt.tc, nil) if passed != tt.expected { t.Errorf("Expected passed=%v, got passed=%v, error: %s", tt.expected, passed, errMsg) } }) } } func TestAsserter_ToolResult(t *testing.T) { tests := []struct { name string tc *Case response *context.Response expected bool }{ { name: "tool result - success check", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", "result": map[string]interface{}{ "success": true, }, }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Result: map[string]interface{}{"success": true, "message": "Setup complete"}, }, }, }, expected: true, }, { name: "tool result - message check with regex", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", "result": map[string]interface{}{ "message": "regex:(?i)setup.*complete", }, }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Result: map[string]interface{}{"success": true, 
"message": "Setup complete!"}, }, }, }, expected: true, }, { name: "tool result - no expected result (just check no error)", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Result: map[string]interface{}{"success": true}, }, }, }, expected: true, }, { name: "tool result - tool has error", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Error: "permission denied", }, }, }, expected: false, }, { name: "tool result - result mismatch", tc: &Case{ Assert: map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", "result": map[string]interface{}{ "success": true, }, }, }, }, response: &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Result: map[string]interface{}{"success": false}, }, }, }, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { asserter := NewAsserter().WithResponse(tt.response) passed, errMsg := asserter.Validate(tt.tc, nil) if passed != tt.expected { t.Errorf("Expected passed=%v, got passed=%v, error: %s", tt.expected, passed, errMsg) } }) } } func TestAsserter_MultipleToolAssertions(t *testing.T) { // This tests the exact scenario from setup-006: tool_called + tool_result response := &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Arguments: `{"action":"init","config":{"default_currency":"USD","categories":[{"id":"meals","name":"Meals","daily_limit":100}]}}`, Result: map[string]interface{}{ "success": true, "action": "init", "config": map[string]interface{}{ "default_currency": "USD", }, "message": "Setup complete!", }, }, }, } asserter := 
NewAsserter().WithResponse(response) tc := &Case{ Assert: []interface{}{ map[string]interface{}{ "type": "tool_called", "value": map[string]interface{}{ "name": "setup", "arguments": map[string]interface{}{ "action": "init", }, }, }, map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", "result": map[string]interface{}{ "success": true, }, }, }, }, } result := asserter.ValidateWithDetails(tc, nil) if !result.Passed { t.Errorf("Expected multiple tool assertions to pass, got: %s", result.Message) } } func TestAsserter_SharedAsserterWithResponse(t *testing.T) { // Test that a shared asserter correctly uses WithResponse asserter := NewAsserter() // First call without response - should fail tc := &Case{ Assert: map[string]interface{}{ "type": "tool_called", "value": "setup", }, } result := asserter.ValidateWithDetails(tc, nil) if result.Passed { t.Error("Expected tool_called to fail without response") } if result.Message != "no response available for tool_called assertion" { t.Errorf("Unexpected message: %s", result.Message) } // Now set response response := &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Result: map[string]interface{}{"success": true}, }, }, } asserter.WithResponse(response) // Should pass now result = asserter.ValidateWithDetails(tc, nil) if !result.Passed { t.Errorf("Expected tool_called to pass with response, got: %s", result.Message) } } func TestAsserter_Setup006Scenario(t *testing.T) { // Exact reproduction of setup-006 scenario // Turn 2: tool was called with action: init, result has success: true response := &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Arguments: `{"action":"init","config":{"default_currency":"USD","categories":[{"id":"meals","name":"Meals","daily_limit":100},{"id":"travel","name":"Travel","daily_limit":500}]}}`, Result: map[string]interface{}{ "config": map[string]interface{}{ "categories": 
[]interface{}{ map[string]interface{}{"daily_limit": float64(100), "id": "meals", "name": "Meals"}, map[string]interface{}{"daily_limit": float64(500), "id": "travel", "name": "Travel"}, }, "default_currency": "USD", }, "message": "Setup complete! The expense system has been initialized successfully with the configured settings. You can now start submitting expenses.", "success": true, "action": "init", }, }, }, } // This is the exact assert from setup-006's quick_complete checkpoint assertDef := []interface{}{ map[string]interface{}{ "type": "tool_called", "value": map[string]interface{}{ "name": "setup", "arguments": map[string]interface{}{ "action": "init", }, }, }, map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", "result": map[string]interface{}{ "success": true, }, }, }, } asserter := NewAsserter().WithResponse(response) tc := &Case{Assert: assertDef} result := asserter.ValidateWithDetails(tc, nil) if !result.Passed { t.Errorf("Expected setup-006 scenario to pass, got: %s", result.Message) } // Also test individual assertions t.Run("tool_called only", func(t *testing.T) { tc2 := &Case{Assert: assertDef[0]} result2 := asserter.ValidateWithDetails(tc2, nil) if !result2.Passed { t.Errorf("Expected tool_called to pass, got: %s", result2.Message) } }) t.Run("tool_result only", func(t *testing.T) { tc3 := &Case{Assert: assertDef[1]} result3 := asserter.ValidateWithDetails(tc3, nil) if !result3.Passed { t.Errorf("Expected tool_result to pass, got: %s", result3.Message) } }) } func TestAsserter_Setup003Scenario(t *testing.T) { // Exact reproduction of setup-003 scenario // Turn 3: tool was called with action: update response := &context.Response{ Tools: []context.ToolCallResponse{ { Tool: "agents_expense_tools__setup", Arguments: `{"action":"update","config":{"categories":[{"daily_limit":500,"id":"meals","name":"Business Meals"}]}}`, Result: map[string]interface{}{ "action": "update", "config": map[string]interface{}{ 
"categories": []interface{}{ map[string]interface{}{"daily_limit": float64(500), "id": "meals", "name": "Business Meals"}, }, }, "message": "Configuration updated successfully! Your changes have been saved.", "success": true, }, }, }, } // This is the exact assert from setup-003's update_complete checkpoint assertDef := []interface{}{ map[string]interface{}{ "type": "tool_called", "value": map[string]interface{}{ "name": "setup", "arguments": map[string]interface{}{ "action": "update", }, }, }, map[string]interface{}{ "type": "tool_result", "value": map[string]interface{}{ "tool": "setup", "result": map[string]interface{}{ "success": true, }, }, }, } asserter := NewAsserter().WithResponse(response) tc := &Case{Assert: assertDef} result := asserter.ValidateWithDetails(tc, nil) if !result.Passed { t.Errorf("Expected setup-003 scenario to pass, got: %s", result.Message) } // Test individual assertions t.Run("tool_called with action:update", func(t *testing.T) { tc2 := &Case{Assert: assertDef[0]} result2 := asserter.ValidateWithDetails(tc2, nil) if !result2.Passed { t.Errorf("Expected tool_called to pass, got: %s", result2.Message) } }) } func TestMatchToolName(t *testing.T) { tests := []struct { actual string expected string match bool }{ {"agents_expense_tools__setup", "agents_expense_tools__setup", true}, {"agents_expense_tools__setup", "setup", true}, {"agents.expense.tools.setup", "setup", true}, {"agents_expense_tools__setup", "init", false}, {"setup", "setup", true}, } for _, tt := range tests { t.Run(tt.actual+"_"+tt.expected, func(t *testing.T) { if matchToolName(tt.actual, tt.expected) != tt.match { t.Errorf("matchToolName(%q, %q) = %v, want %v", tt.actual, tt.expected, !tt.match, tt.match) } }) } } ================================================ FILE: agent/test/context.go ================================================ package test import ( stdContext "context" "github.com/yaoapp/yao/agent/context" "github.com/yaoapp/yao/agent/output/message" 
"github.com/yaoapp/yao/openapi/oauth/types" ) // NewTestContext creates a new context for testing // This is similar to newAgentNextTestContext in agent_next_test.go // but configurable via Environment func NewTestContext(chatID, assistantID string, env *Environment) *context.Context { // Build authorized info from environment authorized := buildAuthorizedInfo(env) // Create context with standard initialization ctx := context.New(stdContext.Background(), authorized, chatID) ctx.ID = chatID ctx.AssistantID = assistantID ctx.Locale = env.Locale ctx.Client = context.Client{ Type: env.ClientType, UserAgent: "yao-agent-test/1.0", IP: env.ClientIP, } ctx.Referer = env.Referer ctx.Accept = context.AcceptStandard ctx.IDGenerator = message.NewIDGenerator() ctx.Metadata = make(map[string]interface{}) // Apply metadata from context config if available if env.ContextConfig != nil && env.ContextConfig.Metadata != nil { for k, v := range env.ContextConfig.Metadata { ctx.Metadata[k] = v } } // Initialize interrupt controller ctx.Interrupt = context.NewInterruptController() // Close the default logger created by context.New() and use noop logger // to suppress LLM debug output during tests if ctx.Logger != nil { ctx.Logger.Close() } ctx.Logger = context.Noop() return ctx } // buildAuthorizedInfo builds AuthorizedInfo from Environment func buildAuthorizedInfo(env *Environment) *types.AuthorizedInfo { authorized := &types.AuthorizedInfo{ Subject: env.UserID, UserID: env.UserID, TeamID: env.TeamID, TenantID: env.TeamID, } // Apply custom authorized config if available if env.ContextConfig != nil && env.ContextConfig.Authorized != nil { authCfg := env.ContextConfig.Authorized if authCfg.Sub != "" { authorized.Subject = authCfg.Sub } if authCfg.ClientID != "" { authorized.ClientID = authCfg.ClientID } if authCfg.Scope != "" { authorized.Scope = authCfg.Scope } if authCfg.SessionID != "" { authorized.SessionID = authCfg.SessionID } if authCfg.UserID != "" { authorized.UserID = 
// GenerateChatID generates a unique chat ID for testing.
//
// The first run (runNumber <= 1) yields "test-<testID>"; every later run
// appends "-run<N>" with N rendered in decimal, e.g. "test-abc-run12".
//
// The previous implementation converted the run number with
// string(rune('0'+runNumber)), which only works for a single digit: run 10
// produced ":" and larger values produced arbitrary characters.
func GenerateChatID(testID string, runNumber int) string {
	if runNumber <= 1 {
		return "test-" + testID
	}

	// Render the (strictly positive) run number in decimal, filling the buffer
	// from the right. Done by hand so this file needs no additional import.
	var digits [20]byte // enough for any 64-bit integer
	pos := len(digits)
	for n := runNumber; n > 0; n /= 10 {
		pos--
		digits[pos] = byte('0' + n%10)
	}

	return "test-" + testID + "-run" + string(digits[pos:])
}
customer ordering coffee (JSONL must be single line) testCase := `{"id": "coffee-order-flow", "name": "Complete Coffee Order", "input": "Hi, I would like to order a coffee please", "simulator": {"use": "tests.simulator-agent", "options": {"metadata": {"persona": "A customer who wants to order a medium latte with oat milk", "goal": "Successfully complete a coffee order"}}}, "checkpoints": [{"id": "greeting", "description": "Agent greets and asks for order", "assert": {"type": "regex", "value": "(?i)(order|like|help)"}}, {"id": "ask_size", "description": "Agent asks for size", "after": ["greeting"], "assert": {"type": "regex", "value": "(?i)size"}}, {"id": "confirm_order", "description": "Agent confirms the order", "after": ["ask_size"], "assert": {"type": "regex", "value": "(?i)confirm"}}], "max_turns": 8}` err = os.WriteFile(inputFile, []byte(testCase), 0644) require.NoError(t, err, "Failed to write test file") // Run dynamic test opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.dynamic-test-agent", Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") require.NotNil(t, report.Summary, "Summary should not be nil") // Log results t.Logf("Total: %d, Passed: %d, Failed: %d", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) // Check results if len(report.Results) > 0 { result := report.Results[0] t.Logf("Test [%s] Status: %s", result.ID, result.Status) // Check metadata for dynamic mode info if result.Metadata != nil { if mode, ok := result.Metadata["mode"].(string); ok { assert.Equal(t, "dynamic", mode, "Should be dynamic mode") } if turns, ok := result.Metadata["total_turns"].(int); ok { t.Logf("Total turns: %d", turns) } } if result.Error != "" { t.Logf("Error: %s", result.Error) } } } // 
TestDynamicRunner_WithInitialInput tests dynamic mode with initial user input func TestDynamicRunner_WithInitialInput(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Create a test case with initial input tmpDir := t.TempDir() inputFile := filepath.Join(tmpDir, "dynamic-inputs.jsonl") // Start with user's first message (JSONL must be single line) testCase := `{"id": "coffee-with-initial", "name": "Coffee Order with Initial Message", "input": "Hi, I want to order a coffee", "simulator": {"use": "tests.simulator-agent", "options": {"metadata": {"persona": "Customer ordering a large cappuccino", "goal": "Complete the coffee order"}}}, "checkpoints": [{"id": "acknowledge", "description": "Agent acknowledges the order request", "assert": {"type": "regex", "value": "(?i)(coffee|order|help)"}}, {"id": "ask_details", "description": "Agent asks for more details", "after": ["acknowledge"], "assert": {"type": "regex", "value": "(?i)(size|type|what)"}}], "max_turns": 5}` err = os.WriteFile(inputFile, []byte(testCase), 0644) require.NoError(t, err, "Failed to write test file") // Run test opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.dynamic-test-agent", Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") t.Logf("Total: %d, Passed: %d, Failed: %d", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) } // TestDynamicRunner_OptionalCheckpoint tests optional checkpoint behavior func TestDynamicRunner_OptionalCheckpoint(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") tmpDir := t.TempDir() 
inputFile := filepath.Join(tmpDir, "dynamic-inputs.jsonl") // Test with one required and one optional checkpoint (JSONL must be single line) testCase := `{"id": "optional-checkpoint-test", "name": "Test with Optional Checkpoint", "input": "Hello", "simulator": {"use": "tests.simulator-agent", "options": {"metadata": {"persona": "Simple customer", "goal": "Get a greeting response"}}}, "checkpoints": [{"id": "greeting_response", "description": "Agent responds with greeting", "assert": {"type": "regex", "value": "(?i)(hello|hi|help)"}}, {"id": "special_offer", "description": "Agent mentions special offer (optional)", "required": false, "assert": {"type": "contains", "value": "special offer"}}], "max_turns": 3}` err = os.WriteFile(inputFile, []byte(testCase), 0644) require.NoError(t, err, "Failed to write test file") // Run test opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.dynamic-test-agent", Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") // Test should pass even if optional checkpoint is not reached t.Logf("Total: %d, Passed: %d, Failed: %d", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) // If the required checkpoint is reached, the test should pass if len(report.Results) > 0 && report.Results[0].Metadata != nil { if checkpoints, ok := report.Results[0].Metadata["checkpoints"].(map[string]*agenttest.CheckpointResult); ok { for id, cp := range checkpoints { t.Logf("Checkpoint [%s]: reached=%v, required=%v", id, cp.Reached, cp.Required) } } } } // TestDynamicRunner_MaxTurnsExceeded tests behavior when max turns is exceeded func TestDynamicRunner_MaxTurnsExceeded(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, 
err, "Failed to load agents") tmpDir := t.TempDir() inputFile := filepath.Join(tmpDir, "dynamic-inputs.jsonl") // Test case with impossible checkpoint and low max_turns (JSONL must be single line) testCase := `{"id": "max-turns-test", "name": "Test Max Turns Exceeded", "input": "Hello", "simulator": {"use": "tests.simulator-agent", "options": {"metadata": {"persona": "Persistent customer", "goal": "Keep talking"}}}, "checkpoints": [{"id": "impossible", "description": "This checkpoint will never be reached", "assert": {"type": "contains", "value": "IMPOSSIBLE_STRING_NEVER_APPEARS_12345"}}], "max_turns": 2}` err = os.WriteFile(inputFile, []byte(testCase), 0644) require.NoError(t, err, "Failed to write test file") // Run test opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.dynamic-test-agent", Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") // Test should fail due to max turns exceeded or goal achieved without checkpoints assert.Equal(t, 1, report.Summary.Failed, "Test should fail") if len(report.Results) > 0 { result := report.Results[0] assert.Equal(t, agenttest.StatusFailed, result.Status, "Status should be failed") // Either max turns exceeded or simulator signaled goal achieved without checkpoints validError := strings.Contains(result.Error, "max turns") || strings.Contains(result.Error, "not all required checkpoints reached") assert.True(t, validError, "Error should mention max turns or checkpoints not reached, got: %s", result.Error) t.Logf("Error (expected): %s", result.Error) } } // TestDynamicRunner_CheckpointOrdering tests that checkpoint ordering is enforced func TestDynamicRunner_CheckpointOrderingEnforced(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := 
agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") tmpDir := t.TempDir() inputFile := filepath.Join(tmpDir, "dynamic-inputs.jsonl") // Test case with ordered checkpoints (JSONL must be single line) testCase := `{"id": "ordered-checkpoints", "name": "Test Checkpoint Ordering", "input": "I want to order coffee", "simulator": {"use": "tests.simulator-agent", "options": {"metadata": {"persona": "Customer ordering step by step", "goal": "Complete coffee order following the flow"}}}, "checkpoints": [{"id": "step1_greeting", "description": "Agent greets", "assert": {"type": "regex", "value": "(?i)(hello|hi|help|order)"}}, {"id": "step2_size", "description": "Agent asks about size", "after": ["step1_greeting"], "assert": {"type": "regex", "value": "(?i)size"}}, {"id": "step3_confirm", "description": "Agent confirms", "after": ["step2_size"], "assert": {"type": "regex", "value": "(?i)confirm"}}], "max_turns": 10}` err = os.WriteFile(inputFile, []byte(testCase), 0644) require.NoError(t, err, "Failed to write test file") // Run test opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.dynamic-test-agent", Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") t.Logf("Total: %d, Passed: %d, Failed: %d", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) // Log checkpoint order if len(report.Results) > 0 && report.Results[0].Metadata != nil { if checkpoints, ok := report.Results[0].Metadata["checkpoints"].(map[string]*agenttest.CheckpointResult); ok { for id, cp := range checkpoints { t.Logf("Checkpoint [%s]: reached=%v, at_turn=%d", id, cp.Reached, cp.ReachedAtTurn) } } } } ================================================ FILE: agent/test/dynamic_runner.go 
================================================ package test import ( "fmt" "time" jsoniter "github.com/json-iterator/go" goutext "github.com/yaoapp/gou/text" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" ) // DynamicRunner handles dynamic (simulator-driven) test execution type DynamicRunner struct { opts *Options output *OutputWriter asserter *Asserter } // NewDynamicRunner creates a new dynamic runner func NewDynamicRunner(opts *Options) *DynamicRunner { return &DynamicRunner{ opts: opts, output: NewOutputWriter(opts.Verbose), asserter: NewAsserter(), } } // RunDynamic executes a dynamic test case func (r *DynamicRunner) RunDynamic(ast *assistant.Assistant, tc *Case, agentID string) *DynamicResult { startTime := time.Now() result := &DynamicResult{ ID: tc.ID, Turns: make([]*TurnResult, 0), Checkpoints: make(map[string]*CheckpointResult), } // Initialize checkpoints for _, cp := range tc.Checkpoints { result.Checkpoints[cp.ID] = &CheckpointResult{ ID: cp.ID, Reached: false, Required: cp.IsRequired(), } } // Get simulator agent simAST, err := assistant.Get(tc.Simulator.Use) if err != nil { result.Status = StatusError result.Error = fmt.Sprintf("failed to get simulator agent: %s", err.Error()) result.DurationMs = time.Since(startTime).Milliseconds() return result } // Get configuration maxTurns := tc.GetMaxTurns() timeout := tc.GetTimeout(r.opts.Timeout) // Build simulator metadata simMetadata := make(map[string]interface{}) if tc.Simulator.Options != nil && tc.Simulator.Options.Metadata != nil { for k, v := range tc.Simulator.Options.Metadata { simMetadata[k] = v } } // Conversation history messages := make([]context.Message, 0) // Get initial input if provided initialMessages, err := tc.GetMessages() if err == nil && len(initialMessages) > 0 { messages = append(messages, initialMessages...) 
} // Output dynamic test start if r.opts.Verbose { r.output.Verbose("Dynamic test: %s (max %d turns)", tc.ID, maxTurns) } // Use consistent chatID across all turns to preserve session state (ctx.memory.chat) // Priority: context config > generated ID chatID := fmt.Sprintf("dynamic-%s", tc.ID) if r.opts.ContextData != nil && r.opts.ContextData.ChatID != "" { chatID = r.opts.ContextData.ChatID } // Conversation loop for turn := 1; turn <= maxTurns; turn++ { turnStart := time.Now() turnResult := &TurnResult{Turn: turn} // Check timeout if time.Since(startTime) > timeout { result.Status = StatusTimeout result.Error = fmt.Sprintf("timeout after %s", timeout) result.DurationMs = time.Since(startTime).Milliseconds() result.TotalTurns = turn - 1 return result } // For turns after the first, get input from simulator if turn > 1 || len(messages) == 0 { simInput := r.buildSimulatorInput(tc, messages, result, turn, maxTurns, simMetadata) simOutput, err := r.callSimulator(simAST, tc, simInput) if err != nil { turnResult.Error = fmt.Sprintf("simulator error: %s", err.Error()) result.Turns = append(result.Turns, turnResult) result.Status = StatusError result.Error = turnResult.Error result.DurationMs = time.Since(startTime).Milliseconds() result.TotalTurns = turn return result } // Check if goal achieved if simOutput.GoalAchieved { if r.opts.Verbose { r.output.Verbose("Turn %d: Simulator signaled goal achieved", turn) } // Check if all required checkpoints reached if r.allRequiredCheckpointsReached(result) { result.Status = StatusPassed } else { result.Status = StatusFailed result.Error = "simulator signaled goal achieved but not all required checkpoints reached" } result.DurationMs = time.Since(startTime).Milliseconds() result.TotalTurns = turn - 1 return result } // Add user message userMessage := context.Message{ Role: context.RoleUser, Content: simOutput.Message, } messages = append(messages, userMessage) turnResult.Input = simOutput.Message if r.opts.Verbose { 
r.output.Verbose("Turn %d: User: %s", turn, truncateOutput(simOutput.Message, 50)) } } else { // Use initial input for first turn if len(messages) > 0 { lastMsg := messages[len(messages)-1] turnResult.Input = lastMsg.Content if r.opts.Verbose { r.output.Verbose("Turn %d: User: %s", turn, truncateOutput(lastMsg.Content, 50)) } } } // Call target agent // Use consistent chatID across all turns to preserve session state (ctx.memory.chat) ctx := NewTestContextFromOptions( chatID, agentID, r.opts, tc, ) opts := buildContextOptions(tc, r.opts) response, err := ast.Stream(ctx, messages, opts) ctx.Release() if err != nil { turnResult.Error = err.Error() turnResult.DurationMs = time.Since(turnStart).Milliseconds() result.Turns = append(result.Turns, turnResult) result.Status = StatusError result.Error = fmt.Sprintf("agent error at turn %d: %s", turn, err.Error()) result.DurationMs = time.Since(startTime).Milliseconds() result.TotalTurns = turn return result } // Extract output (summary for display and conversation history) output := extractOutput(response) turnResult.Output = output turnResult.DurationMs = time.Since(turnStart).Milliseconds() // Store full response for reporting turnResult.Response = buildTurnResponse(response) if r.opts.Verbose { r.output.Verbose("Turn %d: Agent: %s", turn, truncateOutput(output, 50)) } // Add assistant response to messages, including tool calls if any messages = appendAssistantMessages(messages, response) // Check checkpoints against this response (including tool results) reachedIDs := r.checkCheckpoints(tc.Checkpoints, response, result) turnResult.CheckpointsReached = reachedIDs if r.opts.Verbose && len(reachedIDs) > 0 { for _, id := range reachedIDs { r.output.Verbose(" ✓ checkpoint: %s", id) } } result.Turns = append(result.Turns, turnResult) // Check if all required checkpoints reached if r.allRequiredCheckpointsReached(result) { result.Status = StatusPassed result.DurationMs = time.Since(startTime).Milliseconds() result.TotalTurns = 
turn return result } } // Max turns exceeded result.Status = StatusFailed result.Error = fmt.Sprintf("max turns (%d) exceeded without reaching all checkpoints", maxTurns) result.DurationMs = time.Since(startTime).Milliseconds() result.TotalTurns = maxTurns return result } // buildSimulatorInput builds the input for the simulator agent func (r *DynamicRunner) buildSimulatorInput( tc *Case, messages []context.Message, result *DynamicResult, turn, maxTurns int, metadata map[string]interface{}, ) *SimulatorInput { input := &SimulatorInput{ Conversation: messages, TurnNumber: turn, MaxTurns: maxTurns, } // Extract persona and goal from metadata if persona, ok := metadata["persona"].(string); ok { input.Persona = persona } if goal, ok := metadata["goal"].(string); ok { input.Goal = goal } // Build checkpoint lists input.CheckpointsReached = make([]string, 0) input.CheckpointsPending = make([]string, 0) for id, cp := range result.Checkpoints { if cp.Reached { input.CheckpointsReached = append(input.CheckpointsReached, id) } else { input.CheckpointsPending = append(input.CheckpointsPending, id) } } // Store extra metadata input.Extra = make(map[string]interface{}) for k, v := range metadata { if k != "persona" && k != "goal" { input.Extra[k] = v } } return input } // callSimulator calls the simulator agent and parses the response func (r *DynamicRunner) callSimulator(simAST *assistant.Assistant, tc *Case, input *SimulatorInput) (*SimulatorOutput, error) { // Create context env := NewEnvironment("", "") ctx := NewTestContext("simulator", tc.Simulator.Use, env) defer ctx.Release() // Build options - skip history and trace opts := &context.Options{ Skip: &context.Skip{ History: true, Trace: true, Output: true, }, Metadata: map[string]interface{}{ "test_mode": "simulator", }, } // Override connector if specified if tc.Simulator.Options != nil && tc.Simulator.Options.Connector != "" { opts.Connector = tc.Simulator.Options.Connector } // Build message inputJSON, err := 
jsoniter.Marshal(input) if err != nil { return nil, fmt.Errorf("failed to marshal simulator input: %w", err) } messages := []context.Message{{ Role: context.RoleUser, Content: string(inputJSON), }} // Call simulator response, err := simAST.Stream(ctx, messages, opts) if err != nil { return nil, fmt.Errorf("simulator agent error: %w", err) } // Parse response return r.parseSimulatorResponse(response) } // parseSimulatorResponse parses the simulator agent's response func (r *DynamicRunner) parseSimulatorResponse(response *context.Response) (*SimulatorOutput, error) { if response == nil || response.Completion == nil { return nil, fmt.Errorf("empty response from simulator") } // Extract content content := response.Completion.Content if content == nil { return nil, fmt.Errorf("no content in simulator response") } // Convert to string var text string switch v := content.(type) { case string: text = v default: data, err := jsoniter.Marshal(content) if err != nil { return nil, fmt.Errorf("failed to marshal content: %w", err) } text = string(data) } // Use goutext.ExtractJSON for fault-tolerant parsing parsed := goutext.ExtractJSON(text) if parsed == nil { // Try to use the text as the message directly return &SimulatorOutput{ Message: text, GoalAchieved: false, }, nil } // Parse as SimulatorOutput output := &SimulatorOutput{} if m, ok := parsed.(map[string]interface{}); ok { if msg, ok := m["message"].(string); ok { output.Message = msg } if achieved, ok := m["goal_achieved"].(bool); ok { output.GoalAchieved = achieved } if reasoning, ok := m["reasoning"].(string); ok { output.Reasoning = reasoning } } if output.Message == "" { return nil, fmt.Errorf("simulator returned empty message") } return output, nil } // checkCheckpoints validates checkpoints against current response // It checks both the completion content and tool results for comprehensive validation func (r *DynamicRunner) checkCheckpoints(checkpoints []*Checkpoint, response *context.Response, result 
*DynamicResult) []string { reachedIDs := make([]string, 0) // Build combined output for checkpoint validation // This includes both content and tool result messages combinedOutput := buildCombinedOutput(response) // Set response on asserter for tool-related assertions r.asserter.WithResponse(response) for _, cp := range checkpoints { cpResult := result.Checkpoints[cp.ID] if cpResult.Reached { continue // Already reached } // Check "after" constraint if len(cp.After) > 0 { allAfterReached := true for _, afterID := range cp.After { if afterResult, ok := result.Checkpoints[afterID]; ok { if !afterResult.Reached { allAfterReached = false break } } } if !allAfterReached { continue // Dependencies not met } } // Validate using asserter against combined output with full details tempCase := &Case{Assert: cp.Assert} assertResult := r.asserter.ValidateWithDetails(tempCase, combinedOutput) if assertResult.Passed { cpResult.Reached = true cpResult.Passed = true cpResult.ReachedAtTurn = len(result.Turns) + 1 cpResult.Message = assertResult.Message reachedIDs = append(reachedIDs, cp.ID) } else { // Store failure message for debugging (but don't mark as failed yet - it might pass in a later turn) if cpResult.Message == "" { cpResult.Message = assertResult.Message } } // Store agent validation details if this is an agent assertion if isAgentAssertion(cp.Assert) { // Extract criteria from assertion value var criteria string if assertMap, ok := cp.Assert.(map[string]interface{}); ok { if c, ok := assertMap["value"].(string); ok { criteria = c } } cpResult.AgentValidation = &AgentValidationResult{ Passed: assertResult.Passed, Criteria: criteria, Input: combinedOutput, // Content sent to validator for checking } // Extract reason and store full response from validator if assertResult.Expected != nil { if validatorResponse, ok := assertResult.Expected.(map[string]interface{}); ok { if reason, ok := validatorResponse["reason"].(string); ok { cpResult.AgentValidation.Reason = reason } // 
// truncateForReport truncates content for report output.
//
// Non-string content (including nil) is returned unchanged. A string of at
// most maxLen bytes is returned as-is; a longer one is cut to at most maxLen
// bytes and suffixed with "... (truncated)".
//
// The cut never lands inside a multi-byte UTF-8 character: it is moved back
// to the start of the character that would otherwise be split, so the
// report stays valid UTF-8. A negative maxLen is treated as 0 instead of
// panicking on an out-of-range slice.
func truncateForReport(content interface{}, maxLen int) interface{} {
	if content == nil {
		return nil
	}

	str, ok := content.(string)
	if !ok {
		return content
	}

	if maxLen < 0 {
		maxLen = 0
	}
	if len(str) <= maxLen {
		return str
	}

	// str[cut] is the first byte that gets dropped. If it is a UTF-8
	// continuation byte (10xxxxxx) the cut is mid-character, so step back.
	// A character has at most 3 continuation bytes, which bounds the loop
	// even for malformed input.
	cut := maxLen
	for back := 0; back < 3 && cut > 0 && str[cut]&0xC0 == 0x80; back++ {
		cut--
	}

	return str[:cut] + "... (truncated)"
}
// joinNonEmpty concatenates the non-empty entries of parts, placing sep
// between consecutive kept entries. Empty entries are skipped entirely, so
// they never produce a doubled separator. Returns "" when nothing is kept.
func joinNonEmpty(parts []string, sep string) string {
	joined := ""
	first := true

	for _, part := range parts {
		if part == "" {
			continue
		}
		if first {
			// The first kept entry is written without a leading separator.
			joined = part
			first = false
			continue
		}
		joined += sep + part
	}

	return joined
}
// appendAssistantMessages appends assistant messages to the conversation history // including tool calls and tool results if present func appendAssistantMessages(messages []context.Message, response *context.Response) []context.Message { if response == nil { return messages } // Check if there are tool calls in the completion hasToolCalls := response.Completion != nil && len(response.Completion.ToolCalls) > 0 if hasToolCalls { // Add assistant message with tool calls assistantMsg := context.Message{ Role: context.RoleAssistant, ToolCalls: response.Completion.ToolCalls, } // Include content if present if response.Completion.Content != nil && !isEmptyValue(response.Completion.Content) { assistantMsg.Content = response.Completion.Content } messages = append(messages, assistantMsg) // Add tool result messages for each tool call for i, tc := range response.Completion.ToolCalls { toolCallID := tc.ID var resultContent string // Get result from response.Tools if available if i < len(response.Tools) { resultJSON, err := jsoniter.MarshalToString(response.Tools[i].Result) if err == nil { resultContent = resultJSON } else { resultContent = fmt.Sprintf("%v", response.Tools[i].Result) } } else { resultContent = "{}" } messages = append(messages, context.Message{ Role: context.RoleTool, ToolCallID: &toolCallID, Content: resultContent, }) } } else { // No tool calls, just add content if present content := extractOutput(response) if content != nil && !isEmptyValue(content) { messages = append(messages, context.Message{ Role: context.RoleAssistant, Content: content, }) } } return messages } ================================================ FILE: agent/test/dynamic_runner_test.go ================================================ package test_test import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent" "github.com/yaoapp/yao/agent/test" "github.com/yaoapp/yao/config" testutils "github.com/yaoapp/yao/test" ) func TestCase_IsDynamicMode(t *testing.T) 
{ tests := []struct { name string tc *test.Case expected bool }{ { name: "standard mode - no simulator", tc: &test.Case{ ID: "T001", Input: "Hello", }, expected: false, }, { name: "standard mode - simulator but no checkpoints", tc: &test.Case{ ID: "T002", Input: "Hello", Simulator: &test.Simulator{Use: "tests.simulator-agent"}, }, expected: false, }, { name: "standard mode - checkpoints but no simulator", tc: &test.Case{ ID: "T003", Input: "Hello", Checkpoints: []*test.Checkpoint{ {ID: "cp1", Assert: map[string]interface{}{"type": "contains", "value": "hi"}}, }, }, expected: false, }, { name: "dynamic mode - has both simulator and checkpoints", tc: &test.Case{ ID: "T004", Simulator: &test.Simulator{Use: "tests.simulator-agent"}, Checkpoints: []*test.Checkpoint{ {ID: "cp1", Assert: map[string]interface{}{"type": "contains", "value": "hi"}}, }, }, expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.tc.IsDynamicMode() assert.Equal(t, tt.expected, result) }) } } func TestCase_GetMaxTurns(t *testing.T) { tests := []struct { name string tc *test.Case expected int }{ { name: "default max turns", tc: &test.Case{ID: "T001"}, expected: 20, }, { name: "custom max turns", tc: &test.Case{ID: "T002", MaxTurns: 10}, expected: 10, }, { name: "zero max turns uses default", tc: &test.Case{ID: "T003", MaxTurns: 0}, expected: 20, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.tc.GetMaxTurns() assert.Equal(t, tt.expected, result) }) } } func TestCheckpoint_IsRequired(t *testing.T) { boolTrue := true boolFalse := false tests := []struct { name string cp *test.Checkpoint expected bool }{ { name: "default is required", cp: &test.Checkpoint{ID: "cp1"}, expected: true, }, { name: "explicitly required", cp: &test.Checkpoint{ID: "cp2", Required: &boolTrue}, expected: true, }, { name: "explicitly not required", cp: &test.Checkpoint{ID: "cp3", Required: &boolFalse}, expected: false, }, } for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { result := tt.cp.IsRequired() assert.Equal(t, tt.expected, result) }) } } func TestDynamicResult_ToResult(t *testing.T) { dr := &test.DynamicResult{ ID: "T001", Status: test.StatusPassed, TotalTurns: 3, DurationMs: 5000, Turns: []*test.TurnResult{ {Turn: 1, Input: "Hello", Output: "Hi there!"}, {Turn: 2, Input: "How are you?", Output: "I'm doing well!"}, {Turn: 3, Input: "Goodbye", Output: "Bye!"}, }, Checkpoints: map[string]*test.CheckpointResult{ "greet": {ID: "greet", Reached: true, ReachedAtTurn: 1, Required: true}, "bye": {ID: "bye", Reached: true, ReachedAtTurn: 3, Required: true}, }, } result := dr.ToResult() assert.Equal(t, "T001", result.ID) assert.Equal(t, test.StatusPassed, result.Status) assert.Equal(t, int64(5000), result.DurationMs) assert.Equal(t, "Hello", result.Input) assert.Equal(t, "Bye!", result.Output) // Check metadata assert.NotNil(t, result.Metadata) assert.Equal(t, "dynamic", result.Metadata["mode"]) assert.Equal(t, 3, result.Metadata["total_turns"]) } func TestDynamicRunner_Integration(t *testing.T) { // Skip if running in short mode if testing.Short() { t.Skip("skipping integration test in short mode") } // Prepare test environment testutils.Prepare(t, config.Conf) defer testutils.Clean() // Load agents err := agent.Load(config.Conf) if err != nil { t.Skipf("Failed to load agents: %v", err) } // Create a dynamic test case tc := &test.Case{ ID: "dynamic-greeting", Simulator: &test.Simulator{ Use: "tests.simulator-agent", Options: &test.SimulatorOptions{ Metadata: map[string]interface{}{ "persona": "Friendly user", "goal": "Have a brief greeting exchange", }, }, }, Input: "Hello!", Checkpoints: []*test.Checkpoint{ { ID: "greeting", Description: "Agent should greet back", Assert: map[string]interface{}{ "type": "regex", "value": "(?i)(hello|hi|hey|greetings)", }, }, }, MaxTurns: 3, } // Verify it's dynamic mode assert.True(t, tc.IsDynamicMode()) // Create runner options opts := &test.Options{ Verbose: true, 
Timeout: 30 * time.Second, } // Create dynamic runner runner := test.NewDynamicRunner(opts) assert.NotNil(t, runner) // Note: Full integration test would require the simulator agent to be loaded // and would make actual LLM calls. For CI, we test the structure and logic. } func TestDynamicRunner_CheckpointOrdering(t *testing.T) { // Test that checkpoints with "after" constraints are properly ordered testutils.Prepare(t, config.Conf) defer testutils.Clean() // Load agents err := agent.Load(config.Conf) if err != nil { t.Skipf("Failed to load agents: %v", err) } // Create a test case with ordered checkpoints tc := &test.Case{ ID: "ordered-checkpoints", Simulator: &test.Simulator{ Use: "tests.simulator-agent", Options: &test.SimulatorOptions{ Metadata: map[string]interface{}{ "persona": "Customer", "goal": "Complete a purchase", }, }, }, Checkpoints: []*test.Checkpoint{ { ID: "ask_product", Description: "Agent asks about product", Assert: map[string]interface{}{ "type": "contains", "value": "product", }, }, { ID: "confirm_order", Description: "Agent confirms order", After: []string{"ask_product"}, Assert: map[string]interface{}{ "type": "contains", "value": "confirm", }, }, { ID: "complete", Description: "Order completed", After: []string{"confirm_order"}, Assert: map[string]interface{}{ "type": "contains", "value": "complete", }, }, }, MaxTurns: 10, } // Verify checkpoint structure assert.Len(t, tc.Checkpoints, 3) assert.Empty(t, tc.Checkpoints[0].After) assert.Equal(t, []string{"ask_product"}, tc.Checkpoints[1].After) assert.Equal(t, []string{"confirm_order"}, tc.Checkpoints[2].After) } func TestSimulatorInput_Structure(t *testing.T) { // Test SimulatorInput structure input := &test.SimulatorInput{ Persona: "Test user", Goal: "Complete task", TurnNumber: 3, MaxTurns: 10, CheckpointsReached: []string{"cp1", "cp2"}, CheckpointsPending: []string{"cp3"}, Extra: map[string]interface{}{ "style": "formal", }, } assert.Equal(t, "Test user", input.Persona) assert.Equal(t, 
"Complete task", input.Goal) assert.Equal(t, 3, input.TurnNumber) assert.Equal(t, 10, input.MaxTurns) assert.Len(t, input.CheckpointsReached, 2) assert.Len(t, input.CheckpointsPending, 1) assert.Equal(t, "formal", input.Extra["style"]) } func TestSimulatorOutput_Structure(t *testing.T) { // Test SimulatorOutput structure output := &test.SimulatorOutput{ Message: "I'd like to buy a product", GoalAchieved: false, Reasoning: "Continuing toward purchase goal", } assert.Equal(t, "I'd like to buy a product", output.Message) assert.False(t, output.GoalAchieved) assert.Equal(t, "Continuing toward purchase goal", output.Reasoning) } ================================================ FILE: agent/test/dynamic_types.go ================================================ package test import "github.com/yaoapp/yao/agent/context" // DynamicResult represents the result of a dynamic (simulator-driven) test type DynamicResult struct { // ID is the test case identifier ID string `json:"id"` // Status is the overall test status Status Status `json:"status"` // Turns contains results for each conversation turn Turns []*TurnResult `json:"turns"` // Checkpoints maps checkpoint ID to its result Checkpoints map[string]*CheckpointResult `json:"checkpoints"` // TotalTurns is the number of turns executed TotalTurns int `json:"total_turns"` // DurationMs is the total execution time in milliseconds DurationMs int64 `json:"duration_ms"` // Error contains error message if status is failed/error/timeout Error string `json:"error,omitempty"` } // TurnResult represents the result of a single conversation turn type TurnResult struct { // Turn is the turn number (1-based) Turn int `json:"turn"` // Input is the user message (from simulator or initial input) Input interface{} `json:"input"` // Output is the agent's response (summary for display and conversation history) Output interface{} `json:"output,omitempty"` // Response is the full agent response including completion and tool results Response 
// TurnResponse contains the full agent response for a turn.
//
// It is assembled by buildTurnResponse from the agent's response. Every field
// is tagged omitempty, so unset parts are left out of the serialized report.
type TurnResponse struct {
	// Content is the content of the LLM completion. It is copied as-is from the
	// completion, hence the untyped interface{}.
	Content interface{} `json:"content,omitempty"`

	// ToolCalls contains the tool calls made by the agent. Entries come from
	// the completion's tool calls, with results attached by position; when the
	// completion lists no calls, entries are created from the tool results alone.
	ToolCalls []ToolCallInfo `json:"tool_calls,omitempty"`

	// Next is the data returned from the Next hook. buildTurnResponse sets it
	// only when the hook returned a non-empty value.
	Next interface{} `json:"next,omitempty"`
}
determined the assertion passed Passed bool `json:"passed"` // Reason is the explanation from the agent validator Reason string `json:"reason,omitempty"` // Criteria is the validation criteria that was checked Criteria string `json:"criteria,omitempty"` // Input is the content that was sent to validator for checking Input interface{} `json:"input,omitempty"` // Response is the raw response from the validator agent Response interface{} `json:"response,omitempty"` } // SimulatorInput is the input sent to the simulator agent type SimulatorInput struct { // Persona describes the user being simulated Persona string `json:"persona,omitempty"` // Goal is what the user is trying to achieve Goal string `json:"goal,omitempty"` // Conversation is the message history Conversation []context.Message `json:"conversation"` // TurnNumber is the current turn (1-based) TurnNumber int `json:"turn_number"` // MaxTurns is the maximum allowed turns MaxTurns int `json:"max_turns"` // CheckpointsReached lists checkpoint IDs already reached CheckpointsReached []string `json:"checkpoints_reached,omitempty"` // CheckpointsPending lists checkpoint IDs still pending CheckpointsPending []string `json:"checkpoints_pending,omitempty"` // Extra metadata from simulator options Extra map[string]interface{} `json:"extra,omitempty"` } // SimulatorOutput is the expected output from the simulator agent type SimulatorOutput struct { // Message is the simulated user message Message string `json:"message"` // GoalAchieved indicates if the user's goal has been accomplished GoalAchieved bool `json:"goal_achieved"` // Reasoning explains the simulator's response strategy Reasoning string `json:"reasoning,omitempty"` } // ToResult converts DynamicResult to standard Result for reporting func (dr *DynamicResult) ToResult() *Result { result := &Result{ ID: dr.ID, Status: dr.Status, DurationMs: dr.DurationMs, Error: dr.Error, } // Store dynamic-specific data in metadata result.Metadata = map[string]interface{}{ 
"mode": "dynamic", "total_turns": dr.TotalTurns, "turns": dr.Turns, "checkpoints": dr.Checkpoints, } // Set input from first turn if len(dr.Turns) > 0 { result.Input = dr.Turns[0].Input } // Set output from last turn if len(dr.Turns) > 0 { result.Output = dr.Turns[len(dr.Turns)-1].Output } return result } // IsDynamicMode checks if a test case should run in dynamic mode func (tc *Case) IsDynamicMode() bool { return tc.Simulator != nil && len(tc.Checkpoints) > 0 } // GetMaxTurns returns the max turns for dynamic mode func (tc *Case) GetMaxTurns() int { if tc.MaxTurns > 0 { return tc.MaxTurns } return 20 // Default max turns } // IsRequired returns true if the checkpoint is required func (cp *Checkpoint) IsRequired() bool { if cp.Required == nil { return true // Default to required } return *cp.Required } ================================================ FILE: agent/test/extract.go ================================================ package test import ( "fmt" "os" "path/filepath" "strings" jsoniter "github.com/json-iterator/go" ) // ExtractOptions represents options for extracting test results type ExtractOptions struct { // InputFile is the path to the output JSONL file from test run InputFile string // OutputDir is the directory to write extracted files (default: same as input file) OutputDir string // Format is the output format: "markdown" (default), "json" Format string } // Extractor extracts test results to individual files for review type Extractor struct { opts *ExtractOptions } // NewExtractor creates a new extractor func NewExtractor(opts *ExtractOptions) *Extractor { if opts.Format == "" { opts.Format = "markdown" } if opts.OutputDir == "" { opts.OutputDir = filepath.Dir(opts.InputFile) } return &Extractor{opts: opts} } // Extract reads the test output file and extracts results to individual files func (e *Extractor) Extract() ([]string, error) { // Read the JSONL file data, err := os.ReadFile(e.opts.InputFile) if err != nil { return nil, fmt.Errorf("failed 
to read input file: %w", err) } // Parse the JSON (the output file is a single JSON object, not JSONL) var report Report if err := jsoniter.Unmarshal(data, &report); err != nil { return nil, fmt.Errorf("failed to parse test report: %w", err) } // Create output directory if it doesn't exist if err := os.MkdirAll(e.opts.OutputDir, 0755); err != nil { return nil, fmt.Errorf("failed to create output directory: %w", err) } var extractedFiles []string // Extract each result for _, result := range report.Results { var filename string var content string switch e.opts.Format { case "markdown": filename = filepath.Join(e.opts.OutputDir, result.ID+".md") content = e.formatMarkdown(result) case "json": filename = filepath.Join(e.opts.OutputDir, result.ID+".json") jsonBytes, err := jsoniter.MarshalIndent(result, "", " ") if err != nil { return extractedFiles, fmt.Errorf("failed to marshal result %s: %w", result.ID, err) } content = string(jsonBytes) default: return nil, fmt.Errorf("unsupported format: %s", e.opts.Format) } if err := os.WriteFile(filename, []byte(content), 0644); err != nil { return extractedFiles, fmt.Errorf("failed to write file %s: %w", filename, err) } extractedFiles = append(extractedFiles, filename) } return extractedFiles, nil } // formatMarkdown formats a single test result as Markdown func (e *Extractor) formatMarkdown(result *Result) string { var sb strings.Builder // Title sb.WriteString(fmt.Sprintf("# %s\n\n", result.ID)) // Status badge switch result.Status { case StatusPassed: sb.WriteString("**Status**: ✅ PASSED\n\n") case StatusFailed: sb.WriteString("**Status**: ❌ FAILED\n\n") case StatusError: sb.WriteString("**Status**: ⚠️ ERROR\n\n") case StatusTimeout: sb.WriteString("**Status**: ⏱️ TIMEOUT\n\n") case StatusSkipped: sb.WriteString("**Status**: ⏭️ SKIPPED\n\n") } // Duration sb.WriteString(fmt.Sprintf("**Duration**: %dms\n\n", result.DurationMs)) // Error (if any) if result.Error != "" { sb.WriteString("## Error\n\n") sb.WriteString("```\n") 
sb.WriteString(result.Error) sb.WriteString("\n```\n\n") } // Input sb.WriteString("## Input\n\n") sb.WriteString("```markdown\n") sb.WriteString(formatInputAsString(result.Input)) sb.WriteString("\n```\n\n") // Output sb.WriteString("## Output\n\n") output := formatOutputAsString(result.Output) // Remove markdown code block wrapper if present output = strings.TrimPrefix(output, "```markdown\n") output = strings.TrimSuffix(output, "\n```") output = strings.TrimSuffix(output, "```") sb.WriteString(output) sb.WriteString("\n") return sb.String() } // formatInputAsString converts input to string format func formatInputAsString(input interface{}) string { switch v := input.(type) { case string: return v case map[string]interface{}: // Single message format if content, ok := v["content"].(string); ok { return content } // Fallback to JSON jsonBytes, _ := jsoniter.MarshalIndent(v, "", " ") return string(jsonBytes) case []interface{}: // Array of messages - extract content from last user message for i := len(v) - 1; i >= 0; i-- { if msg, ok := v[i].(map[string]interface{}); ok { if role, ok := msg["role"].(string); ok && role == "user" { if content, ok := msg["content"].(string); ok { return content } } } } // Fallback to JSON jsonBytes, _ := jsoniter.MarshalIndent(v, "", " ") return string(jsonBytes) default: jsonBytes, _ := jsoniter.MarshalIndent(input, "", " ") return string(jsonBytes) } } // formatOutputAsString converts output to string format func formatOutputAsString(output interface{}) string { switch v := output.(type) { case string: return v case map[string]interface{}, []interface{}: jsonBytes, _ := jsoniter.MarshalIndent(v, "", " ") return string(jsonBytes) default: if output == nil { return "(no output)" } return fmt.Sprintf("%v", output) } } ================================================ FILE: agent/test/input.go ================================================ package test import ( "encoding/base64" "fmt" "mime" "os" "path/filepath" "strings" jsoniter 
"github.com/json-iterator/go" "github.com/yaoapp/yao/agent/context" ) // FileProtocol is the protocol prefix for local file references const FileProtocol = "file://" // SupportedImageExtensions lists supported image file extensions var SupportedImageExtensions = map[string]string{ ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png", ".gif": "image/gif", ".webp": "image/webp", ".bmp": "image/bmp", } // SupportedAudioExtensions lists supported audio file extensions var SupportedAudioExtensions = map[string]string{ ".wav": "wav", ".mp3": "mp3", ".flac": "flac", ".ogg": "ogg", ".m4a": "m4a", } // SupportedFileExtensions lists supported document file extensions var SupportedFileExtensions = map[string]string{ // Documents ".pdf": "application/pdf", ".doc": "application/msword", ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".xls": "application/vnd.ms-excel", ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".txt": "text/plain", ".csv": "text/csv", ".json": "application/json", ".xml": "application/xml", ".html": "text/html", ".htm": "text/html", ".md": "text/markdown", // Source code ".yao": "application/json", // Yao DSL files (JSON-based) ".ts": "text/typescript", // TypeScript ".tsx": "text/typescript", // TypeScript JSX ".js": "text/javascript", // JavaScript ".jsx": "text/javascript", // JavaScript JSX ".go": "text/x-go", // Go ".py": "text/x-python", // Python ".rs": "text/x-rust", // Rust ".java": "text/x-java", // Java ".c": "text/x-c", // C ".cpp": "text/x-c++", // C++ ".h": "text/x-c", // C header ".hpp": "text/x-c++", // C++ header ".rb": "text/x-ruby", // Ruby ".php": "text/x-php", // PHP ".sh": "text/x-shellscript", // Shell script ".bash": "text/x-shellscript", // Bash script ".zsh": "text/x-shellscript", // Zsh script ".sql": "text/x-sql", // SQL ".yaml": "text/yaml", // YAML ".yml": "text/yaml", // YAML ".toml": "text/x-toml", // TOML ".ini": "text/x-ini", // INI ".conf": 
"text/plain", // Config files ".css": "text/css", // CSS ".scss": "text/x-scss", // SCSS ".less": "text/x-less", // LESS ".vue": "text/x-vue", // Vue ".svelte": "text/x-svelte", // Svelte } // InputOptions configures how input is parsed type InputOptions struct { // BaseDir is the base directory for resolving relative file paths // If empty, the current working directory is used BaseDir string } // ParseInput converts various input formats to []context.Message // Supported formats: // - string: converted to single user message // - map (Message): single message with role and content // - []interface{} ([]Message): array of messages (conversation history) func ParseInput(input interface{}) ([]context.Message, error) { return ParseInputWithOptions(input, nil) } // ParseInputWithOptions converts various input formats to []context.Message with options // Supported formats: // - string: converted to single user message // - map (Message): single message with role and content // - []interface{} ([]Message): array of messages (conversation history) // // File references in content parts (type="image", "file", "audio") with "source" field // starting with "file://" will be loaded and converted to appropriate format: // - Images: converted to base64 data URL in image_url field // - Audio: converted to base64 in input_audio field // - Files: converted to base64 data URL in file field func ParseInputWithOptions(input interface{}, opts *InputOptions) ([]context.Message, error) { if input == nil { return nil, fmt.Errorf("input is nil") } if opts == nil { opts = &InputOptions{} } switch v := input.(type) { case string: // Simple string input -> single user message return []context.Message{ { Role: context.RoleUser, Content: v, }, }, nil case map[string]interface{}: // Single message object msg, err := parseMessageMap(v, opts) if err != nil { return nil, fmt.Errorf("failed to parse message: %w", err) } return []context.Message{*msg}, nil case []interface{}: // Array of messages 
(conversation history) messages := make([]context.Message, 0, len(v)) for i, item := range v { switch m := item.(type) { case map[string]interface{}: msg, err := parseMessageMap(m, opts) if err != nil { return nil, fmt.Errorf("failed to parse message at index %d: %w", i, err) } messages = append(messages, *msg) default: return nil, fmt.Errorf("invalid message type at index %d: expected object, got %T", i, item) } } return messages, nil default: return nil, fmt.Errorf("unsupported input type: %T", input) } } // parseMessageMap converts a map to context.Message func parseMessageMap(m map[string]interface{}, opts *InputOptions) (*context.Message, error) { msg := &context.Message{} // Parse role (required) if role, ok := m["role"].(string); ok { msg.Role = context.MessageRole(role) } else { // Default to user role if not specified msg.Role = context.RoleUser } // Parse content (required) if content, ok := m["content"]; ok { // Process content to handle file:// references processedContent, err := processContent(content, opts) if err != nil { return nil, fmt.Errorf("failed to process content: %w", err) } msg.Content = processedContent } else { return nil, fmt.Errorf("message missing 'content' field") } // Parse optional name if name, ok := m["name"].(string); ok { msg.Name = &name } // Parse optional tool_call_id (for tool messages) if toolCallID, ok := m["tool_call_id"].(string); ok { msg.ToolCallID = &toolCallID } // Parse optional tool_calls (for assistant messages) if toolCalls, ok := m["tool_calls"].([]interface{}); ok { msg.ToolCalls = make([]context.ToolCall, 0, len(toolCalls)) for _, tc := range toolCalls { if tcMap, ok := tc.(map[string]interface{}); ok { toolCall, err := parseToolCall(tcMap) if err != nil { return nil, fmt.Errorf("failed to parse tool_call: %w", err) } msg.ToolCalls = append(msg.ToolCalls, *toolCall) } } } // Parse optional refusal (for assistant messages) if refusal, ok := m["refusal"].(string); ok { msg.Refusal = &refusal } return msg, nil } 
// processContent processes content to handle file:// references // Returns the processed content with files loaded and converted func processContent(content interface{}, opts *InputOptions) (interface{}, error) { switch v := content.(type) { case string: // Simple string content, no processing needed return v, nil case []interface{}: // Array of content parts processedParts := make([]context.ContentPart, 0, len(v)) for i, part := range v { if partMap, ok := part.(map[string]interface{}); ok { processedPart, err := processContentPart(partMap, opts) if err != nil { return nil, fmt.Errorf("failed to process content part at index %d: %w", i, err) } processedParts = append(processedParts, *processedPart) } else { return nil, fmt.Errorf("invalid content part type at index %d: expected object, got %T", i, part) } } return processedParts, nil case map[string]interface{}: // Single content part processedPart, err := processContentPart(v, opts) if err != nil { return nil, fmt.Errorf("failed to process content part: %w", err) } return []context.ContentPart{*processedPart}, nil default: return content, nil } } // processContentPart processes a single content part map // Handles file:// references and converts them to appropriate format func processContentPart(partMap map[string]interface{}, opts *InputOptions) (*context.ContentPart, error) { partType, _ := partMap["type"].(string) switch partType { case "text": text, _ := partMap["text"].(string) return &context.ContentPart{ Type: context.ContentText, Text: text, }, nil case "image": return processImagePart(partMap, opts) case "image_url": // Already in correct format, just parse it return parseImageURLPart(partMap) case "audio", "input_audio": return processAudioPart(partMap, opts) case "file": return processFilePart(partMap, opts) case "data": return parseDataPart(partMap) default: // Unknown type, try to preserve as-is return parseGenericPart(partMap) } } // processImagePart processes an image content part // Supports: 
source="file://path" for local files func processImagePart(partMap map[string]interface{}, opts *InputOptions) (*context.ContentPart, error) { source, hasSource := partMap["source"].(string) // Check for file:// protocol if hasSource && strings.HasPrefix(source, FileProtocol) { filePath := strings.TrimPrefix(source, FileProtocol) return loadImageFile(filePath, opts) } // Check for url field (already a URL or base64) if url, ok := partMap["url"].(string); ok { detail := context.DetailAuto if d, ok := partMap["detail"].(string); ok { detail = context.ImageDetailLevel(d) } return &context.ContentPart{ Type: context.ContentImageURL, ImageURL: &context.ImageURL{ URL: url, Detail: detail, }, }, nil } return nil, fmt.Errorf("image part requires 'source' (file://...) or 'url' field") } // parseImageURLPart parses an image_url content part func parseImageURLPart(partMap map[string]interface{}) (*context.ContentPart, error) { imageURL, ok := partMap["image_url"].(map[string]interface{}) if !ok { return nil, fmt.Errorf("image_url part requires 'image_url' object") } url, _ := imageURL["url"].(string) detail := context.DetailAuto if d, ok := imageURL["detail"].(string); ok { detail = context.ImageDetailLevel(d) } return &context.ContentPart{ Type: context.ContentImageURL, ImageURL: &context.ImageURL{ URL: url, Detail: detail, }, }, nil } // processAudioPart processes an audio content part // Supports: source="file://path" for local files func processAudioPart(partMap map[string]interface{}, opts *InputOptions) (*context.ContentPart, error) { source, hasSource := partMap["source"].(string) // Check for file:// protocol if hasSource && strings.HasPrefix(source, FileProtocol) { filePath := strings.TrimPrefix(source, FileProtocol) return loadAudioFile(filePath, opts) } // Check for data field (already base64) if data, ok := partMap["data"].(string); ok { format, _ := partMap["format"].(string) return &context.ContentPart{ Type: context.ContentInputAudio, InputAudio: 
&context.InputAudio{ Data: data, Format: format, }, }, nil } // Check for input_audio field if inputAudio, ok := partMap["input_audio"].(map[string]interface{}); ok { data, _ := inputAudio["data"].(string) format, _ := inputAudio["format"].(string) return &context.ContentPart{ Type: context.ContentInputAudio, InputAudio: &context.InputAudio{ Data: data, Format: format, }, }, nil } return nil, fmt.Errorf("audio part requires 'source' (file://...) or 'data'/'input_audio' field") } // processFilePart processes a file content part // Supports: source="file://path" for local files func processFilePart(partMap map[string]interface{}, opts *InputOptions) (*context.ContentPart, error) { source, hasSource := partMap["source"].(string) // Check for file:// protocol if hasSource && strings.HasPrefix(source, FileProtocol) { filePath := strings.TrimPrefix(source, FileProtocol) name, _ := partMap["name"].(string) return loadFile(filePath, name, opts) } // Check for url field (already a URL) if url, ok := partMap["url"].(string); ok { filename, _ := partMap["filename"].(string) if filename == "" { filename, _ = partMap["name"].(string) } return &context.ContentPart{ Type: context.ContentFile, File: &context.FileAttachment{ URL: url, Filename: filename, }, }, nil } // Check for file field if file, ok := partMap["file"].(map[string]interface{}); ok { url, _ := file["url"].(string) filename, _ := file["filename"].(string) return &context.ContentPart{ Type: context.ContentFile, File: &context.FileAttachment{ URL: url, Filename: filename, }, }, nil } return nil, fmt.Errorf("file part requires 'source' (file://...), 'url', or 'file' field") } // parseDataPart parses a data content part func parseDataPart(partMap map[string]interface{}) (*context.ContentPart, error) { data, ok := partMap["data"].(map[string]interface{}) if !ok { return nil, fmt.Errorf("data part requires 'data' object") } // Convert to DataContent dataContent := &context.DataContent{} if sources, ok := 
data["sources"].([]interface{}); ok { dataContent.Sources = make([]context.DataSource, 0, len(sources)) for _, src := range sources { if srcMap, ok := src.(map[string]interface{}); ok { ds := context.DataSource{} if t, ok := srcMap["type"].(string); ok { ds.Type = context.DataSourceType(t) } if id, ok := srcMap["id"].(string); ok { ds.ID = id } if name, ok := srcMap["name"].(string); ok { ds.Name = name } if filters, ok := srcMap["filters"].(map[string]interface{}); ok { ds.Filters = filters } if metadata, ok := srcMap["metadata"].(map[string]interface{}); ok { ds.Metadata = metadata } dataContent.Sources = append(dataContent.Sources, ds) } } } return &context.ContentPart{ Type: context.ContentData, Data: dataContent, }, nil } // parseGenericPart tries to parse an unknown content part type func parseGenericPart(partMap map[string]interface{}) (*context.ContentPart, error) { partType, _ := partMap["type"].(string) // Try to create a basic ContentPart part := &context.ContentPart{ Type: context.ContentPartType(partType), } // Try to extract text if present if text, ok := partMap["text"].(string); ok { part.Text = text } return part, nil } // loadImageFile loads an image file and converts it to a ContentPart func loadImageFile(filePath string, opts *InputOptions) (*context.ContentPart, error) { absPath := resolveFilePath(filePath, opts) // Read file data, err := os.ReadFile(absPath) if err != nil { return nil, fmt.Errorf("failed to read image file %s: %w", filePath, err) } // Determine MIME type ext := strings.ToLower(filepath.Ext(absPath)) mimeType, ok := SupportedImageExtensions[ext] if !ok { // Try to detect from extension mimeType = mime.TypeByExtension(ext) if mimeType == "" { mimeType = "application/octet-stream" } } // Encode to base64 data URL b64Data := base64.StdEncoding.EncodeToString(data) dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data) return &context.ContentPart{ Type: context.ContentImageURL, ImageURL: &context.ImageURL{ URL: dataURL, 
Detail: context.DetailAuto, }, }, nil } // loadAudioFile loads an audio file and converts it to a ContentPart func loadAudioFile(filePath string, opts *InputOptions) (*context.ContentPart, error) { absPath := resolveFilePath(filePath, opts) // Read file data, err := os.ReadFile(absPath) if err != nil { return nil, fmt.Errorf("failed to read audio file %s: %w", filePath, err) } // Determine format from extension ext := strings.ToLower(filepath.Ext(absPath)) format, ok := SupportedAudioExtensions[ext] if !ok { format = strings.TrimPrefix(ext, ".") } // Encode to base64 b64Data := base64.StdEncoding.EncodeToString(data) return &context.ContentPart{ Type: context.ContentInputAudio, InputAudio: &context.InputAudio{ Data: b64Data, Format: format, }, }, nil } // loadFile loads a file and converts it to a ContentPart func loadFile(filePath string, name string, opts *InputOptions) (*context.ContentPart, error) { absPath := resolveFilePath(filePath, opts) // Read file data, err := os.ReadFile(absPath) if err != nil { return nil, fmt.Errorf("failed to read file %s: %w", filePath, err) } // Determine filename filename := name if filename == "" { filename = filepath.Base(absPath) } // Determine MIME type ext := strings.ToLower(filepath.Ext(absPath)) mimeType, ok := SupportedFileExtensions[ext] if !ok { mimeType = mime.TypeByExtension(ext) if mimeType == "" { mimeType = "application/octet-stream" } } // Encode to base64 data URL b64Data := base64.StdEncoding.EncodeToString(data) dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data) return &context.ContentPart{ Type: context.ContentFile, File: &context.FileAttachment{ URL: dataURL, Filename: filename, }, }, nil } // resolveFilePath resolves a file path relative to the base directory // If the path is absolute, it's returned as-is // If BaseDir is empty, the current working directory is used func resolveFilePath(filePath string, opts *InputOptions) string { // If path is absolute, return as-is if filepath.IsAbs(filePath) 
// ExtractTextContent returns a display string for message content; it is used
// when rendering reports.
//
//   - nil yields "".
//   - A string is returned unchanged.
//   - A []interface{} of content parts yields the "text" of every part whose
//     "type" is "text", joined with newlines; when there is no such part the
//     placeholder "[N content parts]" is returned.
//   - A map is treated as a single text part, or as a message whose "content"
//     is extracted recursively; otherwise it is formatted with %v.
//   - Any other value is formatted with %v.
func ExtractTextContent(content interface{}) string {
	switch value := content.(type) {
	case nil:
		return ""

	case string:
		return value

	case []interface{}:
		// Gather the text of every well-formed text part, ignoring the rest
		collected := make([]string, 0, len(value))
		for _, item := range value {
			entry, isMap := item.(map[string]interface{})
			if !isMap || entry["type"] != "text" {
				continue
			}
			if s, isString := entry["text"].(string); isString {
				collected = append(collected, s)
			}
		}
		if len(collected) == 0 {
			return fmt.Sprintf("[%d content parts]", len(value))
		}
		return strings.Join(collected, "\n")

	case map[string]interface{}:
		// A single text part
		if value["type"] == "text" {
			if s, isString := value["text"].(string); isString {
				return s
			}
		}
		// A message wrapping its own content
		if inner, found := value["content"]; found {
			return ExtractTextContent(inner)
		}
		return fmt.Sprintf("%v", value)

	default:
		return fmt.Sprintf("%v", value)
	}
}
creates a short summary of the input for display func SummarizeInput(input interface{}, maxLen int) string { text := "" switch v := input.(type) { case string: text = v case map[string]interface{}: if content, ok := v["content"]; ok { text = ExtractTextContent(content) } case []interface{}: // Get the last user message for summary for i := len(v) - 1; i >= 0; i-- { if msg, ok := v[i].(map[string]interface{}); ok { if msg["role"] == "user" { if content, ok := msg["content"]; ok { text = ExtractTextContent(content) break } } } } if text == "" && len(v) > 0 { text = fmt.Sprintf("[%d messages]", len(v)) } default: text = fmt.Sprintf("%v", v) } if maxLen > 0 && len(text) > maxLen { return text[:maxLen-3] + "..." } return text } ================================================ FILE: agent/test/input_source.go ================================================ package test import ( "fmt" "net/url" "strconv" "strings" jsoniter "github.com/json-iterator/go" goutext "github.com/yaoapp/gou/text" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" ) // InputSourceType represents the type of input source type InputSourceType string const ( // InputSourceFile indicates input from a JSONL file InputSourceFile InputSourceType = "file" // InputSourceMessage indicates input from a direct message string InputSourceMessage InputSourceType = "message" // InputSourceScript indicates script test mode InputSourceScript InputSourceType = "script" // InputSourceAgent indicates input generated by an agent InputSourceAgent InputSourceType = "agent" ) // InputSource represents a parsed input source type InputSource struct { Type InputSourceType // file, message, script, agent Value string // path, message, script ref, or agent ID Params map[string]interface{} // query parameters (for agent source) } // ParseInputSource parses the -i flag value into an InputSource // Supported formats: // - "agents:workers.test.generator" - Agent-generated test cases // - 
"agents:workers.test.generator?count=10&focus=edge-cases" - With parameters // - "scripts.tests.gen" - Script-generated test cases // - "./tests/inputs.jsonl" - JSONL file // - "Hello, how are you?" - Direct message func ParseInputSource(input string) *InputSource { // Check for agents: prefix if strings.HasPrefix(input, "agents:") { return parseAgentSource(strings.TrimPrefix(input, "agents:")) } // Check for scripts: prefix (for generator scripts) if strings.HasPrefix(input, "scripts:") { return &InputSource{ Type: InputSourceScript, Value: strings.TrimPrefix(input, "scripts:"), } } // Check for script test mode (scripts.xxx format without prefix) if strings.HasPrefix(input, "scripts.") { return &InputSource{ Type: InputSourceScript, Value: input, } } // Check for file extension if strings.HasSuffix(input, ".jsonl") || strings.HasSuffix(input, ".json") { return &InputSource{ Type: InputSourceFile, Value: input, } } // Check if it looks like a file path if strings.Contains(input, "/") || strings.Contains(input, "\\") { return &InputSource{ Type: InputSourceFile, Value: input, } } // Default to message return &InputSource{ Type: InputSourceMessage, Value: input, } } // parseAgentSource parses an agent source string with optional query parameters // Format: "agent.id" or "agent.id?count=10&focus=edge-cases" func parseAgentSource(input string) *InputSource { source := &InputSource{ Type: InputSourceAgent, Params: make(map[string]interface{}), } // Check for query parameters if idx := strings.Index(input, "?"); idx >= 0 { source.Value = input[:idx] queryStr := input[idx+1:] // Parse query parameters values, err := url.ParseQuery(queryStr) if err == nil { for key, vals := range values { if len(vals) > 0 { // Try to parse as number if num, err := strconv.Atoi(vals[0]); err == nil { source.Params[key] = num } else if num, err := strconv.ParseFloat(vals[0], 64); err == nil { source.Params[key] = num } else if vals[0] == "true" { source.Params[key] = true } else if vals[0] 
// GeneratorInput represents the input sent to a generator agent.
// GenerateTestCases marshals it to JSON and sends the result as the content
// of a single user message.
type GeneratorInput struct {
	// TargetAgent describes the agent the generated cases are meant for.
	TargetAgent *TargetAgentInfo `json:"target_agent"`
	// Count is the number of cases requested. GenerateTestCases starts from 5
	// and overrides it with an int "count" parameter when one is given.
	Count int `json:"count,omitempty"`
	// Focus is taken from the string "focus" parameter, when given.
	Focus string `json:"focus,omitempty"`
	// Extra carries every parameter other than "count" and "focus".
	Extra map[string]interface{} `json:"extra,omitempty"`
}
Call generator agent response, err := ast.Stream(ctx, messages, opts) if err != nil { return nil, fmt.Errorf("generator agent error: %w", err) } // Extract and parse response return parseGeneratedCases(response) } // parseGeneratedCases parses the generator agent's response into test cases func parseGeneratedCases(response *context.Response) ([]*Case, error) { if response == nil || response.Completion == nil { return nil, fmt.Errorf("empty response from generator agent") } // Extract content content := response.Completion.Content if content == nil { return nil, fmt.Errorf("no content in generator response") } // Convert content to string var text string switch v := content.(type) { case string: text = v default: data, err := jsoniter.Marshal(content) if err != nil { return nil, fmt.Errorf("failed to marshal content: %w", err) } text = string(data) } // Use goutext.ExtractJSON for fault-tolerant parsing parsed := goutext.ExtractJSON(text) if parsed == nil { return nil, fmt.Errorf("failed to parse generator response as JSON: %s", truncateOutput(text, 200)) } // Convert to []*Case return convertToCases(parsed) } // convertToCases converts parsed JSON to test cases func convertToCases(parsed interface{}) ([]*Case, error) { // Handle array of cases arr, ok := parsed.([]interface{}) if !ok { // Maybe it's a single case wrapped in an object if obj, ok := parsed.(map[string]interface{}); ok { if cases, ok := obj["cases"].([]interface{}); ok { arr = cases } else if testCases, ok := obj["test_cases"].([]interface{}); ok { arr = testCases } else { // Single case arr = []interface{}{obj} } } else { return nil, fmt.Errorf("expected array of test cases, got %T", parsed) } } cases := make([]*Case, 0, len(arr)) for i, item := range arr { caseMap, ok := item.(map[string]interface{}) if !ok { return nil, fmt.Errorf("test case %d is not an object", i) } tc, err := mapToCase(caseMap) if err != nil { return nil, fmt.Errorf("failed to parse test case %d: %w", i, err) } cases = 
append(cases, tc) } return cases, nil } // mapToCase converts a map to a Case struct func mapToCase(m map[string]interface{}) (*Case, error) { tc := &Case{} // Required: id if id, ok := m["id"].(string); ok { tc.ID = id } else { return nil, fmt.Errorf("missing required field 'id'") } // Required: input if input, ok := m["input"]; ok { tc.Input = input } else { return nil, fmt.Errorf("missing required field 'input'") } // Optional: assertions/assert if assertions, ok := m["assertions"]; ok { tc.Assert = assertions } else if assert, ok := m["assert"]; ok { tc.Assert = assert } // Optional: options - convert map to CaseOptions if options, ok := m["options"].(map[string]interface{}); ok { tc.Options = mapToCaseOptions(options) } // Optional: before/after if before, ok := m["before"].(string); ok { tc.Before = before } if after, ok := m["after"].(string); ok { tc.After = after } // Optional: timeout if timeout, ok := m["timeout"].(string); ok { tc.Timeout = timeout } return tc, nil } // ToInputMode converts InputSourceType to InputMode for backward compatibility func (s *InputSource) ToInputMode() InputMode { switch s.Type { case InputSourceFile: return InputModeFile case InputSourceMessage: return InputModeMessage case InputSourceScript: return InputModeScript case InputSourceAgent: // Agent source generates cases, then runs in file mode return InputModeFile default: return InputModeMessage } } // mapToCaseOptions converts a map to CaseOptions func mapToCaseOptions(m map[string]interface{}) *CaseOptions { opts := &CaseOptions{} if connector, ok := m["connector"].(string); ok { opts.Connector = connector } if mode, ok := m["mode"].(string); ok { opts.Mode = mode } if disableGlobalPrompts, ok := m["disable_global_prompts"].(bool); ok { opts.DisableGlobalPrompts = disableGlobalPrompts } if search, ok := m["search"].(bool); ok { opts.Search = &search } if metadata, ok := m["metadata"].(map[string]interface{}); ok { opts.Metadata = metadata } if skip, ok := 
m["skip"].(map[string]interface{}); ok { opts.Skip = &CaseSkipOptions{} if history, ok := skip["history"].(bool); ok { opts.Skip.History = history } if trace, ok := skip["trace"].(bool); ok { opts.Skip.Trace = trace } if output, ok := skip["output"].(bool); ok { opts.Skip.Output = output } if keyword, ok := skip["keyword"].(bool); ok { opts.Skip.Keyword = keyword } if searchSkip, ok := skip["search"].(bool); ok { opts.Skip.Search = searchSkip } } return opts } ================================================ FILE: agent/test/input_source_test.go ================================================ package test_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/agent" agenttest "github.com/yaoapp/yao/agent/test" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestParseInputSource(t *testing.T) { tests := []struct { name string input string wantType agenttest.InputSourceType wantValue string wantParams map[string]interface{} }{ { name: "JSONL file", input: "./tests/inputs.jsonl", wantType: agenttest.InputSourceFile, wantValue: "./tests/inputs.jsonl", }, { name: "JSON file", input: "./tests/inputs.json", wantType: agenttest.InputSourceFile, wantValue: "./tests/inputs.json", }, { name: "direct message", input: "Hello, how are you?", wantType: agenttest.InputSourceMessage, wantValue: "Hello, how are you?", }, { name: "agent source simple", input: "agents:tests.generator-agent", wantType: agenttest.InputSourceAgent, wantValue: "tests.generator-agent", }, { name: "agent source with params", input: "agents:tests.generator-agent?count=10&focus=edge-cases", wantType: agenttest.InputSourceAgent, wantValue: "tests.generator-agent", wantParams: map[string]interface{}{ "count": 10, "focus": "edge-cases", }, }, { name: "agent source with boolean param", input: "agents:tests.generator-agent?verbose=true", wantType: agenttest.InputSourceAgent, wantValue: "tests.generator-agent", wantParams: map[string]interface{}{ "verbose": true, }, 
}, { name: "script source with prefix", input: "scripts:tests.gen.Generate", wantType: agenttest.InputSourceScript, wantValue: "tests.gen.Generate", }, { name: "script test mode", input: "scripts.tests.gen", wantType: agenttest.InputSourceScript, wantValue: "scripts.tests.gen", }, { name: "path with separator", input: "/path/to/inputs.jsonl", wantType: agenttest.InputSourceFile, wantValue: "/path/to/inputs.jsonl", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { source := agenttest.ParseInputSource(tt.input) assert.Equal(t, tt.wantType, source.Type, "Type mismatch") assert.Equal(t, tt.wantValue, source.Value, "Value mismatch") if tt.wantParams != nil { for k, v := range tt.wantParams { assert.Equal(t, v, source.Params[k], "Param %s mismatch", k) } } }) } } func TestInputSource_ToInputMode(t *testing.T) { tests := []struct { name string source *agenttest.InputSource wantMode agenttest.InputMode }{ { name: "file source", source: &agenttest.InputSource{Type: agenttest.InputSourceFile}, wantMode: agenttest.InputModeFile, }, { name: "message source", source: &agenttest.InputSource{Type: agenttest.InputSourceMessage}, wantMode: agenttest.InputModeMessage, }, { name: "script source", source: &agenttest.InputSource{Type: agenttest.InputSourceScript}, wantMode: agenttest.InputModeScript, }, { name: "agent source", source: &agenttest.InputSource{Type: agenttest.InputSourceAgent}, wantMode: agenttest.InputModeFile, // Agent generates cases, then runs in file mode }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mode := tt.source.ToInputMode() assert.Equal(t, tt.wantMode, mode) }) } } func TestGenerateTestCases(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agent (includes assistants) err := agent.Load(config.Conf) if err != nil { t.Fatalf("Failed to load agent: %v", err) } // Test generating test cases from the generator agent targetInfo := &agenttest.TargetAgentInfo{ ID: "tests.next", Description: "A simple test 
// TestMapToCaseOptions checks that a numeric query parameter on an "agents:"
// input source ends up in Params as an int.
//
// NOTE(review): despite its name, this test never calls mapToCaseOptions; it
// only exercises ParseInputSource. Presumably mapToCaseOptions is out of reach
// because it is unexported and this file is in package test_test — confirm,
// and consider renaming the test or moving it to an internal test file.
func TestMapToCaseOptions(t *testing.T) {
	// "count=5" must be stored as the int 5, not as the string "5"
	source := agenttest.ParseInputSource("agents:test?count=5")
	assert.Equal(t, 5, source.Params["count"])
}
'user', got '%s'", messages[0].Role) } } func TestParseInput_MessageArray(t *testing.T) { input := []interface{}{ map[string]interface{}{"role": "user", "content": "Hello"}, map[string]interface{}{"role": "assistant", "content": "Hi there"}, map[string]interface{}{"role": "user", "content": "Follow-up"}, } messages, err := ParseInput(input) if err != nil { t.Fatalf("ParseInput failed: %v", err) } if len(messages) != 3 { t.Fatalf("Expected 3 messages, got %d", len(messages)) } if messages[0].Role != context.RoleUser { t.Errorf("Expected first message role 'user', got '%s'", messages[0].Role) } if messages[1].Role != context.RoleAssistant { t.Errorf("Expected second message role 'assistant', got '%s'", messages[1].Role) } } func TestParseInput_ContentParts(t *testing.T) { input := map[string]interface{}{ "role": "user", "content": []interface{}{ map[string]interface{}{"type": "text", "text": "Analyze this"}, map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{ "url": "https://example.com/image.jpg", "detail": "high", }}, }, } messages, err := ParseInput(input) if err != nil { t.Fatalf("ParseInput failed: %v", err) } if len(messages) != 1 { t.Fatalf("Expected 1 message, got %d", len(messages)) } parts, ok := messages[0].Content.([]context.ContentPart) if !ok { t.Fatalf("Expected []ContentPart, got %T", messages[0].Content) } if len(parts) != 2 { t.Fatalf("Expected 2 content parts, got %d", len(parts)) } if parts[0].Type != context.ContentText { t.Errorf("Expected first part type 'text', got '%s'", parts[0].Type) } if parts[0].Text != "Analyze this" { t.Errorf("Expected text 'Analyze this', got '%s'", parts[0].Text) } if parts[1].Type != context.ContentImageURL { t.Errorf("Expected second part type 'image_url', got '%s'", parts[1].Type) } if parts[1].ImageURL == nil { t.Fatal("Expected ImageURL to be set") } if parts[1].ImageURL.URL != "https://example.com/image.jpg" { t.Errorf("Expected URL 'https://example.com/image.jpg', got '%s'", 
parts[1].ImageURL.URL) } if parts[1].ImageURL.Detail != context.DetailHigh { t.Errorf("Expected detail 'high', got '%s'", parts[1].ImageURL.Detail) } } func TestParseInputWithOptions_FileProtocol_Image(t *testing.T) { // Create a temporary test image file tmpDir := t.TempDir() imgPath := filepath.Join(tmpDir, "test.png") // Create a minimal PNG file (1x1 pixel, red) pngData := []byte{ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, // IDAT chunk 0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x00, 0x05, 0xFE, 0xD4, 0xEF, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, // IEND chunk 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, } if err := os.WriteFile(imgPath, pngData, 0644); err != nil { t.Fatalf("Failed to create test image: %v", err) } input := map[string]interface{}{ "role": "user", "content": []interface{}{ map[string]interface{}{"type": "text", "text": "Analyze this image"}, map[string]interface{}{"type": "image", "source": "file://test.png"}, }, } opts := &InputOptions{BaseDir: tmpDir} messages, err := ParseInputWithOptions(input, opts) if err != nil { t.Fatalf("ParseInputWithOptions failed: %v", err) } if len(messages) != 1 { t.Fatalf("Expected 1 message, got %d", len(messages)) } parts, ok := messages[0].Content.([]context.ContentPart) if !ok { t.Fatalf("Expected []ContentPart, got %T", messages[0].Content) } if len(parts) != 2 { t.Fatalf("Expected 2 content parts, got %d", len(parts)) } // Check image part imgPart := parts[1] if imgPart.Type != context.ContentImageURL { t.Errorf("Expected type 'image_url', got '%s'", imgPart.Type) } if imgPart.ImageURL == nil { t.Fatal("Expected ImageURL to be set") } if !strings.HasPrefix(imgPart.ImageURL.URL, "data:image/png;base64,") { t.Errorf("Expected base64 data URL, got '%s'", 
imgPart.ImageURL.URL[:50]) } // Verify the base64 content b64Part := strings.TrimPrefix(imgPart.ImageURL.URL, "data:image/png;base64,") decoded, err := base64.StdEncoding.DecodeString(b64Part) if err != nil { t.Fatalf("Failed to decode base64: %v", err) } if len(decoded) != len(pngData) { t.Errorf("Decoded data length mismatch: expected %d, got %d", len(pngData), len(decoded)) } } func TestParseInputWithOptions_FileProtocol_Audio(t *testing.T) { // Create a temporary test audio file tmpDir := t.TempDir() audioPath := filepath.Join(tmpDir, "test.wav") // Create a minimal WAV file header wavData := []byte{ 0x52, 0x49, 0x46, 0x46, // "RIFF" 0x24, 0x00, 0x00, 0x00, // File size - 8 0x57, 0x41, 0x56, 0x45, // "WAVE" 0x66, 0x6D, 0x74, 0x20, // "fmt " 0x10, 0x00, 0x00, 0x00, // Subchunk1Size (16 for PCM) 0x01, 0x00, // AudioFormat (1 = PCM) 0x01, 0x00, // NumChannels (1 = mono) 0x44, 0xAC, 0x00, 0x00, // SampleRate (44100) 0x88, 0x58, 0x01, 0x00, // ByteRate 0x02, 0x00, // BlockAlign 0x10, 0x00, // BitsPerSample (16) 0x64, 0x61, 0x74, 0x61, // "data" 0x00, 0x00, 0x00, 0x00, // Subchunk2Size (0 = no data) } if err := os.WriteFile(audioPath, wavData, 0644); err != nil { t.Fatalf("Failed to create test audio: %v", err) } input := map[string]interface{}{ "role": "user", "content": []interface{}{ map[string]interface{}{"type": "text", "text": "Transcribe this"}, map[string]interface{}{"type": "audio", "source": "file://test.wav"}, }, } opts := &InputOptions{BaseDir: tmpDir} messages, err := ParseInputWithOptions(input, opts) if err != nil { t.Fatalf("ParseInputWithOptions failed: %v", err) } parts, ok := messages[0].Content.([]context.ContentPart) if !ok { t.Fatalf("Expected []ContentPart, got %T", messages[0].Content) } // Check audio part audioPart := parts[1] if audioPart.Type != context.ContentInputAudio { t.Errorf("Expected type 'input_audio', got '%s'", audioPart.Type) } if audioPart.InputAudio == nil { t.Fatal("Expected InputAudio to be set") } if 
audioPart.InputAudio.Format != "wav" { t.Errorf("Expected format 'wav', got '%s'", audioPart.InputAudio.Format) } if audioPart.InputAudio.Data == "" { t.Error("Expected base64 data to be set") } } func TestParseInputWithOptions_FileProtocol_File(t *testing.T) { // Create a temporary test file tmpDir := t.TempDir() pdfPath := filepath.Join(tmpDir, "document.pdf") // Create a minimal PDF file pdfData := []byte("%PDF-1.4\n1 0 obj\n<<>>\nendobj\ntrailer\n<<>>\n%%EOF") if err := os.WriteFile(pdfPath, pdfData, 0644); err != nil { t.Fatalf("Failed to create test PDF: %v", err) } input := map[string]interface{}{ "role": "user", "content": []interface{}{ map[string]interface{}{"type": "text", "text": "Analyze this document"}, map[string]interface{}{"type": "file", "source": "file://document.pdf", "name": "my_doc.pdf"}, }, } opts := &InputOptions{BaseDir: tmpDir} messages, err := ParseInputWithOptions(input, opts) if err != nil { t.Fatalf("ParseInputWithOptions failed: %v", err) } parts, ok := messages[0].Content.([]context.ContentPart) if !ok { t.Fatalf("Expected []ContentPart, got %T", messages[0].Content) } // Check file part filePart := parts[1] if filePart.Type != context.ContentFile { t.Errorf("Expected type 'file', got '%s'", filePart.Type) } if filePart.File == nil { t.Fatal("Expected File to be set") } if filePart.File.Filename != "my_doc.pdf" { t.Errorf("Expected filename 'my_doc.pdf', got '%s'", filePart.File.Filename) } if !strings.HasPrefix(filePart.File.URL, "data:application/pdf;base64,") { t.Errorf("Expected base64 data URL with PDF mime type, got '%s'", filePart.File.URL[:40]) } } func TestParseInputWithOptions_FileProtocol_AbsolutePath(t *testing.T) { // Create a temporary test image file tmpDir := t.TempDir() imgPath := filepath.Join(tmpDir, "absolute.png") // Create a minimal PNG file pngData := []byte{ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 
0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x00, 0x05, 0xFE, 0xD4, 0xEF, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, } if err := os.WriteFile(imgPath, pngData, 0644); err != nil { t.Fatalf("Failed to create test image: %v", err) } // Use absolute path input := map[string]interface{}{ "role": "user", "content": []interface{}{ map[string]interface{}{"type": "image", "source": "file://" + imgPath}, }, } // BaseDir should be ignored for absolute paths opts := &InputOptions{BaseDir: "/some/other/dir"} messages, err := ParseInputWithOptions(input, opts) if err != nil { t.Fatalf("ParseInputWithOptions failed: %v", err) } parts, ok := messages[0].Content.([]context.ContentPart) if !ok { t.Fatalf("Expected []ContentPart, got %T", messages[0].Content) } if parts[0].Type != context.ContentImageURL { t.Errorf("Expected type 'image_url', got '%s'", parts[0].Type) } } func TestParseInputWithOptions_FileNotFound(t *testing.T) { input := map[string]interface{}{ "role": "user", "content": []interface{}{ map[string]interface{}{"type": "image", "source": "file://nonexistent.png"}, }, } opts := &InputOptions{BaseDir: t.TempDir()} _, err := ParseInputWithOptions(input, opts) if err == nil { t.Fatal("Expected error for non-existent file") } if !strings.Contains(err.Error(), "failed to read image file") { t.Errorf("Expected 'failed to read image file' error, got: %v", err) } } func TestResolveFilePath(t *testing.T) { tests := []struct { name string filePath string baseDir string expected string }{ { name: "relative path with base dir", filePath: "fixtures/image.png", baseDir: "/app/tests", expected: "/app/tests/fixtures/image.png", }, { name: "relative path without base dir", filePath: "fixtures/image.png", baseDir: "", expected: "fixtures/image.png", }, { name: "absolute path ignores base dir", filePath: "/absolute/path/image.png", 
baseDir: "/app/tests", expected: "/absolute/path/image.png", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := &InputOptions{BaseDir: tt.baseDir} result := resolveFilePath(tt.filePath, opts) if result != tt.expected { t.Errorf("Expected '%s', got '%s'", tt.expected, result) } }) } } func TestExtractTextContent(t *testing.T) { tests := []struct { name string content interface{} expected string }{ { name: "string content", content: "Hello world", expected: "Hello world", }, { name: "content parts array", content: []interface{}{ map[string]interface{}{"type": "text", "text": "First"}, map[string]interface{}{"type": "image", "source": "file://test.png"}, map[string]interface{}{"type": "text", "text": "Second"}, }, expected: "First\nSecond", }, { name: "nil content", content: nil, expected: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := ExtractTextContent(tt.content) if result != tt.expected { t.Errorf("Expected '%s', got '%s'", tt.expected, result) } }) } } func TestSummarizeInput(t *testing.T) { tests := []struct { name string input interface{} maxLen int expected string }{ { name: "short string", input: "Hello", maxLen: 10, expected: "Hello", }, { name: "long string truncated", input: "This is a very long message that should be truncated", maxLen: 20, expected: "This is a very lo...", }, { name: "message array - last user message", input: []interface{}{ map[string]interface{}{"role": "user", "content": "First"}, map[string]interface{}{"role": "assistant", "content": "Response"}, map[string]interface{}{"role": "user", "content": "Last user message"}, }, maxLen: 50, expected: "Last user message", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := SummarizeInput(tt.input, tt.maxLen) if result != tt.expected { t.Errorf("Expected '%s', got '%s'", tt.expected, result) } }) } } ================================================ FILE: agent/test/interfaces.go 
// Runner is the interface for test execution.
//
// The ctx parameters are the standard library context.Context: this file
// imports "context", not the agent context package used elsewhere in the
// test package.
type Runner interface {
	// Run executes all test cases and returns the report
	Run(ctx context.Context) (*Report, error)

	// RunCase executes a single test case
	RunCase(ctx context.Context, tc *Case) (*Result, error)

	// GetAgentInfo returns information about the agent being tested
	GetAgentInfo() *AgentInfo

	// SetProgressCallback sets a callback for progress updates.
	// See ProgressCallback for the arguments passed to the callback.
	SetProgressCallback(callback ProgressCallback)
}
// Validator is the interface for validating test outputs
type Validator interface {
	// Validate compares actual output against expected output
	// Returns nil if validation passes, error otherwise
	Validate(actual, expected interface{}) error

	// ValidateJSON validates JSON outputs with flexible comparison
	ValidateJSON(actual, expected interface{}) error
}

// OutputAdapter adapts agent output to a comparable format
type OutputAdapter interface {
	// Adapt transforms the raw agent output to a normalized format
	Adapt(output interface{}) (interface{}, error)
}

// RunnerFactory creates Runner instances
type RunnerFactory interface {
	// Create creates a new Runner with the given options
	Create(opts *Options) (Runner, error)
}

// ReporterFactory creates Reporter instances
type ReporterFactory interface {
	// Create creates a new Reporter for the given format
	Create(format OutputFormat) (Reporter, error)

	// CreateFromPath creates a Reporter based on output file extension
	CreateFromPath(outputPath string) (Reporter, error)
}

// Hook allows customization of test execution.
// The four callbacks bracket the whole run (BeforeAll/AfterAll) and each
// individual case (BeforeEach/AfterEach); every callback can return an error.
// NOTE(review): how a runner reacts to a non-nil error is not visible here —
// confirm against the Runner implementation.
type Hook interface {
	// BeforeAll is called before any tests run
	BeforeAll(ctx context.Context, cases []*Case) error

	// BeforeEach is called before each test case
	BeforeEach(ctx context.Context, tc *Case) error

	// AfterEach is called after each test case
	AfterEach(ctx context.Context, tc *Case, result *Result) error

	// AfterAll is called after all tests complete
	AfterAll(ctx context.Context, report *Report) error
}
context.Context, report *Report) error { return nil } ================================================ FILE: agent/test/loader.go ================================================ package test import ( "bufio" "fmt" "os" "regexp" "strings" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/process" ) // JSONLLoader loads test cases from JSONL files type JSONLLoader struct{} // NewLoader creates a new JSONL loader func NewLoader() Loader { return &JSONLLoader{} } // Load loads test cases from the default input source // This is a placeholder - actual implementation would use configured path func (l *JSONLLoader) Load() ([]*Case, error) { return nil, fmt.Errorf("Load() requires explicit path, use LoadFile() instead") } // LoadFile loads test cases from a JSONL file // If path is relative and YAO_ROOT is set, resolves relative to YAO_ROOT func (l *JSONLLoader) LoadFile(path string) ([]*Case, error) { // Resolve path relative to YAO_ROOT if it's a relative path resolvedPath := ResolvePathWithYaoRoot(path) file, err := os.Open(resolvedPath) if err != nil { return nil, fmt.Errorf("failed to open file %s: %w", path, err) } defer file.Close() var cases []*Case scanner := bufio.NewScanner(file) lineNum := 0 // Increase buffer size for long lines const maxCapacity = 1024 * 1024 // 1MB buf := make([]byte, maxCapacity) scanner.Buffer(buf, maxCapacity) for scanner.Scan() { lineNum++ line := strings.TrimSpace(scanner.Text()) // Skip empty lines if line == "" { continue } // Skip comments (lines starting with //) if strings.HasPrefix(line, "//") { continue } var tc Case if err := jsoniter.UnmarshalFromString(line, &tc); err != nil { return nil, fmt.Errorf("failed to parse line %d: %w", lineNum, err) } // Validate required fields if tc.ID == "" { return nil, fmt.Errorf("line %d: missing required field 'id'", lineNum) } if tc.Input == nil { return nil, fmt.Errorf("line %d (id=%s): missing required field 'input'", lineNum, tc.ID) } cases = append(cases, &tc) } if 
err := scanner.Err(); err != nil { return nil, fmt.Errorf("error reading file: %w", err) } if len(cases) == 0 { return nil, fmt.Errorf("no test cases found in %s", path) } return cases, nil } // ValidateTestCases validates a slice of test cases func ValidateTestCases(cases []*Case) error { ids := make(map[string]bool) for i, tc := range cases { // Check for duplicate IDs if ids[tc.ID] { return fmt.Errorf("duplicate test case ID: %s", tc.ID) } ids[tc.ID] = true // Validate input can be parsed if _, err := tc.GetMessages(); err != nil { return fmt.Errorf("test case %s (index %d): invalid input: %w", tc.ID, i, err) } // Validate timeout format if specified if tc.Timeout != "" { // GetTimeout returns a duration, parsing error would return default // We validate by checking if the string is parseable if _, err := time.ParseDuration(tc.Timeout); err != nil { return fmt.Errorf("test case %s: invalid timeout format: %s", tc.ID, tc.Timeout) } } } return nil } // FilterTestCases filters test cases based on criteria func FilterTestCases(cases []*Case, filter func(*Case) bool) []*Case { var result []*Case for _, tc := range cases { if filter(tc) { result = append(result, tc) } } return result } // FilterSkipped returns test cases that are not skipped func FilterSkipped(cases []*Case) []*Case { return FilterTestCases(cases, func(tc *Case) bool { return !tc.Skip }) } // FilterByIDs returns test cases matching the given IDs func FilterByIDs(cases []*Case, ids []string) []*Case { idSet := make(map[string]bool) for _, id := range ids { idSet[id] = true } return FilterTestCases(cases, func(tc *Case) bool { return idSet[tc.ID] }) } // FilterByPattern returns test cases whose ID matches the given regex pattern func FilterByPattern(cases []*Case, pattern *regexp.Regexp) []*Case { return FilterTestCases(cases, func(tc *Case) bool { return pattern.MatchString(tc.ID) }) } // LoadFromAgent generates test cases using a generator agent func (l *JSONLLoader) LoadFromAgent(agentID string, 
targetInfo *TargetAgentInfo, params map[string]interface{}) ([]*Case, error) { return GenerateTestCases(agentID, targetInfo, params) } // LoadFromScript generates test cases using a script // scriptRef format: "module.FunctionName" (e.g., "tests.gen.Generate") func (l *JSONLLoader) LoadFromScript(scriptRef string, targetInfo *TargetAgentInfo) ([]*Case, error) { // Parse script reference parts := strings.Split(scriptRef, ".") if len(parts) < 2 { return nil, fmt.Errorf("invalid script reference format: %s (expected 'module.Function')", scriptRef) } // Build process name: scripts.module.Function processName := "scripts." + scriptRef // Execute via process p, err := process.Of(processName, targetInfo) if err != nil { return nil, fmt.Errorf("failed to create process %s: %w", processName, err) } result, err := p.Exec() if err != nil { return nil, fmt.Errorf("script execution failed: %w", err) } // Parse result as test cases return convertToCases(result) } ================================================ FILE: agent/test/output.go ================================================ package test import ( "fmt" "strings" "time" "github.com/fatih/color" jsoniter "github.com/json-iterator/go" ) // OutputWriter handles colored console output for test execution type OutputWriter struct { verbose bool } // NewOutputWriter creates a new output writer func NewOutputWriter(verbose bool) *OutputWriter { return &OutputWriter{verbose: verbose} } // Header prints a header section func (w *OutputWriter) Header(title string) { fmt.Println() color.New(color.FgCyan, color.Bold).Println("═══════════════════════════════════════════════════════════════") color.New(color.FgCyan, color.Bold).Printf(" %s\n", title) color.New(color.FgCyan, color.Bold).Println("═══════════════════════════════════════════════════════════════") } // SubHeader prints a sub-header func (w *OutputWriter) SubHeader(title string) { fmt.Println() color.New(color.FgWhite, 
color.Bold).Println("───────────────────────────────────────────────────────────────") color.New(color.FgWhite, color.Bold).Printf(" %s\n", title) color.New(color.FgWhite, color.Bold).Println("───────────────────────────────────────────────────────────────") } // Info prints an info message func (w *OutputWriter) Info(format string, args ...interface{}) { color.New(color.FgBlue).Printf("ℹ ") fmt.Printf(format+"\n", args...) } // Success prints a success message func (w *OutputWriter) Success(format string, args ...interface{}) { color.New(color.FgGreen).Printf("✓ ") fmt.Printf(format+"\n", args...) } // Error prints an error message func (w *OutputWriter) Error(format string, args ...interface{}) { color.New(color.FgRed).Printf("✗ ") fmt.Printf(format+"\n", args...) } // Warning prints a warning message func (w *OutputWriter) Warning(format string, args ...interface{}) { color.New(color.FgYellow).Printf("⚠ ") fmt.Printf(format+"\n", args...) } // Skip prints a skip message func (w *OutputWriter) Skip(format string, args ...interface{}) { color.New(color.FgYellow).Printf("○ ") fmt.Printf(format+"\n", args...) } // Verbose prints a verbose message (only if verbose mode is enabled) func (w *OutputWriter) Verbose(format string, args ...interface{}) { if w.verbose { color.New(color.FgHiBlack).Printf(" │ ") fmt.Printf(format+"\n", args...) 
} } // TestStart prints test case start func (w *OutputWriter) TestStart(id string, input string, runNum int) { inputPreview := truncateString(input, 50) if runNum > 1 { color.New(color.FgWhite).Printf("► [%s] Run %d: ", id, runNum) } else { color.New(color.FgWhite).Printf("► [%s] ", id) } color.New(color.FgHiBlack).Printf("%s", inputPreview) fmt.Print(" ") } // TestResult prints test case result func (w *OutputWriter) TestResult(status Status, duration time.Duration) { switch status { case StatusPassed: color.New(color.FgGreen, color.Bold).Printf("PASSED") case StatusFailed: color.New(color.FgRed, color.Bold).Printf("FAILED") case StatusSkipped: color.New(color.FgYellow).Printf("SKIPPED") case StatusError: color.New(color.FgRed, color.Bold).Printf("ERROR") case StatusTimeout: color.New(color.FgRed).Printf("TIMEOUT") } color.New(color.FgHiBlack).Printf(" (%s)\n", formatDuration(duration)) } // TestError prints test error details func (w *OutputWriter) TestError(err string) { color.New(color.FgRed).Printf(" └─ %s\n", err) } // TestOutput prints test output (verbose mode) func (w *OutputWriter) TestOutput(output string) { if w.verbose && output != "" { outputPreview := truncateString(output, 100) color.New(color.FgHiBlack).Printf(" └─ Output: %s\n", outputPreview) } } // Progress prints progress information func (w *OutputWriter) Progress(current, total int) { percentage := float64(current) / float64(total) * 100 color.New(color.FgHiBlack).Printf("\r Progress: %d/%d (%.0f%%)", current, total, percentage) } // Summary prints the test summary func (w *OutputWriter) Summary(summary *Summary, duration time.Duration) { w.SubHeader("Summary") // Agent info color.New(color.FgWhite).Printf(" Agent: ") color.New(color.FgCyan).Printf("%s\n", summary.AgentID) if summary.Connector != "" { color.New(color.FgWhite).Printf(" Connector: ") color.New(color.FgCyan).Printf("%s\n", summary.Connector) } // Results color.New(color.FgWhite).Printf(" Total: ") fmt.Printf("%d\n", 
summary.Total) color.New(color.FgWhite).Printf(" Passed: ") if summary.Passed > 0 { color.New(color.FgGreen).Printf("%d\n", summary.Passed) } else { fmt.Printf("%d\n", summary.Passed) } color.New(color.FgWhite).Printf(" Failed: ") if summary.Failed > 0 { color.New(color.FgRed).Printf("%d\n", summary.Failed) } else { fmt.Printf("%d\n", summary.Failed) } if summary.Skipped > 0 { color.New(color.FgWhite).Printf(" Skipped: ") color.New(color.FgYellow).Printf("%d\n", summary.Skipped) } if summary.Errors > 0 { color.New(color.FgWhite).Printf(" Errors: ") color.New(color.FgRed).Printf("%d\n", summary.Errors) } if summary.Timeouts > 0 { color.New(color.FgWhite).Printf(" Timeouts: ") color.New(color.FgRed).Printf("%d\n", summary.Timeouts) } // Pass rate passRate := float64(0) if summary.Total > 0 { passRate = float64(summary.Passed) / float64(summary.Total) * 100 } color.New(color.FgWhite).Printf(" Pass Rate: ") if passRate == 100 { color.New(color.FgGreen, color.Bold).Printf("%.1f%%\n", passRate) } else if passRate >= 80 { color.New(color.FgYellow).Printf("%.1f%%\n", passRate) } else { color.New(color.FgRed).Printf("%.1f%%\n", passRate) } // Duration color.New(color.FgWhite).Printf(" Duration: ") fmt.Printf("%s\n", formatDuration(duration)) // Stability info (if runs > 1) if summary.RunsPerCase > 1 { fmt.Println() color.New(color.FgWhite, color.Bold).Println(" Stability Analysis:") color.New(color.FgWhite).Printf(" Runs/Case: %d\n", summary.RunsPerCase) color.New(color.FgWhite).Printf(" Total Runs: %d\n", summary.TotalRuns) color.New(color.FgWhite).Printf(" Stable Cases: ") if summary.StableCases == summary.Total { color.New(color.FgGreen).Printf("%d\n", summary.StableCases) } else { color.New(color.FgYellow).Printf("%d\n", summary.StableCases) } color.New(color.FgWhite).Printf(" Unstable: ") if summary.UnstableCases > 0 { color.New(color.FgRed).Printf("%d\n", summary.UnstableCases) } else { fmt.Printf("%d\n", summary.UnstableCases) } } } // OutputFile prints the output 
file path func (w *OutputWriter) OutputFile(path string) { fmt.Println() color.New(color.FgWhite).Printf(" Output: ") color.New(color.FgCyan).Printf("%s\n", path) } // FinalResult prints the final result banner func (w *OutputWriter) FinalResult(passed bool) { fmt.Println() if passed { color.New(color.FgGreen, color.Bold).Println("═══════════════════════════════════════════════════════════════") color.New(color.FgGreen, color.Bold).Println(" ✨ ALL TESTS PASSED ✨") color.New(color.FgGreen, color.Bold).Println("═══════════════════════════════════════════════════════════════") } else { color.New(color.FgRed, color.Bold).Println("═══════════════════════════════════════════════════════════════") color.New(color.FgRed, color.Bold).Println(" ❌ TESTS FAILED") color.New(color.FgRed, color.Bold).Println("═══════════════════════════════════════════════════════════════") } fmt.Println() } // DirectOutput prints the agent output directly (for development mode) func (w *OutputWriter) DirectOutput(output interface{}) { if output == nil { return } // Try to format as JSON if it's a complex type switch v := output.(type) { case string: fmt.Println(v) case map[string]interface{}, []interface{}: // Pretty print JSON jsonBytes, err := jsoniter.MarshalIndent(v, "", " ") if err != nil { fmt.Printf("%v\n", output) } else { fmt.Println(string(jsonBytes)) } default: // Try to marshal as JSON jsonBytes, err := jsoniter.MarshalIndent(output, "", " ") if err != nil { fmt.Printf("%v\n", output) } else { fmt.Println(string(jsonBytes)) } } } // ScriptTestSummary prints the script test summary func (w *OutputWriter) ScriptTestSummary(summary *ScriptTestSummary, duration time.Duration) { w.SubHeader("Summary") // Results color.New(color.FgWhite).Printf(" Total: ") fmt.Printf("%d\n", summary.Total) color.New(color.FgWhite).Printf(" Passed: ") if summary.Passed > 0 { color.New(color.FgGreen).Printf("%d\n", summary.Passed) } else { fmt.Printf("%d\n", summary.Passed) } 
color.New(color.FgWhite).Printf(" Failed: ") if summary.Failed > 0 { color.New(color.FgRed).Printf("%d\n", summary.Failed) } else { fmt.Printf("%d\n", summary.Failed) } if summary.Skipped > 0 { color.New(color.FgWhite).Printf(" Skipped: ") color.New(color.FgYellow).Printf("%d\n", summary.Skipped) } // Pass rate passRate := float64(0) if summary.Total > 0 { passRate = float64(summary.Passed) / float64(summary.Total) * 100 } color.New(color.FgWhite).Printf(" Pass Rate: ") if passRate == 100 { color.New(color.FgGreen, color.Bold).Printf("%.1f%%\n", passRate) } else if passRate >= 80 { color.New(color.FgYellow).Printf("%.1f%%\n", passRate) } else { color.New(color.FgRed).Printf("%.1f%%\n", passRate) } // Duration color.New(color.FgWhite).Printf(" Duration: ") fmt.Printf("%s\n", formatDuration(duration)) } // DynamicTestStart outputs the start of a dynamic test func (w *OutputWriter) DynamicTestStart(id string, checkpointCount int) { color.New(color.FgWhite).Printf("► [%s] ", id) color.New(color.FgCyan).Printf("(dynamic, %d checkpoints)\n", checkpointCount) } // DynamicTurn outputs a single turn in dynamic testing func (w *OutputWriter) DynamicTurn(turn int, inputSummary string, checkpointsReached, total int) { if w.verbose { color.New(color.FgHiBlack).Printf("│ ├─ Turn %d: %s ", turn, inputSummary) color.New(color.FgCyan).Printf("[%d/%d checkpoints]\n", checkpointsReached, total) } } // DynamicCheckpoint outputs a checkpoint being reached func (w *OutputWriter) DynamicCheckpoint(checkpointID string) { if w.verbose { color.New(color.FgGreen).Printf("│ │ └─ ✓ checkpoint: %s\n", checkpointID) } } // DynamicTestResult outputs the result of a dynamic test func (w *OutputWriter) DynamicTestResult(status Status, turns int, checkpoints int, duration time.Duration) { color.New(color.FgHiBlack).Printf(" └─ ") switch status { case StatusPassed: color.New(color.FgGreen).Printf("PASSED") case StatusFailed: color.New(color.FgRed).Printf("FAILED") case StatusError: 
// truncateString collapses s onto a single line and shortens it to at most
// maxLen characters.
//
// Newlines become spaces, carriage returns are removed, and every run of
// whitespace is reduced to a single space. When the normalized text is longer
// than maxLen it is cut and "..." is appended so that the result is exactly
// maxLen characters long. Length is counted in runes, so multi-byte UTF-8
// text is never cut in the middle of a character.
//
// For maxLen < 3 there is no room for the ellipsis and the text is simply cut
// to maxLen characters; a non-positive maxLen yields "".
func truncateString(s string, maxLen int) string {
	// Remove newlines and extra spaces
	s = strings.ReplaceAll(s, "\n", " ")
	s = strings.ReplaceAll(s, "\r", "")
	s = strings.Join(strings.Fields(s), " ")

	if maxLen <= 0 {
		return ""
	}

	// Fast path: a string of at most maxLen bytes has at most maxLen runes
	if len(s) <= maxLen {
		return s
	}

	runes := []rune(s)
	if len(runes) <= maxLen {
		return s
	}
	if maxLen < 3 {
		return string(runes[:maxLen])
	}
	return string(runes[:maxLen-3]) + "..."
}

// formatDuration renders d with a unit chosen by magnitude: whole
// microseconds below one millisecond, whole milliseconds below one second,
// seconds with one decimal below one minute, and minutes with one decimal
// otherwise.
func formatDuration(d time.Duration) string {
	if d < time.Millisecond {
		return fmt.Sprintf("%dµs", d.Microseconds())
	}
	if d < time.Second {
		return fmt.Sprintf("%dms", d.Milliseconds())
	}
	if d < time.Minute {
		return fmt.Sprintf("%.1fs", d.Seconds())
	}
	return fmt.Sprintf("%.1fm", d.Minutes())
}
map[string]interface{}{ "type": "stability", "id": sr.ID, "runs": sr.Runs, "passed": sr.Passed, "failed": sr.Failed, "pass_rate": sr.PassRate, "stable": sr.Stable, "stability_class": sr.StabilityClass, "avg_duration_ms": sr.AvgDurationMs, } if err := writeJSONLineToWriter(writer, stabilityEvent); err != nil { return err } } } // Summary event summaryEvent := map[string]interface{}{ "type": "summary", "total": report.Summary.Total, "passed": report.Summary.Passed, "failed": report.Summary.Failed, "skipped": report.Summary.Skipped, "errors": report.Summary.Errors, "timeouts": report.Summary.Timeouts, "duration_ms": report.Summary.DurationMs, } if report.Summary.RunsPerCase > 1 { summaryEvent["runs_per_case"] = report.Summary.RunsPerCase summaryEvent["total_runs"] = report.Summary.TotalRuns summaryEvent["overall_pass_rate"] = report.Summary.OverallPassRate summaryEvent["stable_cases"] = report.Summary.StableCases summaryEvent["unstable_cases"] = report.Summary.UnstableCases } return writeJSONLineToWriter(writer, summaryEvent) } // writeJSONLineToWriter writes a JSON line to the writer func writeJSONLineToWriter(writer *bufio.Writer, data interface{}) error { line, err := jsoniter.Marshal(data) if err != nil { return err } _, err = writer.Write(line) if err != nil { return err } _, err = writer.WriteString("\n") return err } // JSONReporter generates full JSON format reports type JSONReporter struct{} // NewJSONReporter creates a new JSON reporter func NewJSONReporter() *JSONReporter { return &JSONReporter{} } // Generate generates a JSON report func (r *JSONReporter) Generate(report *Report) error { return nil } // Write writes the report in JSON format func (r *JSONReporter) Write(report *Report, w io.Writer) error { encoder := jsoniter.NewEncoder(w) encoder.SetIndent("", " ") return encoder.Encode(report) } // MarkdownReporter generates Markdown format reports type MarkdownReporter struct{} // NewMarkdownReporter creates a new Markdown reporter func 
NewMarkdownReporter() *MarkdownReporter { return &MarkdownReporter{} } // Generate generates a Markdown report func (r *MarkdownReporter) Generate(report *Report) error { return nil } // Write writes the report in Markdown format func (r *MarkdownReporter) Write(report *Report, w io.Writer) error { var sb strings.Builder // Header sb.WriteString("# Agent Test Report\n\n") // Summary sb.WriteString("## Summary\n\n") sb.WriteString("| Metric | Value |\n") sb.WriteString("| ------ | ----- |\n") sb.WriteString(fmt.Sprintf("| Agent | %s |\n", report.Summary.AgentID)) if report.Summary.Connector != "" { sb.WriteString(fmt.Sprintf("| Connector | %s |\n", report.Summary.Connector)) } sb.WriteString(fmt.Sprintf("| Total | %d |\n", report.Summary.Total)) sb.WriteString(fmt.Sprintf("| Passed | %d |\n", report.Summary.Passed)) sb.WriteString(fmt.Sprintf("| Failed | %d |\n", report.Summary.Failed)) if report.Summary.Skipped > 0 { sb.WriteString(fmt.Sprintf("| Skipped | %d |\n", report.Summary.Skipped)) } if report.Summary.Errors > 0 { sb.WriteString(fmt.Sprintf("| Errors | %d |\n", report.Summary.Errors)) } if report.Summary.Timeouts > 0 { sb.WriteString(fmt.Sprintf("| Timeouts | %d |\n", report.Summary.Timeouts)) } passRate := float64(0) if report.Summary.Total > 0 { passRate = float64(report.Summary.Passed) / float64(report.Summary.Total) * 100 } sb.WriteString(fmt.Sprintf("| Pass Rate | %.1f%% |\n", passRate)) sb.WriteString(fmt.Sprintf("| Duration | %dms |\n", report.Summary.DurationMs)) sb.WriteString("\n") // Environment if report.Environment != nil { sb.WriteString("## Environment\n\n") sb.WriteString("| Setting | Value |\n") sb.WriteString("| ------- | ----- |\n") sb.WriteString(fmt.Sprintf("| User | %s |\n", report.Environment.UserID)) sb.WriteString(fmt.Sprintf("| Team | %s |\n", report.Environment.TeamID)) sb.WriteString(fmt.Sprintf("| Locale | %s |\n", report.Environment.Locale)) sb.WriteString("\n") } // Results sb.WriteString("## Results\n\n") if report.Results != 
nil { for _, result := range report.Results { statusIcon := "✅" switch result.Status { case StatusFailed: statusIcon = "❌" case StatusError: statusIcon = "💥" case StatusTimeout: statusIcon = "⏱️" case StatusSkipped: statusIcon = "⏭️" } sb.WriteString(fmt.Sprintf("### %s %s - %s (%dms)\n\n", statusIcon, result.ID, result.Status, result.DurationMs)) if result.Error != "" { sb.WriteString(fmt.Sprintf("**Error:** %s\n\n", result.Error)) } } } // Stability results if report.StabilityResults != nil { sb.WriteString("## Stability Analysis\n\n") sb.WriteString("| ID | Pass Rate | Runs | Status | Avg Duration |\n") sb.WriteString("| -- | --------- | ---- | ------ | ------------ |\n") for _, sr := range report.StabilityResults { status := string(sr.StabilityClass) sb.WriteString(fmt.Sprintf("| %s | %.0f%% | %d/%d | %s | %.0fms |\n", sr.ID, sr.PassRate, sr.Passed, sr.Runs, status, sr.AvgDurationMs)) } sb.WriteString("\n") } // Metadata sb.WriteString("## Metadata\n\n") sb.WriteString(fmt.Sprintf("- **Started:** %s\n", report.Metadata.StartedAt.Format(time.RFC3339))) sb.WriteString(fmt.Sprintf("- **Completed:** %s\n", report.Metadata.CompletedAt.Format(time.RFC3339))) if report.Metadata.InputFile != "" { sb.WriteString(fmt.Sprintf("- **Input File:** %s\n", report.Metadata.InputFile)) } if report.Metadata.OutputFile != "" { sb.WriteString(fmt.Sprintf("- **Output File:** %s\n", report.Metadata.OutputFile)) } _, err := w.Write([]byte(sb.String())) return err } // HTMLReporter generates HTML format reports type HTMLReporter struct{} // NewHTMLReporter creates a new HTML reporter func NewHTMLReporter() *HTMLReporter { return &HTMLReporter{} } // Generate generates an HTML report func (r *HTMLReporter) Generate(report *Report) error { return nil } // Write writes the report in HTML format func (r *HTMLReporter) Write(report *Report, w io.Writer) error { tmpl, err := template.New("report").Parse(htmlTemplate) if err != nil { return fmt.Errorf("failed to parse HTML template: %w", err) 
} // Calculate pass rate passRate := float64(0) if report.Summary.Total > 0 { passRate = float64(report.Summary.Passed) / float64(report.Summary.Total) * 100 } data := map[string]interface{}{ "Report": report, "PassRate": passRate, } return tmpl.Execute(w, data) } // HTML template for reports const htmlTemplate = ` Agent Test Report - {{.Report.Summary.AgentID}}

Agent Test Report

{{.Report.Summary.AgentID}} {{if .Report.Summary.Connector}}• {{.Report.Summary.Connector}}{{end}}

{{.Report.Summary.Total}}
Total Tests
{{.Report.Summary.Passed}}
Passed
{{.Report.Summary.Failed}}
Failed
{{printf "%.1f" .PassRate}}%
Pass Rate
{{.Report.Summary.DurationMs}}ms
Duration

Test Results

{{range .Report.Results}} {{end}} {{range .Report.StabilityResults}} {{end}}
ID Status Duration Details
{{.ID}} {{.Status}} {{.DurationMs}}ms {{if .Error}}
{{.Error}}
{{end}}
{{.ID}} {{.StabilityClass}} {{printf "%.0f" .AvgDurationMs}}ms avg {{.Passed}}/{{.Runs}} passed ({{printf "%.0f" .PassRate}}%)

Metadata

` // AgentReporter uses a custom agent to generate reports type AgentReporter struct { agentID string format string verbose bool ctx *context.Context // Test context for agent call } // NewAgentReporter creates a new agent-based reporter func NewAgentReporter(agentID, format string, verbose bool) *AgentReporter { return &AgentReporter{ agentID: agentID, format: format, verbose: verbose, } } // SetContext sets the context for agent calls func (r *AgentReporter) SetContext(ctx *context.Context) { r.ctx = ctx } // Generate generates a report using the agent func (r *AgentReporter) Generate(report *Report) error { return nil } // Write writes the report using the agent func (r *AgentReporter) Write(report *Report, w io.Writer) error { // Check if AgentGetterFunc is initialized if caller.AgentGetterFunc == nil { return fmt.Errorf("AgentGetterFunc not initialized, cannot call reporter agent") } // Get the reporter agent agent, err := caller.AgentGetterFunc(r.agentID) if err != nil { return fmt.Errorf("failed to get reporter agent %s: %w", r.agentID, err) } // Build input for the reporter agent input := &ReporterInput{ Report: report, Format: r.format, Options: &ReporterOptions{ Verbose: r.verbose, IncludeOutputs: r.verbose, IncludeInputs: r.verbose, }, } // Convert input to JSON for the agent inputJSON, err := jsoniter.Marshal(input) if err != nil { return fmt.Errorf("failed to marshal reporter input: %w", err) } // Create message for the agent messages := []context.Message{ { Role: context.RoleUser, Content: string(inputJSON), }, } // Create context if not provided ctx := r.ctx if ctx == nil { // Create a minimal context for the reporter agent call ctx = NewTestContext("reporter", r.agentID, NewEnvironment("", "")) defer ctx.Release() } // Call the agent with skip options (no history, no output) options := &context.Options{ Skip: &context.Skip{ History: true, Output: true, }, } response, err := agent.Stream(ctx, messages, options) if err != nil { return 
fmt.Errorf("reporter agent call failed: %w", err) } // Extract content from response content, err := r.extractContent(response) if err != nil { return fmt.Errorf("failed to extract report content: %w", err) } // Write the content to output _, err = w.Write([]byte(content)) if err != nil { return fmt.Errorf("failed to write report: %w", err) } return nil } // extractContent extracts the report content from the agent's *context.Response // Now that agent.Stream() returns *context.Response directly, // we can access fields without type assertions. func (r *AgentReporter) extractContent(response *context.Response) (string, error) { if response == nil { return "", fmt.Errorf("agent returned nil response") } // Priority 1: Check Next field (custom hook data) if response.Next != nil { return r.contentToString(response.Next) } // Priority 2: Extract from completion content if response.Completion != nil && response.Completion.Content != nil { return r.contentToString(response.Completion.Content) } return "", fmt.Errorf("no content in response") } // contentToString converts various content types to string func (r *AgentReporter) contentToString(content interface{}) (string, error) { switch v := content.(type) { case string: return v, nil case []byte: return string(v), nil default: jsonBytes, err := jsoniter.Marshal(content) if err != nil { return fmt.Sprintf("%v", content), nil } return string(jsonBytes), nil } } // GetReporter returns a reporter based on output format func GetReporter(format OutputFormat) Reporter { switch format { case FormatJSON: return NewJSONReporter() case FormatHTML: return NewHTMLReporter() case FormatMarkdown: return NewMarkdownReporter() default: return NewJSONLReporter() } } // GetReporterFromPath returns a reporter based on file extension func GetReporterFromPath(outputPath string) Reporter { format := GetOutputFormat(outputPath) return GetReporter(format) } // GetReporterWithAgent returns an agent-based reporter if agentID is specified, // 
otherwise returns a built-in reporter based on output format func GetReporterWithAgent(agentID, outputPath string, verbose bool) Reporter { if agentID != "" { format := GetOutputFormat(outputPath) return NewAgentReporter(agentID, string(format), verbose) } return GetReporterFromPath(outputPath) } ================================================ FILE: agent/test/resolver.go ================================================ package test import ( "fmt" "os" "path/filepath" "strings" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/application" ) // PathResolver resolves agent information from file paths type PathResolver struct{} // NewResolver creates a new path resolver func NewResolver() Resolver { return &PathResolver{} } // Resolve resolves the agent from options // Priority: explicit AgentID > path-based detection (from input file or cwd) func (r *PathResolver) Resolve(opts *Options) (*AgentInfo, error) { // If explicit agent ID is provided, use it if opts.AgentID != "" { return r.ResolveByID(opts.AgentID) } // For file mode, resolve from input file path if opts.InputMode == InputModeFile { if opts.Input == "" { return nil, fmt.Errorf("no agent ID or input file specified") } return r.ResolveFromPath(opts.Input) } // For message mode, try to resolve from current working directory return r.ResolveFromCwd() } // ResolveFromCwd resolves the agent from the current working directory // It looks for package.yao in the current directory or parent directories func (r *PathResolver) ResolveFromCwd() (*AgentInfo, error) { cwd, err := os.Getwd() if err != nil { return nil, fmt.Errorf("failed to get current working directory: %w", err) } info, err := r.ResolveFromPath(cwd) if err != nil { return nil, fmt.Errorf("no agent found in current directory. 
Use -n to specify agent explicitly") } return info, nil } // ResolveFromPath resolves the agent by traversing up from the input file path // It looks for package.yao in parent directories // If YAO_ROOT is set, it also considers paths relative to YAO_ROOT func (r *PathResolver) ResolveFromPath(inputPath string) (*AgentInfo, error) { // Get absolute path absPath, err := filepath.Abs(inputPath) if err != nil { return nil, fmt.Errorf("failed to get absolute path: %w", err) } // Start from the directory containing the input file dir := filepath.Dir(absPath) // Traverse up to find package.yao for { packagePath := filepath.Join(dir, "package.yao") if _, err := os.Stat(packagePath); err == nil { // Found package.yao return r.loadAgentFromPath(dir, packagePath) } // Move to parent directory parent := filepath.Dir(dir) if parent == dir { // Reached root, no package.yao found break } dir = parent } // If YAO_ROOT is set, try resolving relative to it yaoRoot := os.Getenv("YAO_ROOT") if yaoRoot != "" { // Try the input path relative to YAO_ROOT relPath := inputPath // If inputPath is absolute, try to make it relative if filepath.IsAbs(inputPath) { // Check if inputPath is under YAO_ROOT if rel, err := filepath.Rel(yaoRoot, inputPath); err == nil && !strings.HasPrefix(rel, "..") { relPath = rel } } // Traverse up from YAO_ROOT + relPath dir = filepath.Join(yaoRoot, filepath.Dir(relPath)) for { packagePath := filepath.Join(dir, "package.yao") if _, err := os.Stat(packagePath); err == nil { // Found package.yao return r.loadAgentFromPath(dir, packagePath) } // Move to parent directory, but don't go above YAO_ROOT parent := filepath.Dir(dir) if parent == dir || !strings.HasPrefix(parent, yaoRoot) { break } dir = parent } } return nil, fmt.Errorf("no package.yao found in path hierarchy of %s", inputPath) } // ResolveByID resolves an agent by its ID // This would integrate with the assistant loading system func (r *PathResolver) ResolveByID(agentID string) (*AgentInfo, error) { // 
// deriveAgentID derives an agent ID from the agent's directory path.
//
// The path segments that follow the first "assistants" segment are joined
// with dots, e.g. /app/assistants/workers/system/keyword ->
// workers.system.keyword. When the path has no "assistants" segment, or
// "assistants" is its last segment, the last path element is returned.
//
// The path is cleaned first so that trailing or doubled separators do not
// produce empty segments (which would otherwise yield IDs such as "keyword."
// or ".keyword").
func deriveAgentID(dir string) string {
	cleaned := filepath.Clean(dir)
	parts := strings.Split(filepath.ToSlash(cleaned), "/")

	// Look for the "assistants" marker and use everything after it
	for i, part := range parts {
		if part != "assistants" {
			continue
		}
		if i+1 >= len(parts) {
			// Marker is the last segment: nothing follows it
			break
		}
		return strings.Join(parts[i+1:], ".")
	}

	// No usable marker, fall back to the last directory name
	return filepath.Base(cleaned)
}
".json": return FormatJSON case ".html", ".htm": return FormatHTML case ".md", ".markdown": return FormatMarkdown default: return FormatJSON // Default to JSON } } // ValidateOptions validates test options func ValidateOptions(opts *Options) error { if opts.Input == "" { return fmt.Errorf("input is required (-i flag)") } // For file mode, check input file exists if opts.InputMode == InputModeFile { resolvedPath := ResolvePathWithYaoRoot(opts.Input) if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { return fmt.Errorf("input file not found: %s", opts.Input) } } // Note: For message mode, agent can be resolved from cwd, so no validation here // The resolver will return an error if agent cannot be found // Validate timeout if opts.Timeout < 0 { return fmt.Errorf("timeout cannot be negative") } // Validate parallel if opts.Parallel < 0 { return fmt.Errorf("parallel cannot be negative") } return nil } // DefaultOptions returns options with default values func DefaultOptions() *Options { return &Options{ Timeout: 120 * time.Second, // 2 minutes default timeout Parallel: 1, Runs: 1, Verbose: false, FailFast: false, } } // DetectInputMode detects the input mode from the input string // Returns: // - InputModeScript: if input starts with "scripts." 
// - InputModeFile: if input ends with ".jsonl" or is an existing file // - InputModeMessage: otherwise (direct message mode) func DetectInputMode(input string) InputMode { // Check for script test prefix if strings.HasPrefix(input, "scripts.") { return InputModeScript } // If input ends with .jsonl or .json, treat as file if strings.HasSuffix(input, ".jsonl") || strings.HasSuffix(input, ".json") { return InputModeFile } // If input contains path separator, check if file exists if strings.Contains(input, string(filepath.Separator)) || strings.Contains(input, "/") { if _, err := os.Stat(input); err == nil { return InputModeFile } } // Otherwise treat as direct message return InputModeMessage } // MergeOptions merges user options with defaults func MergeOptions(opts *Options, defaults *Options) *Options { result := *defaults if opts.Input != "" { result.Input = opts.Input result.InputMode = DetectInputMode(opts.Input) } if opts.OutputFile != "" { result.OutputFile = opts.OutputFile } if opts.AgentID != "" { result.AgentID = opts.AgentID } if opts.Connector != "" { result.Connector = opts.Connector } if opts.UserID != "" { result.UserID = opts.UserID } if opts.TeamID != "" { result.TeamID = opts.TeamID } if opts.Locale != "" { result.Locale = opts.Locale } if opts.Timeout > 0 { result.Timeout = opts.Timeout } if opts.Parallel > 0 { result.Parallel = opts.Parallel } if opts.Runs > 0 { result.Runs = opts.Runs } if opts.ReporterID != "" { result.ReporterID = opts.ReporterID } if opts.ContextFile != "" { result.ContextFile = opts.ContextFile } if opts.Run != "" { result.Run = opts.Run } if opts.Verbose { result.Verbose = opts.Verbose } if opts.FailFast { result.FailFast = opts.FailFast } return &result } // GenerateDefaultOutputPath generates the default output path based on input file // Format: {input_directory}/output-{timestamp}.jsonl // Timestamp format: YYYYMMDDHHMMSS func GenerateDefaultOutputPath(inputPath string) string { // Resolve input path considering 
YAO_ROOT resolvedPath := ResolvePathWithYaoRoot(inputPath) dir := filepath.Dir(resolvedPath) timestamp := time.Now().Format("20060102150405") filename := fmt.Sprintf("output-%s.jsonl", timestamp) return filepath.Join(dir, filename) } // ResolveOutputPath resolves the output path based on input mode // - File mode: generate default path in same directory as input // - Message mode: return empty string (output to stdout) // If outputPath is explicitly specified, always use it func ResolveOutputPath(opts *Options) string { if opts.OutputFile != "" { return opts.OutputFile } // For file mode, generate default output path if opts.InputMode == InputModeFile { return GenerateDefaultOutputPath(opts.Input) } // For message mode, output to stdout (empty string) return "" } // CreateTestCaseFromMessage creates a single test case from a direct message func CreateTestCaseFromMessage(message string) *Case { return &Case{ ID: "T001", Input: message, } } // ResolvePathWithYaoRoot resolves a file path relative to current directory // No fallback to YAO_ROOT - paths are always resolved from current working directory func ResolvePathWithYaoRoot(path string) string { if filepath.IsAbs(path) { return path } // Try resolving relative to the application root first if application.App != nil { appRoot := application.App.Root() if appRoot != "" { candidate := filepath.Join(appRoot, path) if _, err := os.Stat(candidate); err == nil { return candidate } } } // Fallback: resolve relative to cwd absPath, err := filepath.Abs(path) if err != nil { return path } return absPath } ================================================ FILE: agent/test/runner.go ================================================ package test import ( stdContext "context" "fmt" "os" "path/filepath" "reflect" "regexp" "strings" "sync" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/yao/agent/assistant" "github.com/yaoapp/yao/agent/context" ) // Executor executes test cases against an agent type Executor struct 
{ opts *Options output *OutputWriter resolver Resolver loader Loader hookExecutor *HookExecutor agentPath string // Path to the agent being tested } // NewRunner creates a new test runner func NewRunner(opts *Options) *Executor { return &Executor{ opts: opts, output: NewOutputWriter(opts.Verbose), resolver: NewResolver(), loader: NewLoader(), hookExecutor: NewHookExecutor(opts.Verbose), } } // Run executes all test cases and returns a report func (r *Executor) Run() (*Report, error) { // For script test mode, use script runner if r.opts.InputMode == InputModeScript { return r.RunScriptTests() } // For direct message mode, use simplified output (development mode) if r.opts.InputMode == InputModeMessage { return r.RunDirect() } return r.RunTests() } // RunScriptTests executes script tests and returns a report func (r *Executor) RunScriptTests() (*Report, error) { scriptRunner := NewScriptRunner(r.opts) scriptReport, err := scriptRunner.Run() if err != nil { return nil, err } // Convert to standard report for unified output handling report := scriptReport.ToReport() // Write output if specified if r.opts.OutputFile != "" { err = r.writeOutput(report) if err != nil { r.output.Error("Failed to write output: %s", err.Error()) } else { r.output.OutputFile(r.opts.OutputFile) } } // Print final result r.output.FinalResult(!report.HasFailures()) return report, nil } // RunDirect executes a single direct message and outputs the result directly // This is optimized for development/debugging scenarios func (r *Executor) RunDirect() (*Report, error) { // Resolve agent agentInfo, err := r.resolver.Resolve(r.opts) if err != nil { return nil, fmt.Errorf("failed to resolve agent: %w", err) } // Get assistant ast, err := assistant.Get(agentInfo.ID) if err != nil { return nil, fmt.Errorf("failed to get assistant: %w", err) } // Create test case from message tc := CreateTestCaseFromMessage(r.opts.Input) // Create context chatID := GenerateChatID(tc.ID, 1) ctx := 
// RunTests loads the test cases for the configured input, runs them against
// the resolved agent and returns the resulting report.
//
// Steps, in order:
//  1. Resolve the agent and remember its path for hook execution.
//  2. Load the test cases: generated by an agent, generated by a script
//     ("scripts:" prefix), or read from the input file.
//  3. In dry-run mode, output the loaded cases and return without running.
//  4. Drop skipped cases, then apply the --run regular expression if set.
//  5. Load the context file, if any, into the options.
//  6. Run BeforeAll, then the tests (stability mode when Runs > 1, single-run
//     mode otherwise), with AfterAll deferred so it also runs on early return.
//  7. Fill in the summary, write the output file if requested and print the
//     final result.
//
// An error is returned when the agent cannot be resolved or loaded, the test
// cases cannot be loaded or generated, the --run pattern is invalid or matches
// nothing, the context file cannot be loaded, or BeforeAll fails. Failures of
// AfterAll and of writing the output file are only reported through the
// output writer.
func (r *Executor) RunTests() (*Report, error) {
	startTime := time.Now()

	// Print header
	r.output.Header("Agent Test")

	// Resolve agent
	agentInfo, err := r.resolver.Resolve(r.opts)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve agent: %w", err)
	}

	r.output.Info("Agent: %s", agentInfo.ID)
	r.agentPath = agentInfo.Path // Store agent path for hook execution
	if r.opts.Connector != "" {
		r.output.Info("Connector: %s (override)", r.opts.Connector)
	} else if agentInfo.Connector != "" {
		r.output.Info("Connector: %s", agentInfo.Connector)
	}

	// Load test cases based on input source
	var testCases []*Case
	inputSource := ParseInputSource(r.opts.Input)

	switch inputSource.Type {
	case InputSourceAgent:
		// Generate test cases using agent
		r.output.Info("Generating test cases from agent: %s", inputSource.Value)
		targetInfo := &TargetAgentInfo{
			ID:          agentInfo.ID,
			Description: agentInfo.Description,
		}
		testCases, err = r.loader.LoadFromAgent(inputSource.Value, targetInfo, inputSource.Params)
		if err != nil {
			return nil, fmt.Errorf("failed to generate test cases: %w", err)
		}
		r.output.Info("Generated: %d test cases", len(testCases))

	case InputSourceScript:
		// Generate test cases using script (if it's a generator script, not test script)
		// Note: scripts. prefix without "scripts:" is handled by RunScriptTests
		if strings.HasPrefix(r.opts.Input, "scripts:") {
			scriptRef := strings.TrimPrefix(r.opts.Input, "scripts:")
			r.output.Info("Generating test cases from script: %s", scriptRef)
			targetInfo := &TargetAgentInfo{
				ID:          agentInfo.ID,
				Description: agentInfo.Description,
			}
			testCases, err = r.loader.LoadFromScript(scriptRef, targetInfo)
			if err != nil {
				return nil, fmt.Errorf("failed to generate test cases from script: %w", err)
			}
			r.output.Info("Generated: %d test cases", len(testCases))
		} else {
			// This is a test script (scripts.xxx format), handled by RunScriptTests
			return nil, fmt.Errorf("script test mode should be handled by RunScriptTests")
		}

	default:
		// File mode - load from JSONL
		testCases, err = r.loader.LoadFile(r.opts.Input)
		if err != nil {
			return nil, fmt.Errorf("failed to load test cases: %w", err)
		}
		r.output.Info("Input: %s (%d test cases)", r.opts.Input, len(testCases))
	}

	// Handle dry-run mode - just output the generated test cases
	if r.opts.DryRun {
		r.output.Info("Dry-run mode: outputting generated test cases")
		return r.outputDryRun(testCases, agentInfo)
	}

	// Filter skipped tests
	activeTests := FilterSkipped(testCases)
	skippedCount := len(testCases) - len(activeTests)
	if skippedCount > 0 {
		r.output.Warning("Skipped: %d test cases", skippedCount)
	}

	// Filter by --run pattern if specified
	// (err is shadowed inside this block; the outer err is untouched)
	if r.opts.Run != "" {
		runPattern, err := regexp.Compile(r.opts.Run)
		if err != nil {
			return nil, fmt.Errorf("invalid --run pattern %q: %w", r.opts.Run, err)
		}
		activeTests = FilterByPattern(activeTests, runPattern)
		if len(activeTests) == 0 {
			return nil, fmt.Errorf("no test cases match pattern %q", r.opts.Run)
		}
		r.output.Info("Filter: %q (%d test cases match)", r.opts.Run, len(activeTests))
	}

	// Load context config if specified
	// The loaded config is stored on the shared options so hooks can read it
	if r.opts.ContextFile != "" {
		ctxConfig, err := LoadContextConfig(r.opts.ContextFile)
		if err != nil {
			return nil, fmt.Errorf("failed to load context file: %w", err)
		}
		r.opts.ContextData = ctxConfig
		r.output.Info("Context: %s", r.opts.ContextFile)
	}

	// Set options on hook executor (for context data access in hooks)
	r.hookExecutor.SetOptions(r.opts)

	// Print test info
	if r.opts.Runs > 1 {
		r.output.Info("Runs: %d per test case (stability analysis)", r.opts.Runs)
	}
	r.output.Info("Timeout: %s", r.opts.Timeout)
	if r.opts.Parallel > 1 {
		r.output.Info("Parallel: %d", r.opts.Parallel)
	}

	// Get assistant
	ast, err := assistant.Get(agentInfo.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to get assistant: %w", err)
	}

	// Determine connector: user-specified > agent default
	connector := r.opts.Connector
	if connector == "" {
		connector = agentInfo.Connector
	}

	// Create report
	// Total counts every loaded case, including skipped and filtered-out ones
	report := &Report{
		Summary: &Summary{
			Total:       len(testCases),
			AgentID:     agentInfo.ID,
			AgentPath:   agentInfo.Path,
			Connector:   connector,
			RunsPerCase: r.opts.Runs,
		},
		Environment: NewEnvironment(r.opts.UserID, r.opts.TeamID),
		Metadata: &ReportMetadata{
			StartedAt: startTime,
			InputFile: r.opts.Input,
			Options:   r.opts,
		},
	}

	// Execute global BeforeAll if specified
	var globalBeforeData interface{}
	if r.opts.BeforeAll != "" {
		r.output.Info("BeforeAll: %s", r.opts.BeforeAll)
		var err error
		globalBeforeData, err = r.hookExecutor.ExecuteBeforeAll(r.opts.BeforeAll, activeTests, agentInfo.Path)
		if err != nil {
			return nil, fmt.Errorf("beforeAll script failed: %w", err)
		}
	}

	// Ensure AfterAll runs even if tests fail
	// It is registered only after BeforeAll succeeded, and it reads
	// report.Results at the time the function returns. Results is assigned
	// only in single-run mode below, so in stability mode it is still unset.
	defer func() {
		if r.opts.AfterAll != "" {
			r.output.Info("AfterAll: %s", r.opts.AfterAll)
			if err := r.hookExecutor.ExecuteAfterAll(r.opts.AfterAll, report.Results, globalBeforeData, agentInfo.Path); err != nil {
				r.output.Warning("afterAll script failed: %s", err.Error())
			}
		}
	}()

	// Run tests
	r.output.SubHeader("Running Tests")

	if r.opts.Runs > 1 {
		// Stability testing mode
		report.StabilityResults = r.runStabilityTests(ast, activeTests, agentInfo.ID)
		r.calculateStabilitySummary(report)
	} else {
		// Single run mode
		report.Results = r.runSingleTests(ast, activeTests, agentInfo.ID)
		r.calculateSingleSummary(report)
	}

	// Add skipped count
	report.Summary.Skipped = skippedCount

	// Complete report
	report.Summary.DurationMs = time.Since(startTime).Milliseconds()
	report.Metadata.CompletedAt = time.Now()

	// Print summary
	r.output.Summary(report.Summary, time.Since(startTime))

	// Write output
	// A write failure is reported but does not fail the run
	if r.opts.OutputFile != "" {
		err = r.writeOutput(report)
		if err != nil {
			r.output.Error("Failed to write output: %s", err.Error())
		} else {
			r.output.OutputFile(r.opts.OutputFile)
		}
	}

	// Print final result
	r.output.FinalResult(!report.HasFailures())

	return report, nil
}
testCases { wg.Add(1) go func(idx int, testCase *Case) { defer wg.Done() sem <- struct{}{} // Acquire defer func() { <-sem }() // Release results[idx] = r.runSingleTest(ast, testCase, agentID, 1) }(i, tc) } wg.Wait() return results } // runSingleTest runs a single test case func (r *Executor) runSingleTest(ast *assistant.Assistant, tc *Case, agentID string, runNum int) *Result { // Check if this is a dynamic mode test if tc.IsDynamicMode() { return r.runDynamicTest(ast, tc, agentID) } // Get input summary for display inputSummary := SummarizeInput(tc.Input, 50) r.output.TestStart(tc.ID, inputSummary, runNum) startTime := time.Now() // Create result result := &Result{ ID: tc.ID, Input: tc.Input, Expected: tc.Expected, Options: tc.Options, } // Execute before script if specified var beforeData interface{} if tc.Before != "" { var err error beforeData, err = r.hookExecutor.ExecuteBefore(tc.Before, tc, r.agentPath) if err != nil { result.Status = StatusError result.Error = fmt.Sprintf("before script failed: %s", err.Error()) result.DurationMs = time.Since(startTime).Milliseconds() r.output.TestResult(result.Status, time.Since(startTime)) r.output.TestError(result.Error) // Note: after script is NOT called when before fails return result } } // Ensure after script runs even if test fails (but only if before succeeded) defer func() { if tc.After != "" && (tc.Before == "" || beforeData != nil || result.Status != StatusError || !isBeforeError(result.Error)) { if err := r.hookExecutor.ExecuteAfter(tc.After, tc, result, beforeData, r.agentPath); err != nil { r.output.Warning("after script failed: %s", err.Error()) } } }() // Parse input to messages with file loading support // BaseDir is derived from the input file directory inputOpts := r.getInputOptions() messages, err := tc.GetMessagesWithOptions(inputOpts) if err != nil { result.Status = StatusError result.Error = fmt.Sprintf("failed to parse input: %s", err.Error()) result.DurationMs = 
time.Since(startTime).Milliseconds() r.output.TestResult(result.Status, time.Since(startTime)) r.output.TestError(result.Error) return result } // Create context chatID := GenerateChatID(tc.ID, runNum) ctx := NewTestContextFromOptions(chatID, agentID, r.opts, tc) defer ctx.Release() // Build context options from test case and runner options opts := buildContextOptions(tc, r.opts) // Create timeout context timeout := tc.GetTimeout(r.opts.Timeout) timeoutCtx, cancel := stdContext.WithTimeout(ctx.Context, timeout) defer cancel() ctx.Context = timeoutCtx // Run the test response, err := ast.Stream(ctx, messages, opts) duration := time.Since(startTime) result.DurationMs = duration.Milliseconds() // Check for timeout if timeoutCtx.Err() != nil { result.Status = StatusTimeout result.Error = fmt.Sprintf("timeout after %s", timeout) r.output.TestResult(result.Status, duration) r.output.TestError(result.Error) return result } // Check for error if err != nil { result.Status = StatusError result.Error = err.Error() r.output.TestResult(result.Status, duration) r.output.TestError(result.Error) return result } // Extract output result.Output = extractOutput(response) // Validate result using asserter (with response for tool_called assertions) asserter := NewAsserter().WithResponse(response) passed, errMsg := asserter.Validate(tc, result.Output) if passed { result.Status = StatusPassed } else { result.Status = StatusFailed result.Error = errMsg } r.output.TestResult(result.Status, duration) if result.Status == StatusFailed { r.output.TestError(result.Error) } r.output.TestOutput(fmt.Sprintf("%v", result.Output)) return result } // runDynamicTest runs a dynamic (simulator-driven) test case func (r *Executor) runDynamicTest(ast *assistant.Assistant, tc *Case, agentID string) *Result { // Output test start for dynamic mode r.output.DynamicTestStart(tc.ID, len(tc.Checkpoints)) startTime := time.Now() // Execute before script if specified var beforeData interface{} if tc.Before != "" 
// isBeforeError reports whether errMsg was produced by a failed before
// script, i.e. whether it starts with "before script failed". An empty
// message is never a before error.
func isBeforeError(errMsg string) bool {
	return strings.HasPrefix(errMsg, "before script failed")
}
tc, agentID, run) rd := &RunDetail{ Run: run, Status: result.Status, DurationMs: result.DurationMs, Output: result.Output, Error: result.Error, } sr.RunDetails = append(sr.RunDetails, rd) } // Calculate stability metrics sr.CalculateStability() // Print stability result r.output.StabilityResult(sr) results = append(results, sr) // Check fail-fast if r.opts.FailFast && !sr.Stable { r.output.Warning("Stopping due to --fail-fast (unstable test: %s)", tc.ID) break } } return results } // calculateSingleSummary calculates summary for single run mode func (r *Executor) calculateSingleSummary(report *Report) { for _, result := range report.Results { switch result.Status { case StatusPassed: report.Summary.Passed++ case StatusFailed: report.Summary.Failed++ case StatusError: report.Summary.Errors++ case StatusTimeout: report.Summary.Timeouts++ } } } // calculateStabilitySummary calculates summary for stability mode func (r *Executor) calculateStabilitySummary(report *Report) { report.Summary.TotalRuns = len(report.StabilityResults) * r.opts.Runs var totalPassRate float64 for _, sr := range report.StabilityResults { if sr.Stable { report.Summary.StableCases++ report.Summary.Passed++ } else { report.Summary.UnstableCases++ report.Summary.Failed++ } totalPassRate += sr.PassRate } if len(report.StabilityResults) > 0 { report.Summary.OverallPassRate = totalPassRate / float64(len(report.StabilityResults)) } } // writeOutput writes the test report to the output file func (r *Executor) writeOutput(report *Report) error { file, err := os.Create(r.opts.OutputFile) if err != nil { return fmt.Errorf("failed to create output file: %w", err) } defer file.Close() // Get reporter based on -r flag or file extension reporter := GetReporterWithAgent(r.opts.ReporterID, r.opts.OutputFile, r.opts.Verbose) // If using agent reporter, set context if agentReporter, ok := reporter.(*AgentReporter); ok { // Create a context for the reporter agent call ctx := NewTestContext("reporter", 
r.opts.ReporterID, report.Environment) defer ctx.Release() agentReporter.SetContext(ctx) } // Write report using the reporter return reporter.Write(report, file) } // buildContextOptions builds context.Options from test case and runner options // Priority: test case options > runner options > defaults func buildContextOptions(tc *Case, runnerOpts *Options) *context.Options { opts := &context.Options{ Skip: &context.Skip{ History: true, // Default: skip history loading - input already contains full conversation }, } // Apply test case options if specified if tc.Options != nil { // Connector: test case > runner if tc.Options.Connector != "" { opts.Connector = tc.Options.Connector } // Mode if tc.Options.Mode != "" { opts.Mode = tc.Options.Mode } // DisableGlobalPrompts if tc.Options.DisableGlobalPrompts { opts.DisableGlobalPrompts = true } // Search (pointer to distinguish unset from false) if tc.Options.Search != nil { opts.Search = tc.Options.Search } // Metadata for hooks if tc.Options.Metadata != nil { opts.Metadata = tc.Options.Metadata } // Skip options from test case if tc.Options.Skip != nil { opts.Skip.Trace = tc.Options.Skip.Trace opts.Skip.Output = tc.Options.Skip.Output opts.Skip.Keyword = tc.Options.Skip.Keyword opts.Skip.Search = tc.Options.Skip.Search // Note: History defaults to true for tests } } // Runner connector override (highest priority) if runnerOpts != nil && runnerOpts.Connector != "" { opts.Connector = runnerOpts.Connector } return opts } // extractOutput extracts the output from the agent response // Priority: Next hook data > Completion content > Tool results message > nil func extractOutput(response *context.Response) interface{} { if response == nil { return nil } // Prefer Next hook data if available and non-empty // response.Next is already the Data value (not NextHookResponse struct) if response.Next != nil && !isEmptyValue(response.Next) { return response.Next } // Fall back to completion response if response.Completion != nil { // 
If content is non-empty, return it if response.Completion.Content != nil && !isEmptyValue(response.Completion.Content) { return response.Completion.Content } } // If no content but tools were executed, extract message from tool results // This handles the case where LLM calls tools but doesn't generate text if len(response.Tools) > 0 { return extractToolResultMessage(response.Tools) } return nil } // extractToolResultMessage extracts the message field from tool results // Returns the first non-empty message found, or a summary of tool calls func extractToolResultMessage(tools []context.ToolCallResponse) interface{} { if len(tools) == 0 { return nil } // Try to extract "message" field from tool results first for _, tool := range tools { if tool.Result != nil { // Try to get message from result map if resultMap, ok := tool.Result.(map[string]interface{}); ok { if msg, exists := resultMap["message"]; exists && msg != nil { if msgStr, ok := msg.(string); ok && msgStr != "" { return msgStr } } } } } // No message found, generate a summary of tool calls var summaries []string for _, tool := range tools { toolName := tool.Tool if toolName == "" { toolName = "unknown" } // Extract key info from result if possible if tool.Result != nil { if resultMap, ok := tool.Result.(map[string]interface{}); ok { // Try common result fields if action, ok := resultMap["action"].(string); ok { summaries = append(summaries, fmt.Sprintf("[%s: %s]", toolName, action)) continue } if success, ok := resultMap["success"].(bool); ok { status := "failed" if success { status = "success" } summaries = append(summaries, fmt.Sprintf("[%s: %s]", toolName, status)) continue } } } summaries = append(summaries, fmt.Sprintf("[%s]", toolName)) } if len(summaries) > 0 { return strings.Join(summaries, " ") } return nil } // isEmptyValue checks if a value is considered "empty" for output purposes func isEmptyValue(v interface{}) bool { if v == nil { return true } // Use reflection to check for typed nil (e.g., 
*NextHookResponse(nil)) rv := reflect.ValueOf(v) if rv.Kind() == reflect.Ptr && rv.IsNil() { return true } switch val := v.(type) { case string: return val == "" case map[string]interface{}: return len(val) == 0 case []interface{}: return len(val) == 0 case *context.NextHookResponse: // Check if NextHookResponse is effectively empty if val == nil { return true } return val.Data == nil && val.Delegate == nil } return false } // validateOutput validates the actual output against expected func validateOutput(actual, expected interface{}) bool { // Simple JSON comparison actualJSON, err1 := jsoniter.Marshal(actual) expectedJSON, err2 := jsoniter.Marshal(expected) if err1 != nil || err2 != nil { return false } return string(actualJSON) == string(expectedJSON) } // getInputOptions returns InputOptions based on the runner configuration // BaseDir is derived from the input file directory (for file mode) or current working directory func (r *Executor) getInputOptions() *InputOptions { opts := &InputOptions{} // For file mode, use the input file's directory as base if r.opts.InputMode == InputModeFile && r.opts.Input != "" { // Resolve path considering YAO_ROOT resolvedPath := ResolvePathWithYaoRoot(r.opts.Input) opts.BaseDir = filepath.Dir(resolvedPath) } // For message mode, BaseDir remains empty (uses current working directory) return opts } // outputDryRun outputs generated test cases without running them func (r *Executor) outputDryRun(testCases []*Case, agentInfo *AgentInfo) (*Report, error) { r.output.Info("Generated Test Cases:") // Output each test case as JSONL for _, tc := range testCases { data, err := jsoniter.Marshal(tc) if err != nil { r.output.Warning("Failed to marshal test case %s: %s", tc.ID, err.Error()) continue } fmt.Println(string(data)) } // Write to output file if specified if r.opts.OutputFile != "" { file, err := os.Create(r.opts.OutputFile) if err != nil { return nil, fmt.Errorf("failed to create output file: %w", err) } defer file.Close() for _, 
tc := range testCases { data, err := jsoniter.Marshal(tc) if err != nil { continue } file.WriteString(string(data) + "\n") } r.output.Info("Output written to: %s", r.opts.OutputFile) } // Return a minimal report connector := r.opts.Connector if connector == "" { connector = agentInfo.Connector } return &Report{ Summary: &Summary{ Total: len(testCases), AgentID: agentInfo.ID, Connector: connector, }, }, nil } ================================================ FILE: agent/test/runner_integration_test.go ================================================ package test_test import ( "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yaoapp/yao/agent" agenttest "github.com/yaoapp/yao/agent/test" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) // TestRunner_AgentDrivenInput tests the complete flow: // 1. Use generator-agent to generate test cases // 2. Run the generated tests against simple-greeting agent func TestRunner_AgentDrivenInput(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Test with agent-driven input opts := &agenttest.Options{ Input: "agents:tests.generator-agent?count=3", AgentID: "tests.simple-greeting", Verbose: true, InputMode: agenttest.InputModeFile, // Will be overridden by ParseInputSource } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") require.NotNil(t, report.Summary, "Summary should not be nil") // Verify report assert.Greater(t, report.Summary.Total, 0, "Should have at least one test case") t.Logf("Total: %d, Passed: %d, Failed: %d", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) } // TestRunner_AgentDrivenInput_DryRun tests dry-run mode func 
TestRunner_AgentDrivenInput_DryRun(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Test with dry-run mode opts := &agenttest.Options{ Input: "agents:tests.generator-agent?count=2", AgentID: "tests.simple-greeting", DryRun: true, Verbose: true, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Dry-run should not return error") require.NotNil(t, report, "Report should not be nil") // In dry-run mode, tests are generated but not executed // So Passed and Failed should both be 0, but Total should have the count assert.Greater(t, report.Summary.Total, 0, "Should have generated test cases") t.Logf("Generated %d test cases in dry-run mode", report.Summary.Total) } // TestRunner_FileInput tests loading test cases from JSONL file func TestRunner_FileInput(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Create a temporary JSONL file with test cases // Use case-insensitive contains for robustness tmpDir := t.TempDir() inputFile := filepath.Join(tmpDir, "inputs.jsonl") testCases := `{"id": "greeting-hello", "input": "Hello", "assert": {"type": "regex", "value": "(?i)hello"}} {"id": "greeting-hi", "input": "Hi there", "assert": {"type": "regex", "value": "(?i)(hi|hello)"}} {"id": "greeting-morning", "input": "Good morning", "assert": {"type": "regex", "value": "(?i)(hello|morning|good)"}}` err = os.WriteFile(inputFile, []byte(testCases), 0644) require.NoError(t, err, "Failed to write test file") // Run tests from file opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.simple-greeting", Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) 
report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") // Verify report assert.Equal(t, 3, report.Summary.Total, "Should have 3 test cases") t.Logf("Total: %d, Passed: %d, Failed: %d", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) // Check results for debugging if report.Results != nil { for _, r := range report.Results { t.Logf(" [%s] Status: %s, Output: %v", r.ID, r.Status, r.Output) } } } // TestRunner_DirectMessage tests direct message mode func TestRunner_DirectMessage(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Test with direct message opts := &agenttest.Options{ Input: "Hello, how are you?", AgentID: "tests.simple-greeting", Verbose: true, InputMode: agenttest.InputModeMessage, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") // Direct message mode returns a minimal report assert.Equal(t, 1, report.Summary.Total, "Should have 1 test case") assert.Equal(t, 1, report.Summary.Passed, "Direct message should pass") } // TestRunner_WithBeforeAfter tests before/after hooks func TestRunner_WithBeforeAfter(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Create a temporary JSONL file with test cases that use hooks tmpDir := t.TempDir() inputFile := filepath.Join(tmpDir, "inputs.jsonl") // Note: hooks-test agent has env_test.ts with Before/After functions testCases := `{"id": "hook-test-1", "input": "Hello", "assert": {"type": "contains", "value": "hello"}, "before": "env_test.Before", "after": "env_test.After"}` err = os.WriteFile(inputFile, 
[]byte(testCases), 0644) require.NoError(t, err, "Failed to write test file") // Run tests with hooks (using hooks-test agent which has the hook scripts) opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.hooks-test", Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") t.Logf("Total: %d, Passed: %d, Failed: %d", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) } // TestRunner_Parallel tests parallel execution func TestRunner_Parallel(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Create a temporary JSONL file with multiple test cases // Use regex for case-insensitive matching tmpDir := t.TempDir() inputFile := filepath.Join(tmpDir, "inputs.jsonl") testCases := `{"id": "parallel-1", "input": "Hello", "assert": {"type": "regex", "value": "(?i)(hello|hi)"}} {"id": "parallel-2", "input": "Hi", "assert": {"type": "regex", "value": "(?i)(hello|hi)"}} {"id": "parallel-3", "input": "Hey", "assert": {"type": "regex", "value": "(?i)(hello|hi|hey)"}} {"id": "parallel-4", "input": "Good day", "assert": {"type": "regex", "value": "(?i)(hello|good|day)"}}` err = os.WriteFile(inputFile, []byte(testCases), 0644) require.NoError(t, err, "Failed to write test file") // Run tests in parallel opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.simple-greeting", Parallel: 2, // Run 2 tests in parallel Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error") require.NotNil(t, report, "Report should not be nil") 
assert.Equal(t, 4, report.Summary.Total, "Should have 4 test cases") t.Logf("Total: %d, Passed: %d, Failed: %d (parallel: 2)", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) } // TestRunner_FailFast tests fail-fast behavior func TestRunner_FailFast(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Load agents err := agent.Load(config.Conf) require.NoError(t, err, "Failed to load agents") // Create a temporary JSONL file with a failing test first tmpDir := t.TempDir() inputFile := filepath.Join(tmpDir, "inputs.jsonl") // First test will fail (expects "impossible" which won't be in response) testCases := `{"id": "fail-first", "input": "Hello", "assert": {"type": "contains", "value": "IMPOSSIBLE_STRING_12345"}} {"id": "should-skip", "input": "Hi", "assert": {"type": "contains", "value": "hi"}}` err = os.WriteFile(inputFile, []byte(testCases), 0644) require.NoError(t, err, "Failed to write test file") // Run tests with fail-fast opts := &agenttest.Options{ Input: inputFile, AgentID: "tests.simple-greeting", FailFast: true, Verbose: true, InputMode: agenttest.InputModeFile, } opts = agenttest.MergeOptions(opts, agenttest.DefaultOptions()) runner := agenttest.NewRunner(opts) report, err := runner.Run() require.NoError(t, err, "Runner should not return error (fail-fast is not an error)") require.NotNil(t, report, "Report should not be nil") // With fail-fast, only the first test should run assert.Equal(t, 1, report.Summary.Failed, "First test should fail") // The second test might not run due to fail-fast t.Logf("Total: %d, Passed: %d, Failed: %d (fail-fast enabled)", report.Summary.Total, report.Summary.Passed, report.Summary.Failed) } ================================================ FILE: agent/test/script.go ================================================ package test import ( "fmt" "path/filepath" "regexp" "strings" "time" "github.com/yaoapp/gou/application" v8 "github.com/yaoapp/gou/runtime/v8" 
"github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/context" "rogchap.com/v8go" ) // ScriptRunner executes script tests type ScriptRunner struct { opts *Options output *OutputWriter } // NewScriptRunner creates a new script test runner func NewScriptRunner(opts *Options) *ScriptRunner { return &ScriptRunner{ opts: opts, output: NewOutputWriter(opts.Verbose), } } // ResolveScript resolves the script path from scripts.xxx.yyy or scripts.xxx.yyy.zzz format // // Resolution strategy: // 1. Find the assistant directory by detecting package.yao from longest to shortest path // 2. Remaining parts after the assistant boundary map to src/ subdirectories + module name // // Examples: // - scripts.expense.setup -> assistants/expense/src/setup_test.ts // - scripts.yao.keeper.seed -> assistants/yao/keeper/src/seed_test.ts // - scripts.yao.keeper.tests.seed -> assistants/yao/keeper/src/tests/seed_test.ts func ResolveScript(input string) (*ScriptInfo, error) { // Remove "scripts." prefix path := strings.TrimPrefix(input, "scripts.") // Split into parts: // "expense.setup" -> ["expense", "setup"] // "yao.keeper.tests.seed" -> ["yao", "keeper", "tests", "seed"] parts := strings.Split(path, ".") if len(parts) < 2 { return nil, fmt.Errorf("invalid script path: %s (expected format: scripts.assistant.module or scripts.assistant.sub.module)", input) } // Strategy: detect assistant boundary by looking for package.yao // Try from the longest possible assistant path down to the shortest var assistantDir, modulePath string assistantFound := false for i := len(parts) - 1; i >= 1; i-- { candidateDir := strings.Join(parts[:i], "/") for _, prefix := range []string{"assistants/", ""} { packagePath := filepath.Join(prefix+candidateDir, "package.yao") exists, err := application.App.Exists(packagePath) if err == nil && exists { assistantDir = candidateDir // Remaining parts form the module path (may include subdirectories) // e.g., parts[i:] = ["tests", "seed"] -> "tests/seed" 
modulePath = strings.Join(parts[i:], "/") assistantFound = true break } } if assistantFound { break } } // Fallback: original behavior — last part is module, rest is assistant dir if !assistantFound { assistantDir = strings.Join(parts[:len(parts)-1], "/") modulePath = parts[len(parts)-1] } // modulePath may contain subdirectories: "tests/seed" -> dir="tests", module="seed" moduleName := filepath.Base(modulePath) moduleSubDir := filepath.Dir(modulePath) if moduleSubDir == "." { moduleSubDir = "" } // Build candidate base paths basePaths := []string{ filepath.Join("assistants", assistantDir, "src", moduleSubDir), filepath.Join(assistantDir, "src", moduleSubDir), } var scriptPath, testPath string for _, basePath := range basePaths { // Check for TypeScript files first, then JavaScript for _, ext := range []string{".ts", ".js"} { candidateScript := filepath.Join(basePath, moduleName+ext) candidateTest := filepath.Join(basePath, moduleName+"_test"+ext) // Check if test file exists exists, err := application.App.Exists(candidateTest) if err == nil && exists { scriptPath = candidateScript testPath = candidateTest break } } if testPath != "" { break } } if testPath == "" { return nil, fmt.Errorf("test file not found for %s (tried: %s)", input, strings.Join(basePaths, ", ")) } return &ScriptInfo{ ID: input, Assistant: assistantDir, Module: moduleName, ScriptPath: scriptPath, TestPath: testPath, }, nil } // DiscoverTests finds all Test* functions in the script func DiscoverTests(scriptPath string) ([]*ScriptTestCase, error) { // Read the script file content, err := application.App.Read(scriptPath) if err != nil { return nil, fmt.Errorf("failed to read script: %w", err) } // Parse the script to find Test* functions // We use a simple regex-like approach to find function declarations tests := make([]*ScriptTestCase, 0) lines := strings.Split(string(content), "\n") for _, line := range lines { line = strings.TrimSpace(line) // Match function declarations: function TestXxx( or 
// extractFunctionName returns the name of the function declared on line, or
// "" when the line is not a function declaration.
//
// Recognized forms (leading/trailing whitespace is ignored):
//
//	function Name(
//	export function Name(
//	async function Name(
//	export async function Name(
//	export default function Name(
//
// The async and default forms matter because DiscoverTests pre-filters lines
// on the substring "function Test": without stripping those modifiers an
// async test would match the filter but never be discovered.
func extractFunctionName(line string) string {
	rest := strings.TrimSpace(line)

	// Strip declaration modifiers in the only order the language allows.
	for _, modifier := range []string{"export ", "default ", "async "} {
		if strings.HasPrefix(rest, modifier) {
			rest = strings.TrimSpace(strings.TrimPrefix(rest, modifier))
		}
	}

	// Match "function Name("
	if !strings.HasPrefix(rest, "function ") {
		return ""
	}
	rest = strings.TrimPrefix(rest, "function ")

	// The name ends at the opening parenthesis of the parameter list.
	end := strings.Index(rest, "(")
	if end == -1 {
		return ""
	}
	return strings.TrimSpace(rest[:end])
}
if specified var ctxConfig *ContextConfig if r.opts.ContextFile != "" { var err error ctxConfig, err = LoadContextConfig(r.opts.ContextFile) if err != nil { return nil, fmt.Errorf("failed to load context file: %w", err) } r.output.Info("Context: %s", r.opts.ContextFile) } // Create environment with optional context config var env *Environment if ctxConfig != nil { env = NewEnvironmentWithContext(r.opts.UserID, r.opts.TeamID, ctxConfig) } else { env = NewEnvironment(r.opts.UserID, r.opts.TeamID) } r.output.Info("User: %s", env.UserID) r.output.Info("Team: %s", env.TeamID) // Load all scripts from src directory (including the test file) // This ensures imports can be resolved properly srcDir := filepath.Dir(scriptInfo.TestPath) loadedCount, err := r.loadAllScripts(srcDir) if err != nil { return nil, fmt.Errorf("failed to load scripts: %w", err) } r.output.Info("Loaded: %d scripts", loadedCount) // Create report report := &ScriptTestReport{ Type: "script_test", Script: scriptInfo.ID, ScriptPath: scriptInfo.TestPath, Summary: &ScriptTestSummary{Total: len(tests)}, Environment: env, Results: make([]*ScriptTestResult, 0, len(tests)), Metadata: &ScriptTestMetadata{ StartedAt: startTime, }, } // Run tests r.output.SubHeader("Running Tests") for _, tc := range tests { result := r.runScriptTest(tc, scriptInfo, env) report.Results = append(report.Results, result) // Update summary switch result.Status { case StatusPassed: report.Summary.Passed++ case StatusFailed, StatusError: // Both Failed and Error count as failures report.Summary.Failed++ case StatusSkipped: report.Summary.Skipped++ } // Check fail-fast (stop on both Failed and Error) if r.opts.FailFast && (result.Status == StatusFailed || result.Status == StatusError) { r.output.Warning("Stopping due to --fail-fast") break } } // Complete report report.Summary.DurationMs = time.Since(startTime).Milliseconds() report.Metadata.CompletedAt = time.Now() // Print summary r.output.ScriptTestSummary(report.Summary, 
time.Since(startTime)) return report, nil } // runScriptTest runs a single script test function func (r *ScriptRunner) runScriptTest(tc *ScriptTestCase, scriptInfo *ScriptInfo, env *Environment) *ScriptTestResult { r.output.TestStart(tc.Name, "", 1) startTime := time.Now() result := &ScriptTestResult{ Name: tc.Name, Status: StatusPassed, } // Create testing.T object testingT := NewTestingT(tc.Name) // Create agent context chatID := fmt.Sprintf("script-test-%s", tc.Name) agentCtx := NewTestContext(chatID, scriptInfo.Assistant, env) defer agentCtx.Release() // Execute the test function err := r.executeTestFunction(tc, scriptInfo, testingT, agentCtx) duration := time.Since(startTime) result.DurationMs = duration.Milliseconds() result.Logs = testingT.Logs() if err != nil { result.Status = StatusError result.Error = err.Error() r.output.TestResult(result.Status, duration) r.output.TestError(result.Error) return result } if testingT.Skipped() { result.Status = StatusSkipped r.output.TestResult(result.Status, duration) return result } if testingT.Failed() { result.Status = StatusFailed errors := testingT.Errors() if len(errors) > 0 { result.Error = errors[0] } result.Assertion = testingT.AssertionInfo() r.output.TestResult(result.Status, duration) r.output.TestError(result.Error) return result } r.output.TestResult(result.Status, duration) return result } // loadAllScripts loads all scripts from the src directory // This ensures that imports can be resolved properly func (r *ScriptRunner) loadAllScripts(srcDir string) (int, error) { count := 0 // Check if src directory exists exists, err := application.App.Exists(srcDir) if err != nil { return 0, err } if !exists { return 0, fmt.Errorf("src directory not found: %s", srcDir) } // Walk through src directory to find all script files exts := []string{"*.ts", "*.js"} err = application.App.Walk(srcDir, func(root, file string, isdir bool) error { if isdir { return nil } // Get relative path relPath := strings.TrimPrefix(file, 
// generateTestScriptID derives the V8 script ID used for test runs from a
// script's file path.
//
// The srcDir prefix and the file extension are dropped, the remaining path
// separators become dots, and the result is prefixed with "test.":
//
//	("app/src/tests/seed_test.ts", "app/src") -> "test.tests.seed_test"
//
// A filePath that does not start with srcDir is used in full (minus one
// leading slash).
func generateTestScriptID(filePath string, srcDir string) string {
	// Compare in slash form so the prefix check is separator-independent.
	file := filepath.ToSlash(filePath)
	prefix := filepath.ToSlash(srcDir) + "/"

	rel := strings.TrimPrefix(strings.TrimPrefix(file, prefix), "/")

	// filepath.Ext is always a suffix of rel, so slicing drops exactly it.
	if ext := filepath.Ext(rel); ext != "" {
		rel = rel[:len(rel)-len(ext)]
	}

	return "test." + strings.ReplaceAll(rel, "/", ".")
}
tc.Function, err) } if !fnValue.IsFunction() { return fmt.Errorf("test function %s is not a function", tc.Function) } fn, err := fnValue.AsFunction() if err != nil { return fmt.Errorf("failed to convert to function: %w", err) } // Call the test function with (t, ctx) result, err := fn.Call(global, testingTObj, agentCtxObj) if err != nil { // Check if this is an assertion failure or a real error if testingT.Failed() { // Assertion failure - already recorded return nil } return fmt.Errorf("test function error: %w", err) } // Check if the result is a JavaScript Error (thrown by bridge.JsException) if result != nil && result.IsNativeError() { // Get error message from Error object if result.IsObject() { obj, err := result.AsObject() if err == nil { if msgVal, err := obj.Get("message"); err == nil && !msgVal.IsUndefined() { return fmt.Errorf("test threw exception: %s", msgVal.String()) } } } return fmt.Errorf("test threw exception: %s", result.String()) } return nil } // RegisterTestingGlobals registers testing-related global functions for V8 // This is called once during initialization func RegisterTestingGlobals() { v8.RegisterFunction("__testing_log", testingLogEmbed) } // testingLogEmbed provides a console.log-like function for tests func testingLogEmbed(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() parts := make([]string, len(args)) for i, arg := range args { goVal, err := bridge.GoValue(arg, info.Context()) if err != nil { parts[i] = arg.String() } else { parts[i] = fmt.Sprintf("%v", goVal) } } fmt.Println(strings.Join(parts, " ")) return v8go.Undefined(iso) }) } ================================================ FILE: agent/test/script_assert.go ================================================ package test import ( "encoding/json" "fmt" "reflect" "regexp" "strings" "sync" "github.com/yaoapp/gou/runtime/v8/bridge" "rogchap.com/v8go" ) // TestingT represents the 
testing object passed to test functions // It provides assertion methods and test control flow type TestingT struct { mu sync.Mutex name string failed bool skipped bool logs []string errors []string // Assertion failure details (for the first failure) assertionInfo *ScriptAssertionInfo } // NewTestingT creates a new TestingT instance func NewTestingT(name string) *TestingT { return &TestingT{ name: name, logs: make([]string, 0), errors: make([]string, 0), } } // Name returns the test name func (t *TestingT) Name() string { return t.name } // Failed returns whether the test has failed func (t *TestingT) Failed() bool { t.mu.Lock() defer t.mu.Unlock() return t.failed } // Skipped returns whether the test was skipped func (t *TestingT) Skipped() bool { t.mu.Lock() defer t.mu.Unlock() return t.skipped } // Logs returns all log messages func (t *TestingT) Logs() []string { t.mu.Lock() defer t.mu.Unlock() return append([]string{}, t.logs...) } // Errors returns all error messages func (t *TestingT) Errors() []string { t.mu.Lock() defer t.mu.Unlock() return append([]string{}, t.errors...) 
} // AssertionInfo returns the first assertion failure info func (t *TestingT) AssertionInfo() *ScriptAssertionInfo { t.mu.Lock() defer t.mu.Unlock() return t.assertionInfo } // log adds a log message func (t *TestingT) log(msg string) { t.mu.Lock() defer t.mu.Unlock() t.logs = append(t.logs, msg) } // fail marks the test as failed with an error message func (t *TestingT) fail(msg string, info *ScriptAssertionInfo) { t.mu.Lock() defer t.mu.Unlock() t.failed = true t.errors = append(t.errors, msg) if t.assertionInfo == nil && info != nil { t.assertionInfo = info } } // skip marks the test as skipped func (t *TestingT) skip(reason string) { t.mu.Lock() defer t.mu.Unlock() t.skipped = true if reason != "" { t.logs = append(t.logs, "SKIP: "+reason) } } // NewTestingTObject creates a JavaScript testing.T object for V8 func NewTestingTObject(v8ctx *v8go.Context, t *TestingT) (*v8go.Value, error) { iso := v8ctx.Isolate() // Create the main testing object testObj := v8go.NewObjectTemplate(iso) // Set name property testObj.Set("name", t.name) // Create assert object assertObj, err := newAssertObject(v8ctx, t) if err != nil { return nil, err } // Set methods testObj.Set("log", t.logMethod(iso)) testObj.Set("error", t.errorMethod(iso)) testObj.Set("skip", t.skipMethod(iso)) testObj.Set("fail", t.failMethod(iso)) testObj.Set("fatal", t.fatalMethod(iso)) // Create instance instance, err := testObj.NewInstance(v8ctx) if err != nil { return nil, err } obj, err := instance.Value.AsObject() if err != nil { return nil, err } // Set assert object obj.Set("assert", assertObj) // Set failed getter (dynamic) obj.Set("failed", t.failed) return instance.Value, nil } // logMethod implements t.log(...args) func (t *TestingT) logMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() parts := make([]string, len(args)) for i, arg := range args { parts[i] = arg.String() } 
t.log(strings.Join(parts, " ")) return v8go.Undefined(iso) }) } // errorMethod implements t.error(...args) func (t *TestingT) errorMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() parts := make([]string, len(args)) for i, arg := range args { parts[i] = arg.String() } msg := strings.Join(parts, " ") t.fail(msg, nil) return v8go.Undefined(iso) }) } // skipMethod implements t.skip(reason?) func (t *TestingT) skipMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { reason := "" if len(info.Args()) > 0 { reason = info.Args()[0].String() } t.skip(reason) return v8go.Undefined(iso) }) } // failMethod implements t.fail(reason?) func (t *TestingT) failMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { reason := "test failed" if len(info.Args()) > 0 { reason = info.Args()[0].String() } t.fail(reason, nil) return v8go.Undefined(iso) }) } // fatalMethod implements t.fatal(reason?) 
// Same as fail but intended to stop execution (in JS, this is handled by throwing) func (t *TestingT) fatalMethod(iso *v8go.Isolate) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() reason := "fatal error" if len(info.Args()) > 0 { reason = info.Args()[0].String() } t.fail(reason, nil) // Return exception to stop execution return bridge.JsException(v8ctx, reason) }) } // newAssertObject creates the assert object with all assertion methods func newAssertObject(v8ctx *v8go.Context, t *TestingT) (*v8go.Value, error) { iso := v8ctx.Isolate() assertObj := v8go.NewObjectTemplate(iso) // Boolean assertions assertObj.Set("True", assertTrueMethod(iso, t)) assertObj.Set("False", assertFalseMethod(iso, t)) // Equality assertions assertObj.Set("Equal", assertEqualMethod(iso, t)) assertObj.Set("NotEqual", assertNotEqualMethod(iso, t)) // Nil assertions assertObj.Set("Nil", assertNilMethod(iso, t)) assertObj.Set("NotNil", assertNotNilMethod(iso, t)) // String assertions assertObj.Set("Contains", assertContainsMethod(iso, t)) assertObj.Set("NotContains", assertNotContainsMethod(iso, t)) // Length assertion assertObj.Set("Len", assertLenMethod(iso, t)) // Comparison assertions assertObj.Set("Greater", assertGreaterMethod(iso, t)) assertObj.Set("GreaterOrEqual", assertGreaterOrEqualMethod(iso, t)) assertObj.Set("Less", assertLessMethod(iso, t)) assertObj.Set("LessOrEqual", assertLessOrEqualMethod(iso, t)) // Error assertions assertObj.Set("Error", assertErrorMethod(iso, t)) assertObj.Set("NoError", assertNoErrorMethod(iso, t)) // Panic assertions assertObj.Set("Panic", assertPanicMethod(iso, t)) assertObj.Set("NoPanic", assertNoPanicMethod(iso, t)) // Regex assertions assertObj.Set("Match", assertMatchMethod(iso, t)) assertObj.Set("NotMatch", assertNotMatchMethod(iso, t)) // Type assertion assertObj.Set("Type", assertTypeMethod(iso, t)) // JSON path assertion assertObj.Set("JSONPath", 
assertJSONPathMethod(iso, t)) // Agent-driven assertion assertObj.Set("Agent", assertAgentMethod(iso, t)) // Create instance instance, err := assertObj.NewInstance(v8ctx) if err != nil { return nil, err } return instance.Value, nil } // Helper function to get optional message argument func getMessage(args []*v8go.Value, startIdx int) string { if len(args) > startIdx && args[startIdx].IsString() { return args[startIdx].String() } return "" } // assertTrueMethod implements assert.True(value, message?) func assertTrueMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 1 { t.fail("True requires a value argument", &ScriptAssertionInfo{Type: "True"}) return v8go.Undefined(iso) } value := args[0].Boolean() message := getMessage(args, 1) if !value { msg := "expected true, got false" if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "True", Expected: true, Actual: false, Message: message, }) } return v8go.Undefined(iso) }) } // assertFalseMethod implements assert.False(value, message?) func assertFalseMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 1 { t.fail("False requires a value argument", &ScriptAssertionInfo{Type: "False"}) return v8go.Undefined(iso) } value := args[0].Boolean() message := getMessage(args, 1) if value { msg := "expected false, got true" if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "False", Expected: false, Actual: true, Message: message, }) } return v8go.Undefined(iso) }) } // assertEqualMethod implements assert.Equal(actual, expected, message?) 
func assertEqualMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { t.fail("Equal requires actual and expected arguments", &ScriptAssertionInfo{Type: "Equal"}) return v8go.Undefined(iso) } actual, _ := bridge.GoValue(args[0], v8ctx) expected, _ := bridge.GoValue(args[1], v8ctx) message := getMessage(args, 2) if !deepEqual(actual, expected) { msg := fmt.Sprintf("expected %v, got %v", expected, actual) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Equal", Expected: expected, Actual: actual, Message: message, }) } return v8go.Undefined(iso) }) } // assertNotEqualMethod implements assert.NotEqual(actual, expected, message?) func assertNotEqualMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { t.fail("NotEqual requires actual and expected arguments", &ScriptAssertionInfo{Type: "NotEqual"}) return v8go.Undefined(iso) } actual, _ := bridge.GoValue(args[0], v8ctx) expected, _ := bridge.GoValue(args[1], v8ctx) message := getMessage(args, 2) if deepEqual(actual, expected) { msg := fmt.Sprintf("expected values to be different, both are %v", actual) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "NotEqual", Expected: expected, Actual: actual, Message: message, }) } return v8go.Undefined(iso) }) } // assertNilMethod implements assert.Nil(value, message?) 
func assertNilMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 1 { t.fail("Nil requires a value argument", &ScriptAssertionInfo{Type: "Nil"}) return v8go.Undefined(iso) } isNil := args[0].IsNull() || args[0].IsUndefined() message := getMessage(args, 1) if !isNil { msg := fmt.Sprintf("expected nil, got %v", args[0].String()) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Nil", Expected: nil, Actual: args[0].String(), Message: message, }) } return v8go.Undefined(iso) }) } // assertNotNilMethod implements assert.NotNil(value, message?) func assertNotNilMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 1 { t.fail("NotNil requires a value argument", &ScriptAssertionInfo{Type: "NotNil"}) return v8go.Undefined(iso) } isNil := args[0].IsNull() || args[0].IsUndefined() message := getMessage(args, 1) if isNil { msg := "expected non-nil value, got nil" if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "NotNil", Actual: nil, Message: message, }) } return v8go.Undefined(iso) }) } // assertContainsMethod implements assert.Contains(str, substr, message?) 
func assertContainsMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("Contains requires str and substr arguments", &ScriptAssertionInfo{Type: "Contains"}) return v8go.Undefined(iso) } str := args[0].String() substr := args[1].String() message := getMessage(args, 2) if !strings.Contains(str, substr) { msg := fmt.Sprintf("expected '%s' to contain '%s'", str, substr) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Contains", Expected: substr, Actual: str, Message: message, }) } return v8go.Undefined(iso) }) } // assertNotContainsMethod implements assert.NotContains(str, substr, message?) func assertNotContainsMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("NotContains requires str and substr arguments", &ScriptAssertionInfo{Type: "NotContains"}) return v8go.Undefined(iso) } str := args[0].String() substr := args[1].String() message := getMessage(args, 2) if strings.Contains(str, substr) { msg := fmt.Sprintf("expected '%s' to not contain '%s'", str, substr) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "NotContains", Expected: substr, Actual: str, Message: message, }) } return v8go.Undefined(iso) }) } // assertLenMethod implements assert.Len(value, length, message?) 
func assertLenMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { t.fail("Len requires value and length arguments", &ScriptAssertionInfo{Type: "Len"}) return v8go.Undefined(iso) } value, _ := bridge.GoValue(args[0], v8ctx) expectedLen := int(args[1].Integer()) message := getMessage(args, 2) actualLen := getLength(value) if actualLen != expectedLen { msg := fmt.Sprintf("expected length %d, got %d", expectedLen, actualLen) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Len", Expected: expectedLen, Actual: actualLen, Message: message, }) } return v8go.Undefined(iso) }) } // assertGreaterMethod implements assert.Greater(a, b, message?) func assertGreaterMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("Greater requires two arguments", &ScriptAssertionInfo{Type: "Greater"}) return v8go.Undefined(iso) } a := args[0].Number() b := args[1].Number() message := getMessage(args, 2) if !(a > b) { msg := fmt.Sprintf("expected %v > %v", a, b) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Greater", Expected: fmt.Sprintf("> %v", b), Actual: a, Message: message, }) } return v8go.Undefined(iso) }) } // assertGreaterOrEqualMethod implements assert.GreaterOrEqual(a, b, message?) 
func assertGreaterOrEqualMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("GreaterOrEqual requires two arguments", &ScriptAssertionInfo{Type: "GreaterOrEqual"}) return v8go.Undefined(iso) } a := args[0].Number() b := args[1].Number() message := getMessage(args, 2) if !(a >= b) { msg := fmt.Sprintf("expected %v >= %v", a, b) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "GreaterOrEqual", Expected: fmt.Sprintf(">= %v", b), Actual: a, Message: message, }) } return v8go.Undefined(iso) }) } // assertLessMethod implements assert.Less(a, b, message?) func assertLessMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("Less requires two arguments", &ScriptAssertionInfo{Type: "Less"}) return v8go.Undefined(iso) } a := args[0].Number() b := args[1].Number() message := getMessage(args, 2) if !(a < b) { msg := fmt.Sprintf("expected %v < %v", a, b) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Less", Expected: fmt.Sprintf("< %v", b), Actual: a, Message: message, }) } return v8go.Undefined(iso) }) } // assertLessOrEqualMethod implements assert.LessOrEqual(a, b, message?) 
func assertLessOrEqualMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("LessOrEqual requires two arguments", &ScriptAssertionInfo{Type: "LessOrEqual"}) return v8go.Undefined(iso) } a := args[0].Number() b := args[1].Number() message := getMessage(args, 2) if !(a <= b) { msg := fmt.Sprintf("expected %v <= %v", a, b) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "LessOrEqual", Expected: fmt.Sprintf("<= %v", b), Actual: a, Message: message, }) } return v8go.Undefined(iso) }) } // assertErrorMethod implements assert.Error(err, message?) func assertErrorMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 1 { t.fail("Error requires an argument", &ScriptAssertionInfo{Type: "Error"}) return v8go.Undefined(iso) } // Check if it's null/undefined (no error) isError := !args[0].IsNull() && !args[0].IsUndefined() message := getMessage(args, 1) if !isError { msg := "expected an error, got nil" if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Error", Actual: nil, Message: message, }) } return v8go.Undefined(iso) }) } // assertNoErrorMethod implements assert.NoError(err, message?) 
func assertNoErrorMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 1 { t.fail("NoError requires an argument", &ScriptAssertionInfo{Type: "NoError"}) return v8go.Undefined(iso) } // Check if it's null/undefined (no error) isError := !args[0].IsNull() && !args[0].IsUndefined() message := getMessage(args, 1) if isError { msg := fmt.Sprintf("expected no error, got %v", args[0].String()) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "NoError", Actual: args[0].String(), Message: message, }) } return v8go.Undefined(iso) }) } // assertPanicMethod implements assert.Panic(fn, message?) func assertPanicMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { t.fail("Panic requires a function argument", &ScriptAssertionInfo{Type: "Panic"}) return v8go.Undefined(iso) } if !args[0].IsFunction() { t.fail("Panic requires a function argument", &ScriptAssertionInfo{Type: "Panic"}) return v8go.Undefined(iso) } message := getMessage(args, 1) // Try to call the function and check if it throws fn, _ := args[0].AsFunction() _, err := fn.Call(v8ctx.Global()) if err == nil { msg := "expected function to panic, but it didn't" if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Panic", Message: message, }) } return v8go.Undefined(iso) }) } // assertNoPanicMethod implements assert.NoPanic(fn, message?) 
func assertNoPanicMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 1 { t.fail("NoPanic requires a function argument", &ScriptAssertionInfo{Type: "NoPanic"}) return v8go.Undefined(iso) } if !args[0].IsFunction() { t.fail("NoPanic requires a function argument", &ScriptAssertionInfo{Type: "NoPanic"}) return v8go.Undefined(iso) } message := getMessage(args, 1) // Try to call the function and check if it throws fn, _ := args[0].AsFunction() _, err := fn.Call(v8ctx.Global()) if err != nil { msg := fmt.Sprintf("expected function not to panic, but got: %v", err) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "NoPanic", Actual: err.Error(), Message: message, }) } return v8go.Undefined(iso) }) } // assertMatchMethod implements assert.Match(value, pattern, message?) func assertMatchMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("Match requires value and pattern arguments", &ScriptAssertionInfo{Type: "Match"}) return v8go.Undefined(iso) } value := args[0].String() pattern := args[1].String() message := getMessage(args, 2) re, err := regexp.Compile(pattern) if err != nil { t.fail(fmt.Sprintf("invalid regex pattern: %v", err), &ScriptAssertionInfo{ Type: "Match", Message: message, }) return v8go.Undefined(iso) } if !re.MatchString(value) { msg := fmt.Sprintf("expected '%s' to match pattern '%s'", value, pattern) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Match", Expected: pattern, Actual: value, Message: message, }) } return v8go.Undefined(iso) }) } // assertNotMatchMethod implements assert.NotMatch(value, pattern, message?) 
func assertNotMatchMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("NotMatch requires value and pattern arguments", &ScriptAssertionInfo{Type: "NotMatch"}) return v8go.Undefined(iso) } value := args[0].String() pattern := args[1].String() message := getMessage(args, 2) re, err := regexp.Compile(pattern) if err != nil { t.fail(fmt.Sprintf("invalid regex pattern: %v", err), &ScriptAssertionInfo{ Type: "NotMatch", Message: message, }) return v8go.Undefined(iso) } if re.MatchString(value) { msg := fmt.Sprintf("expected '%s' to not match pattern '%s'", value, pattern) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "NotMatch", Expected: pattern, Actual: value, Message: message, }) } return v8go.Undefined(iso) }) } // assertTypeMethod implements assert.Type(value, typeName, message?) func assertTypeMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { args := info.Args() if len(args) < 2 { t.fail("Type requires value and typeName arguments", &ScriptAssertionInfo{Type: "Type"}) return v8go.Undefined(iso) } value := args[0] expectedType := args[1].String() message := getMessage(args, 2) actualType := getJsType(value) if actualType != expectedType { msg := fmt.Sprintf("expected type '%s', got '%s'", expectedType, actualType) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "Type", Expected: expectedType, Actual: actualType, Message: message, }) } return v8go.Undefined(iso) }) } // assertJSONPathMethod implements assert.JSONPath(obj, path, expected, message?) 
func assertJSONPathMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 3 { t.fail("JSONPath requires obj, path, and expected arguments", &ScriptAssertionInfo{Type: "JSONPath"}) return v8go.Undefined(iso) } obj, _ := bridge.GoValue(args[0], v8ctx) path := args[1].String() expected, _ := bridge.GoValue(args[2], v8ctx) message := getMessage(args, 3) // Use the existing extractPath function from assert.go asserter := &Asserter{} actual := asserter.extractPath(obj, strings.TrimPrefix(path, "$.")) if !deepEqual(actual, expected) { msg := fmt.Sprintf("path '%s': expected %v, got %v", path, expected, actual) if message != "" { msg = message } t.fail(msg, &ScriptAssertionInfo{ Type: "JSONPath", Expected: expected, Actual: actual, Message: message, }) } return v8go.Undefined(iso) }) } // assertAgentMethod implements assert.Agent(response, agentID, options?) 
// Uses a validator agent to check the response // agentID is the direct agent ID (no "agents:" prefix needed) func assertAgentMethod(iso *v8go.Isolate, t *TestingT) *v8go.FunctionTemplate { return v8go.NewFunctionTemplate(iso, func(info *v8go.FunctionCallbackInfo) *v8go.Value { v8ctx := info.Context() args := info.Args() if len(args) < 2 { t.fail("Agent requires response and agentID arguments", &ScriptAssertionInfo{Type: "Agent"}) return v8go.Undefined(iso) } response, _ := bridge.GoValue(args[0], v8ctx) agentID := args[1].String() // Get options if provided var options map[string]interface{} if len(args) > 2 && args[2].IsObject() { optVal, _ := bridge.GoValue(args[2], v8ctx) options, _ = optVal.(map[string]interface{}) } // Build assertion with agents: prefix assertion := &Assertion{ Type: "agent", Use: "agents:" + agentID, } // Extract criteria and metadata from options if options != nil { if criteria, ok := options["criteria"]; ok { assertion.Value = criteria } if metadata, ok := options["metadata"].(map[string]interface{}); ok { assertion.Options = &AssertionOptions{Metadata: metadata} } if connector, ok := options["connector"].(string); ok { if assertion.Options == nil { assertion.Options = &AssertionOptions{} } assertion.Options.Connector = connector } } // Use the asserter to validate asserter := &Asserter{} result := asserter.assertAgent(assertion, response, nil) if !result.Passed { msg := result.Message if msg == "" { msg = "agent assertion failed" } t.fail(msg, &ScriptAssertionInfo{ Type: "Agent", Actual: response, Message: msg, }) } return v8go.Undefined(iso) }) } // Helper functions // deepEqual performs deep equality comparison func deepEqual(a, b interface{}) bool { // Handle nil cases if a == nil && b == nil { return true } if a == nil || b == nil { return false } // Try JSON comparison for complex types aJSON, errA := json.Marshal(a) bJSON, errB := json.Marshal(b) if errA == nil && errB == nil { return string(aJSON) == string(bJSON) } // Fall back 
to reflect.DeepEqual return reflect.DeepEqual(a, b) } // getLength returns the length of a value (array, string, map) func getLength(v interface{}) int { if v == nil { return 0 } switch val := v.(type) { case string: return len(val) case []interface{}: return len(val) case map[string]interface{}: return len(val) default: rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Slice, reflect.Array, reflect.Map, reflect.String: return rv.Len() } } return 0 } // getJsType returns the JavaScript type name of a value func getJsType(v *v8go.Value) string { if v.IsNull() { return "null" } if v.IsUndefined() { return "undefined" } if v.IsString() { return "string" } if v.IsNumber() { return "number" } if v.IsBoolean() { return "boolean" } if v.IsArray() { return "array" } if v.IsFunction() { return "function" } if v.IsObject() { return "object" } return "unknown" } ================================================ FILE: agent/test/script_hooks.go ================================================ package test import ( "fmt" "path/filepath" "strings" "github.com/yaoapp/gou/application" v8 "github.com/yaoapp/gou/runtime/v8" "github.com/yaoapp/gou/runtime/v8/bridge" "github.com/yaoapp/yao/agent/context" "rogchap.com/v8go" ) // HookExecutor executes before/after scripts from *_test.ts files // Scripts are loaded via V8 and executed directly, not via Process() type HookExecutor struct { verbose bool output *OutputWriter loadedDirs map[string]bool // Track which directories have been loaded agentContext *context.Context opts *Options // Test options (includes ContextData from --ctx) } // NewHookExecutor creates a new hook executor func NewHookExecutor(verbose bool) *HookExecutor { return &HookExecutor{ verbose: verbose, output: NewOutputWriter(verbose), loadedDirs: make(map[string]bool), } } // SetAgentContext sets the agent context for script execution func (h *HookExecutor) SetAgentContext(ctx *context.Context) { h.agentContext = ctx } // SetOptions sets the test options for 
hook execution func (h *HookExecutor) SetOptions(opts *Options) { h.opts = opts } // HookRef represents a parsed hook reference // Format: "src/env_test.ts:Before" or just "Before" (uses default test file) type HookRef struct { ScriptFile string // e.g., "env_test.ts" Function string // e.g., "Before" } // ParseHookRef parses a hook reference string // Formats: // - "Before" -> uses first *_test.ts file found // - "env_test.Before" -> uses src/env_test.ts // - "src/env_test.Before" -> uses src/env_test.ts func ParseHookRef(ref string) (*HookRef, error) { if ref == "" { return nil, fmt.Errorf("empty hook reference") } // Split by last dot to get function name lastDot := strings.LastIndex(ref, ".") if lastDot == -1 { // Just function name, will use default test file return &HookRef{ ScriptFile: "", // Will be resolved later Function: ref, }, nil } scriptPart := ref[:lastDot] funcName := ref[lastDot+1:] // Normalize script file name scriptFile := scriptPart if !strings.HasSuffix(scriptFile, "_test") { scriptFile += "_test" } scriptFile += ".ts" // Remove "src/" prefix if present scriptFile = strings.TrimPrefix(scriptFile, "src/") return &HookRef{ ScriptFile: scriptFile, Function: funcName, }, nil } // LoadTestScripts loads all *_test.ts scripts from the agent's src directory // Returns the script IDs that were loaded func (h *HookExecutor) LoadTestScripts(agentPath string) ([]string, error) { srcDir := filepath.Join(agentPath, "src") // Convert to relative path for application.App // application.App expects paths relative to YAO_ROOT relSrcDir := srcDir if application.App != nil { if rel, err := filepath.Rel(application.App.Root(), srcDir); err == nil { relSrcDir = rel } } // Check if already loaded (use absolute path as key) // No logging for already loaded - this is normal and happens frequently if h.loadedDirs[srcDir] { return nil, nil } // Check if src directory exists exists, err := application.App.Exists(relSrcDir) if err != nil { return nil, err } if !exists { 
return nil, nil // No src directory, not an error } var loadedScripts []string exts := []string{"*_test.ts", "*_test.js"} err = application.App.Walk(relSrcDir, func(root, file string, isdir bool) error { if isdir { return nil } // Only load *_test.ts/js files base := filepath.Base(file) if !strings.HasSuffix(base, "_test.ts") && !strings.HasSuffix(base, "_test.js") { return nil } // Generate script ID (use relative path for consistency) scriptID := generateHookScriptID(file, relSrcDir) // Load the script (file path from Walk is relative to App root) _, err := v8.Load(file, scriptID) if err != nil { if h.verbose { h.output.Warning("Failed to load hook script %s: %v", base, err) } return nil // Continue loading other scripts } loadedScripts = append(loadedScripts, scriptID) return nil }, exts...) if err != nil { return nil, fmt.Errorf("failed to walk src directory: %w", err) } h.loadedDirs[srcDir] = true // Log summary only once when scripts are first loaded if h.verbose && len(loadedScripts) > 0 { h.output.Verbose("Loaded %d hook scripts from %s", len(loadedScripts), relSrcDir) } return loadedScripts, nil } // generateHookScriptID generates a script ID for hook scripts // Example: assistants/test/src/env_test.ts -> hook.env_test func generateHookScriptID(filePath string, srcDir string) string { filePath = filepath.ToSlash(filePath) srcDir = filepath.ToSlash(srcDir) relPath := strings.TrimPrefix(filePath, srcDir+"/") relPath = strings.TrimPrefix(relPath, "/") relPath = strings.TrimSuffix(relPath, filepath.Ext(relPath)) return "hook." + strings.ReplaceAll(relPath, "/", ".") } // FindTestScript finds a loaded test script by pattern // If scriptFile is empty, returns the first *_test script found func (h *HookExecutor) FindTestScript(scriptFile string) (*v8.Script, string, error) { if scriptFile != "" { // Look for specific script scriptID := "hook." 
+ strings.TrimSuffix(scriptFile, ".ts") scriptID = strings.TrimSuffix(scriptID, ".js") if script, ok := v8.Scripts[scriptID]; ok { return script, scriptID, nil } return nil, "", fmt.Errorf("hook script not found: %s (id: %s)", scriptFile, scriptID) } // Find first *_test script for id, script := range v8.Scripts { if strings.HasPrefix(id, "hook.") && strings.Contains(id, "_test") { return script, id, nil } } return nil, "", fmt.Errorf("no hook test script found") } // ExecuteBefore executes a Before function from a test script func (h *HookExecutor) ExecuteBefore(ref string, testCase *Case, agentPath string) (interface{}, error) { hookRef, err := ParseHookRef(ref) if err != nil { return nil, err } // Ensure scripts are loaded if _, err := h.LoadTestScripts(agentPath); err != nil { return nil, fmt.Errorf("failed to load test scripts: %w", err) } // Find the script script, scriptID, err := h.FindTestScript(hookRef.ScriptFile) if err != nil { return nil, err } if h.verbose { h.output.Verbose("Executing %s from %s", hookRef.Function, scriptID) } // Execute the function return h.executeHookFunction(script, hookRef.Function, testCase, nil, nil) } // ExecuteAfter executes an After function from a test script func (h *HookExecutor) ExecuteAfter(ref string, testCase *Case, result *Result, beforeData interface{}, agentPath string) error { hookRef, err := ParseHookRef(ref) if err != nil { return err } // Ensure scripts are loaded if _, err := h.LoadTestScripts(agentPath); err != nil { return fmt.Errorf("failed to load test scripts: %w", err) } // Find the script script, scriptID, err := h.FindTestScript(hookRef.ScriptFile) if err != nil { return err } if h.verbose { h.output.Verbose("Executing %s from %s", hookRef.Function, scriptID) } // Execute the function _, err = h.executeHookFunction(script, hookRef.Function, testCase, result, beforeData) return err } // ExecuteBeforeAll executes a BeforeAll function func (h *HookExecutor) ExecuteBeforeAll(ref string, testCases []*Case, 
agentPath string) (interface{}, error) { hookRef, err := ParseHookRef(ref) if err != nil { return nil, err } // Ensure scripts are loaded if _, err := h.LoadTestScripts(agentPath); err != nil { return nil, fmt.Errorf("failed to load test scripts: %w", err) } // Find the script script, scriptID, err := h.FindTestScript(hookRef.ScriptFile) if err != nil { return nil, err } if h.verbose { h.output.Verbose("Executing %s from %s", hookRef.Function, scriptID) } // Execute with test cases array return h.executeHookFunctionWithCases(script, hookRef.Function, testCases) } // ExecuteAfterAll executes an AfterAll function func (h *HookExecutor) ExecuteAfterAll(ref string, results []*Result, beforeData interface{}, agentPath string) error { hookRef, err := ParseHookRef(ref) if err != nil { return err } // Ensure scripts are loaded if _, err := h.LoadTestScripts(agentPath); err != nil { return fmt.Errorf("failed to load test scripts: %w", err) } // Find the script script, scriptID, err := h.FindTestScript(hookRef.ScriptFile) if err != nil { return err } if h.verbose { h.output.Verbose("Executing %s from %s", hookRef.Function, scriptID) } // Execute with results array _, err = h.executeHookFunctionWithResults(script, hookRef.Function, results, beforeData) return err } // executeHookFunction executes a hook function with test case context func (h *HookExecutor) executeHookFunction(script *v8.Script, funcName string, testCase *Case, result *Result, beforeData interface{}) (interface{}, error) { // Create script context scriptCtx, err := script.NewContext("", nil) if err != nil { return nil, fmt.Errorf("failed to create script context: %w", err) } defer scriptCtx.Close() v8ctx := scriptCtx.Context // Set share data if err := h.setShareData(v8ctx); err != nil { return nil, err } // Get the function global := v8ctx.Global() fnValue, err := global.Get(funcName) if err != nil { return nil, fmt.Errorf("failed to get function %s: %w", funcName, err) } if fnValue.IsUndefined() || 
fnValue.IsNull() { return nil, fmt.Errorf("function %s not defined", funcName) } if !fnValue.IsFunction() { return nil, fmt.Errorf("%s is not a function", funcName) } fn, err := fnValue.AsFunction() if err != nil { return nil, fmt.Errorf("failed to convert to function: %w", err) } // Build arguments args, err := h.buildHookArgs(v8ctx, testCase, result, beforeData) if err != nil { return nil, err } // Convert to v8go.Valuer slice for Call valuerArgs := make([]v8go.Valuer, len(args)) for i, arg := range args { valuerArgs[i] = arg } // Call the function jsResult, err := fn.Call(global, valuerArgs...) if err != nil { return nil, fmt.Errorf("hook function %s failed: %w", funcName, err) } // Convert result to Go value if jsResult == nil || jsResult.IsUndefined() || jsResult.IsNull() { return nil, nil } goResult, err := bridge.GoValue(jsResult, v8ctx) if err != nil { return nil, fmt.Errorf("failed to convert result: %w", err) } // Extract data field if present if resultMap, ok := goResult.(map[string]interface{}); ok { if data, exists := resultMap["data"]; exists { return data, nil } } return goResult, nil } // executeHookFunctionWithCases executes BeforeAll with test cases array func (h *HookExecutor) executeHookFunctionWithCases(script *v8.Script, funcName string, testCases []*Case) (interface{}, error) { scriptCtx, err := script.NewContext("", nil) if err != nil { return nil, fmt.Errorf("failed to create script context: %w", err) } defer scriptCtx.Close() v8ctx := scriptCtx.Context if err := h.setShareData(v8ctx); err != nil { return nil, err } global := v8ctx.Global() fnValue, err := global.Get(funcName) if err != nil { return nil, fmt.Errorf("failed to get function %s: %w", funcName, err) } if fnValue.IsUndefined() || fnValue.IsNull() { return nil, fmt.Errorf("function %s not defined", funcName) } if !fnValue.IsFunction() { return nil, fmt.Errorf("%s is not a function", funcName) } fn, err := fnValue.AsFunction() if err != nil { return nil, fmt.Errorf("failed to 
convert to function: %w", err) } // Build ctx argument ctxJS, err := h.buildCtxArg(v8ctx) if err != nil { return nil, err } // Convert test cases to JS array casesJS, err := h.testCasesToJS(v8ctx, testCases) if err != nil { return nil, err } jsResult, err := fn.Call(global, ctxJS, casesJS) if err != nil { return nil, fmt.Errorf("hook function %s failed: %w", funcName, err) } if jsResult == nil || jsResult.IsUndefined() || jsResult.IsNull() { return nil, nil } goResult, err := bridge.GoValue(jsResult, v8ctx) if err != nil { return nil, fmt.Errorf("failed to convert result: %w", err) } if resultMap, ok := goResult.(map[string]interface{}); ok { if data, exists := resultMap["data"]; exists { return data, nil } } return goResult, nil } // executeHookFunctionWithResults executes AfterAll with results array func (h *HookExecutor) executeHookFunctionWithResults(script *v8.Script, funcName string, results []*Result, beforeData interface{}) (interface{}, error) { scriptCtx, err := script.NewContext("", nil) if err != nil { return nil, fmt.Errorf("failed to create script context: %w", err) } defer scriptCtx.Close() v8ctx := scriptCtx.Context if err := h.setShareData(v8ctx); err != nil { return nil, err } global := v8ctx.Global() fnValue, err := global.Get(funcName) if err != nil { return nil, fmt.Errorf("failed to get function %s: %w", funcName, err) } if fnValue.IsUndefined() || fnValue.IsNull() { return nil, fmt.Errorf("function %s not defined", funcName) } if !fnValue.IsFunction() { return nil, fmt.Errorf("%s is not a function", funcName) } fn, err := fnValue.AsFunction() if err != nil { return nil, fmt.Errorf("failed to convert to function: %w", err) } // Build ctx argument ctxJS, err := h.buildCtxArg(v8ctx) if err != nil { return nil, err } // Convert results to JS array resultsJS, err := h.resultsToJS(v8ctx, results) if err != nil { return nil, err } // Convert beforeData to JS beforeDataJS, err := bridge.JsValue(v8ctx, beforeData) if err != nil { return nil, 
fmt.Errorf("failed to convert beforeData: %w", err) } jsResult, err := fn.Call(global, ctxJS, resultsJS, beforeDataJS) if err != nil { return nil, fmt.Errorf("hook function %s failed: %w", funcName, err) } if jsResult == nil || jsResult.IsUndefined() || jsResult.IsNull() { return nil, nil } goResult, err := bridge.GoValue(jsResult, v8ctx) if err != nil { return nil, fmt.Errorf("failed to convert result: %w", err) } return goResult, nil } // setShareData sets the share data for script execution func (h *HookExecutor) setShareData(v8ctx *v8go.Context) error { var authorized map[string]interface{} if h.agentContext != nil && h.agentContext.Authorized != nil { authorized = h.agentContext.Authorized.AuthorizedToMap() } return bridge.SetShareData(v8ctx, v8ctx.Global(), &bridge.Share{ Sid: "", Root: false, Global: nil, Authorized: authorized, }) } // buildCtxArg builds the context argument for hook functions func (h *HookExecutor) buildCtxArg(v8ctx *v8go.Context) (*v8go.Value, error) { ctxMap := map[string]interface{}{ "locale": "en", } // Use ContextData from --ctx flag if available if h.opts != nil && h.opts.ContextData != nil { cfg := h.opts.ContextData if cfg.Locale != "" { ctxMap["locale"] = cfg.Locale } if cfg.Authorized != nil { authorized := map[string]interface{}{} if cfg.Authorized.UserID != "" { authorized["user_id"] = cfg.Authorized.UserID } if cfg.Authorized.TeamID != "" { authorized["team_id"] = cfg.Authorized.TeamID } if cfg.Authorized.TenantID != "" { authorized["tenant_id"] = cfg.Authorized.TenantID } if cfg.Authorized.Sub != "" { authorized["sub"] = cfg.Authorized.Sub } ctxMap["authorized"] = authorized } if cfg.Metadata != nil { ctxMap["metadata"] = cfg.Metadata } } return bridge.JsValue(v8ctx, ctxMap) } // buildHookArgs builds the arguments for a hook function call // Arguments order: ctx, testCase, result (for After), beforeData (for After) func (h *HookExecutor) buildHookArgs(v8ctx *v8go.Context, testCase *Case, result *Result, beforeData 
interface{}) ([]*v8go.Value, error) { var args []*v8go.Value // Arg 1: ctx (context) - build from opts.ContextData if available ctxMap := map[string]interface{}{ "locale": "en", } // Use ContextData from --ctx flag if available if h.opts != nil && h.opts.ContextData != nil { cfg := h.opts.ContextData if cfg.Locale != "" { ctxMap["locale"] = cfg.Locale } if cfg.Authorized != nil { authorized := map[string]interface{}{} if cfg.Authorized.UserID != "" { authorized["user_id"] = cfg.Authorized.UserID } if cfg.Authorized.TeamID != "" { authorized["team_id"] = cfg.Authorized.TeamID } if cfg.Authorized.TenantID != "" { authorized["tenant_id"] = cfg.Authorized.TenantID } if cfg.Authorized.Sub != "" { authorized["sub"] = cfg.Authorized.Sub } ctxMap["authorized"] = authorized } if cfg.Metadata != nil { ctxMap["metadata"] = cfg.Metadata } } else if testCase != nil { // Fallback to test case fields if testCase.UserID != "" { ctxMap["user_id"] = testCase.UserID } if testCase.TeamID != "" { ctxMap["team_id"] = testCase.TeamID } // Build authorized info authorized := map[string]interface{}{} if testCase.UserID != "" { authorized["user_id"] = testCase.UserID } if testCase.TeamID != "" { authorized["team_id"] = testCase.TeamID } if len(authorized) > 0 { ctxMap["authorized"] = authorized } } ctxJS, err := bridge.JsValue(v8ctx, ctxMap) if err != nil { return nil, fmt.Errorf("failed to convert ctx: %w", err) } args = append(args, ctxJS) // Arg 2: testCase if testCase != nil { tcMap := map[string]interface{}{ "id": testCase.ID, "input": testCase.Input, } if testCase.Metadata != nil { tcMap["metadata"] = testCase.Metadata } if testCase.Assert != nil { tcMap["assert"] = testCase.Assert } // Include simulator options for dynamic tests if testCase.Simulator != nil { tcMap["simulator"] = testCase.Simulator } tcJS, err := bridge.JsValue(v8ctx, tcMap) if err != nil { return nil, fmt.Errorf("failed to convert testCase: %w", err) } args = append(args, tcJS) } else { // Pass empty object if no 
testCase emptyJS, _ := bridge.JsValue(v8ctx, map[string]interface{}{}) args = append(args, emptyJS) } // Arg 2: result (for After) if result != nil { resultMap := map[string]interface{}{ "id": result.ID, "status": string(result.Status), "duration_ms": result.DurationMs, } if result.Output != nil { resultMap["output"] = result.Output } if result.Error != "" { resultMap["error"] = result.Error } resultJS, err := bridge.JsValue(v8ctx, resultMap) if err != nil { return nil, fmt.Errorf("failed to convert result: %w", err) } args = append(args, resultJS) } // Arg 3: beforeData (for After) if beforeData != nil { beforeDataJS, err := bridge.JsValue(v8ctx, beforeData) if err != nil { return nil, fmt.Errorf("failed to convert beforeData: %w", err) } args = append(args, beforeDataJS) } return args, nil } // testCasesToJS converts test cases to a JS array func (h *HookExecutor) testCasesToJS(v8ctx *v8go.Context, testCases []*Case) (*v8go.Value, error) { cases := make([]map[string]interface{}, len(testCases)) for i, tc := range testCases { cases[i] = map[string]interface{}{ "id": tc.ID, "input": tc.Input, } if tc.Metadata != nil { cases[i]["metadata"] = tc.Metadata } } return bridge.JsValue(v8ctx, cases) } // resultsToJS converts results to a JS array func (h *HookExecutor) resultsToJS(v8ctx *v8go.Context, results []*Result) (*v8go.Value, error) { resultMaps := make([]map[string]interface{}, len(results)) for i, r := range results { resultMaps[i] = map[string]interface{}{ "id": r.ID, "status": string(r.Status), "duration_ms": r.DurationMs, } if r.Output != nil { resultMaps[i]["output"] = r.Output } if r.Error != "" { resultMaps[i]["error"] = r.Error } } return bridge.JsValue(v8ctx, resultMaps) } ================================================ FILE: agent/test/script_hooks_test.go ================================================ package test_test import ( "testing" "github.com/stretchr/testify/assert" v8 "github.com/yaoapp/gou/runtime/v8" agenttest 
"github.com/yaoapp/yao/agent/test" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) const hooksTestAgent = "assistants/tests/hooks-test" func TestParseHookRef(t *testing.T) { tests := []struct { name string input string wantFile string wantFunc string expectErr bool }{ { name: "function only", input: "Before", wantFile: "", wantFunc: "Before", }, { name: "with script file", input: "env_test.Before", wantFile: "env_test.ts", wantFunc: "Before", }, { name: "with src prefix", input: "src/env_test.Before", wantFile: "env_test.ts", wantFunc: "Before", }, { name: "nested path", input: "setup/db_test.Before", wantFile: "setup/db_test.ts", wantFunc: "Before", }, { name: "empty string", input: "", expectErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ref, err := agenttest.ParseHookRef(tt.input) if tt.expectErr { assert.Error(t, err) return } assert.NoError(t, err) assert.Equal(t, tt.wantFile, ref.ScriptFile) assert.Equal(t, tt.wantFunc, ref.Function) }) } } func TestHookExecutorLoadTestScripts(t *testing.T) { // Prepare test environment test.Prepare(t, config.Conf) defer test.Clean() // Load agent test scripts using the utility function scripts := test.LoadAgentTestScripts(t, hooksTestAgent) assert.NotEmpty(t, scripts, "Should load at least one test script") // Verify the script was loaded into V8 found := false for _, scriptID := range scripts { if _, ok := v8.Scripts[scriptID]; ok { found = true t.Logf("Loaded script: %s", scriptID) break } } assert.True(t, found, "At least one script should be loaded into V8") } func TestHookExecutorExecuteBefore(t *testing.T) { // Prepare test environment test.Prepare(t, config.Conf) defer test.Clean() // Load agent test scripts test.LoadAgentTestScripts(t, hooksTestAgent) executor := agenttest.NewHookExecutor(true) testCase := &agenttest.Case{ ID: "TEST001", Input: "Hello World", } // Execute Before hook beforeData, err := executor.ExecuteBefore("env_test.Before", testCase, hooksTestAgent) 
assert.NoError(t, err) assert.NotNil(t, beforeData) // Verify returned data dataMap, ok := beforeData.(map[string]interface{}) assert.True(t, ok, "beforeData should be a map") assert.Equal(t, "TEST001", dataMap["test_id"]) assert.NotEmpty(t, dataMap["mock_user_id"]) assert.NotEmpty(t, dataMap["mock_session_id"]) } func TestHookExecutorExecuteAfter(t *testing.T) { // Prepare test environment test.Prepare(t, config.Conf) defer test.Clean() // Load agent test scripts test.LoadAgentTestScripts(t, hooksTestAgent) executor := agenttest.NewHookExecutor(true) testCase := &agenttest.Case{ ID: "TEST002", Input: "Test input", } result := &agenttest.Result{ ID: "TEST002", Status: agenttest.StatusPassed, DurationMs: 100, } beforeData := map[string]interface{}{ "test_id": "TEST002", "mock_user_id": "user_TEST002_12345", "mock_session_id": "session_12345", } // Execute After hook err := executor.ExecuteAfter("env_test.After", testCase, result, beforeData, hooksTestAgent) assert.NoError(t, err) } func TestHookExecutorExecuteBeforeAll(t *testing.T) { // Prepare test environment test.Prepare(t, config.Conf) defer test.Clean() // Load agent test scripts test.LoadAgentTestScripts(t, hooksTestAgent) executor := agenttest.NewHookExecutor(true) testCases := []*agenttest.Case{ {ID: "T001", Input: "Test 1"}, {ID: "T002", Input: "Test 2"}, {ID: "T003", Input: "Test 3"}, } // Execute BeforeAll hook globalData, err := executor.ExecuteBeforeAll("env_test.BeforeAll", testCases, hooksTestAgent) assert.NoError(t, err) assert.NotNil(t, globalData) // Verify returned data dataMap, ok := globalData.(map[string]interface{}) assert.True(t, ok, "globalData should be a map") assert.NotEmpty(t, dataMap["suite_id"]) assert.Equal(t, float64(3), dataMap["test_count"]) // JSON numbers are float64 } func TestHookExecutorExecuteAfterAll(t *testing.T) { // Prepare test environment test.Prepare(t, config.Conf) defer test.Clean() // Load agent test scripts test.LoadAgentTestScripts(t, hooksTestAgent) executor := 
agenttest.NewHookExecutor(true) results := []*agenttest.Result{ {ID: "T001", Status: agenttest.StatusPassed, DurationMs: 100}, {ID: "T002", Status: agenttest.StatusFailed, DurationMs: 200, Error: "assertion failed"}, {ID: "T003", Status: agenttest.StatusPassed, DurationMs: 150}, } globalData := map[string]interface{}{ "suite_id": "suite_12345", "test_count": 3, } // Execute AfterAll hook err := executor.ExecuteAfterAll("env_test.AfterAll", results, globalData, hooksTestAgent) assert.NoError(t, err) } func TestHookExecutorFunctionNotFound(t *testing.T) { // Prepare test environment test.Prepare(t, config.Conf) defer test.Clean() // Load agent test scripts test.LoadAgentTestScripts(t, hooksTestAgent) executor := agenttest.NewHookExecutor(true) testCase := &agenttest.Case{ ID: "TEST001", Input: "Hello", } // Try to execute non-existent function _, err := executor.ExecuteBefore("env_test.NonExistent", testCase, hooksTestAgent) assert.Error(t, err) assert.Contains(t, err.Error(), "not defined") } func TestHookExecutorScriptNotFound(t *testing.T) { // Prepare test environment test.Prepare(t, config.Conf) defer test.Clean() // Load agent test scripts test.LoadAgentTestScripts(t, hooksTestAgent) executor := agenttest.NewHookExecutor(true) testCase := &agenttest.Case{ ID: "TEST001", Input: "Hello", } // Try to execute from non-existent script _, err := executor.ExecuteBefore("nonexistent_test.Before", testCase, hooksTestAgent) assert.Error(t, err) assert.Contains(t, err.Error(), "not found") } ================================================ FILE: agent/test/script_types.go ================================================ package test import "time" // ScriptInfo contains information about the script being tested type ScriptInfo struct { // ID is the script identifier (e.g., "scripts.expense.setup") ID string `json:"id"` // Assistant is the assistant directory name (e.g., "expense") Assistant string `json:"assistant"` // Module is the module name (e.g., "setup") Module 
string `json:"module"` // ScriptPath is the path to the main script file (e.g., "expense/src/setup.ts") ScriptPath string `json:"script_path"` // TestPath is the path to the test script file (e.g., "expense/src/setup_test.ts") TestPath string `json:"test_path"` } // ScriptTestCase represents a single script test function type ScriptTestCase struct { // Name is the test function name (e.g., "TestSystemReady") Name string `json:"name"` // Function is the full function reference Function string `json:"function"` } // ScriptTestResult represents the result of running a script test function type ScriptTestResult struct { // Name is the test function name Name string `json:"name"` // Status is the test execution status Status Status `json:"status"` // DurationMs is the execution duration in milliseconds DurationMs int64 `json:"duration_ms"` // Error contains the error message if the test failed Error string `json:"error,omitempty"` // Assertion contains assertion failure details Assertion *ScriptAssertionInfo `json:"assertion,omitempty"` // Logs contains log messages from the test Logs []string `json:"logs,omitempty"` } // ScriptAssertionInfo contains details about an assertion failure type ScriptAssertionInfo struct { // Type is the assertion type (e.g., "Equal", "True") Type string `json:"type"` // Expected is the expected value Expected interface{} `json:"expected,omitempty"` // Actual is the actual value Actual interface{} `json:"actual,omitempty"` // Message is the custom failure message Message string `json:"message,omitempty"` } // ScriptTestSummary contains aggregated statistics for script tests type ScriptTestSummary struct { // Total number of test functions Total int `json:"total"` // Passed number of test functions that passed Passed int `json:"passed"` // Failed number of test functions that failed Failed int `json:"failed"` // Skipped number of test functions that were skipped Skipped int `json:"skipped"` // DurationMs is the total execution duration in 
milliseconds DurationMs int64 `json:"duration_ms"` } // ScriptTestReport represents the complete script test report type ScriptTestReport struct { // Type indicates this is a script test report Type string `json:"type"` // "script_test" // Script is the script identifier (e.g., "scripts.expense.setup") Script string `json:"script"` // ScriptPath is the path to the test script file ScriptPath string `json:"script_path"` // Summary contains aggregated statistics Summary *ScriptTestSummary `json:"summary"` // Environment contains the test environment configuration Environment *Environment `json:"environment"` // Results contains individual test results Results []*ScriptTestResult `json:"results"` // Metadata contains additional report metadata Metadata *ScriptTestMetadata `json:"metadata"` } // ScriptTestMetadata contains metadata about the script test report type ScriptTestMetadata struct { // StartedAt is when the test run started StartedAt time.Time `json:"started_at"` // CompletedAt is when the test run completed CompletedAt time.Time `json:"completed_at"` // Version is the Yao version Version string `json:"version"` } // HasFailures returns true if there are any failed tests func (r *ScriptTestReport) HasFailures() bool { return r.Summary.Failed > 0 } // PassRate returns the pass rate as a percentage (0-100) func (r *ScriptTestReport) PassRate() float64 { if r.Summary.Total == 0 { return 0 } return float64(r.Summary.Passed) / float64(r.Summary.Total) * 100 } // ToReport converts ScriptTestReport to a standard Report for unified reporting func (r *ScriptTestReport) ToReport() *Report { return &Report{ Summary: &Summary{ Total: r.Summary.Total, Passed: r.Summary.Passed, Failed: r.Summary.Failed, Skipped: r.Summary.Skipped, DurationMs: r.Summary.DurationMs, AgentID: r.Script, AgentPath: r.ScriptPath, }, Environment: r.Environment, Results: r.toResults(), Metadata: &ReportMetadata{ StartedAt: r.Metadata.StartedAt, CompletedAt: r.Metadata.CompletedAt, Version: 
r.Metadata.Version, }, } } // toResults converts script test results to standard results func (r *ScriptTestReport) toResults() []*Result { results := make([]*Result, len(r.Results)) for i, sr := range r.Results { results[i] = &Result{ ID: sr.Name, Status: sr.Status, Input: sr.Name, DurationMs: sr.DurationMs, Error: sr.Error, } } return results } ================================================ FILE: agent/test/types.go ================================================ package test import ( "encoding/json" "fmt" "math" "os" "time" "github.com/yaoapp/yao/agent/context" ) // Status represents the status of a test case execution type Status string const ( // StatusPassed indicates the test passed StatusPassed Status = "passed" // StatusFailed indicates the test failed StatusFailed Status = "failed" // StatusSkipped indicates the test was skipped StatusSkipped Status = "skipped" // StatusError indicates a runtime error occurred StatusError Status = "error" // StatusTimeout indicates the test timed out StatusTimeout Status = "timeout" ) // OutputFormat represents the output format for test reports type OutputFormat string const ( // FormatJSON outputs JSON format (for CI integration) FormatJSON OutputFormat = "json" // FormatHTML outputs HTML format (for human review) FormatHTML OutputFormat = "html" // FormatMarkdown outputs Markdown format (for documentation) FormatMarkdown OutputFormat = "markdown" ) // StabilityClass represents the stability classification of a test case type StabilityClass string const ( // StabilityStable indicates 100% pass rate StabilityStable StabilityClass = "stable" // StabilityMostlyStable indicates 80-99% pass rate StabilityMostlyStable StabilityClass = "mostly_stable" // StabilityUnstable indicates 50-79% pass rate StabilityUnstable StabilityClass = "unstable" // StabilityHighlyUnstable indicates < 50% pass rate StabilityHighlyUnstable StabilityClass = "highly_unstable" ) // InputMode represents the input mode for test cases type InputMode 
string const ( // InputModeFile indicates input from a JSONL file InputModeFile InputMode = "file" // InputModeMessage indicates input from a direct message string InputModeMessage InputMode = "message" // InputModeScript indicates script test mode (testing agent handler scripts) InputModeScript InputMode = "script" ) // Options represents the configuration options for running tests type Options struct { // Input/Output // =============================== // Input is the input source: either a file path or a direct message Input string `json:"input"` // InputMode is the input mode (auto-detected from Input) InputMode InputMode `json:"input_mode"` // OutputFile is the path to write the test report // Format is determined by file extension (.json, .html, .md) OutputFile string `json:"output_file"` // Agent Selection // =============================== // AgentID is the explicit agent ID to test (optional) // If not set, agent is resolved from InputFile path AgentID string `json:"agent_id,omitempty"` // Connector overrides the agent's default connector (optional) Connector string `json:"connector,omitempty"` // Test Environment // =============================== // UserID is the test user ID (-u flag) UserID string `json:"user_id,omitempty"` // TeamID is the test team ID (-t flag) TeamID string `json:"team_id,omitempty"` // Locale is the locale for the test context (default: "en-us") Locale string `json:"locale,omitempty"` // ContextFile is the path to a JSON file containing custom context data (-ctx flag) // This allows full customization of authorized info, metadata, etc. 
ContextFile string `json:"context_file,omitempty"` // ContextData is the parsed context data from ContextFile // This is populated internally after loading the file ContextData *ContextConfig `json:"-"` // Execution // =============================== // Timeout is the default timeout for each test case // Can be overridden per test case Timeout time.Duration `json:"timeout,omitempty"` // Parallel is the number of tests to run in parallel // Default is 1 (sequential execution) Parallel int `json:"parallel,omitempty"` // Runs is the number of times to run each test case // Default is 1. When > 1, stability metrics are collected Runs int `json:"runs,omitempty"` // Reporting // =============================== // ReporterID is the reporter agent ID for custom report generation // If not set, default JSONL format is used ReporterID string `json:"reporter_id,omitempty"` // Behavior // =============================== // Verbose enables verbose output during test execution Verbose bool `json:"verbose,omitempty"` // FailFast stops execution on first failure FailFast bool `json:"fail_fast,omitempty"` // Run is a regex pattern to filter which tests to run (similar to go test -run) // Only tests matching the pattern will be executed // Example: "TestSystem" matches TestSystemReady, TestSystemError, etc. 
Run string `json:"run,omitempty"` // BeforeAll is the global before script (e.g., "scripts:tests.env.BeforeAll") // Called once before all test cases BeforeAll string `json:"before_all,omitempty"` // AfterAll is the global after script (e.g., "scripts:tests.env.AfterAll") // Called once after all test cases AfterAll string `json:"after_all,omitempty"` // DryRun generates test cases without running them // Useful for previewing agent-generated test cases DryRun bool `json:"dry_run,omitempty"` // Simulator is the default simulator agent ID for dynamic mode // Can be overridden per test case in JSONL Simulator string `json:"simulator,omitempty"` } // ContextConfig represents custom context configuration from JSON file // This allows full customization of the test context including authorized info type ContextConfig struct { // ChatID is the chat session identifier // Used to maintain session state across turns in dynamic tests ChatID string `json:"chat_id,omitempty"` // Authorized contains custom authorization data Authorized *AuthorizedConfig `json:"authorized,omitempty"` // Metadata contains custom metadata to pass to the context Metadata map[string]interface{} `json:"metadata,omitempty"` // Client contains custom client information Client *ClientConfig `json:"client,omitempty"` // Locale overrides the locale setting Locale string `json:"locale,omitempty"` // Referer overrides the referer setting Referer string `json:"referer,omitempty"` } // AuthorizedConfig represents custom authorization configuration // Matches the structure of types.AuthorizedInfo from openapi/oauth/types type AuthorizedConfig struct { // Sub is the subject identifier (JWT sub claim) Sub string `json:"sub,omitempty"` // ClientID is the OAuth client ID ClientID string `json:"client_id,omitempty"` // Scope is the access scope Scope string `json:"scope,omitempty"` // SessionID is the session identifier SessionID string `json:"session_id,omitempty"` // UserID is the user identifier UserID string 
`json:"user_id,omitempty"` // TeamID is the team identifier TeamID string `json:"team_id,omitempty"` // TenantID is the tenant identifier TenantID string `json:"tenant_id,omitempty"` // RememberMe is the remember me flag RememberMe bool `json:"remember_me,omitempty"` // Constraints contains data access constraints (set by ACL enforcement) Constraints *DataConstraintsConfig `json:"constraints,omitempty"` } // DataConstraintsConfig represents data access constraints // Matches the structure of types.DataConstraints from openapi/oauth/types type DataConstraintsConfig struct { // OwnerOnly - only access owner's data OwnerOnly bool `json:"owner_only,omitempty"` // CreatorOnly - only access creator's data CreatorOnly bool `json:"creator_only,omitempty"` // EditorOnly - only access editor's data EditorOnly bool `json:"editor_only,omitempty"` // TeamOnly - only access team's data (filter by team_id) TeamOnly bool `json:"team_only,omitempty"` // Extra contains user-defined constraints (department, region, etc.) 
Extra map[string]interface{} `json:"extra,omitempty"` } // ClientConfig represents custom client configuration type ClientConfig struct { // Type is the client type (e.g., "web", "mobile", "test") Type string `json:"type,omitempty"` // UserAgent is the client user agent string UserAgent string `json:"user_agent,omitempty"` // IP is the client IP address IP string `json:"ip,omitempty"` } // Environment configures the test execution context type Environment struct { // UserID is the user ID for authorized info (-u flag) UserID string `json:"user_id"` // TeamID is the team ID for authorized info (-t flag) TeamID string `json:"team_id"` // Locale is the locale (default: "en-us") Locale string `json:"locale"` // ClientType is the client type (default: "test") ClientType string `json:"client_type"` // ClientIP is the client IP (default: "127.0.0.1") ClientIP string `json:"client_ip"` // Referer is the request referer (default: "test") Referer string `json:"referer"` // Accept is the accept format (default: "standard") Accept string `json:"accept"` // ContextConfig contains custom context configuration (from -ctx flag) ContextConfig *ContextConfig `json:"-"` } // NewEnvironment creates a new test environment with defaults func NewEnvironment(userID, teamID string) *Environment { env := &Environment{ UserID: userID, TeamID: teamID, Locale: "en-us", ClientType: "test", ClientIP: "127.0.0.1", Referer: "test", Accept: "standard", } // Apply defaults if not set if env.UserID == "" { env.UserID = "test-user" } if env.TeamID == "" { env.TeamID = "test-team" } return env } // NewEnvironmentWithContext creates a new test environment with custom context config func NewEnvironmentWithContext(userID, teamID string, ctxConfig *ContextConfig) *Environment { env := NewEnvironment(userID, teamID) if ctxConfig == nil { return env } env.ContextConfig = ctxConfig // Override with context config values if ctxConfig.Locale != "" { env.Locale = ctxConfig.Locale } if ctxConfig.Referer != "" { 
env.Referer = ctxConfig.Referer } if ctxConfig.Client != nil { if ctxConfig.Client.Type != "" { env.ClientType = ctxConfig.Client.Type } if ctxConfig.Client.IP != "" { env.ClientIP = ctxConfig.Client.IP } } if ctxConfig.Authorized != nil { if ctxConfig.Authorized.UserID != "" { env.UserID = ctxConfig.Authorized.UserID } // TeamID takes precedence over TenantID for team override if ctxConfig.Authorized.TeamID != "" { env.TeamID = ctxConfig.Authorized.TeamID } else if ctxConfig.Authorized.TenantID != "" { env.TeamID = ctxConfig.Authorized.TenantID } } return env } // LoadContextConfig loads context configuration from a JSON file func LoadContextConfig(filePath string) (*ContextConfig, error) { resolvedPath := ResolvePathWithYaoRoot(filePath) data, err := os.ReadFile(resolvedPath) if err != nil { return nil, fmt.Errorf("failed to read context file: %w", err) } var config ContextConfig if err := json.Unmarshal(data, &config); err != nil { return nil, fmt.Errorf("failed to parse context file: %w", err) } return &config, nil } // Case represents a single test case loaded from JSONL type Case struct { // ID is the unique identifier for this test case (e.g., "T001") ID string `json:"id"` // Input is the test input, can be: // - string: simple text input // - map (Message): single message with role and content // - []map ([]Message): conversation history Input interface{} `json:"input"` // Expected is the expected output for validation (optional) // If set, the actual output will be compared against this Expected interface{} `json:"expected,omitempty"` // Assert defines custom assertion rules (optional) // If set, these rules will be used instead of simple expected comparison // Can be a single assertion or an array of assertions Assert interface{} `json:"assert,omitempty"` // Environment (per-test case, can be overridden by command line flags) // =============================== // UserID is the user ID for this test case (overridden by -u flag) UserID string 
`json:"user,omitempty"` // TeamID is the team ID for this test case (overridden by -t flag) TeamID string `json:"team,omitempty"` // Metadata contains additional metadata for the test case // This is passed to ctx.Metadata and can be used by Create Hook Metadata map[string]interface{} `json:"metadata,omitempty"` // Options contains context options for this test case // Supports: connector, skip (history, trace, output, keyword, search), mode Options *CaseOptions `json:"options,omitempty"` // Skip indicates whether to skip this test case Skip bool `json:"skip,omitempty"` // Timeout overrides the default timeout for this test case // Format: "30s", "1m", "2m30s" Timeout string `json:"timeout,omitempty"` // Before script function (e.g., "scripts:tests.env.Before") // Called before the test case runs, returns data passed to After Before string `json:"before,omitempty"` // After script function (e.g., "scripts:tests.env.After") // Called after the test case completes (pass or fail) After string `json:"after,omitempty"` // Dynamic Mode Fields // =============================== // Simulator configures the user simulator for dynamic testing // When set, the test runs in dynamic mode with multi-turn conversation Simulator *Simulator `json:"simulator,omitempty"` // Checkpoints define validation points for dynamic testing // Each checkpoint is checked after every agent response Checkpoints []*Checkpoint `json:"checkpoints,omitempty"` // MaxTurns is the maximum number of conversation turns (default: 20) MaxTurns int `json:"max_turns,omitempty"` } // Simulator configures the user simulator for dynamic testing type Simulator struct { // Use is the simulator agent ID (no prefix needed) Use string `json:"use"` // Options for the simulator agent Options *SimulatorOptions `json:"options,omitempty"` } // SimulatorOptions configures simulator behavior type SimulatorOptions struct { // Metadata passed to the simulator agent // Common fields: persona, goal, style Metadata 
map[string]interface{} `json:"metadata,omitempty"` // Connector overrides the simulator's default connector Connector string `json:"connector,omitempty"` } // Checkpoint defines a validation point in dynamic testing type Checkpoint struct { // ID is the unique identifier for this checkpoint ID string `json:"id"` // Description is a human-readable description Description string `json:"description,omitempty"` // Assert defines the assertion to validate // Same format as Case.Assert Assert interface{} `json:"assert"` // After specifies checkpoint IDs that must be reached before this one // Used to enforce ordering (e.g., "ask_type" must come before "confirm") After []string `json:"after,omitempty"` // Required indicates if this checkpoint must be reached (default: true) // Optional checkpoints don't cause test failure if not reached Required *bool `json:"required,omitempty"` } // CaseOptions represents per-test-case context options // Maps to context.Options fields type CaseOptions struct { // Connector overrides the agent's default connector Connector string `json:"connector,omitempty"` // Skip configuration Skip *CaseSkipOptions `json:"skip,omitempty"` // DisableGlobalPrompts temporarily disables global prompts for this request DisableGlobalPrompts bool `json:"disable_global_prompts,omitempty"` // Search mode, default is true (use pointer to distinguish unset from false) Search *bool `json:"search,omitempty"` // Mode is the agent mode (default: "chat") Mode string `json:"mode,omitempty"` // Metadata for passing custom data to hooks (e.g., scenario selection) Metadata map[string]interface{} `json:"metadata,omitempty"` } // CaseSkipOptions represents skip configuration for a test case // Maps to context.Skip fields type CaseSkipOptions struct { History bool `json:"history,omitempty"` // Skip history loading Trace bool `json:"trace,omitempty"` // Skip trace logging Output bool `json:"output,omitempty"` // Skip output to client Keyword bool `json:"keyword,omitempty"` // 
Skip keyword extraction Search bool `json:"search,omitempty"` // Skip auto search } // Assertion represents a single assertion rule type Assertion struct { // Type is the assertion type: // - "equals": exact match (default if expected is set) // - "contains": output contains the expected string/value // - "not_contains": output does not contain the string/value // - "json_path": extract value using JSON path and compare // - "regex": match output against regex pattern // - "script": run a custom assertion script // - "type": check output type (string, object, array, number, boolean) // - "schema": validate against JSON schema // - "agent": use an agent to validate the response Type string `json:"type"` // Value is the expected value or pattern (depends on type) Value interface{} `json:"value,omitempty"` // Path is the JSON path for json_path assertions (e.g., "$.need_search") Path string `json:"path,omitempty"` // Script is the assertion script name for script assertions // The script receives (output, input, expected) and returns {pass: bool, message: string} Script string `json:"script,omitempty"` // Use specifies the agent/script for validation // For agent assertions: "agents:tests.validator-agent" (with prefix) // For script assertions: "scripts:tests.validate" (with prefix) Use string `json:"use,omitempty"` // Options for agent-driven assertions (aligned with context.Options) Options *AssertionOptions `json:"options,omitempty"` // Message is a custom failure message Message string `json:"message,omitempty"` // Negate inverts the assertion result Negate bool `json:"negate,omitempty"` } // AssertionOptions for agent-driven assertions type AssertionOptions struct { // Connector overrides the agent's default connector Connector string `json:"connector,omitempty"` // Metadata contains custom data passed to the validator agent Metadata map[string]interface{} `json:"metadata,omitempty"` } // AssertionResult represents the result of an assertion type AssertionResult 
struct { // Passed indicates whether the assertion passed Passed bool `json:"passed"` // Message describes the assertion result Message string `json:"message,omitempty"` // Assertion is the original assertion that was evaluated Assertion *Assertion `json:"assertion,omitempty"` // Actual is the actual value that was compared Actual interface{} `json:"actual,omitempty"` // Expected is the expected value Expected interface{} `json:"expected,omitempty"` } // GetEnvironment returns the effective test environment for this test case // Priority: command line flags > context config > test case fields > defaults func (tc *Case) GetEnvironment(opts *Options) *Environment { // Start with context config if available, otherwise use defaults var env *Environment if opts != nil && opts.ContextData != nil { env = NewEnvironmentWithContext("", "", opts.ContextData) } else { env = NewEnvironment("", "") } // Apply test case specific values if tc.UserID != "" { env.UserID = tc.UserID } if tc.TeamID != "" { env.TeamID = tc.TeamID } // Apply command line overrides (highest priority) if opts != nil { if opts.UserID != "" { env.UserID = opts.UserID } if opts.TeamID != "" { env.TeamID = opts.TeamID } if opts.Locale != "" { env.Locale = opts.Locale } } return env } // GetMessages converts the Input to a slice of context.Message // This handles all input formats: string, Message, []Message func (tc *Case) GetMessages() ([]context.Message, error) { return ParseInput(tc.Input) } // GetMessagesWithOptions converts the Input to a slice of context.Message with options // This handles all input formats: string, Message, []Message // It also processes file:// references in content parts func (tc *Case) GetMessagesWithOptions(opts *InputOptions) ([]context.Message, error) { return ParseInputWithOptions(tc.Input, opts) } // GetTimeout returns the timeout duration for this test case // Returns the override timeout if set, otherwise returns the default func (tc *Case) GetTimeout(defaultTimeout 
time.Duration) time.Duration { if tc.Timeout == "" { return defaultTimeout } d, err := time.ParseDuration(tc.Timeout) if err != nil { return defaultTimeout } return d } // Result represents the result of running a single test case type Result struct { // ID is the test case identifier ID string `json:"id"` // Status is the test execution status Status Status `json:"status"` // Input is the original test input (for reference in reports) Input interface{} `json:"input"` // Output is the actual output from the agent Output interface{} `json:"output,omitempty"` // Expected is the expected output (if specified in test case) Expected interface{} `json:"expected,omitempty"` // DurationMs is the execution duration in milliseconds DurationMs int64 `json:"duration_ms"` // Error contains the error message if status is failed/error/timeout Error string `json:"error,omitempty"` // Options contains the context options used for this test case Options *CaseOptions `json:"options,omitempty"` // Metadata contains additional result metadata Metadata map[string]interface{} `json:"metadata,omitempty"` } // RunDetail represents the result of a single run in stability testing type RunDetail struct { // Run is the run number (1-based) Run int `json:"run"` // Status is the execution status for this run Status Status `json:"status"` // DurationMs is the execution duration in milliseconds DurationMs int64 `json:"duration_ms"` // Output is the output from this run Output interface{} `json:"output,omitempty"` // Error contains the error message if this run failed Error string `json:"error,omitempty"` } // StabilityResult represents the stability analysis result for a test case type StabilityResult struct { // ID is the test case identifier ID string `json:"id"` // Input is the original test input Input interface{} `json:"input"` // Expected is the expected output (if specified) Expected interface{} `json:"expected,omitempty"` // Runs is the total number of runs Runs int `json:"runs"` // Passed 
is the number of runs that passed Passed int `json:"passed"` // Failed is the number of runs that failed Failed int `json:"failed"` // PassRate is the pass rate percentage (0-100) PassRate float64 `json:"pass_rate"` // Consistency is a measure of output consistency (0-1) // 1.0 means all outputs are identical, lower values indicate variation Consistency float64 `json:"consistency"` // Stable indicates whether the test is considered stable Stable bool `json:"stable"` // StabilityClass is the stability classification StabilityClass StabilityClass `json:"stability_class"` // Timing statistics AvgDurationMs float64 `json:"avg_duration_ms"` MinDurationMs int64 `json:"min_duration_ms"` MaxDurationMs int64 `json:"max_duration_ms"` StdDeviationMs float64 `json:"std_deviation_ms"` // RunDetails contains details for each run RunDetails []*RunDetail `json:"run_details"` } // CalculateStability calculates stability metrics from run details func (sr *StabilityResult) CalculateStability() { if len(sr.RunDetails) == 0 { return } sr.Runs = len(sr.RunDetails) sr.Passed = 0 sr.Failed = 0 var totalDuration int64 sr.MinDurationMs = math.MaxInt64 sr.MaxDurationMs = 0 for _, rd := range sr.RunDetails { if rd.Status == StatusPassed { sr.Passed++ } else { sr.Failed++ } totalDuration += rd.DurationMs if rd.DurationMs < sr.MinDurationMs { sr.MinDurationMs = rd.DurationMs } if rd.DurationMs > sr.MaxDurationMs { sr.MaxDurationMs = rd.DurationMs } } // Calculate pass rate sr.PassRate = float64(sr.Passed) / float64(sr.Runs) * 100 // Calculate average duration sr.AvgDurationMs = float64(totalDuration) / float64(sr.Runs) // Calculate standard deviation var sumSquares float64 for _, rd := range sr.RunDetails { diff := float64(rd.DurationMs) - sr.AvgDurationMs sumSquares += diff * diff } sr.StdDeviationMs = math.Sqrt(sumSquares / float64(sr.Runs)) // Determine stability classification sr.StabilityClass = ClassifyStability(sr.PassRate) sr.Stable = sr.PassRate == 100 // Calculate consistency 
(simplified: based on pass rate) sr.Consistency = sr.PassRate / 100 } // ClassifyStability returns the stability classification based on pass rate func ClassifyStability(passRate float64) StabilityClass { switch { case passRate == 100: return StabilityStable case passRate >= 80: return StabilityMostlyStable case passRate >= 50: return StabilityUnstable default: return StabilityHighlyUnstable } } // Summary contains aggregated statistics for the test run type Summary struct { // Total number of test cases Total int `json:"total"` // Passed number of test cases that passed Passed int `json:"passed"` // Failed number of test cases that failed Failed int `json:"failed"` // Skipped number of test cases that were skipped Skipped int `json:"skipped"` // Errors number of test cases with runtime errors Errors int `json:"errors"` // Timeouts number of test cases that timed out Timeouts int `json:"timeouts"` // DurationMs is the total execution duration in milliseconds DurationMs int64 `json:"duration_ms"` // AgentID is the ID of the agent being tested AgentID string `json:"agent_id"` // AgentPath is the file path of the agent (for path-based resolution) AgentPath string `json:"agent_path,omitempty"` // Connector is the connector used for the test Connector string `json:"connector"` // Stability metrics (when Runs > 1) // =============================== // RunsPerCase is the number of runs per test case RunsPerCase int `json:"runs_per_case,omitempty"` // TotalRuns is the total number of runs (Total * RunsPerCase) TotalRuns int `json:"total_runs,omitempty"` // OverallPassRate is the overall pass rate percentage OverallPassRate float64 `json:"overall_pass_rate,omitempty"` // StableCases is the number of cases with 100% pass rate StableCases int `json:"stable_cases,omitempty"` // UnstableCases is the number of cases with < 100% pass rate UnstableCases int `json:"unstable_cases,omitempty"` } // Report represents the complete test report type Report struct { // Summary contains 
aggregated statistics Summary *Summary `json:"summary"` // Environment contains the test environment configuration Environment *Environment `json:"environment,omitempty"` // Results contains individual test results (for single run) Results []*Result `json:"results,omitempty"` // StabilityResults contains stability analysis results (for multiple runs) StabilityResults []*StabilityResult `json:"stability_results,omitempty"` // Metadata contains additional report metadata Metadata *ReportMetadata `json:"metadata"` } // ReportMetadata contains metadata about the test report type ReportMetadata struct { // StartedAt is when the test run started StartedAt time.Time `json:"started_at"` // CompletedAt is when the test run completed CompletedAt time.Time `json:"completed_at"` // Version is the Yao version Version string `json:"version"` // InputFile is the path to the input file InputFile string `json:"input_file"` // OutputFile is the path to the output file OutputFile string `json:"output_file"` // Options contains the test options used Options *Options `json:"options,omitempty"` } // HasFailures returns true if there are any failed, error, or timeout tests func (r *Report) HasFailures() bool { return r.Summary.Failed > 0 || r.Summary.Errors > 0 || r.Summary.Timeouts > 0 } // PassRate returns the pass rate as a percentage (0-100) func (r *Report) PassRate() float64 { if r.Summary.Total == 0 { return 0 } return float64(r.Summary.Passed) / float64(r.Summary.Total) * 100 } // IsStabilityTest returns true if this is a stability test (multiple runs) func (r *Report) IsStabilityTest() bool { return r.Summary.RunsPerCase > 1 } // AgentInfo contains information about the agent being tested type AgentInfo struct { // ID is the agent identifier ID string `json:"id"` // Name is the human-readable name Name string `json:"name"` // Description is the agent description Description string `json:"description,omitempty"` // Path is the file system path to the agent Path string 
`json:"path"` // Connector is the default connector Connector string `json:"connector"` // Type is the agent type (e.g., "worker", "assistant") Type string `json:"type,omitempty"` } // ReporterInput is the input passed to a custom reporter agent type ReporterInput struct { // Report is the test report to format Report *Report `json:"report"` // Format is the desired output format Format string `json:"format"` // Options contains additional formatting options Options *ReporterOptions `json:"options,omitempty"` } // ReporterOptions contains options for custom reporter agents type ReporterOptions struct { // Verbose includes detailed output in the report Verbose bool `json:"verbose,omitempty"` // IncludeOutputs includes full outputs in the report IncludeOutputs bool `json:"include_outputs,omitempty"` // IncludeInputs includes full inputs in the report IncludeInputs bool `json:"include_inputs,omitempty"` // MaxOutputLength limits the output length in the report MaxOutputLength int `json:"max_output_length,omitempty"` // Theme is the report theme (for HTML reports) Theme string `json:"theme,omitempty"` // Title is the report title Title string `json:"title,omitempty"` } ================================================ FILE: agent/testutils/testutils.go ================================================ package testutils import ( "testing" _ "github.com/yaoapp/gou/encoding" "github.com/yaoapp/gou/model" "github.com/yaoapp/gou/query" "github.com/yaoapp/gou/query/gou" _ "github.com/yaoapp/gou/text" "github.com/yaoapp/xun/capsule" "github.com/yaoapp/yao/agent" "github.com/yaoapp/yao/agent/caller" "github.com/yaoapp/yao/agent/llm" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/kb" "github.com/yaoapp/yao/test" // Import assistant to trigger init() which registers AgentGetterFunc _ "github.com/yaoapp/yao/agent/assistant" ) // Prepare prepare the test environment with optional V8 mode configuration // Usage: // // testutils.Prepare(t) // standard mode (default) // 
testutils.Prepare(t, test.PrepareOption{V8Mode: "performance"}) // performance mode for benchmarks func Prepare(t *testing.T, opts ...interface{}) { test.Prepare(t, config.Conf, opts...) // Load KB (required for agent KB features) _, err := kb.Load(config.Conf) if err != nil { t.Fatal(err) } // Load agent err = agent.Load(config.Conf) if err != nil { t.Fatal(err) } // Ensure JSAPI factories are registered (may be called multiple times, idempotent) // This is needed because Go's init() order is not guaranteed across packages caller.SetJSAPIFactory() llm.SetJSAPIFactory() // Register default query engine (required for DB search) // capsule.Global is initialized by test.Prepare if _, has := query.Engines["default"]; !has && capsule.Global != nil { query.Register("default", &gou.Query{ Query: capsule.Query(), GetTableName: func(s string) string { if mod, has := model.Models[s]; has { return mod.MetaData.Table.Name } return s }, AESKey: config.Conf.DB.AESKey, }) } } // Clean clean the test environment func Clean(t *testing.T) { test.Clean() } ================================================ FILE: agent/types/dsl.go ================================================ package types import "github.com/yaoapp/gou/store" // GetCacheStore get the cache store func (dsl *DSL) GetCacheStore() (store.Store, error) { if dsl.Cache == "" { return store.Get("__yao.agent.cache") } return store.Get(dsl.Cache) } ================================================ FILE: agent/types/types.go ================================================ package types import ( "github.com/yaoapp/yao/agent/assistant" searchTypes "github.com/yaoapp/yao/agent/search/types" store "github.com/yaoapp/yao/agent/store/types" ) // DSL AI assistant type DSL struct { // Agent Global Settings // =============================== Uses *Uses `json:"uses,omitempty" yaml:"uses,omitempty"` // Which assistant to use default, title, prompt StoreSetting store.Setting `json:"store" yaml:"store"` // The store setting of the 
assistant Cache string `json:"cache" yaml:"cache"` // The cache store of the assistant, if not set, default is "__yao.agent.cache" // System Agents Connector Settings // =============================== // System configures connectors for system agents (__yao.keyword, __yao.querydsl, __yao.title, __yao.prompt) // Each agent can have its own connector, or use the default // If not set, fallback to the first connector that supports the required capabilities System *System `json:"system,omitempty" yaml:"system,omitempty"` // Global External Settings // =============================== KB *store.KBSetting `json:"kb,omitempty" yaml:"kb,omitempty"` // The knowledge base configuration loaded from agent/kb.yml Search *searchTypes.Config `json:"search,omitempty" yaml:"search,omitempty"` // The search configuration loaded from agent/search.yao // Internal // =============================== // ID string `json:"-" yaml:"-"` // The id of the instance Assistant assistant.API `json:"-" yaml:"-"` // The default assistant Store store.Store `json:"-" yaml:"-"` // The store of the assistant GlobalPrompts []store.Prompt `json:"-" yaml:"-"` // Global prompts loaded from agent/prompts.yml } // Uses the default assistant settings // =============================== type Uses struct { Default string `json:"default,omitempty" yaml:"default,omitempty"` // The default assistant to use Title string `json:"title,omitempty" yaml:"title,omitempty"` // The assistant for generating the topic title. Prompt string `json:"prompt,omitempty" yaml:"prompt,omitempty"` // The assistant for generating the prompt. RobotPrompt string `json:"robot_prompt,omitempty" yaml:"robot_prompt,omitempty"` // The assistant for generating Robot's system prompt (responsibilities description). 
Vision string `json:"vision,omitempty" yaml:"vision,omitempty"` // The assistant for generating the image/video description, if the assistant enable the vision and model not support vision, use the vision model to describe the image/video, and return the messages with the image/video's description. Format: "agent" or "mcp:mcp_server_id" Audio string `json:"audio,omitempty" yaml:"audio,omitempty"` // The assistant for processing audio (speech-to-text, text-to-speech). If the model doesn't support audio, use this to convert audio to text. Format: "agent" or "mcp:mcp_server_id" Search string `json:"search,omitempty" yaml:"search,omitempty"` // The assistant for searching the knowledge, global web search. If not set, and the assistant enable the knowledge, it will search the result from the knowledge automatically. Fetch string `json:"fetch,omitempty" yaml:"fetch,omitempty"` // The assistant for fetching the http/https/ftp/sftp/etc. file, and return the file's content. if not set, use the http process to fetch the file. // Search-related processing tools (NLP) Web string `json:"web,omitempty" yaml:"web,omitempty"` // Web search handler: "builtin", "", "mcp:." Keyword string `json:"keyword,omitempty" yaml:"keyword,omitempty"` // Keyword extraction: "builtin", "", "mcp:." QueryDSL string `json:"querydsl,omitempty" yaml:"querydsl,omitempty"` // QueryDSL generation: "builtin", "", "mcp:." Rerank string `json:"rerank,omitempty" yaml:"rerank,omitempty"` // Result reranking: "builtin", "", "mcp:." 
} // System configures connectors for system agents // =============================== type System struct { Default string `json:"default,omitempty" yaml:"default,omitempty"` // Default connector for all system agents Keyword string `json:"keyword,omitempty" yaml:"keyword,omitempty"` // Connector for __yao.keyword agent QueryDSL string `json:"querydsl,omitempty" yaml:"querydsl,omitempty"` // Connector for __yao.querydsl agent Title string `json:"title,omitempty" yaml:"title,omitempty"` // Connector for __yao.title agent Prompt string `json:"prompt,omitempty" yaml:"prompt,omitempty"` // Connector for __yao.prompt agent RobotPrompt string `json:"robot_prompt,omitempty" yaml:"robot_prompt,omitempty"` // Connector for __yao.robot_prompt agent NeedSearch string `json:"needsearch,omitempty" yaml:"needsearch,omitempty"` // Connector for __yao.needsearch agent Entity string `json:"entity,omitempty" yaml:"entity,omitempty"` // Connector for __yao.entity agent } // Mention Structure // =============================== type Mention struct { ID string `json:"id"` Name string `json:"name"` Avatar string `json:"avatar,omitempty"` Type string `json:"type,omitempty"` } ================================================ FILE: aigc/aigc.go ================================================ package aigc import ( "fmt" "strings" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/process" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/openai" ) // Autopilots the loaded autopilots var Autopilots = []string{} // AIGCs the loaded AIGCs var AIGCs = map[string]*DSL{} // Select select the AIGC func Select(id string) (*DSL, error) { if AIGCs[id] == nil { return nil, fmt.Errorf("aigc %s not found", id) } return AIGCs[id], nil } // Call the AIGC func (ai *DSL) Call(content string, user string, option map[string]interface{}) (interface{}, *exception.Exception) { messages := []map[string]interface{}{} for _, prompt := range ai.Prompts { message := 
map[string]interface{}{"role": prompt.Role, "content": prompt.Content} if prompt.Name != "" { message["name"] = prompt.Name } messages = append(messages, message) } // add the user message message := map[string]interface{}{"role": "user", "content": content} if user != "" { message["user"] = user } messages = append(messages, message) bytes, err := jsoniter.Marshal(messages) if err != nil { return nil, exception.New(err.Error(), 400) } token, err := ai.AI.Tiktoken(string(bytes)) if err != nil { return nil, exception.New(err.Error(), 400) } if token > ai.AI.MaxToken() { return nil, exception.New("token limit exceeded", 400) } // call the AI res, ex := ai.AI.ChatCompletions(messages, option, nil) if ex != nil { return nil, ex } resText, ex := ai.AI.GetContent(res) if ex != nil { return nil, ex } if ai.Process == "" { return resText, nil } var param interface{} = resText if ai.Optional.JSON { err = jsoniter.Unmarshal([]byte(resText), ¶m) if err != nil { return nil, exception.New("%s parse error: %s", 400, resText, err.Error()) } } p, err := process.Of(ai.Process, param) if err != nil { return nil, exception.New(err.Error(), 400) } resProcess, err := p.Exec() if err != nil { return nil, exception.New(err.Error(), 500) } return resProcess, nil } // NewAI create a new AI func (ai *DSL) newAI() (AI, error) { if ai.Connector == "" || strings.HasPrefix(ai.Connector, "moapi") { model := "gpt-3.5-turbo" if strings.HasPrefix(ai.Connector, "moapi:") { model = strings.TrimPrefix(ai.Connector, "moapi:") } mo, err := openai.NewMoapi(model) if err != nil { return nil, err } return mo, nil } conn, err := connector.Select(ai.Connector) if err != nil { return nil, err } if conn.Is(connector.OPENAI) { return openai.New(ai.Connector) } return nil, fmt.Errorf("%s connector %s not support, should be a openai", ai.ID, ai.Connector) } ================================================ FILE: aigc/aigc_test.go ================================================ package aigc import ( "testing" 
"github.com/stretchr/testify/assert" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestCall(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() prepare(t) aigc, err := Select("translate") if err != nil { t.Fatal(err) } content, ex := aigc.Call("你好哇", "", nil) if ex != nil { t.Fatal(ex.Message) } assert.Contains(t, content, "Hello") } func TestCallWithProcess(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() prepare(t) aigc, err := Select("draw") if err != nil { t.Fatal(err) } args, ex := aigc.Call("帮我画一只小白兔,要有白色的耳朵. 画布高度 256,宽度 256", "", nil) if ex != nil { t.Fatal(ex.Message) } data, ok := args.(map[string]interface{}) if !ok { t.Fatal("args is not map[string]interface{}") } assert.Equal(t, float64(256), data["height"]) assert.Equal(t, float64(256), data["width"]) } func prepare(t *testing.T) { err := Load(config.Conf) if err != nil { t.Fatal(err) } } ================================================ FILE: aigc/load.go ================================================ package aigc import ( "fmt" "strings" "github.com/yaoapp/gou/application" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/share" ) // Load load AIGC func Load(cfg config.Config) error { // Ignore if the aigcs directory does not exist exists, err := application.App.Exists("aigcs") if err != nil { return err } if !exists { return nil } exts := []string{"*.ai.yml", "*.ai.yaml"} messages := []string{} err = application.App.Walk("aigcs", func(root, file string, isdir bool) error { if isdir { return nil } id := aigcID(root, file) _, err := LoadFile(file, id) if err != nil { messages = append(messages, err.Error()) } return nil }, exts...) 
if err != nil { return err } if len(messages) > 0 { return fmt.Errorf("%s", strings.Join(messages, ";\n")) } return nil } // aigcID parses AIGC ID from file path // Special handling for .ai.yml and .ai.yaml extensions // e.g., "aigcs/translate.ai.yml" -> "translate" func aigcID(root, file string) string { id := share.ID(root, file) // Remove "_ai" suffix caused by .ai.yml/.ai.yaml extension // share.ID treats .yml/.yaml as single extension, so "translate.ai.yml" becomes "translate_ai" id = strings.TrimSuffix(id, "_ai") return id } // LoadFile load AIGC by file func LoadFile(file string, id string) (*DSL, error) { data, err := application.App.Read(file) if err != nil { return nil, err } return LoadSource(data, file, id) } // LoadSource load AIGC func LoadSource(data []byte, file, id string) (*DSL, error) { dsl := DSL{ ID: id, Optional: Optional{ Autopilot: false, JSON: false, }, } err := application.Parse(file, data, &dsl) if err != nil { return nil, err } if dsl.Prompts == nil || len(dsl.Prompts) == 0 { return nil, fmt.Errorf("%s prompts is required", id) } // create AI interface dsl.AI, err = dsl.newAI() if err != nil { return nil, err } // add to autopilots if dsl.Optional.Autopilot { Autopilots = append(Autopilots, id) } // add to AIGCs AIGCs[id] = &dsl return AIGCs[id], nil } ================================================ FILE: aigc/load_test.go ================================================ package aigc import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestLoad(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() Load(config.Conf) check(t) } func check(t *testing.T) { ids := map[string]bool{} for id := range AIGCs { ids[id] = true } assert.True(t, ids["translate"]) assert.True(t, ids["draw"]) assert.GreaterOrEqual(t, len(Autopilots), 2) } ================================================ FILE: aigc/process.go ================================================ package aigc 
import ( "github.com/yaoapp/gou/process" "github.com/yaoapp/kun/exception" ) func init() { process.Register("aigcs", processAigcs) } // processScripts func processAigcs(process *process.Process) interface{} { process.ValidateArgNums(1) aigc, err := Select(process.ID) if err != nil { exception.New("aigcs.%s not loaded", 404, process.ID).Throw() return nil } content := process.ArgsString(0) user := "" var option map[string]interface{} = nil if process.NumOfArgs() > 1 { user = process.ArgsString(1) } if process.NumOfArgs() > 2 { option = process.ArgsMap(2) } res, ex := aigc.Call(content, user, option) if ex != nil { ex.Throw() } return res } ================================================ FILE: aigc/process_test.go ================================================ package aigc import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/process" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestProcessAigcs(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() prepare(t) args := []interface{}{"你好"} res := process.New("aigcs.translate", args...).Run() assert.Contains(t, res, "Hello") } ================================================ FILE: aigc/types.go ================================================ package aigc import ( "context" "github.com/yaoapp/kun/exception" ) // DSL the connector DSL type DSL struct { ID string `json:"-" yaml:"-"` Name string `json:"name,omitempty"` Connector string `json:"connector,omitempty"` Process string `json:"process,omitempty"` Prompts []Prompt `json:"prompts"` Optional Optional `json:"optional,omitempty"` AI AI `json:"-" yaml:"-"` } // Prompt a prompt type Prompt struct { Role string `json:"role"` Content string `json:"content"` Name string `json:"name,omitempty"` } // Optional optional type Optional struct { Autopilot bool `json:"autopilot,omitempty"` JSON bool `json:"json,omitempty"` } // AI the AI interface type AI interface { ChatCompletions(messages []map[string]interface{}, option 
map[string]interface{}, cb func(data []byte) int) (interface{}, *exception.Exception) ChatCompletionsWith(ctx context.Context, messages []map[string]interface{}, option map[string]interface{}, cb func(data []byte) int) (interface{}, *exception.Exception) GetContent(response interface{}) (string, *exception.Exception) Embeddings(input interface{}, user string) (interface{}, *exception.Exception) Tiktoken(input string) (int, error) MaxToken() int } ================================================ FILE: api/README.md ================================================ # API ================================================ FILE: api/api.go ================================================ package api import ( "fmt" "strings" "github.com/yaoapp/gou/api" "github.com/yaoapp/gou/application" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/share" ) // Load apis func Load(cfg config.Config) error { messages := []string{} // Ignore if the apis directory does not exist exists, err := application.App.Exists("apis") if err != nil { return err } if !exists { return nil } exts := []string{"*.http.yao", "*.http.json", "*.http.jsonc"} err = application.App.Walk("apis", func(root, file string, isdir bool) error { if isdir { return nil } _, err := api.Load(file, share.ID(root, file)) if err != nil { messages = append(messages, err.Error()) } return err }, exts...) 
if len(messages) > 0 { return fmt.Errorf("%s", strings.Join(messages, ";\n")) } return err } ================================================ FILE: api/api_test.go ================================================ package api import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/api" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestLoad(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() Load(config.Conf) check(t) } func check(t *testing.T) { ids := map[string]bool{} for id := range api.APIs { ids[id] = true } assert.True(t, ids["user"]) assert.True(t, ids["user.pet"]) // assert.True(t, ids["xiang.import"]) // will be removed in the future // assert.True(t, ids["xiang.storage"]) // will be removed in the future // wskeys := []string{} // for key := range websocket.Upgraders { // wskeys = append(wskeys, key) // } // assert.Equal(t, 5, len(keys)) // assert.Equal(t, 1, len(wskeys)) } ================================================ FILE: assert/asserter.go ================================================ package assert import ( "fmt" "regexp" "strings" "github.com/yaoapp/gou/text" ) // Asserter handles assertions/validations type Asserter struct { agentValidator AgentValidator scriptRunner ScriptRunner } // New creates a new Asserter func New() *Asserter { return &Asserter{} } // WithAgentValidator sets the agent validator for agent-type assertions func (a *Asserter) WithAgentValidator(v AgentValidator) *Asserter { a.agentValidator = v return a } // WithScriptRunner sets the script runner for script-type assertions func (a *Asserter) WithScriptRunner(r ScriptRunner) *Asserter { a.scriptRunner = r return a } // Validate validates output against a list of assertions // Returns (passed, error message) func (a *Asserter) Validate(assertions []*Assertion, output interface{}) (bool, string) { if len(assertions) == 0 { return true, "" } var failures []string for _, assertion := range assertions { result := 
a.Evaluate(assertion, output, nil) if !result.Passed { msg := result.Message if assertion.Message != "" { msg = assertion.Message } failures = append(failures, msg) } } if len(failures) > 0 { return false, strings.Join(failures, "; ") } return true, "" } // ValidateWithDetails validates output and returns detailed results func (a *Asserter) ValidateWithDetails(assertions []*Assertion, output interface{}) *Result { if len(assertions) == 0 { return &Result{Passed: true} } if len(assertions) == 1 { return a.Evaluate(assertions[0], output, nil) } var failures []string for _, assertion := range assertions { result := a.Evaluate(assertion, output, nil) if !result.Passed { msg := result.Message if assertion.Message != "" { msg = assertion.Message } failures = append(failures, msg) } } if len(failures) > 0 { return &Result{ Passed: false, Message: strings.Join(failures, "; "), } } return &Result{Passed: true} } // Evaluate evaluates a single assertion func (a *Asserter) Evaluate(assertion *Assertion, output, input interface{}) *Result { result := &Result{ Assertion: assertion, Expected: assertion.Value, } switch assertion.Type { case "equals", "": result = a.assertEquals(assertion, output) case "contains": result = a.assertContains(assertion, output) case "not_contains": result = a.assertNotContains(assertion, output) case "json_path": result = a.assertJSONPath(assertion, output) case "regex": result = a.assertRegex(assertion, output) case "type": result = a.assertType(assertion, output) case "script": result = a.assertScript(assertion, output, input) case "agent": result = a.assertAgent(assertion, output, input) default: result.Passed = false result.Message = fmt.Sprintf("unknown assertion type: %s", assertion.Type) } // Apply negate if assertion.Negate { result.Passed = !result.Passed if result.Passed { result.Message = "negated assertion passed" } else { result.Message = "negated: " + result.Message } } return result } // assertEquals checks for exact equality func (a 
*Asserter) assertEquals(assertion *Assertion, output interface{}) *Result { result := &Result{ Assertion: assertion, Actual: output, Expected: assertion.Value, } if ValidateOutput(output, assertion.Value) { result.Passed = true result.Message = "values are equal" } else { result.Passed = false result.Message = fmt.Sprintf("expected %v, got %v", assertion.Value, output) } return result } // assertContains checks if output contains the expected value func (a *Asserter) assertContains(assertion *Assertion, output interface{}) *Result { result := &Result{ Assertion: assertion, Actual: output, Expected: assertion.Value, } outputStr := ToString(output) expectedStr := ToString(assertion.Value) if strings.Contains(outputStr, expectedStr) { result.Passed = true result.Message = fmt.Sprintf("output contains '%s'", expectedStr) } else { result.Passed = false result.Message = fmt.Sprintf("output does not contain '%s'", expectedStr) } return result } // assertNotContains checks if output does not contain the expected value func (a *Asserter) assertNotContains(assertion *Assertion, output interface{}) *Result { result := a.assertContains(assertion, output) result.Passed = !result.Passed if result.Passed { result.Message = fmt.Sprintf("output does not contain '%s'", ToString(assertion.Value)) } else { result.Message = fmt.Sprintf("output should not contain '%s'", ToString(assertion.Value)) } return result } // assertJSONPath extracts a value using JSON path and compares func (a *Asserter) assertJSONPath(assertion *Assertion, output interface{}) *Result { result := &Result{ Assertion: assertion, Expected: assertion.Value, } // Convert output to JSON if needed var jsonData interface{} switch v := output.(type) { case string: extracted := text.ExtractJSON(v) if extracted != nil { jsonData = extracted } else { result.Passed = false result.Message = fmt.Sprintf("output is not valid JSON: %s", TruncateOutput(v, 100)) return result } case map[string]interface{}, []interface{}: jsonData 
= v default: result.Passed = false result.Message = fmt.Sprintf("output is not a JSON object or array, got: %T", output) return result } // Extract value using path path := strings.TrimPrefix(assertion.Path, "$.") actual := ExtractPath(jsonData, path) result.Actual = actual // Compare if ValidateOutput(actual, assertion.Value) { result.Passed = true result.Message = fmt.Sprintf("path '%s' equals expected value", assertion.Path) return result } // IN semantics: if expected is array, check if actual matches any element if expectedArr, ok := assertion.Value.([]interface{}); ok { if _, actualIsArr := actual.([]interface{}); !actualIsArr { for _, expectedItem := range expectedArr { if ValidateOutput(actual, expectedItem) { result.Passed = true result.Message = fmt.Sprintf("path '%s' equals one of expected values", assertion.Path) return result } } } } result.Passed = false result.Message = fmt.Sprintf("path '%s': expected %v, got %v", assertion.Path, assertion.Value, actual) return result } // assertRegex checks if output matches a regex pattern func (a *Asserter) assertRegex(assertion *Assertion, output interface{}) *Result { result := &Result{ Assertion: assertion, Actual: output, Expected: assertion.Value, } pattern, ok := assertion.Value.(string) if !ok { result.Passed = false result.Message = "regex pattern must be a string" return result } re, err := regexp.Compile(pattern) if err != nil { result.Passed = false result.Message = fmt.Sprintf("invalid regex pattern: %s", err.Error()) return result } outputStr := ToString(output) if re.MatchString(outputStr) { result.Passed = true result.Message = fmt.Sprintf("output matches pattern '%s'", pattern) } else { result.Passed = false result.Message = fmt.Sprintf("output does not match pattern '%s'", pattern) } return result } // assertType checks the type of the output (or a nested field if path is specified) func (a *Asserter) assertType(assertion *Assertion, output interface{}) *Result { result := &Result{ Assertion: 
assertion, Expected: assertion.Value, } expectedType, ok := assertion.Value.(string) if !ok { result.Passed = false result.Message = "type assertion value must be a string" return result } // If path is specified, extract the value first var valueToCheck interface{} = output if assertion.Path != "" { // Convert output to JSON if needed var jsonData interface{} switch v := output.(type) { case string: extracted := text.ExtractJSON(v) if extracted != nil { jsonData = extracted } else { result.Passed = false result.Message = fmt.Sprintf("output is not valid JSON: %s", TruncateOutput(v, 100)) return result } case map[string]interface{}, []interface{}: jsonData = v default: result.Passed = false result.Message = fmt.Sprintf("output is not a JSON object or array, got: %T", output) return result } // Extract value using path path := strings.TrimPrefix(assertion.Path, "$.") valueToCheck = ExtractPath(jsonData, path) if valueToCheck == nil { result.Passed = false result.Actual = nil result.Message = fmt.Sprintf("path '%s' not found in output", assertion.Path) return result } } result.Actual = valueToCheck actualType := GetType(valueToCheck) if actualType == expectedType { result.Passed = true if assertion.Path != "" { result.Message = fmt.Sprintf("path '%s' is of type '%s'", assertion.Path, expectedType) } else { result.Message = fmt.Sprintf("output is of type '%s'", expectedType) } } else { result.Passed = false if assertion.Path != "" { result.Message = fmt.Sprintf("path '%s': expected type '%s', got '%s'", assertion.Path, expectedType, actualType) } else { result.Message = fmt.Sprintf("expected type '%s', got '%s'", expectedType, actualType) } } return result } // assertScript runs a custom assertion script func (a *Asserter) assertScript(assertion *Assertion, output, input interface{}) *Result { result := &Result{ Assertion: assertion, Actual: output, } if a.scriptRunner == nil { result.Passed = false result.Message = "script assertions require a ScriptRunner to be 
configured" return result } scriptName := assertion.Script if scriptName == "" { result.Passed = false result.Message = "script assertion requires a script name" return result } passed, message, err := a.scriptRunner.Run(scriptName, output, input, assertion.Value) if err != nil { result.Passed = false result.Message = fmt.Sprintf("script execution failed: %s", err.Error()) return result } result.Passed = passed result.Message = message return result } // assertAgent uses an agent to validate the output func (a *Asserter) assertAgent(assertion *Assertion, output, input interface{}) *Result { result := &Result{ Assertion: assertion, Actual: output, } if a.agentValidator == nil { result.Passed = false result.Message = "agent assertions require an AgentValidator to be configured" return result } // Parse use field: "agents:validator" if !strings.HasPrefix(assertion.Use, "agents:") { result.Passed = false result.Message = "agent assertion requires 'use' field with 'agents:' prefix" return result } agentID := strings.TrimPrefix(assertion.Use, "agents:") return a.agentValidator.Validate(agentID, output, input, assertion.Value, assertion.Options) } // ParseAssertions parses assertion definitions into Assertion objects func ParseAssertions(input interface{}) []*Assertion { if input == nil { return nil } var assertions []*Assertion switch v := input.(type) { case map[string]interface{}: assertion := mapToAssertion(v) if assertion != nil { assertions = append(assertions, assertion) } case []interface{}: for _, item := range v { if m, ok := item.(map[string]interface{}); ok { assertion := mapToAssertion(m) if assertion != nil { assertions = append(assertions, assertion) } } } case string: assertions = append(assertions, &Assertion{Type: v}) } return assertions } // mapToAssertion converts a map to an Assertion func mapToAssertion(m map[string]interface{}) *Assertion { assertion := &Assertion{} if t, ok := m["type"].(string); ok { assertion.Type = t } if v, ok := m["value"]; ok 
{ assertion.Value = v } if p, ok := m["path"].(string); ok { assertion.Path = p } if s, ok := m["script"].(string); ok { assertion.Script = s } if u, ok := m["use"].(string); ok { assertion.Use = u } if msg, ok := m["message"].(string); ok { assertion.Message = msg } if n, ok := m["negate"].(bool); ok { assertion.Negate = n } if opts, ok := m["options"].(map[string]interface{}); ok { assertion.Options = &AssertionOptions{} if c, ok := opts["connector"].(string); ok { assertion.Options.Connector = c } if meta, ok := opts["metadata"].(map[string]interface{}); ok { assertion.Options.Metadata = meta } } return assertion } ================================================ FILE: assert/asserter_test.go ================================================ package assert import ( "errors" "testing" ) func TestAsserterEquals(t *testing.T) { a := New() tests := []struct { name string value interface{} output interface{} expected bool }{ {"string match", "hello", "hello", true}, {"string mismatch", "hello", "world", false}, {"number match", 42, 42, true}, {"number mismatch", 42, 43, false}, {"bool match", true, true, true}, {"bool mismatch", true, false, false}, {"map match", map[string]interface{}{"a": 1}, map[string]interface{}{"a": 1}, true}, {"map mismatch", map[string]interface{}{"a": 1}, map[string]interface{}{"a": 2}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertion := &Assertion{ Type: "equals", Value: tt.value, } result := a.Evaluate(assertion, tt.output, nil) if result.Passed != tt.expected { t.Errorf("expected passed=%v, got passed=%v", tt.expected, result.Passed) } }) } } func TestAsserterContains(t *testing.T) { a := New() tests := []struct { name string value string output string expected bool }{ {"contains substring", "world", "hello world", true}, {"does not contain", "foo", "hello world", false}, {"exact match", "hello", "hello", true}, {"empty string", "", "hello", true}, } for _, tt := range tests { t.Run(tt.name, func(t 
*testing.T) { assertion := &Assertion{ Type: "contains", Value: tt.value, } result := a.Evaluate(assertion, tt.output, nil) if result.Passed != tt.expected { t.Errorf("expected passed=%v, got passed=%v", tt.expected, result.Passed) } }) } } func TestAsserterNotContains(t *testing.T) { a := New() tests := []struct { name string value string output string expected bool }{ {"does not contain", "foo", "hello world", true}, {"contains substring", "world", "hello world", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertion := &Assertion{ Type: "not_contains", Value: tt.value, } result := a.Evaluate(assertion, tt.output, nil) if result.Passed != tt.expected { t.Errorf("expected passed=%v, got passed=%v", tt.expected, result.Passed) } }) } } func TestAsserterJSONPath(t *testing.T) { a := New() output := map[string]interface{}{ "name": "test", "count": 42, "nested": map[string]interface{}{ "value": "deep", }, "items": []interface{}{"a", "b", "c"}, } tests := []struct { name string path string value interface{} expected bool }{ {"simple field", "name", "test", true}, {"number field", "count", float64(42), true}, {"nested field", "nested.value", "deep", true}, {"array index", "items[0]", "a", true}, {"array index 2", "items[2]", "c", true}, {"wrong value", "name", "wrong", false}, {"non-existent path", "missing", nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: tt.path, Value: tt.value, } result := a.Evaluate(assertion, output, nil) if result.Passed != tt.expected { t.Errorf("expected passed=%v, got passed=%v, message=%s", tt.expected, result.Passed, result.Message) } }) } } func TestAsserterRegex(t *testing.T) { a := New() tests := []struct { name string pattern string output string expected bool }{ {"simple match", "hello", "hello world", true}, {"regex pattern", "^\\d+$", "12345", true}, {"regex no match", "^\\d+$", "abc", false}, {"email pattern", `\w+@\w+\.\w+`, 
"test@example.com", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertion := &Assertion{ Type: "regex", Value: tt.pattern, } result := a.Evaluate(assertion, tt.output, nil) if result.Passed != tt.expected { t.Errorf("expected passed=%v, got passed=%v", tt.expected, result.Passed) } }) } } func TestAsserterType(t *testing.T) { a := New() tests := []struct { name string expectedType string output interface{} expected bool }{ {"string type", "string", "hello", true}, {"number type", "number", 42, true}, {"number type float", "number", 3.14, true}, {"boolean type", "boolean", true, true}, {"array type", "array", []interface{}{1, 2, 3}, true}, {"object type", "object", map[string]interface{}{"a": 1}, true}, {"null type", "null", nil, true}, {"wrong type", "string", 42, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertion := &Assertion{ Type: "type", Value: tt.expectedType, } result := a.Evaluate(assertion, tt.output, nil) if result.Passed != tt.expected { t.Errorf("expected passed=%v, got passed=%v", tt.expected, result.Passed) } }) } } func TestAsserterTypeWithPath(t *testing.T) { a := New() // Test data with nested structure output := map[string]interface{}{ "name": "test", "count": float64(42), "items": []interface{}{"a", "b", "c"}, "enabled": true, "nested": map[string]interface{}{ "value": "nested_value", }, } tests := []struct { name string path string expectedType string expected bool }{ {"string field", "name", "string", true}, {"number field", "count", "number", true}, {"array field", "items", "array", true}, {"boolean field", "enabled", "boolean", true}, {"object field", "nested", "object", true}, {"nested string field", "nested.value", "string", true}, {"wrong type for field", "name", "number", false}, {"non-existent path", "missing", "string", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertion := &Assertion{ Type: "type", Path: tt.path, Value: tt.expectedType, } result 
:= a.Evaluate(assertion, output, nil) if result.Passed != tt.expected { t.Errorf("expected passed=%v, got passed=%v, message=%s", tt.expected, result.Passed, result.Message) } }) } } func TestAsserterNegate(t *testing.T) { a := New() // Test negation assertion := &Assertion{ Type: "equals", Value: "hello", Negate: true, } // Should fail because "hello" == "hello", but negate inverts it result := a.Evaluate(assertion, "hello", nil) if result.Passed { t.Error("negated equals should fail when values match") } // Should pass because "hello" != "world", and negate inverts it result = a.Evaluate(assertion, "world", nil) if !result.Passed { t.Error("negated equals should pass when values don't match") } } func TestAsserterValidate(t *testing.T) { a := New() assertions := []*Assertion{ {Type: "type", Value: "object"}, {Type: "json_path", Path: "name", Value: "test"}, {Type: "json_path", Path: "count", Value: float64(42)}, } output := map[string]interface{}{ "name": "test", "count": 42, } passed, message := a.Validate(assertions, output) if !passed { t.Errorf("validation should pass, got message: %s", message) } // Test with failing assertion assertions = append(assertions, &Assertion{ Type: "json_path", Path: "name", Value: "wrong", }) passed, message = a.Validate(assertions, output) if passed { t.Error("validation should fail with wrong value") } } func TestParseAssertions(t *testing.T) { // Test map input input := map[string]interface{}{ "type": "contains", "value": "hello", } assertions := ParseAssertions(input) if len(assertions) != 1 { t.Errorf("expected 1 assertion, got %d", len(assertions)) } if assertions[0].Type != "contains" { t.Errorf("expected type 'contains', got '%s'", assertions[0].Type) } // Test array input input2 := []interface{}{ map[string]interface{}{"type": "equals", "value": 1}, map[string]interface{}{"type": "contains", "value": "test"}, } assertions = ParseAssertions(input2) if len(assertions) != 2 { t.Errorf("expected 2 assertions, got %d", 
len(assertions)) } // Test string input assertions = ParseAssertions("contains") if len(assertions) != 1 { t.Errorf("expected 1 assertion, got %d", len(assertions)) } if assertions[0].Type != "contains" { t.Errorf("expected type 'contains', got '%s'", assertions[0].Type) } } func TestExtractPath(t *testing.T) { data := map[string]interface{}{ "name": "test", "nested": map[string]interface{}{ "value": "deep", }, "items": []interface{}{ map[string]interface{}{"id": 1}, map[string]interface{}{"id": 2}, }, } tests := []struct { path string expected interface{} }{ {"name", "test"}, {"nested.value", "deep"}, {"items[0].id", 1}, {"items[1].id", 2}, {"missing", nil}, {"nested.missing", nil}, } for _, tt := range tests { t.Run(tt.path, func(t *testing.T) { result := ExtractPath(data, tt.path) if !ValidateOutput(result, tt.expected) { t.Errorf("path '%s': expected %v, got %v", tt.path, tt.expected, result) } }) } } // ============================================================================ // Additional tests for improved coverage // ============================================================================ // Mock implementations for testing type mockScriptRunner struct { passed bool message string err error } func (m *mockScriptRunner) Run(scriptName string, output, input, expected interface{}) (bool, string, error) { return m.passed, m.message, m.err } type mockAgentValidator struct { result *Result } func (m *mockAgentValidator) Validate(agentID string, output, input, criteria interface{}, options *AssertionOptions) *Result { return m.result } // Test WithAgentValidator and WithScriptRunner func TestAsserterConfiguration(t *testing.T) { a := New() // Test chaining mockAgent := &mockAgentValidator{} mockScript := &mockScriptRunner{} result := a.WithAgentValidator(mockAgent).WithScriptRunner(mockScript) if result != a { t.Error("WithAgentValidator should return the same asserter for chaining") } if a.agentValidator != mockAgent { t.Error("agentValidator should be 
set") } if a.scriptRunner != mockScript { t.Error("scriptRunner should be set") } } // Test ValidateWithDetails func TestAsserterValidateWithDetails(t *testing.T) { a := New() t.Run("empty assertions", func(t *testing.T) { result := a.ValidateWithDetails([]*Assertion{}, "output") if !result.Passed { t.Error("empty assertions should pass") } }) t.Run("single assertion pass", func(t *testing.T) { result := a.ValidateWithDetails([]*Assertion{ {Type: "equals", Value: "hello"}, }, "hello") if !result.Passed { t.Error("single matching assertion should pass") } }) t.Run("single assertion fail", func(t *testing.T) { result := a.ValidateWithDetails([]*Assertion{ {Type: "equals", Value: "hello"}, }, "world") if result.Passed { t.Error("single non-matching assertion should fail") } }) t.Run("multiple assertions with custom message", func(t *testing.T) { result := a.ValidateWithDetails([]*Assertion{ {Type: "equals", Value: "hello"}, {Type: "contains", Value: "world", Message: "custom failure message"}, }, "hello") if result.Passed { t.Error("should fail when one assertion fails") } if result.Message != "custom failure message" { t.Errorf("should use custom message, got: %s", result.Message) } }) t.Run("multiple assertions all pass", func(t *testing.T) { result := a.ValidateWithDetails([]*Assertion{ {Type: "contains", Value: "hello"}, {Type: "contains", Value: "world"}, }, "hello world") if !result.Passed { t.Error("all matching assertions should pass") } }) } // Test unknown assertion type func TestAsserterUnknownType(t *testing.T) { a := New() assertion := &Assertion{ Type: "unknown_type", Value: "test", } result := a.Evaluate(assertion, "test", nil) if result.Passed { t.Error("unknown assertion type should fail") } if result.Message != "unknown assertion type: unknown_type" { t.Errorf("unexpected message: %s", result.Message) } } // Test default type (empty string = equals) func TestAsserterDefaultType(t *testing.T) { a := New() assertion := &Assertion{ Type: "", // empty = 
equals Value: "hello", } result := a.Evaluate(assertion, "hello", nil) if !result.Passed { t.Error("empty type should default to equals") } } // Test assertScript func TestAsserterScript(t *testing.T) { t.Run("no script runner configured", func(t *testing.T) { a := New() assertion := &Assertion{ Type: "script", Script: "test.script", } result := a.Evaluate(assertion, "output", nil) if result.Passed { t.Error("should fail without script runner") } if result.Message != "script assertions require a ScriptRunner to be configured" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("empty script name", func(t *testing.T) { a := New().WithScriptRunner(&mockScriptRunner{}) assertion := &Assertion{ Type: "script", Script: "", } result := a.Evaluate(assertion, "output", nil) if result.Passed { t.Error("should fail with empty script name") } if result.Message != "script assertion requires a script name" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("script execution error", func(t *testing.T) { a := New().WithScriptRunner(&mockScriptRunner{ err: errors.New("execution failed"), }) assertion := &Assertion{ Type: "script", Script: "test.script", } result := a.Evaluate(assertion, "output", nil) if result.Passed { t.Error("should fail on script error") } if result.Message != "script execution failed: execution failed" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("script passes", func(t *testing.T) { a := New().WithScriptRunner(&mockScriptRunner{ passed: true, message: "script passed", }) assertion := &Assertion{ Type: "script", Script: "test.script", } result := a.Evaluate(assertion, "output", nil) if !result.Passed { t.Error("should pass when script passes") } if result.Message != "script passed" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("script fails", func(t *testing.T) { a := New().WithScriptRunner(&mockScriptRunner{ passed: false, message: "validation failed", }) assertion := &Assertion{ Type: "script", 
Script: "test.script", } result := a.Evaluate(assertion, "output", nil) if result.Passed { t.Error("should fail when script fails") } }) } // Test assertAgent func TestAsserterAgent(t *testing.T) { t.Run("no agent validator configured", func(t *testing.T) { a := New() assertion := &Assertion{ Type: "agent", Use: "agents:validator", } result := a.Evaluate(assertion, "output", nil) if result.Passed { t.Error("should fail without agent validator") } if result.Message != "agent assertions require an AgentValidator to be configured" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("invalid use field format", func(t *testing.T) { a := New().WithAgentValidator(&mockAgentValidator{}) assertion := &Assertion{ Type: "agent", Use: "invalid_format", } result := a.Evaluate(assertion, "output", nil) if result.Passed { t.Error("should fail with invalid use format") } if result.Message != "agent assertion requires 'use' field with 'agents:' prefix" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("agent validation passes", func(t *testing.T) { a := New().WithAgentValidator(&mockAgentValidator{ result: &Result{Passed: true, Message: "agent validated"}, }) assertion := &Assertion{ Type: "agent", Use: "agents:validator", } result := a.Evaluate(assertion, "output", nil) if !result.Passed { t.Error("should pass when agent validates") } }) t.Run("agent validation fails", func(t *testing.T) { a := New().WithAgentValidator(&mockAgentValidator{ result: &Result{Passed: false, Message: "agent rejected"}, }) assertion := &Assertion{ Type: "agent", Use: "agents:validator", } result := a.Evaluate(assertion, "output", nil) if result.Passed { t.Error("should fail when agent rejects") } }) } // Test assertJSONPath edge cases func TestAsserterJSONPathEdgeCases(t *testing.T) { a := New() t.Run("string output with valid JSON", func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: "name", Value: "test", } result := a.Evaluate(assertion, `{"name": 
"test"}`, nil) if !result.Passed { t.Errorf("should pass with valid JSON string, message: %s", result.Message) } }) t.Run("string output with invalid JSON", func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: "name", Value: "test", } result := a.Evaluate(assertion, "not json", nil) if result.Passed { t.Error("should fail with invalid JSON string") } }) t.Run("non-JSON output type", func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: "name", Value: "test", } result := a.Evaluate(assertion, 12345, nil) if result.Passed { t.Error("should fail with non-JSON type") } }) t.Run("IN semantics with array expected", func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: "status", Value: []interface{}{"active", "pending", "completed"}, } output := map[string]interface{}{"status": "pending"} result := a.Evaluate(assertion, output, nil) if !result.Passed { t.Errorf("should pass with IN semantics, message: %s", result.Message) } }) t.Run("IN semantics no match", func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: "status", Value: []interface{}{"active", "completed"}, } output := map[string]interface{}{"status": "pending"} result := a.Evaluate(assertion, output, nil) if result.Passed { t.Error("should fail when value not in expected array") } }) t.Run("path with $. prefix", func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: "$.name", Value: "test", } output := map[string]interface{}{"name": "test"} result := a.Evaluate(assertion, output, nil) if !result.Passed { t.Errorf("should handle $. 
prefix, message: %s", result.Message) } }) t.Run("array output", func(t *testing.T) { assertion := &Assertion{ Type: "json_path", Path: "[0]", Value: "first", } output := []interface{}{"first", "second"} result := a.Evaluate(assertion, output, nil) if !result.Passed { t.Errorf("should work with array output, message: %s", result.Message) } }) } // Test assertRegex edge cases func TestAsserterRegexEdgeCases(t *testing.T) { a := New() t.Run("non-string pattern", func(t *testing.T) { assertion := &Assertion{ Type: "regex", Value: 12345, // not a string } result := a.Evaluate(assertion, "test", nil) if result.Passed { t.Error("should fail with non-string pattern") } if result.Message != "regex pattern must be a string" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("invalid regex pattern", func(t *testing.T) { assertion := &Assertion{ Type: "regex", Value: "[invalid", } result := a.Evaluate(assertion, "test", nil) if result.Passed { t.Error("should fail with invalid regex") } }) } // Test assertType edge cases func TestAsserterTypeEdgeCases(t *testing.T) { a := New() t.Run("non-string type value", func(t *testing.T) { assertion := &Assertion{ Type: "type", Value: 12345, // not a string } result := a.Evaluate(assertion, "test", nil) if result.Passed { t.Error("should fail with non-string type value") } if result.Message != "type assertion value must be a string" { t.Errorf("unexpected message: %s", result.Message) } }) t.Run("type with path from JSON string", func(t *testing.T) { assertion := &Assertion{ Type: "type", Path: "items", Value: "array", } result := a.Evaluate(assertion, `{"items": [1, 2, 3]}`, nil) if !result.Passed { t.Errorf("should pass with JSON string input, message: %s", result.Message) } }) t.Run("type with path from invalid JSON string", func(t *testing.T) { assertion := &Assertion{ Type: "type", Path: "items", Value: "array", } result := a.Evaluate(assertion, "not json", nil) if result.Passed { t.Error("should fail with invalid JSON 
string") } }) t.Run("type with path from non-JSON type", func(t *testing.T) { assertion := &Assertion{ Type: "type", Path: "items", Value: "array", } result := a.Evaluate(assertion, 12345, nil) if result.Passed { t.Error("should fail with non-JSON type") } }) t.Run("type with path from array", func(t *testing.T) { assertion := &Assertion{ Type: "type", Path: "[0]", Value: "string", } result := a.Evaluate(assertion, []interface{}{"hello"}, nil) if !result.Passed { t.Errorf("should work with array, message: %s", result.Message) } }) } // Test Validate with custom message func TestAsserterValidateWithCustomMessage(t *testing.T) { a := New() assertions := []*Assertion{ {Type: "equals", Value: "expected", Message: "custom failure"}, } passed, message := a.Validate(assertions, "actual") if passed { t.Error("should fail") } if message != "custom failure" { t.Errorf("should use custom message, got: %s", message) } } // Test ParseAssertions edge cases func TestParseAssertionsEdgeCases(t *testing.T) { t.Run("nil input", func(t *testing.T) { result := ParseAssertions(nil) if result != nil { t.Error("nil input should return nil") } }) t.Run("array with non-map items", func(t *testing.T) { input := []interface{}{ "string item", map[string]interface{}{"type": "equals"}, } result := ParseAssertions(input) if len(result) != 1 { t.Errorf("should only parse map items, got %d", len(result)) } }) t.Run("map with all fields", func(t *testing.T) { input := map[string]interface{}{ "type": "agent", "value": "criteria", "path": "$.field", "script": "test.script", "use": "agents:validator", "message": "custom message", "negate": true, "options": map[string]interface{}{ "connector": "openai", "metadata": map[string]interface{}{"key": "value"}, }, } result := ParseAssertions(input) if len(result) != 1 { t.Fatalf("expected 1 assertion, got %d", len(result)) } a := result[0] if a.Type != "agent" { t.Errorf("type mismatch: %s", a.Type) } if a.Path != "$.field" { t.Errorf("path mismatch: %s", 
a.Path) } if a.Script != "test.script" { t.Errorf("script mismatch: %s", a.Script) } if a.Use != "agents:validator" { t.Errorf("use mismatch: %s", a.Use) } if a.Message != "custom message" { t.Errorf("message mismatch: %s", a.Message) } if !a.Negate { t.Error("negate should be true") } if a.Options == nil { t.Fatal("options should not be nil") } if a.Options.Connector != "openai" { t.Errorf("connector mismatch: %s", a.Options.Connector) } if a.Options.Metadata["key"] != "value" { t.Error("metadata mismatch") } }) } // Test helper functions func TestToString(t *testing.T) { tests := []struct { name string input interface{} expected string }{ {"nil", nil, ""}, {"string", "hello", "hello"}, {"bytes", []byte("hello"), "hello"}, {"number", 42, "42"}, {"map", map[string]interface{}{"a": 1}, `{"a":1}`}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := ToString(tt.input) if result != tt.expected { t.Errorf("expected %s, got %s", tt.expected, result) } }) } } func TestGetType(t *testing.T) { tests := []struct { name string input interface{} expected string }{ {"nil", nil, "null"}, {"string", "hello", "string"}, {"float64", float64(3.14), "number"}, {"float32", float32(3.14), "number"}, {"int", 42, "number"}, {"int64", int64(42), "number"}, {"int32", int32(42), "number"}, {"bool", true, "boolean"}, {"array", []interface{}{1, 2}, "array"}, {"object", map[string]interface{}{"a": 1}, "object"}, {"other", struct{}{}, "struct {}"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetType(tt.input) if result != tt.expected { t.Errorf("expected %s, got %s", tt.expected, result) } }) } } func TestTruncateOutput(t *testing.T) { tests := []struct { name string input interface{} maxLen int expected string }{ {"nil", nil, 10, ""}, {"short string", "hello", 10, "hello"}, {"long string", "hello world", 5, "hello..."}, {"object", map[string]interface{}{"a": 1}, 100, `{"a":1}`}, {"long object", map[string]interface{}{"key": "value"}, 5, 
`{"key...`}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := TruncateOutput(tt.input, tt.maxLen) if result != tt.expected { t.Errorf("expected %s, got %s", tt.expected, result) } }) } } func TestExtractJSON(t *testing.T) { // Test basic JSON extraction result := ExtractJSON(`{"name": "test"}`) if result == nil { t.Error("should extract JSON") } if m, ok := result.(map[string]interface{}); ok { if m["name"] != "test" { t.Error("should extract correct value") } } else { t.Error("should return map") } } func TestExtractPathEdgeCases(t *testing.T) { t.Run("invalid array index", func(t *testing.T) { data := map[string]interface{}{ "items": []interface{}{"a", "b"}, } result := ExtractPath(data, "items[abc]") if result != nil { t.Error("invalid index should return nil") } }) t.Run("array index on non-array", func(t *testing.T) { data := map[string]interface{}{ "name": "test", } result := ExtractPath(data, "name[0]") if result != nil { t.Error("array index on non-array should return nil") } }) t.Run("negative array index", func(t *testing.T) { data := map[string]interface{}{ "items": []interface{}{"a", "b"}, } result := ExtractPath(data, "items[-1]") if result != nil { t.Error("negative index should return nil") } }) t.Run("out of bounds array index", func(t *testing.T) { data := map[string]interface{}{ "items": []interface{}{"a", "b"}, } result := ExtractPath(data, "items[99]") if result != nil { t.Error("out of bounds index should return nil") } }) t.Run("field access on non-map", func(t *testing.T) { data := map[string]interface{}{ "name": "test", } result := ExtractPath(data, "name.field") if result != nil { t.Error("field access on non-map should return nil") } }) t.Run("empty path segment", func(t *testing.T) { data := map[string]interface{}{ "name": "test", } result := ExtractPath(data, ".name") if result != "test" { t.Errorf("should handle leading dot, got: %v", result) } }) } func TestValidateOutputEdgeCases(t *testing.T) { // Test with 
unmarshalable types (channels, functions) ch := make(chan int) result := ValidateOutput(ch, ch) if result { t.Error("unmarshalable types should return false") } } ================================================ FILE: assert/helpers.go ================================================ package assert import ( "encoding/json" "fmt" "strconv" "strings" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/text" ) // ValidateOutput compares two values for equality using JSON serialization func ValidateOutput(actual, expected interface{}) bool { actualJSON, err1 := jsoniter.Marshal(actual) expectedJSON, err2 := jsoniter.Marshal(expected) if err1 != nil || err2 != nil { return false } return string(actualJSON) == string(expectedJSON) } // ToString converts a value to string for comparison func ToString(v interface{}) string { if v == nil { return "" } switch val := v.(type) { case string: return val case []byte: return string(val) default: b, err := json.Marshal(v) if err != nil { return fmt.Sprintf("%v", v) } return string(b) } } // GetType returns the type name of a value func GetType(v interface{}) string { if v == nil { return "null" } switch v.(type) { case string: return "string" case float64, float32, int, int64, int32: return "number" case bool: return "boolean" case []interface{}: return "array" case map[string]interface{}: return "object" default: return fmt.Sprintf("%T", v) } } // ExtractPath extracts a value from JSON using dot-notation path with array index support // Supports: "field", "field.nested", "field[0]", "field[0].nested", "field.nested[0].value" func ExtractPath(data interface{}, path string) interface{} { current := data segments := ParsePathSegments(path) for _, segment := range segments { if segment == "" { continue } // Check if this is an array index like "[0]" if strings.HasPrefix(segment, "[") && strings.HasSuffix(segment, "]") { indexStr := segment[1 : len(segment)-1] index, err := strconv.Atoi(indexStr) if err != nil { return nil } 
arr, ok := current.([]interface{}) if !ok { return nil } if index < 0 || index >= len(arr) { return nil } current = arr[index] } else { // Regular field access switch v := current.(type) { case map[string]interface{}: current = v[segment] default: return nil } } } return current } // ParsePathSegments splits a path like "wheres[0].like" into ["wheres", "[0]", "like"] func ParsePathSegments(path string) []string { var segments []string var current strings.Builder for i := 0; i < len(path); i++ { ch := path[i] switch ch { case '.': if current.Len() > 0 { segments = append(segments, current.String()) current.Reset() } case '[': if current.Len() > 0 { segments = append(segments, current.String()) current.Reset() } j := i + 1 for j < len(path) && path[j] != ']' { j++ } if j < len(path) { segments = append(segments, path[i:j+1]) i = j } default: current.WriteByte(ch) } } if current.Len() > 0 { segments = append(segments, current.String()) } return segments } // TruncateOutput truncates output for error messages func TruncateOutput(output interface{}, maxLen int) string { var s string switch v := output.(type) { case string: s = v case nil: return "" default: bytes, err := jsoniter.Marshal(v) if err != nil { s = fmt.Sprintf("%v", v) } else { s = string(bytes) } } if len(s) > maxLen { return s[:maxLen] + "..." } return s } // ExtractJSON extracts JSON from text (handles markdown code blocks, etc.) func ExtractJSON(content string) interface{} { return text.ExtractJSON(content) } ================================================ FILE: assert/types.go ================================================ // Package assert provides a universal assertion/validation library for Yao. // It can be used by agent/robot, flow, pipe, widget, and other modules. 
//
// Design:
//   - Independent implementation (no dependency on agent/test)
//   - Supports both rule-based and semantic validation
//   - Extensible through interfaces (AgentValidator, ScriptRunner)
package assert

// Assertion represents a single assertion rule
type Assertion struct {
	// Type is the assertion type:
	// - "equals": exact match (default if expected is set)
	// - "contains": output contains the expected string/value
	// - "not_contains": output does not contain the string/value
	// - "json_path": extract value using JSON path and compare
	// - "regex": match output against regex pattern
	// - "type": check output type (string, object, array, number, boolean)
	// - "script": run a custom assertion script (requires ScriptRunner)
	// - "agent": use an agent to validate (requires AgentValidator)
	Type string `json:"type"`

	// Value is the expected value or pattern (depends on type)
	Value interface{} `json:"value,omitempty"`

	// Path is the JSON path for json_path assertions (e.g., "$.count", "items[0].name")
	// NOTE(review): ExtractPath in helpers.go has no special handling for a
	// leading "$" - confirm the json_path evaluator strips it before lookup.
	Path string `json:"path,omitempty"`

	// Script is the script/process name for script assertions
	Script string `json:"script,omitempty"`

	// Use specifies the agent for validation (e.g., "agents:validator")
	Use string `json:"use,omitempty"`

	// Options for agent-driven assertions
	Options *AssertionOptions `json:"options,omitempty"`

	// Message is a custom failure message
	Message string `json:"message,omitempty"`

	// Negate inverts the assertion result
	Negate bool `json:"negate,omitempty"`
}

// AssertionOptions for agent-driven assertions
type AssertionOptions struct {
	// Connector overrides the agent's default connector
	Connector string `json:"connector,omitempty"`

	// Metadata contains custom data passed to the validator
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// Result represents the result of an assertion
type Result struct {
	// Passed indicates whether the assertion passed
	Passed bool `json:"passed"`

	// Message describes the assertion result
	Message string `json:"message,omitempty"`

	// Assertion is the original assertion that was evaluated
	Assertion *Assertion `json:"assertion,omitempty"`

	// Actual is the actual value that was compared
	Actual interface{} `json:"actual,omitempty"`

	// Expected is the expected value
	Expected interface{} `json:"expected,omitempty"`
}

// AgentValidator is an interface for agent-based validation
// Implementations should call an AI agent to perform semantic validation
type AgentValidator interface {
	// Validate validates output using an agent
	// agentID: the agent identifier (e.g., "validator")
	// output: the output to validate
	// input: the original input (for context)
	// criteria: validation criteria from assertion.Value
	// options: assertion options
	// Returns the outcome as a *Result
	Validate(agentID string, output, input, criteria interface{}, options *AssertionOptions) *Result
}

// ScriptRunner is an interface for running assertion scripts
// Implementations should call a Yao process to perform validation
type ScriptRunner interface {
	// Run runs an assertion script
	// scriptName: the script/process name
	// output: the output to validate
	// input: the original input
	// expected: the expected value from assertion.Value
	// Returns (passed, message, error)
	Run(scriptName string, output, input, expected interface{}) (bool, string, error)
}

================================================
FILE: attachment/README.md
================================================
# Attachment Package

A comprehensive file upload package for Go that supports chunked uploads, file format validation, compression, and multiple storage backends.
## Features - **Multiple Storage Backends**: Local filesystem and S3-compatible storage - **Chunked Upload Support**: Handle large files with standard HTTP Content-Range headers - **File Deduplication**: Content-based fingerprinting to avoid duplicate uploads - **File Compression**: - Gzip compression for any file type - Image compression with configurable size limits - **File Validation**: - File size limits - MIME type and extension validation - Wildcard pattern support (e.g., `image/*`, `text/*`) - **Flexible File Organization**: Hierarchical storage with multi-level group organization - **Multiple Read Methods**: Stream, bytes, and base64 encoding - **Global Manager Registry**: Support for registering and accessing managers globally - **Upload Status Tracking**: Track upload progress with status field - **Content Synchronization**: Support for synchronized uploads with Content-Sync header ## Installation ```bash go get github.com/yaoapp/yao/attachment ``` ## Quick Start ### Basic Usage ```go package main import ( "context" "fmt" "strings" "mime/multipart" "github.com/yaoapp/yao/attachment" ) func main() { // Create a manager with default settings manager, err := attachment.RegisterDefault("uploads") if err != nil { panic(err) } // Or create a custom manager customManager, err := attachment.New(attachment.ManagerOption{ Driver: "local", MaxSize: "20M", ChunkSize: "2M", AllowedTypes: []string{"text/*", "image/*", ".pdf"}, Options: map[string]interface{}{ "path": "/var/uploads", }, }) if err != nil { panic(err) } // Upload a file content := "Hello, World!" fileHeader := &attachment.FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "hello.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := attachment.UploadOption{ Groups: []string{"user123", "chat456"}, // Multi-level groups (e.g., user, chat, knowledge, etc.)
OriginalFilename: "my_document.txt", // Preserve original filename } file, err := manager.Upload(context.Background(), fileHeader, strings.NewReader(content), option) if err != nil { panic(err) } // Check upload status if file.Status == "uploaded" { fmt.Printf("File uploaded successfully: %s\n", file.ID) } // Read the file back data, err := manager.Read(context.Background(), file.ID) if err != nil { panic(err) } println(string(data)) // Output: Hello, World! } ``` ### Storage Backends #### Local Storage ```go manager, err := attachment.New(attachment.ManagerOption{ Driver: "local", MaxSize: "20M", Options: map[string]interface{}{ "path": "/var/uploads", "base_url": "https://example.com/files", }, }) ``` #### S3 Storage ```go manager, err := attachment.New(attachment.ManagerOption{ Driver: "s3", MaxSize: "100M", Options: map[string]interface{}{ "endpoint": "https://s3.amazonaws.com", "region": "us-east-1", "key": "your-access-key", "secret": "your-secret-key", "bucket": "your-bucket-name", "prefix": "attachments/", }, }) ``` ### Chunked Upload For large files, you can upload in chunks using standard HTTP Content-Range headers: ```go // Upload chunks totalSize := int64(1024000) // 1MB file chunkSize := int64(1024) // 1KB chunks uid := "unique-file-id-123" for start := int64(0); start < totalSize; start += chunkSize { end := start + chunkSize - 1 if end >= totalSize { end = totalSize - 1 } chunkData := make([]byte, end-start+1) // ... fill chunkData with actual data ... 
chunkHeader := &attachment.FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "large_file.zip", Size: end - start + 1, Header: make(map[string][]string), }, } chunkHeader.Header.Set("Content-Type", "application/zip") chunkHeader.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize)) chunkHeader.Header.Set("Content-Uid", uid) file, err := manager.Upload(ctx, chunkHeader, bytes.NewReader(chunkData), option) if err != nil { return err } // File is complete when the last chunk is uploaded if chunkHeader.Complete() { fmt.Printf("Upload complete: %s\n", file.ID) break } } ``` ### Compression #### Gzip Compression ```go option := attachment.UploadOption{ Gzip: true, // Enable gzip compression } file, err := manager.Upload(ctx, fileHeader, reader, option) ``` #### Image Compression ```go option := attachment.UploadOption{ CompressImage: true, CompressSize: 1920, // Max dimension in pixels (default: 1920) } file, err := manager.Upload(ctx, imageHeader, imageReader, option) ``` ### Multi-level Groups The `Groups` field supports hierarchical file organization: ```go // Single level grouping option := attachment.UploadOption{ Groups: []string{"users"}, } // Multi-level grouping option := attachment.UploadOption{ Groups: []string{"users", "user123", "chats", "chat456"}, } // Knowledge base organization option := attachment.UploadOption{ Groups: []string{"knowledge", "documents", "technical"}, } ``` This creates nested directory structures for better organization and access control. 
### File Validation #### Size Limits ```go manager, err := attachment.New(attachment.ManagerOption{ MaxSize: "20M", // Maximum file size // Supports: B, K, M, G (e.g., "1024B", "2K", "10M", "1G") }) ``` #### Type Validation ```go manager, err := attachment.New(attachment.ManagerOption{ AllowedTypes: []string{ "text/*", // All text types "image/*", // All image types "application/pdf", // Specific MIME type ".txt", // File extension ".jpg", // File extension }, }) ``` ### Reading Files #### Stream Reading ```go response, err := manager.Download(ctx, fileID) if err != nil { return err } defer response.Reader.Close() // Use response.Reader as io.ReadCloser // response.ContentType contains the MIME type // response.Extension contains the file extension ``` #### Read as Bytes ```go data, err := manager.Read(ctx, fileID) if err != nil { return err } // data is []byte ``` #### Read as Base64 ```go base64Data, err := manager.ReadBase64(ctx, fileID) if err != nil { return err } // base64Data is string ``` ### Global Managers You can register managers globally for easy access: ```go // Register default manager with sensible defaults attachment.RegisterDefault("main") // Register custom managers attachment.Register("local", "local", attachment.ManagerOption{ Driver: "local", Options: map[string]interface{}{ "path": "/var/uploads", }, }) attachment.Register("s3", "s3", attachment.ManagerOption{ Driver: "s3", Options: map[string]interface{}{ "bucket": "my-bucket", "key": "access-key", "secret": "secret-key", }, }) // Use global managers localManager := attachment.Managers["local"] s3Manager := attachment.Managers["s3"] defaultManager := attachment.Managers["main"] ``` ## File Organization Files are organized in a hierarchical structure: ``` attachments/ ├── 20240101/ # Date (YYYYMMDD) │ └── user123/ # First level group (optional) │ └── chat456/ # Second level group (optional) │ └── knowledge/ # Additional group levels (optional) │ └── ab/ # First 2 chars of hash │ └── cd/ # 
Next 2 chars of hash │ └── abcdef12.txt # Hash + extension ``` The file ID generation includes: - Date prefix for organization - Multi-level groups for access control and organization - Content hash for deduplication - Original file extension ## API Reference ### Manager #### `New(option ManagerOption) (*Manager, error)` Creates a new attachment manager. #### `Register(name string, driver string, option ManagerOption) (*Manager, error)` Registers a global attachment manager. #### `Upload(ctx context.Context, fileheader *FileHeader, reader io.Reader, option UploadOption) (*File, error)` Uploads a file (supports chunked upload). #### `Download(ctx context.Context, fileID string) (*FileResponse, error)` Downloads a file as a stream. #### `Read(ctx context.Context, fileID string) ([]byte, error)` Reads a file as bytes. #### `ReadBase64(ctx context.Context, fileID string) (string, error)` Reads a file as base64 encoded string. ### Storage Interface All storage backends implement the following interface: ```go type Storage interface { Upload(ctx context.Context, fileID string, reader io.Reader, contentType string) (string, error) UploadChunk(ctx context.Context, fileID string, chunkIndex int, reader io.Reader, contentType string) error MergeChunks(ctx context.Context, fileID string, totalChunks int) error Download(ctx context.Context, fileID string) (io.ReadCloser, string, error) Reader(ctx context.Context, fileID string) (io.ReadCloser, error) URL(ctx context.Context, fileID string) string Exists(ctx context.Context, fileID string) bool Delete(ctx context.Context, fileID string) error } ``` ### Types #### `ManagerOption` Configuration for creating a manager: - `Driver`: "local" or "s3" - `MaxSize`: Maximum file size (e.g., "20M") - `ChunkSize`: Chunk size for uploads (e.g., "2M") - `AllowedTypes`: Array of allowed MIME types/extensions - `Options`: Driver-specific options #### `UploadOption` Options for file upload: - `CompressImage`: Enable image compression - 
`CompressSize`: Maximum image dimension (default: 1920) - `Gzip`: Enable gzip compression - `Groups`: Multi-level group identifiers for hierarchical file organization (e.g., []string{"user123", "chat456", "knowledge"}) - `OriginalFilename`: Original filename to preserve (avoids encoding issues) #### `File` Uploaded file information: - `ID`: Unique file identifier - `Filename`: Original filename - `ContentType`: MIME type - `Bytes`: File size - `CreatedAt`: Upload timestamp - `Status`: Upload status ("uploading", "uploaded", "indexing", "indexed", "upload_failed", "index_failed") #### `FileResponse` Download response: - `Reader`: io.ReadCloser for file content - `ContentType`: MIME type - `Extension`: File extension ## Chunked Upload Details The package supports chunked uploads using standard HTTP headers: - `Content-Range`: Specifies byte range (e.g., "bytes 0-1023/2048") - `Content-Uid`: Unique identifier for the file being uploaded ### Chunk Index Calculation The package uses a standard chunk size (1024 bytes by default) to calculate chunk indices consistently. This ensures proper chunk ordering during merge operations. ### Content Type Preservation For chunked uploads, the content type is preserved from the first chunk and applied to the final merged file, ensuring proper MIME type handling across all storage backends. ## Error Handling The package returns descriptive errors for common issues: - File size exceeds limit - Unsupported file type - Storage backend errors - Invalid chunk information - Missing required configuration ## Testing Run the tests: ```bash # Run all tests go test ./... # Run with S3 credentials (optional) export S3_ACCESS_KEY="your-key" export S3_SECRET_KEY="your-secret" export S3_BUCKET="your-bucket" export S3_API="https://your-s3-endpoint" go test ./... 
``` The package includes comprehensive tests for: - Basic file upload/download - Chunked uploads with content type preservation - Compression (gzip and image) - File validation (size, type, wildcards) - Multiple storage backends (local and S3) - Error handling and edge cases ### Test Coverage - **Manager Tests**: Upload, download, validation, compression - **Local Storage Tests**: File operations, chunked uploads, directory management - **S3 Storage Tests**: S3 operations, chunked uploads, presigned URLs (requires credentials) ## Performance Considerations - **Chunked Uploads**: Use appropriate chunk sizes (1-5MB) for optimal performance - **Image Compression**: Automatically resizes large images to reduce storage costs - **Gzip Compression**: Reduces storage size for text-based files - **Content Type Detection**: Efficient MIME type detection and preservation ## Security Features - **File Type Validation**: Prevents upload of unauthorized file types - **Size Limits**: Configurable file size restrictions - **Path Sanitization**: Secure file path generation - **Access Control**: Multi-level hierarchical file organization ## License This package is part of the Yao project and follows the same license terms. 
### File Deduplication with Fingerprints The package supports file deduplication using content fingerprints: ```go // Set a content fingerprint to enable deduplication fileHeader := &attachment.FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "document.pdf", Size: fileSize, Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "application/pdf") fileHeader.Header.Set("Content-Fingerprint", "sha256:abcdef123456") // Content-based hash file, err := manager.Upload(ctx, fileHeader, reader, option) ``` ### Content Synchronization For synchronized uploads across multiple clients: ```go // Enable content synchronization fileHeader.Header.Set("Content-Sync", "true") // Each client can upload the same content with the same fingerprint // The system will deduplicate based on the content fingerprint ``` ### Chunked Upload with Enhanced Headers For large files, you can upload in chunks using standard HTTP Content-Range headers with additional metadata: ```go // Upload chunks with unique identifier and fingerprint totalSize := int64(1024000) // 1MB file chunkSize := int64(1024) // 1KB chunks uid := "unique-file-id-123" fingerprint := "sha256:content-hash-here" for start := int64(0); start < totalSize; start += chunkSize { end := start + chunkSize - 1 if end >= totalSize { end = totalSize - 1 } chunkData := make([]byte, end-start+1) // ... fill chunkData with actual data ... 
chunkHeader := &attachment.FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "large_file.zip", Size: end - start + 1, Header: make(map[string][]string), }, } chunkHeader.Header.Set("Content-Type", "application/zip") chunkHeader.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize)) chunkHeader.Header.Set("Content-Uid", uid) chunkHeader.Header.Set("Content-Fingerprint", fingerprint) chunkHeader.Header.Set("Content-Sync", "true") // Enable synchronization option := attachment.UploadOption{ Groups: []string{"user123", "chat456"}, // Multi-level groups OriginalFilename: "my_large_file.zip", // Preserve original name } file, err := manager.Upload(ctx, chunkHeader, bytes.NewReader(chunkData), option) if err != nil { return err } // Check if upload is complete if file.Status == "uploaded" { fmt.Printf("Upload complete: %s\n", file.ID) break } else if file.Status == "uploading" { fmt.Printf("Chunk uploaded, progress: %d/%d\n", chunkHeader.GetChunkSize(), chunkHeader.GetTotalSize()) } } ``` ### FileHeader Methods The `FileHeader` type provides several utility methods: ```go // Get unique identifier for chunked uploads uid := fileHeader.UID() // Get content fingerprint for deduplication fingerprint := fileHeader.Fingerprint() // Get byte range for chunked uploads rangeHeader := fileHeader.Range() // Check if synchronization is enabled isSync := fileHeader.Sync() // Check if this is a chunked upload isChunk := fileHeader.IsChunk() // Check if upload is complete (for chunked uploads) isComplete := fileHeader.Complete() // Get detailed chunk information start, end, total, err := fileHeader.GetChunkInfo() // Get total file size (for chunked uploads) totalSize := fileHeader.GetTotalSize() // Get current chunk size chunkSize := fileHeader.GetChunkSize() ``` ## File Headers and Metadata The package supports several HTTP headers for enhanced functionality: - `Content-Range`: Standard HTTP range header for chunked uploads (e.g., "bytes 0-1023/2048") 
- `Content-Uid`: Unique identifier for file uploads (for deduplication and tracking) - `Content-Fingerprint`: Content-based hash for deduplication (e.g., "sha256:abc123") - `Content-Sync`: Enable synchronized uploads across multiple clients ("true"/"false") ### Header Processing When processing uploads, headers can be extracted from both HTTP request headers and multipart file headers: ```go // Extract headers from HTTP request and file headers header := attachment.GetHeader(requestHeader, fileHeader, fileSize) // The resulting FileHeader will contain merged headers from both sources uid := header.UID() fingerprint := header.Fingerprint() isSync := header.Sync() ``` ## Upload Status Tracking Files have a status field that tracks the upload lifecycle: - `"uploading"`: File upload is in progress (for chunked uploads) - `"uploaded"`: File has been successfully uploaded - `"indexing"`: File is being processed for search indexing - `"indexed"`: File has been indexed and is fully processed - `"upload_failed"`: Upload failed due to an error - `"index_failed"`: Indexing failed but file is still accessible ```go file, err := manager.Upload(ctx, fileHeader, reader, option) if err != nil { return err } switch file.Status { case "uploading": fmt.Println("Upload in progress...") case "uploaded": fmt.Println("Upload completed successfully") case "upload_failed": fmt.Println("Upload failed") } ``` ## Text Content Storage The attachment package supports storing parsed text content extracted from files (e.g., from PDFs, Word documents, or image OCR). This is useful for building search indexes or providing text-based previews. The system automatically maintains two versions of the text content: - **Full content** (`content`): Complete text, stored as longText (up to 4GB) - **Preview** (`content_preview`): First 2000 characters, stored as text for quick access ### Saving Parsed Text Content Use `SaveText` to store the extracted text content. 
It automatically saves both full content and preview: ```go // Upload a PDF file file, err := manager.Upload(ctx, fileHeader, reader, option) if err != nil { return err } // Extract text from the PDF (using your preferred library) parsedText := extractTextFromPDF(file.ID) // Save the parsed text (automatically saves both full and preview) err = manager.SaveText(ctx, file.ID, parsedText) if err != nil { return fmt.Errorf("failed to save text content: %w", err) } ``` ### Retrieving Parsed Text Content Use `GetText` to retrieve text content. By default, it returns the preview for better performance: ```go // Get preview (first 2000 characters) - Fast, suitable for UI display preview, err := manager.GetText(ctx, file.ID) if err != nil { return fmt.Errorf("failed to get preview: %w", err) } if preview == "" { fmt.Println("No text content available for this file") } else { fmt.Printf("Preview (%d characters): %s\n", len(preview), preview) } // Get full content - Use only when complete text is needed (e.g., for indexing) fullText, err := manager.GetText(ctx, file.ID, true) if err != nil { return fmt.Errorf("failed to get full text: %w", err) } fmt.Printf("Full content (%d characters)\n", len(fullText)) ``` ### Performance Optimization The text content fields are optimized for different use cases: | Field | Size Limit | Use Case | Performance | |-------|------------|----------|-------------| | `content_preview` | 2000 chars | Quick preview, UI display, snippets | ⚡ Very Fast | | `content` | 4GB | Full text search, complete content | 🐌 Slow for large files | **Best Practices:** 1. Use preview by default: `GetText(ctx, fileID)` 2. Only request full content when necessary: `GetText(ctx, fileID, true)` 3. Both fields are excluded from `List()` by default for optimal performance 4. Preview uses character (rune) count, not bytes, for proper UTF-8 handling ### Example: Complete Text Processing Workflow ```go // 1. 
Upload file file, err := manager.Upload(ctx, fileHeader, reader, option) if err != nil { return err } // 2. Process file based on content type var parsedText string switch { case strings.HasPrefix(file.ContentType, "image/"): // Use OCR to extract text from image parsedText, err = performOCR(file.ID) case file.ContentType == "application/pdf": // Extract text from PDF parsedText, err = extractPDFText(file.ID) case strings.Contains(file.ContentType, "wordprocessingml"): // Extract text from Word document parsedText, err = extractWordText(file.ID) } if err != nil { return fmt.Errorf("failed to extract text: %w", err) } // 3. Save the extracted text if parsedText != "" { err = manager.SaveText(ctx, file.ID, parsedText) if err != nil { return fmt.Errorf("failed to save text: %w", err) } } // 4. Later, retrieve the text for search or display savedText, err := manager.GetText(ctx, file.ID) if err != nil { return err } fmt.Printf("Retrieved text: %s\n", savedText) ``` ### Text Content Features - **Dual Storage**: Automatically maintains both full content and preview (2000 chars) - **Size Limits**: - Preview: 2000 characters (text type) - Full content: Up to 4GB (longText type) - **Smart Retrieval**: Returns preview by default, full content on demand - **Update**: Text content can be updated at any time using `SaveText` - **Clear**: Set text to empty string to clear both fields - **UTF-8 Safe**: Preview uses character (rune) count, not bytes, ensuring proper multi-byte character handling - **Performance**: Both `content` and `content_preview` fields are excluded by default in `List()` and `Info()` operations to avoid loading text data. Use `GetText()` to explicitly retrieve text content when needed #### `RegisterDefault(name string) (*Manager, error)` Registers a default attachment manager with sensible defaults for common file types. ## Process API The attachment package provides a set of Yao Process APIs for file management with built-in permission support. 
### Available Processes | Process | Description | |---------|-------------| | `attachment.Save` | Save a file from base64 data URI | | `attachment.Read` | Read file content as base64 data URI | | `attachment.Info` | Get file metadata | | `attachment.List` | List files with pagination and filtering | | `attachment.Delete` | Delete a file | | `attachment.Exists` | Check if file exists | | `attachment.URL` | Get file URL | | `attachment.SaveText` | Save parsed text content for a file | | `attachment.GetText` | Get parsed text content for a file | ### Permission Model The Process API integrates with Yao's `process.Authorized` mechanism: - **Authorized Info**: Reads `UserID`, `TeamID`, `TenantID` from `process.Authorized` (set by OAuth guard) - **Auto Permission Storage**: On save, automatically stores `__yao_created_by`, `__yao_team_id`, `__yao_tenant_id` from `process.Authorized` - **Data Constraints**: Respects `Constraints.OwnerOnly` and `Constraints.TeamOnly` from ACL enforcement - **Owner Access**: When `OwnerOnly` is set, only file creator (`__yao_created_by`) can access their files - **Team Access**: When `TeamOnly` is set, team members can access files with `share: "team"` - **Public Access**: Files with `public: true` are readable by everyone regardless of constraints - **No Constraints**: If no constraints are set, all authenticated users can access all files ### Usage Examples #### JavaScript (Yao Scripts) ```javascript // Save a file from base64 data URI const file = Process("attachment.Save", "default", "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA...", "photo.png", { share: "team" } ); console.log("Saved file ID:", file.file_id); // Save text file const textFile = Process("attachment.Save", "default", "data:text/plain;base64,SGVsbG8gV29ybGQh", "hello.txt" ); // Read file content as data URI const dataURI = Process("attachment.Read", "default", file.file_id); // Returns: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA..." 
// Get file info const info = Process("attachment.Info", "default", file.file_id); // List files with pagination const result = Process("attachment.List", "default", { page: 1, page_size: 20, filters: { status: "uploaded", content_type: "image/*" }, order_by: "created_at desc" }); // Check if file exists const exists = Process("attachment.Exists", "default", file.file_id); // Get file URL const url = Process("attachment.URL", "default", file.file_id); // Save parsed text content (e.g., OCR result, PDF text) Process("attachment.SaveText", "default", file.file_id, "Extracted text content..."); // Get text content (preview by default) const preview = Process("attachment.GetText", "default", file.file_id); // Get full text content const fullText = Process("attachment.GetText", "default", file.file_id, true); // Delete file Process("attachment.Delete", "default", file.file_id); ``` #### Flow DSL ```json { "name": "Save Image", "nodes": [ { "name": "save", "process": "attachment.Save", "args": [ "default", "{{$in.dataURI}}", "{{$in.filename}}", { "share": "team" } ] } ], "output": "{{$res.save}}" } ``` ### Process Reference #### `attachment.Save` Save a file from base64 data URI. Automatically parses content type from data URI header and stores permission fields from `process.Authorized`. **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `content` (string) - Base64 data URI (e.g., `"data:image/png;base64,xxxx"`) or plain base64 3. `filename` (string, optional) - Original filename (auto-generated if not provided) 4. 
`option` (map, optional) - Upload options: - `groups` ([]string) - Directory groups for organization - `gzip` (bool) - Enable gzip compression - `compress_image` (bool) - Enable image compression - `compress_size` (int) - Target image size in pixels - `public` (bool) - Make file publicly accessible - `share` (string) - Share scope: "private" or "team" **Returns:** `*File` - Saved file information **Example:** ```javascript // With data URI (auto-detect content type) Process("attachment.Save", "default", "data:image/png;base64,iVBORw0KGgo...", "photo.png") // With plain base64 (defaults to application/octet-stream) Process("attachment.Save", "default", "SGVsbG8gV29ybGQh", "hello.txt") // With options Process("attachment.Save", "default", "data:application/pdf;base64,...", "doc.pdf", { groups: ["documents"], share: "team", public: false }) ``` --- #### `attachment.Read` Read file content as base64 data URI. **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `fileID` (string) - The file ID **Returns:** `string` - Base64 data URI (e.g., `"data:image/png;base64,xxxx"`) **Example:** ```javascript const dataURI = Process("attachment.Read", "default", "abc123") // Returns: "data:image/png;base64,iVBORw0KGgo..." ``` --- #### `attachment.Info` Get file metadata. **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `fileID` (string) - The file ID **Returns:** `*File` - File metadata --- #### `attachment.List` List files with pagination and filtering. **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `option` (map, optional) - List options: - `page` (int) - Page number (default: 1) - `page_size` (int) - Items per page (default: 20, max: 100) - `filters` (map) - Filter conditions (e.g., `{"status": "uploaded"}`) - `order_by` (string) - Sort order (e.g., "created_at desc") - `select` ([]string) - Fields to return **Returns:** `*ListResult` - Paginated file list --- #### `attachment.Delete` Delete a file. 
Requires write permission (owner only). **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `fileID` (string) - The file ID **Returns:** `bool` - Success status --- #### `attachment.Exists` Check if a file exists. **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `fileID` (string) - The file ID **Returns:** `bool` - Whether file exists --- #### `attachment.URL` Get the URL of a file. **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `fileID` (string) - The file ID **Returns:** `string` - File URL --- #### `attachment.SaveText` Save parsed text content for a file (e.g., OCR result, PDF extracted text). **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `fileID` (string) - The file ID 3. `text` (string) - Text content to save **Returns:** `bool` - Success status --- #### `attachment.GetText` Get parsed text content for a file. **Arguments:** 1. `uploaderID` (string) - The uploader/manager ID 2. `fileID` (string) - The file ID 3. 
// CompressImage shrinks an encoded image so that its longer side is at most
// maxSize pixels, keeping the aspect ratio, and re-encodes it as JPEG
// (quality 85) or PNG according to contentType.
//
// The input bytes are returned unchanged when there is nothing to do or the
// result could not be encoded in the requested format: the image already fits
// within maxSize, maxSize is not positive, or contentType is neither
// "image/jpeg" nor "image/png".
//
// An error is returned when the data cannot be read, decoded or re-encoded.
func CompressImage(reader io.Reader, contentType string, maxSize int) ([]byte, error) {
	// The whole payload is needed both for decoding and for the pass-through case
	data, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("failed to read image data: %w", err)
	}

	// Decode first so that corrupt input is still reported as an error
	img, _, err := image.Decode(bytes.NewReader(data))
	if err != nil {
		return nil, fmt.Errorf("failed to decode image: %w", err)
	}

	bounds := img.Bounds()
	width, height := bounds.Dx(), bounds.Dy()

	// A non-positive limit cannot describe a valid target size, so it is
	// treated as "no limit" instead of producing an empty image
	if maxSize <= 0 || (width <= maxSize && height <= maxSize) {
		return data, nil
	}

	// Only JPEG and PNG can be written back; skip the scaling work for
	// anything else since the original bytes would be returned anyway
	if contentType != "image/jpeg" && contentType != "image/png" {
		return data, nil
	}

	// The longer side becomes maxSize, the shorter side is scaled to match
	var newWidth, newHeight int
	if width > height {
		newWidth = maxSize
		newHeight = int(float64(height) * (float64(maxSize) / float64(width)))
	} else {
		newHeight = maxSize
		newWidth = int(float64(width) * (float64(maxSize) / float64(height)))
	}

	// Very elongated images would otherwise round the short side down to 0,
	// which the encoders reject
	if newWidth < 1 {
		newWidth = 1
	}
	if newHeight < 1 {
		newHeight = 1
	}

	// Scale with nearest-neighbour sampling. Source coordinates are offset by
	// bounds.Min because an image's bounds need not start at the origin
	newImg := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight))
	for y := 0; y < newHeight; y++ {
		srcY := bounds.Min.Y + int(float64(y)*float64(height)/float64(newHeight))
		for x := 0; x < newWidth; x++ {
			srcX := bounds.Min.X + int(float64(x)*float64(width)/float64(newWidth))
			newImg.Set(x, y, img.At(srcX, srcY))
		}
	}

	// Encode in the requested format (already validated above)
	var buf bytes.Buffer
	if contentType == "image/jpeg" {
		err = jpeg.Encode(&buf, newImg, &jpeg.Options{Quality: 85})
	} else {
		err = png.Encode(&buf, newImg)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to encode image: %w", err)
	}

	return buf.Bytes(), nil
}
Create an S3 storage manager s3Manager, err := New(ManagerOption{ Driver: "s3", MaxSize: "100M", ChunkSize: "5M", AllowedTypes: []string{"*"}, // Allow all types Options: map[string]interface{}{ "endpoint": "https://s3.amazonaws.com", "region": "us-east-1", "key": "your-access-key", "secret": "your-secret-key", "bucket": "your-bucket-name", "prefix": "attachments/", }, }) if err != nil { log.Printf("Failed to create S3 manager (this is expected without credentials): %v", err) } else { fmt.Printf("Created S3 manager successfully\n") // Demonstrate S3 manager usage if credentials are available if s3Manager != nil { fmt.Printf("S3 manager is ready for use with bucket: %s\n", s3Manager.storage.(*s3.Storage).Bucket) } } // 3. Register managers globally _, err = Register("local", "local", ManagerOption{ Driver: "local", MaxSize: "20M", AllowedTypes: []string{"text/*", "image/*"}, Options: map[string]interface{}{ "path": "/var/uploads", }, }) if err != nil { log.Printf("Failed to register local manager: %v", err) } // Try to register S3 manager (will fail without credentials) _, err = Register("s3", "s3", ManagerOption{ Driver: "s3", MaxSize: "100M", Options: map[string]interface{}{ "bucket": "your-bucket", "key": "your-key", "secret": "your-secret", }, }) if err != nil { log.Printf("Failed to register S3 manager (expected without credentials): %v", err) } ctx := context.Background() // 4. Example: Simple file upload content := "Hello, World! This is a test file with some content to demonstrate the attachment package." 
fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "hello.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") uploadOption := UploadOption{ Groups: []string{"user123"}, Gzip: false, // No compression for small text files } file, err := localManager.Upload(ctx, fileHeader, strings.NewReader(content), uploadOption) if err != nil { log.Fatalf("Failed to upload file: %v", err) } fmt.Printf("Uploaded file: %s (ID: %s, Size: %d bytes)\n", file.Filename, file.ID, file.Bytes) // 5. Example: File upload with gzip compression largeContent := strings.Repeat("This is a large text file that benefits from compression. ", 100) gzipFileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "large_text.txt", Size: int64(len(largeContent)), Header: make(map[string][]string), }, } gzipFileHeader.Header.Set("Content-Type", "text/plain") gzipOption := UploadOption{ Groups: []string{"user123"}, Gzip: true, // Enable compression } gzipFile, err := localManager.Upload(ctx, gzipFileHeader, strings.NewReader(largeContent), gzipOption) if err != nil { log.Fatalf("Failed to upload gzipped file: %v", err) } fmt.Printf("Uploaded compressed file: %s (ID: %s)\n", gzipFile.Filename, gzipFile.ID) // 6. Example: Image upload with compression imageUploadOption := UploadOption{ Groups: []string{"user123"}, CompressImage: true, CompressSize: 1920, // Resize to max 1920px Gzip: false, } // Simulate image upload (you would get this from multipart form) imageHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "photo.jpg", Size: 1024000, // 1MB Header: make(map[string][]string), }, } imageHeader.Header.Set("Content-Type", "image/jpeg") fmt.Printf("Image upload option configured: compress=%v, size=%d\n", imageUploadOption.CompressImage, imageUploadOption.CompressSize) // 6.5. 
Example: Multi-level groups fmt.Println("\n--- Multi-level Groups Examples ---") // Single level grouping singleGroupOption := UploadOption{ Groups: []string{"knowledge"}, } singleGroupHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "knowledge_doc.txt", Size: int64(len("Knowledge base document")), Header: make(map[string][]string), }, } singleGroupHeader.Header.Set("Content-Type", "text/plain") singleFile, err := localManager.Upload(ctx, singleGroupHeader, strings.NewReader("Knowledge base document"), singleGroupOption) if err != nil { log.Printf("Failed to upload single group file: %v", err) } else { fmt.Printf("Single group file uploaded: %s (ID: %s)\n", singleFile.Filename, singleFile.ID) } // Multi-level grouping multiGroupOption := UploadOption{ Groups: []string{"users", "user123", "chats", "chat456", "documents"}, } multiGroupHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "chat_document.txt", Size: int64(len("Document in user chat")), Header: make(map[string][]string), }, } multiGroupHeader.Header.Set("Content-Type", "text/plain") multiFile, err := localManager.Upload(ctx, multiGroupHeader, strings.NewReader("Document in user chat"), multiGroupOption) if err != nil { log.Printf("Failed to upload multi-group file: %v", err) } else { fmt.Printf("Multi-level group file uploaded: %s (ID: %s)\n", multiFile.Filename, multiFile.ID) fmt.Printf("File path includes hierarchy: users/user123/chats/chat456/documents\n") } // Knowledge base organization knowledgeOption := UploadOption{ Groups: []string{"knowledge", "technical", "api", "documentation"}, } knowledgeHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "api_guide.md", Size: int64(len("# API Documentation\n\nThis is technical documentation.")), Header: make(map[string][]string), }, } knowledgeHeader.Header.Set("Content-Type", "text/markdown") knowledgeFile, err := localManager.Upload(ctx, knowledgeHeader, strings.NewReader("# API Documentation\n\nThis is 
technical documentation."), knowledgeOption) if err != nil { log.Printf("Failed to upload knowledge file: %v", err) } else { fmt.Printf("Knowledge base file uploaded: %s (ID: %s)\n", knowledgeFile.Filename, knowledgeFile.ID) fmt.Printf("Organized in: knowledge/technical/api/documentation\n") } // 7. Example: Chunked upload largeContent = strings.Repeat("This is a large file content that will be uploaded in chunks. ", 1000) chunkSize := 1024 totalSize := len(largeContent) uid := "unique-large-file-123" fmt.Printf("Starting chunked upload: total size=%d, chunk size=%d\n", totalSize, chunkSize) var lastFile *File chunkCount := 0 // Split into chunks and upload for i := 0; i < totalSize; i += chunkSize { end := i + chunkSize if end > totalSize { end = totalSize } chunk := largeContent[i:end] chunkHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "large_file.txt", Size: int64(len(chunk)), Header: make(map[string][]string), }, } chunkHeader.Header.Set("Content-Type", "text/plain") chunkHeader.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", i, end-1, totalSize)) chunkHeader.Header.Set("Content-Uid", uid) chunkOption := UploadOption{ Groups: []string{"user123"}, Gzip: true, // Compress chunks } chunkFile, err := localManager.Upload(ctx, chunkHeader, strings.NewReader(chunk), chunkOption) if err != nil { log.Fatalf("Failed to upload chunk %d: %v", chunkCount, err) } chunkCount++ lastFile = chunkFile // Check if this is the last chunk if chunkHeader.Complete() { fmt.Printf("Uploaded large file in %d chunks: %s (ID: %s)\n", chunkCount, chunkFile.Filename, chunkFile.ID) break } } // 8. 
Example: Download and read files if file != nil { // Download as stream response, err := localManager.Download(ctx, file.ID) if err != nil { log.Fatalf("Failed to download file: %v", err) } defer response.Reader.Close() fmt.Printf("Downloaded file content type: %s, extension: %s\n", response.ContentType, response.Extension) // Read as bytes data, err := localManager.Read(ctx, file.ID) if err != nil { log.Fatalf("Failed to read file: %v", err) } fmt.Printf("File content length: %d bytes\n", len(data)) if len(data) < 100 { fmt.Printf("File content: %s\n", string(data)) } else { fmt.Printf("File content preview: %s...\n", string(data[:100])) } // Read as base64 base64Data, err := localManager.ReadBase64(ctx, file.ID) if err != nil { log.Fatalf("Failed to read file as base64: %v", err) } fmt.Printf("File as base64 (first 50 chars): %s...\n", base64Data[:min(50, len(base64Data))]) } // 9. Example: Read chunked file if lastFile != nil { chunkData, err := localManager.Read(ctx, lastFile.ID) if err != nil { log.Fatalf("Failed to read chunked file: %v", err) } // Since the chunks were compressed, we need to decompress decompressed, err := Gunzip(chunkData) if err != nil { log.Fatalf("Failed to decompress chunked file: %v", err) } fmt.Printf("Chunked file content length: %d bytes (decompressed)\n", len(decompressed)) if len(decompressed) < 200 { fmt.Printf("Chunked file content: %s\n", string(decompressed)) } else { fmt.Printf("Chunked file content preview: %s...\n", string(decompressed[:200])) } } // 10. 
Example: Using global managers globalManager := Managers["local"] if globalManager != nil { fmt.Println("Using global manager for local storage") // Test a simple upload with global manager testContent := "Test content using global manager" testHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "global_test.txt", Size: int64(len(testContent)), Header: make(map[string][]string), }, } testHeader.Header.Set("Content-Type", "text/plain") testFile, err := globalManager.Upload(ctx, testHeader, strings.NewReader(testContent), UploadOption{ Groups: []string{"global_user"}, }) if err != nil { log.Printf("Failed to upload with global manager: %v", err) } else { fmt.Printf("Global manager upload successful: %s\n", testFile.ID) } } // 11. Example: File validation fmt.Println("\n--- File Validation Examples ---") // Test file size validation tooLargeContent := strings.Repeat("x", 25*1024*1024) // 25MB, exceeds 20MB limit largeFileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "too_large.txt", Size: int64(len(tooLargeContent)), Header: make(map[string][]string), }, } largeFileHeader.Header.Set("Content-Type", "text/plain") _, err = localManager.Upload(ctx, largeFileHeader, strings.NewReader(tooLargeContent), UploadOption{}) if err != nil { fmt.Printf("Expected error for large file: %v\n", err) } // Test file type validation invalidFileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "script.exe", Size: 1024, Header: make(map[string][]string), }, } invalidFileHeader.Header.Set("Content-Type", "application/x-executable") _, err = localManager.Upload(ctx, invalidFileHeader, strings.NewReader("fake exe content"), UploadOption{}) if err != nil { fmt.Printf("Expected error for invalid file type: %v\n", err) } fmt.Println("\n--- Example Usage Complete ---") } // ExampleChunkedUpload demonstrates how to handle chunked uploads properly func ExampleChunkedUpload(manager *Manager, filename string, totalSize int64, contentType string) 
error { ctx := context.Background() chunkSize := int64(1024 * 1024) // 1MB chunks uid := "unique-file-" + filename fmt.Printf("Starting chunked upload: file=%s, total=%d bytes, chunks=%d\n", filename, totalSize, (totalSize+chunkSize-1)/chunkSize) for offset := int64(0); offset < totalSize; offset += chunkSize { end := offset + chunkSize - 1 if end >= totalSize { end = totalSize - 1 } chunkSize := end - offset + 1 // Create chunk header chunkHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: filename, Size: chunkSize, Header: make(map[string][]string), }, } chunkHeader.Header.Set("Content-Type", contentType) chunkHeader.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", offset, end, totalSize)) chunkHeader.Header.Set("Content-Uid", uid) // In real usage, you would read the actual chunk data from the source chunkData := make([]byte, chunkSize) // Fill with sample data for demonstration for i := range chunkData { chunkData[i] = byte('A' + (i % 26)) } option := UploadOption{ Groups: []string{"user123"}, Gzip: false, // Disable compression for this example } file, err := manager.Upload(ctx, chunkHeader, bytes.NewReader(chunkData), option) if err != nil { return fmt.Errorf("failed to upload chunk at offset %d: %w", offset, err) } fmt.Printf("Uploaded chunk %d-%d/%d\n", offset, end, totalSize) // Check if this was the last chunk if chunkHeader.Complete() { fmt.Printf("File upload completed: %s (ID: %s)\n", file.Filename, file.ID) // Verify the uploaded file data, err := manager.Read(ctx, file.ID) if err != nil { return fmt.Errorf("failed to read uploaded file: %w", err) } if int64(len(data)) != totalSize { return fmt.Errorf("uploaded file size mismatch: expected %d, got %d", totalSize, len(data)) } fmt.Printf("File verification successful: %d bytes\n", len(data)) break } } return nil } // ExampleS3Upload demonstrates S3-specific features func ExampleS3Upload() { // This example requires actual S3 credentials s3Manager, err := New(ManagerOption{ 
Driver: "s3", MaxSize: "50M", Options: map[string]interface{}{ "endpoint": "https://s3.amazonaws.com", "region": "us-east-1", "key": "your-access-key", "secret": "your-secret-key", "bucket": "your-bucket", "prefix": "test-uploads/", }, }) if err != nil { log.Printf("S3 manager creation failed (expected without credentials): %v", err) return } ctx := context.Background() content := "Test content for S3 upload" fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "s3_test.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") file, err := s3Manager.Upload(ctx, fileHeader, strings.NewReader(content), UploadOption{ Groups: []string{"s3_user"}, }) if err != nil { log.Printf("S3 upload failed: %v", err) return } fmt.Printf("S3 upload successful: %s\n", file.ID) // Get presigned URL url := s3Manager.storage.URL(ctx, file.ID) fmt.Printf("Presigned URL: %s\n", url) } // Helper function for min func min(a, b int) int { if a < b { return a } return b } ================================================ FILE: attachment/fileheader.go ================================================ package attachment import ( "strconv" "strings" ) // UID is the uid of the file, it is the unique identifier of the file func (fileheader *FileHeader) UID() string { return fileheader.Header.Get("Content-Uid") } // Fingerprint is the fingerprint of the file, it is the fingerprint of the file func (fileheader *FileHeader) Fingerprint() string { return fileheader.Header.Get("Content-Fingerprint") } // Range is the range of the file, it is the start and end of the file func (fileheader *FileHeader) Range() string { return fileheader.Header.Get("Content-Range") } // Sync is the sync of the file, it is the sync of the file func (fileheader *FileHeader) Sync() bool { return fileheader.Header.Get("Content-Sync") == "true" } // IsChunk is the chunk of the file, it is the chunk of the file func (fileheader *FileHeader) 
IsChunk() bool { return fileheader.Range() != "" } // Complete checks if the chunk upload is completed // For non-chunk files, it returns true // For chunk files, it parses the Content-Range header to determine if this is the last chunk func (fileheader *FileHeader) Complete() bool { if !fileheader.IsChunk() { return true } // Parse Content-Range header: "bytes start-end/total" rangeHeader := fileheader.Range() if rangeHeader == "" { return false } // Remove "bytes " prefix rangeStr := strings.TrimPrefix(rangeHeader, "bytes ") // Split by "/" parts := strings.Split(rangeStr, "/") if len(parts) != 2 { return false } // Parse total size total, err := strconv.ParseInt(parts[1], 10, 64) if err != nil { return false } // Parse range "start-end" rangeParts := strings.Split(parts[0], "-") if len(rangeParts) != 2 { return false } end, err := strconv.ParseInt(rangeParts[1], 10, 64) if err != nil { return false } // Check if this is the last chunk: end + 1 == total return end+1 == total } // GetChunkInfo returns the chunk information parsed from Content-Range header // Returns start, end, total, and error func (fileheader *FileHeader) GetChunkInfo() (start, end, total int64, err error) { if !fileheader.IsChunk() { return 0, 0, 0, nil } rangeHeader := fileheader.Range() if rangeHeader == "" { return 0, 0, 0, nil } // Remove "bytes " prefix rangeStr := strings.TrimPrefix(rangeHeader, "bytes ") // Split by "/" parts := strings.Split(rangeStr, "/") if len(parts) != 2 { return 0, 0, 0, nil } // Parse total size total, err = strconv.ParseInt(parts[1], 10, 64) if err != nil { return 0, 0, 0, err } // Parse range "start-end" rangeParts := strings.Split(parts[0], "-") if len(rangeParts) != 2 { return 0, 0, 0, nil } start, err = strconv.ParseInt(rangeParts[0], 10, 64) if err != nil { return 0, 0, 0, err } end, err = strconv.ParseInt(rangeParts[1], 10, 64) if err != nil { return 0, 0, 0, err } return start, end, total, nil } // GetTotalSize returns the total file size from 
// GzipCompressor gzips data incrementally into an in-memory buffer. Data is
// supplied either directly through Write or, for compressors created with
// NewGzipCompressorFromFile, pulled from a source file in chunks.
type GzipCompressor struct {
	writer *gzip.Writer
	buffer *bytes.Buffer
	file   *os.File // optional source file; nil unless created from a file
}

// NewGzipCompressor creates a compressor with no source file; feed it with
// Write and collect the result with Close.
func NewGzipCompressor() *GzipCompressor {
	buf := &bytes.Buffer{}
	gz := gzip.NewWriter(buf)
	return &GzipCompressor{
		writer: gz,
		buffer: buf,
		file:   nil,
	}
}

// NewGzipCompressorFromFile opens filePath for reading and returns a
// compressor that can consume it with ReadChunk or CompressFileInChunks.
// The file stays open until Close or Reset is called.
func NewGzipCompressorFromFile(filePath string) (*GzipCompressor, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file %s: %w", filePath, err)
	}

	buf := &bytes.Buffer{}
	gz := gzip.NewWriter(buf)
	return &GzipCompressor{
		writer: gz,
		buffer: buf,
		file:   file,
	}, nil
}

// ReadChunk reads up to chunkSize bytes from the source file and compresses
// them. It returns false once the end of the file has been reached.
//
// It fails when the compressor has no source file or when chunkSize is not
// positive: a zero-length read never reaches EOF (which would make
// CompressFileInChunks spin forever) and a negative size cannot be allocated.
func (gc *GzipCompressor) ReadChunk(chunkSize int) (bool, error) {
	if gc.file == nil {
		return false, fmt.Errorf("no file associated with this compressor")
	}
	if chunkSize <= 0 {
		return false, fmt.Errorf("invalid chunk size %d: must be positive", chunkSize)
	}

	chunk := make([]byte, chunkSize)
	n, err := gc.file.Read(chunk)
	if err != nil && err != io.EOF {
		return false, fmt.Errorf("failed to read from file: %w", err)
	}

	// Read may return data together with io.EOF, so compress before deciding
	if n > 0 {
		if err := gc.Write(chunk[:n]); err != nil {
			return false, err
		}
	}

	// return whether there is more data
	return err != io.EOF, nil
}

// CompressFileInChunks compresses the rest of the source file, reading
// chunkSize bytes at a time. It returns the first error from ReadChunk.
func (gc *GzipCompressor) CompressFileInChunks(chunkSize int) error {
	if gc.file == nil {
		return fmt.Errorf("no file associated with this compressor")
	}

	for {
		hasMore, err := gc.ReadChunk(chunkSize)
		if err != nil {
			return err
		}
		if !hasMore {
			break
		}
	}
	return nil
}

// Write compresses data into the internal buffer; it may be called any
// number of times before Close.
func (gc *GzipCompressor) Write(data []byte) error {
	_, err := gc.writer.Write(data)
	if err != nil {
		return fmt.Errorf("failed to write data to gzip: %w", err)
	}
	return nil
}

// Flush pushes pending compressed data into the buffer without finishing
// the gzip stream.
func (gc *GzipCompressor) Flush() error {
	return gc.writer.Flush()
}

// Close finishes the gzip stream, releases the source file if there is one,
// and returns the complete compressed data.
func (gc *GzipCompressor) Close() ([]byte, error) {
	err := gc.writer.Close()

	// Release the source file even when finishing the stream failed, so an
	// error here never leaks the descriptor. The handle was opened read-only,
	// so its close error carries no data-loss risk and is ignored
	if gc.file != nil {
		gc.file.Close()
		gc.file = nil
	}

	if err != nil {
		return nil, fmt.Errorf("failed to close gzip writer: %w", err)
	}
	return gc.buffer.Bytes(), nil
}

// GetCompressedData flushes and returns the data compressed so far without
// closing the compressor. The result is not a complete gzip stream until
// Close has been called, and the returned slice is only valid until the
// compressor is written to or reset again.
func (gc *GzipCompressor) GetCompressedData() []byte {
	// The signature has no way to report a flush error; the data that did
	// reach the buffer is returned regardless
	_ = gc.writer.Flush()
	return gc.buffer.Bytes()
}

// Reset discards all buffered output, prepares the compressor for a new
// stream, and closes the source file if there is one.
func (gc *GzipCompressor) Reset() {
	gc.buffer.Reset()
	gc.writer.Reset(gc.buffer)
	if gc.file != nil {
		gc.file.Close()
		gc.file = nil
	}
}

// Gzip compresses data in one go and returns a complete gzip stream.
func Gzip(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)

	_, err := gz.Write(data)
	if err != nil {
		return nil, fmt.Errorf("failed to gzip data: %w", err)
	}

	// Close writes the trailer; the output is incomplete without it
	err = gz.Close()
	if err != nil {
		return nil, fmt.Errorf("failed to close gzip writer: %w", err)
	}

	return buf.Bytes(), nil
}

// GzipChunks compresses the chunks, in order, into a single gzip stream.
func GzipChunks(chunks [][]byte) ([]byte, error) {
	compressor := NewGzipCompressor()

	for _, chunk := range chunks {
		if err := compressor.Write(chunk); err != nil {
			return nil, err
		}
	}

	return compressor.Close()
}
// Gunzip decompresses a complete gzip stream held in memory and returns the
// original bytes. It fails when data is not valid gzip.
func Gunzip(data []byte) ([]byte, error) {
	reader, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, fmt.Errorf("failed to create gzip reader: %w", err)
	}
	defer reader.Close()

	var buf bytes.Buffer
	_, err = io.Copy(&buf, reader)
	if err != nil {
		return nil, fmt.Errorf("failed to decompress data: %w", err)
	}

	return buf.Bytes(), nil
}

// GzipToWriter compresses data and writes the complete gzip stream to writer.
//
// The error from closing the gzip writer is returned rather than dropped:
// Close is what flushes the buffered compressed data and the trailer, so a
// failure there means writer received a truncated stream.
func GzipToWriter(data []byte, writer io.Writer) error {
	gz := gzip.NewWriter(writer)

	if _, err := gz.Write(data); err != nil {
		// The write error is the one worth reporting; Close is only called
		// to release the writer's state
		_ = gz.Close()
		return fmt.Errorf("failed to write gzip data: %w", err)
	}

	if err := gz.Close(); err != nil {
		return fmt.Errorf("failed to close gzip writer: %w", err)
	}

	return nil
}

// GzipFromReaderToWriter streams everything from reader through gzip into
// writer.
//
// As with GzipToWriter, a failure while closing the gzip writer is reported,
// because the tail of the stream is only written at that point.
func GzipFromReaderToWriter(reader io.Reader, writer io.Writer) error {
	gz := gzip.NewWriter(writer)

	if _, err := io.Copy(gz, reader); err != nil {
		// Report the copy error; Close only releases the writer's state
		_ = gz.Close()
		return fmt.Errorf("failed to copy and compress data: %w", err)
	}

	if err := gz.Close(); err != nil {
		return fmt.Errorf("failed to close gzip writer: %w", err)
	}

	return nil
}
// GzipFileStream opens filePath and returns a reader that yields the file's
// gzip-compressed contents. Compression runs in a background goroutine that
// feeds the reader through a pipe, so data is produced as it is consumed.
//
// An error is returned only when the file cannot be opened; a failure while
// reading the file is delivered through the returned reader.
func GzipFileStream(filePath string) (io.Reader, error) {
	src, err := os.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file %s: %w", filePath, err)
	}

	reader, writer := io.Pipe()
	go pumpGzip(src, writer)

	return reader, nil
}

// pumpGzip compresses src into dst and then releases both.
func pumpGzip(src *os.File, dst *io.PipeWriter) {
	// Deferred calls run in reverse order: finish the gzip stream first,
	// then close the source file, and only then signal EOF on the pipe
	defer dst.Close()
	defer src.Close()

	zw := gzip.NewWriter(dst)
	defer zw.Close()

	if _, err := io.Copy(zw, src); err != nil {
		// Hand the failure to the reading side instead of a clean EOF
		dst.CloseWithError(err)
	}
}
if len(messages) > 0 { for _, message := range messages { log.Error("Load filesystem uploaders error: %s", message) } return fmt.Errorf("%s", strings.Join(messages, ";\n")) } return nil } // loadSystemUploaders load system uploaders func loadSystemUploaders(cfg config.Config) error { for id, path := range systemUploaders { content, err := data.Read(path) if err != nil { return err } // Parse uploader config var option ManagerOption err = application.Parse(path, content, &option) if err != nil { return err } // Replace environment variables and paths option.ReplaceEnv(cfg.DataRoot) // Register the uploader manager _, err = Register(id, option.Driver, option) if err != nil { log.Error("register system uploader %s error: %s", id, err.Error()) return err } log.Info("loaded system uploader: %s (%s)", id, option.Label) } return nil } // loadUploaderFile load a single uploader file func loadUploaderFile(root, file string, cfg config.Config) error { // Generate uploader ID id := share.ID(root, file) // Read file content content, err := application.App.Read(file) if err != nil { return fmt.Errorf("failed to read uploader file %s: %v", file, err) } // Parse uploader config var option ManagerOption err = application.Parse(file, content, &option) if err != nil { return fmt.Errorf("failed to parse uploader file %s: %v", file, err) } // Validate driver consistency between filename and config filenameDriver := extractDriverFromFilename(file) if filenameDriver != "" && option.Driver != "" && filenameDriver != option.Driver { log.Warn("Driver mismatch in uploader file %s: filename suggests '%s' but config has '%s'", file, filenameDriver, option.Driver) } // Replace environment variables and paths option.ReplaceEnv(cfg.DataRoot) // Register the uploader manager _, err = Register(id, option.Driver, option) if err != nil { log.Error("register uploader %s error: %s", id, err.Error()) return fmt.Errorf("failed to register uploader %s: %v", id, err) } log.Info("loaded uploader: %s (%s)", 
id, option.Label) return nil } // isUploaderFile checks if the file is an uploader configuration file func isUploaderFile(filename string) bool { // Accept files with specific driver patterns: *.s3.yao, *.local.yao, etc. lower := strings.ToLower(filename) return strings.HasSuffix(lower, ".s3.yao") || strings.HasSuffix(lower, ".local.yao") || strings.HasSuffix(lower, ".s3.json") || strings.HasSuffix(lower, ".local.json") || strings.HasSuffix(lower, ".s3.jsonc") || strings.HasSuffix(lower, ".local.jsonc") } // extractDriverFromFilename extracts the driver name from filename (e.g., "test.s3.yao" -> "s3") func extractDriverFromFilename(filename string) string { lower := strings.ToLower(filename) // Extract driver from patterns like "*.s3.yao", "*.local.json", etc. if strings.Contains(lower, ".s3.") { return "s3" } else if strings.Contains(lower, ".local.") { return "local" } return "" } ================================================ FILE: attachment/load_test.go ================================================ package attachment import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestLoad(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() err := Load(config.Conf) assert.NoError(t, err) check(t) } func check(t *testing.T) { // Check that managers are loaded assert.NotEmpty(t, Managers, "Managers should not be empty after loading") // Check system uploader _, exists := Managers["__yao.attachment"] assert.True(t, exists, "System uploader __yao.attachment should be loaded") // Check test app uploaders (must exist) // These are the uploaders in yao-dev-app/uploaders/ _, hasData := Managers["data"] _, hasTest := Managers["test"] // Both test uploaders should be loaded assert.True(t, hasData, "Test uploader 'data' should be loaded from data.local.yao") assert.True(t, hasTest, "Test uploader 'test' should be loaded from test.s3.yao") // Log all loaded managers for debugging t.Logf("Loaded 
managers: %v", getManagerNames()) } // getManagerNames returns a slice of manager names for testing func getManagerNames() []string { names := make([]string, 0, len(Managers)) for name := range Managers { names = append(names, name) } return names } ================================================ FILE: attachment/local/storage.go ================================================ package local import ( "compress/gzip" "context" "crypto/sha256" "fmt" "io" "mime" "net/http" "os" "path/filepath" "strings" "time" ) // MaxImageSize maximum image size (1920x1080) const MaxImageSize = 1920 // Storage the local storage driver type Storage struct { Path string `json:"path" yaml:"path"` Compression bool `json:"compression" yaml:"compression"` BaseURL string `json:"base_url" yaml:"base_url"` PreviewURL func(fileID string) string `json:"-" yaml:"-"` } // New create a new local storage func New(options map[string]interface{}) (*Storage, error) { storage := &Storage{ Compression: true, } if path, ok := options["path"].(string); ok { storage.Path = path } if compression, ok := options["compression"].(bool); ok { storage.Compression = compression } if baseURL, ok := options["base_url"].(string); ok { storage.BaseURL = baseURL } if previewURL, ok := options["preview_url"].(func(string) string); ok { storage.PreviewURL = previewURL } if storage.Path == "" { return nil, fmt.Errorf("path is required") } // Ensure the base path exists if err := os.MkdirAll(storage.Path, 0755); err != nil { return nil, fmt.Errorf("failed to create base path: %w", err) } return storage, nil } // Upload upload file to local storage func (storage *Storage) Upload(ctx context.Context, path string, reader io.Reader, contentType string) (string, error) { fullPath := filepath.Join(storage.Path, path) // Create directory if not exists dir := filepath.Dir(fullPath) if err := os.MkdirAll(dir, 0755); err != nil { return "", err } // Create and write file file, err := os.Create(fullPath) if err != nil { return "", 
err
	}
	defer file.Close()

	_, err = io.Copy(file, reader)
	if err != nil {
		return "", err
	}

	return path, nil
}

// UploadChunk uploads a chunk of a file.
// Chunks are stored as "<storage path>/.chunks/<path>/chunk_<index>" until
// MergeChunks assembles them.
func (storage *Storage) UploadChunk(ctx context.Context, path string, chunkIndex int, reader io.Reader, contentType string) error {
	// Create chunks directory
	chunksDir := filepath.Join(storage.Path, ".chunks", path)
	if err := os.MkdirAll(chunksDir, 0755); err != nil {
		return err
	}

	// Write chunk file
	chunkPath := filepath.Join(chunksDir, fmt.Sprintf("chunk_%d", chunkIndex))
	file, err := os.Create(chunkPath)
	if err != nil {
		return err
	}
	defer file.Close()

	_, err = io.Copy(file, reader)
	return err
}

// MergeChunks merges all chunks into the final file.
// Chunks 0..totalChunks-1 are concatenated in order; the chunk directory is
// removed only after the final file has been written and closed successfully.
func (storage *Storage) MergeChunks(ctx context.Context, path string, totalChunks int) error {
	chunksDir := filepath.Join(storage.Path, ".chunks", path)
	finalPath := filepath.Join(storage.Path, path)

	// Create directory for final file
	dir := filepath.Dir(finalPath)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return err
	}

	// Create final file
	finalFile, err := os.Create(finalPath)
	if err != nil {
		return err
	}
	// Covers the early-return paths; after the explicit Close below this is a harmless no-op.
	defer finalFile.Close()

	// Read and merge chunks in order
	for i := 0; i < totalChunks; i++ {
		chunkPath := filepath.Join(chunksDir, fmt.Sprintf("chunk_%d", i))
		chunkFile, err := os.Open(chunkPath)
		if err != nil {
			return fmt.Errorf("failed to read chunk %d: %w", i, err)
		}

		_, err = io.Copy(finalFile, chunkFile)
		chunkFile.Close()
		if err != nil {
			return fmt.Errorf("failed to copy chunk %d: %w", i, err)
		}
	}

	// Close before deleting the chunks so a failed write is reported while
	// the source chunks still exist.
	if err := finalFile.Close(); err != nil {
		return err
	}

	// Clean up chunks directory (best effort: the merged file is already complete)
	os.RemoveAll(chunksDir)

	return nil
}

// gzipFileReader streams the decompressed content of a gzip file. Closing a
// *gzip.Reader does not close the reader it wraps, so this type keeps the
// file handle and releases it in Close.
type gzipFileReader struct {
	*gzip.Reader
	file *os.File
}

// Close closes the gzip reader and then the underlying file, returning the
// first error encountered.
func (r *gzipFileReader) Close() error {
	gzErr := r.Reader.Close()
	fileErr := r.file.Close()
	if gzErr != nil {
		return gzErr
	}
	return fileErr
}

// openStoredFile opens fullPath for reading. When gzipped is true the content
// is transparently decompressed. The file handle is released if the gzip
// header cannot be read.
func openStoredFile(fullPath string, gzipped bool) (io.ReadCloser, error) {
	file, err := os.Open(fullPath)
	if err != nil {
		return nil, err
	}

	if !gzipped {
		return file, nil
	}

	gz, err := gzip.NewReader(file)
	if err != nil {
		file.Close()
		return nil, err
	}
	return &gzipFileReader{Reader: gz, file: file}, nil
}

// Reader read file from local storage.
// Paths ending in ".gz" are decompressed on the fly. The caller must close
// the returned reader, which also releases the underlying file.
func (storage *Storage) Reader(ctx context.Context, path string) (io.ReadCloser, error) {
	fullpath := filepath.Join(storage.Path, path)
	return openStoredFile(fullpath, strings.HasSuffix(path, ".gz"))
}

// Download download file from local storage.
// It returns the (decompressed, for ".gz" paths) content together with a
// content type derived from the file extension, ignoring a trailing ".gz".
// The extension table is shared with LocalPath so both report the same type.
func (storage *Storage) Download(ctx context.Context, path string) (io.ReadCloser, string, error) {
	fullPath := filepath.Join(storage.Path, path)

	reader, err := openStoredFile(fullPath, strings.HasSuffix(path, ".gz"))
	if err != nil {
		return nil, "", err
	}

	// Detect content type from the original (pre-gzip) file extension
	ext := filepath.Ext(strings.TrimSuffix(path, ".gz"))
	contentType, err := detectContentTypeFromExtension(ext)
	if err != nil {
		reader.Close()
		return nil, "", err
	}

	return reader, contentType, nil
}

// URL get file url
func (storage *Storage) URL(ctx context.Context, path string) string {
	if storage.PreviewURL != nil {
		return storage.PreviewURL(path)
	}

	if storage.BaseURL != "" {
		return fmt.Sprintf("%s/%s", strings.TrimRight(storage.BaseURL, "/"), path)
	}

	return fmt.Sprintf("%s/%s", storage.Path, path)
}

// GetContent gets file content as bytes
func (storage *Storage) GetContent(ctx context.Context, path string) ([]byte, error) {
	reader, err := storage.Reader(ctx, path)
	if err != nil {
		return nil, err
	}
	defer reader.Close()

	return io.ReadAll(reader)
}

// Exists checks if a file exists
func (storage *Storage) Exists(ctx context.Context, path string)
bool { fullpath := filepath.Join(storage.Path, path) _, err := os.Stat(fullpath) return err == nil } // Delete deletes a file func (storage *Storage) Delete(ctx context.Context, path string) error { fullpath := filepath.Join(storage.Path, path) return os.Remove(fullpath) } func (storage *Storage) makeID(filename string, ext string) string { date := time.Now().Format("20060102") hash := fmt.Sprintf("%x", sha256.Sum256([]byte(filename)))[:8] name := strings.TrimSuffix(filepath.Base(filename), ext) return fmt.Sprintf("%s/%s-%s%s", date, name, hash, ext) } // LocalPath returns the absolute path of the file and its content type func (storage *Storage) LocalPath(ctx context.Context, path string) (string, string, error) { fullPath := filepath.Join(storage.Path, path) // Check if file exists if _, err := os.Stat(fullPath); os.IsNotExist(err) { return "", "", fmt.Errorf("file not found: %s", path) } // For gzipped files, we need to detect the original content type, not the gzip wrapper var contentType string var err error if strings.HasSuffix(path, ".gz") { // For gzipped files, detect content type of the decompressed content originalPath := strings.TrimSuffix(path, ".gz") ext := filepath.Ext(originalPath) // First try to detect by original file extension contentType, err = detectContentTypeFromExtension(ext) if err != nil || contentType == "application/octet-stream" { // Fallback: decompress and detect from content contentType, err = detectContentTypeFromGzippedFile(fullPath) if err != nil { return "", "", fmt.Errorf("failed to detect content type from gzipped file: %w", err) } } } else { // Regular file content type detection contentType, err = detectContentType(fullPath) if err != nil { return "", "", fmt.Errorf("failed to detect content type: %w", err) } } // Return absolute path absPath, err := filepath.Abs(fullPath) if err != nil { return "", "", fmt.Errorf("failed to get absolute path: %w", err) } return absPath, contentType, nil } // detectContentType detects 
content type based on file extension and content func detectContentType(filePath string) (string, error) { // First try to detect by file extension ext := strings.ToLower(filepath.Ext(filePath)) // Common file extensions mapping switch ext { case ".txt": return "text/plain", nil case ".html", ".htm": return "text/html", nil case ".css": return "text/css", nil case ".js": return "application/javascript", nil case ".json": return "application/json", nil case ".xml": return "application/xml", nil case ".jpg", ".jpeg": return "image/jpeg", nil case ".png": return "image/png", nil case ".gif": return "image/gif", nil case ".webp": return "image/webp", nil case ".svg": return "image/svg+xml", nil case ".pdf": return "application/pdf", nil case ".doc": return "application/msword", nil case ".docx": return "application/vnd.openxmlformats-officedocument.wordprocessingml.document", nil case ".xls": return "application/vnd.ms-excel", nil case ".xlsx": return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil case ".ppt": return "application/vnd.ms-powerpoint", nil case ".pptx": return "application/vnd.openxmlformats-officedocument.presentationml.presentation", nil case ".zip": return "application/zip", nil case ".tar": return "application/x-tar", nil case ".gz": return "application/gzip", nil case ".mp3": return "audio/mpeg", nil case ".wav": return "audio/wav", nil case ".m4a": return "audio/mp4", nil case ".ogg": return "audio/ogg", nil case ".mp4": return "video/mp4", nil case ".avi": return "video/x-msvideo", nil case ".mov": return "video/quicktime", nil case ".webm": return "video/webm", nil case ".md", ".mdx": return "text/markdown", nil case ".yao": return "application/yao", nil case ".csv": return "text/csv", nil } // Try to detect by MIME package if contentType := mime.TypeByExtension(ext); contentType != "" { return contentType, nil } // Fallback: detect by reading file content file, err := os.Open(filePath) if err != nil { return 
"application/octet-stream", nil // Default fallback } defer file.Close() // Read first 512 bytes for content detection buffer := make([]byte, 512) n, err := file.Read(buffer) if err != nil && err != io.EOF { return "application/octet-stream", nil } // Use http.DetectContentType to detect based on content contentType := http.DetectContentType(buffer[:n]) return contentType, nil } // detectContentTypeFromExtension detects content type based only on file extension func detectContentTypeFromExtension(ext string) (string, error) { ext = strings.ToLower(ext) // Common file extensions mapping switch ext { case ".txt": return "text/plain", nil case ".html", ".htm": return "text/html", nil case ".css": return "text/css", nil case ".js": return "application/javascript", nil case ".json": return "application/json", nil case ".xml": return "application/xml", nil case ".jpg", ".jpeg": return "image/jpeg", nil case ".png": return "image/png", nil case ".gif": return "image/gif", nil case ".webp": return "image/webp", nil case ".svg": return "image/svg+xml", nil case ".pdf": return "application/pdf", nil case ".doc": return "application/msword", nil case ".docx": return "application/vnd.openxmlformats-officedocument.wordprocessingml.document", nil case ".xls": return "application/vnd.ms-excel", nil case ".xlsx": return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil case ".ppt": return "application/vnd.ms-powerpoint", nil case ".pptx": return "application/vnd.openxmlformats-officedocument.presentationml.presentation", nil case ".zip": return "application/zip", nil case ".tar": return "application/x-tar", nil case ".mp3": return "audio/mpeg", nil case ".wav": return "audio/wav", nil case ".m4a": return "audio/mp4", nil case ".ogg": return "audio/ogg", nil case ".mp4": return "video/mp4", nil case ".avi": return "video/x-msvideo", nil case ".mov": return "video/quicktime", nil case ".webm": return "video/webm", nil case ".md", ".mdx": return 
"text/markdown", nil case ".yao": return "application/yao", nil case ".csv": return "text/csv", nil } // Try to detect by MIME package if contentType := mime.TypeByExtension(ext); contentType != "" { return contentType, nil } // Return default if not found return "application/octet-stream", nil } // detectContentTypeFromGzippedFile detects content type by decompressing and reading gzipped file func detectContentTypeFromGzippedFile(gzippedFilePath string) (string, error) { file, err := os.Open(gzippedFilePath) if err != nil { return "", err } defer file.Close() // Create gzip reader gzipReader, err := gzip.NewReader(file) if err != nil { return "", err } defer gzipReader.Close() // Read first 512 bytes of decompressed content buffer := make([]byte, 512) n, err := gzipReader.Read(buffer) if err != nil && err != io.EOF { return "", err } // Use http.DetectContentType to detect based on decompressed content contentType := http.DetectContentType(buffer[:n]) return contentType, nil } ================================================ FILE: attachment/local/storage_test.go ================================================ package local import ( "bytes" "context" "image" "image/png" "io" "os" "path/filepath" "testing" "github.com/google/uuid" "github.com/stretchr/testify/assert" ) // generateTestFileName generates a unique test filename with the given prefix and extension func generateTestFileName(prefix, ext string) string { return prefix + "-" + uuid.New().String() + ext } func TestLocalStorage(t *testing.T) { // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "local_storage_test") assert.NoError(t, err) defer os.RemoveAll(tempDir) testPath := filepath.Join(tempDir, "test_storage") t.Run("Create Storage", func(t *testing.T) { storage, err := New(map[string]interface{}{ "path": testPath, "compression": true, }) assert.NoError(t, err) assert.NotNil(t, storage) assert.Equal(t, testPath, storage.Path) assert.True(t, storage.Compression) }) 
t.Run("Upload and Download", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path":        testPath,
			"compression": true,
		})
		assert.NoError(t, err)

		content := []byte("test content")
		reader := bytes.NewReader(content)
		fileID := generateTestFileName("upload-download", ".txt")
		_, err = storage.Upload(context.Background(), fileID, reader, "text/plain")
		assert.NoError(t, err)
		assert.NotEmpty(t, fileID)

		// Download
		reader2, contentType, err := storage.Download(context.Background(), fileID)
		assert.NoError(t, err)
		assert.Contains(t, contentType, "text/plain")

		downloaded, err := io.ReadAll(reader2)
		assert.NoError(t, err)
		assert.Equal(t, content, downloaded)

		// Release the file handle so the temp directory can be removed
		assert.NoError(t, reader2.Close())
	})

	t.Run("Upload and Download Image with Compression", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path":        testPath,
			"compression": true,
		})
		assert.NoError(t, err)

		// Create a test image (100x100 pixels - smaller for faster testing)
		img := image.NewRGBA(image.Rect(0, 0, 100, 100))
		var buf bytes.Buffer
		err = png.Encode(&buf, img)
		assert.NoError(t, err)

		// Upload
		reader := bytes.NewReader(buf.Bytes())
		fileID := generateTestFileName("image-with-compression", ".png")
		_, err = storage.Upload(context.Background(), fileID, reader, "image/png")
		assert.NoError(t, err)
		assert.NotEmpty(t, fileID)

		// Download and verify
		reader2, contentType, err := storage.Download(context.Background(), fileID)
		assert.NoError(t, err)
		assert.Equal(t, "image/png", contentType)

		downloaded, err := io.ReadAll(reader2)
		assert.NoError(t, err)

		// Release the file handle so the temp directory can be removed
		assert.NoError(t, reader2.Close())

		// Decode the downloaded image
		downloadedImg, _, err := image.Decode(bytes.NewReader(downloaded))
		assert.NoError(t, err)

		// Verify image was processed
		bounds := downloadedImg.Bounds()
		assert.True(t, bounds.Dx() > 0)
		assert.True(t, bounds.Dy() > 0)
	})

	t.Run("Upload Image without Compression", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path":        testPath,
			"compression": false,
		})
		assert.NoError(t, err)

		// Create a test image (100x100 pixels)
		img := image.NewRGBA(image.Rect(0, 0, 100, 100))
		var buf bytes.Buffer
		err = png.Encode(&buf, img)
		assert.NoError(t, err)

		// Upload
		reader := bytes.NewReader(buf.Bytes())
		fileID := generateTestFileName("image-without-compression", ".png")
		_, err = storage.Upload(context.Background(), fileID, reader, "image/png")
		assert.NoError(t, err)
		assert.NotEmpty(t, fileID)

		// Download and verify
		reader2, contentType, err := storage.Download(context.Background(), fileID)
		assert.NoError(t, err)
		assert.Equal(t, "image/png", contentType)

		downloaded, err := io.ReadAll(reader2)
		assert.NoError(t, err)

		// Release the file handle so the temp directory can be removed
		assert.NoError(t, reader2.Close())

		// Decode the downloaded image
		downloadedImg, _, err := image.Decode(bytes.NewReader(downloaded))
		assert.NoError(t, err)

		// Verify dimensions are unchanged
		bounds := downloadedImg.Bounds()
		assert.Equal(t, 100, bounds.Dx())
		assert.Equal(t, 100, bounds.Dy())
	})

	t.Run("URL Generation", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path":        testPath,
			"compression": true,
		})
		assert.NoError(t, err)

		fileID := "20240101/test-12345678.txt"
		url := storage.URL(context.Background(), fileID)
		expected := filepath.Join(testPath, fileID)
		assert.Equal(t, expected, url)
	})

	t.Run("Download Non-existent File", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path":        testPath,
			"compression": true,
		})
		assert.NoError(t, err)

		_, _, err = storage.Download(context.Background(), "non-existent.txt")
		assert.Error(t, err)
	})

	t.Run("Chunked Upload", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path": testPath,
		})
		assert.NoError(t, err)

		fileID := "test-chunked.txt"
		content1 := []byte("chunk1")
		content2 := []byte("chunk2")

		// Upload chunks
		err = storage.UploadChunk(context.Background(), fileID, 0, bytes.NewReader(content1), "text/plain")
		assert.NoError(t, err)

		err = storage.UploadChunk(context.Background(), fileID, 1, bytes.NewReader(content2), "text/plain")
		assert.NoError(t, err)

		// Merge chunks
		err = storage.MergeChunks(context.Background(), fileID, 2)
		assert.NoError(t, err)

		// Download and verify
		reader, contentType, err := storage.Download(context.Background(), fileID)
		assert.NoError(t, err)
		assert.Equal(t, "text/plain", contentType)

		downloaded, err := io.ReadAll(reader)
		assert.NoError(t, err)
		assert.Equal(t, append(content1, content2...), downloaded)

		// Release the file handle so the temp directory can be removed
		assert.NoError(t, reader.Close())
	})

	t.Run("File Operations", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path": testPath,
		})
		assert.NoError(t, err)

		fileID := "test-ops.txt"
		content := []byte("test content")

		// Upload file
		_, err = storage.Upload(context.Background(), fileID, bytes.NewReader(content), "text/plain")
		assert.NoError(t, err)

		// Check if file exists
		exists := storage.Exists(context.Background(), fileID)
		assert.True(t, exists)

		// Read file
		reader, err := storage.Reader(context.Background(), fileID)
		assert.NoError(t, err)
		defer reader.Close()

		data, err := io.ReadAll(reader)
		assert.NoError(t, err)
		assert.Equal(t, content, data)

		// Get file content directly
		directContent, err := storage.GetContent(context.Background(), fileID)
		assert.NoError(t, err)
		assert.Equal(t, content, directContent)

		// Delete file
		err = storage.Delete(context.Background(), fileID)
		assert.NoError(t, err)

		// Check if file no longer exists
		exists = storage.Exists(context.Background(), fileID)
		assert.False(t, exists)
	})

	t.Run("LocalPath", func(t *testing.T) {
		storage, err := New(map[string]interface{}{
			"path": testPath,
		})
		assert.NoError(t, err)

		// Test different file types to verify content type detection
		testFiles := []struct {
			ext         string
			content     []byte
			contentType string
			expectedCT  string
		}{
			{".txt", []byte("Hello World"), "text/plain", "text/plain"},
			{".json", []byte(`{"key": "value"}`), "application/json", "application/json"},
			{".html", []byte("Test"), "text/html", "text/html"},
			{".csv", []byte("col1,col2\nval1,val2"), "text/csv", "text/csv"},
			{".md", []byte("# Markdown Content"), "text/markdown", "text/markdown"},
			{".yao", []byte("yao file content"), "application/yao", "application/yao"},
		}

		for _, tf := range
testFiles { // Generate unique filename with UUID to avoid conflicts fileName := generateTestFileName("localpath-test", tf.ext) // Upload file _, err = storage.Upload(context.Background(), fileName, bytes.NewReader(tf.content), tf.contentType) assert.NoError(t, err, "Failed to upload %s", fileName) // Get local path and content type localPath, detectedCT, err := storage.LocalPath(context.Background(), fileName) assert.NoError(t, err, "Failed to get local path for %s", fileName) assert.NotEmpty(t, localPath, "Local path should not be empty for %s", fileName) assert.Equal(t, tf.expectedCT, detectedCT, "Content type mismatch for %s", fileName) // Verify the path is absolute assert.True(t, filepath.IsAbs(localPath), "Path should be absolute for %s", fileName) // Verify the file exists at the returned path _, err = os.Stat(localPath) assert.NoError(t, err, "File should exist at local path for %s", fileName) // Verify file content fileContent, err := os.ReadFile(localPath) assert.NoError(t, err, "Failed to read file at local path for %s", fileName) assert.Equal(t, tf.content, fileContent, "File content mismatch for %s", fileName) } }) t.Run("LocalPath_NonExistentFile", func(t *testing.T) { storage, err := New(map[string]interface{}{ "path": testPath, }) assert.NoError(t, err) // Test with non-existent file _, _, err = storage.LocalPath(context.Background(), "non-existent.txt") assert.Error(t, err) assert.Contains(t, err.Error(), "file not found") }) t.Run("LocalPath_ContentDetection", func(t *testing.T) { storage, err := New(map[string]interface{}{ "path": testPath, }) assert.NoError(t, err) // Upload a file without extension but with recognizable content htmlContent := []byte("Test

Hello

") _, err = storage.Upload(context.Background(), "noext", bytes.NewReader(htmlContent), "application/octet-stream") assert.NoError(t, err) // Get local path - should detect HTML content type localPath, contentType, err := storage.LocalPath(context.Background(), "noext") assert.NoError(t, err) assert.NotEmpty(t, localPath) // Content detection should identify this as HTML assert.Equal(t, "text/html; charset=utf-8", contentType) }) } ================================================ FILE: attachment/manager.go ================================================ package attachment import ( "bytes" "context" "crypto/md5" "encoding/base64" "encoding/hex" "fmt" "io" "mime" "mime/multipart" "net/http" "net/textproto" "os" "path/filepath" "reflect" "strconv" "strings" "sync" "time" "github.com/yaoapp/gou/fs" "github.com/yaoapp/gou/model" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/attachment/local" "github.com/yaoapp/yao/attachment/s3" "github.com/yaoapp/yao/config" ) // Ensure Manager implements FileManager interface var _ FileManager = (*Manager)(nil) // Managers the managers var Managers = map[string]*Manager{} var uploadChunks = sync.Map{} // UploadChunk is the chunk data type UploadChunk struct { Last int Total int64 Chunksize int64 TotalChunks int64 // Cache metadata from first chunk to avoid inconsistencies ContentType string Filename string UserPath string CompressImage bool CompressSize int } // Parse parses an attachment wrapper string and returns uploader name and file ID // Format: __:// // Example: __yao.attachment://ccd472d11feb96e03a3fc468f494045c // Returns (uploader, fileID, isWrapper) func Parse(value string) (string, string, bool) { if !strings.HasPrefix(value, "__") { return "", value, false } // Exclude common protocols (ftp, http, https, etc.) 
excludedProtocols := []string{"__ftp://", "__http://", "__https://", "__ws://", "__wss://"}
	for _, protocol := range excludedProtocols {
		if strings.HasPrefix(value, protocol) {
			return "", value, false
		}
	}

	// Split by ://
	parts := strings.SplitN(value, "://", 2)
	if len(parts) != 2 {
		return "", value, false
	}

	uploader := parts[0] // Keep the __ prefix as it's part of the manager name
	fileID := parts[1]

	return uploader, fileID, true
}

// Base64 resolves value to Base64-encoded file content when possible.
// An attachment wrapper is read through the matching registered manager.
// Any other value is tried as a path in the "data" filesystem. If neither
// yields content, value is returned unchanged.
// Passing dataURI=true requests the "data:<type>;base64,<payload>" form; for
// wrappers the prefix is added only when the file's content type is known.
func Base64(ctx context.Context, value string, dataURI ...bool) string {
	wantDataURI := len(dataURI) > 0 && dataURI[0]

	uploader, fileID, isWrapper := Parse(value)
	if !isWrapper {
		// Not a wrapper: fall back to treating the value as a data-fs path
		if encoded := readFilePathAsBase64(value, wantDataURI); encoded != "" {
			return encoded
		}
		return value
	}

	manager, registered := Managers[uploader]
	if !registered {
		return value
	}

	// The content type is only needed for the data URI prefix
	contentType := ""
	if wantDataURI {
		if info, infoErr := manager.Info(ctx, fileID); infoErr == nil && info != nil {
			contentType = info.ContentType
		}
	}

	encoded, readErr := manager.ReadBase64(ctx, fileID)
	if readErr != nil {
		return value
	}

	if wantDataURI && contentType != "" {
		return fmt.Sprintf("data:%s;base64,%s", contentType, encoded)
	}
	return encoded
}

// readFilePathAsBase64 reads a file from fs data and returns Base64 encoded content
// Returns empty string if file doesn't exist or can't be read
// If dataURI is true, returns data URI format with mime type detection
func readFilePathAsBase64(path string, dataURI bool) string {
	// Check if path looks like a file
path (contains / or \)
	if !strings.Contains(path, "/") && !strings.Contains(path, "\\") {
		return ""
	}

	// Try to get fs data
	dataFS, err := fs.Get("data")
	if err != nil || dataFS == nil {
		return ""
	}

	// Check if file exists
	exists, err := dataFS.Exists(path)
	if err != nil || !exists {
		return ""
	}

	// Read file content
	content, err := dataFS.ReadFile(path)
	if err != nil {
		return ""
	}

	// Encode to Base64
	base64Str := base64.StdEncoding.EncodeToString(content)

	// Return with data URI prefix if requested
	if dataURI {
		// Detect content type from file extension or content
		contentType := detectContentType(path, content)
		if contentType != "" {
			return fmt.Sprintf("data:%s;base64,%s", contentType, base64Str)
		}
	}

	return base64Str
}

// detectContentType detects the MIME type from file path and content.
// The extension is consulted first; if it is unknown, up to the first 512
// bytes of content are sniffed. Returns "" when there is neither a known
// extension nor any content.
func detectContentType(path string, content []byte) string {
	// First try to get from file extension
	ext := filepath.Ext(path)
	if ext != "" {
		mimeType := mime.TypeByExtension(ext)
		if mimeType != "" {
			return mimeType
		}
	}

	// Fallback to detecting from content (first 512 bytes)
	if len(content) > 0 {
		detectSize := len(content)
		if detectSize > 512 {
			detectSize = 512
		}
		return http.DetectContentType(content[:detectSize])
	}

	return ""
}

// GetHeader gets the header from the file header and request header.
// Every value of every part header is copied (Add, not Set, so multi-valued
// headers keep all their values), Size is set from size, and the upload
// control headers Content-Sync, Content-Uid and Content-Range are taken from
// the request when present, replacing any value copied from the part.
func GetHeader(requestHeader http.Header, fileHeader textproto.MIMEHeader, size int64) *FileHeader {

	// Convert the header to a FileHeader
	header := &FileHeader{FileHeader: &multipart.FileHeader{Header: make(map[string][]string), Size: size}}
	for key, values := range fileHeader {
		for _, value := range values {
			header.Header.Add(key, value)
		}
	}

	// Set Content-Sync, Content-Uid, Content-Range
	for _, name := range []string{"Content-Sync", "Content-Uid", "Content-Range"} {
		if value := requestHeader.Get(name); value != "" {
			header.Header.Set(name, value)
		}
	}

	return header
}

// Register registers a global attachment manager
func Register(name string, driver string, option ManagerOption) (*Manager, error) {

	// Create a new manager
	manager, err := New(option)
	if err != nil {
		return nil, err
	}

	// Set the manager name
	manager.Name = name

	// Register the manager
	Managers[name] = manager
	return manager, nil
}

// RegisterDefault registers a default attachment manager
func RegisterDefault(name string) (*Manager, error) {

	option := ManagerOption{
		Driver:    "local",
		Options:   map[string]interface{}{"path": filepath.Join(config.Conf.DataRoot, name)},
		MaxSize:   "50M",
		ChunkSize: "2M",
		AllowedTypes: []string{
			"text/*",
			"image/*",
			"video/*",
			"audio/*",
			"application/x-zip-compressed",
			"application/x-tar",
			"application/x-gzip",
			"application/yao",
			"application/zip",
			"application/pdf",
			"application/json",
			"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
			"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
			"application/vnd.openxmlformats-officedocument.presentationml.presentation",
			"application/vnd.openxmlformats-officedocument.presentationml.slideshow",
			".md",
			".txt",
			".csv",
			".xls",
			".xlsx",
			".ppt",
			".pptx",
			".doc",
			".docx",
			".mdx",
			".m4a",
			".mp3",
			".mp4",
			".wav",
			".webm",
			".yao",
		},
	}

	return Register(name, option.Driver, option)
}

// ReplaceEnv replaces the environment variables in the options
func (option *ManagerOption) ReplaceEnv(root string) {
	if option.Options != nil {

		// Replace the environment variables in the options
		for k, v := range option.Options {
			if iv, ok := v.(string); ok {
				if strings.HasPrefix(iv, "$ENV.") {
					iv = os.ExpandEnv(fmt.Sprintf("${%s}", strings.TrimPrefix(iv, "$ENV.")))
					option.Options[k] = iv
				}

				// Path
				if k == "path" {
					iv = strings.TrimPrefix(iv, "/")
					option.Options[k] = filepath.Join(root, iv)
				}
			}
		}
	}
}

// New creates a new attachment manager
func New(option
ManagerOption) (*Manager, error) { manager := &Manager{ ManagerOption: option, allowedTypes: allowedType{mapping: make(map[string]bool), wildcards: []string{}, }} switch strings.ToLower(option.Driver) { case "local": storage, err := local.New(option.Options) if err != nil { return nil, err } manager.storage = storage case "s3": storage, err := s3.New(option.Options) if err != nil { return nil, err } manager.storage = storage default: return nil, fmt.Errorf("driver %s does not support", option.Driver) } // Max size if option.MaxSize != "" { maxsize, err := getSize(option.MaxSize) if err != nil { return nil, err } manager.maxsize = maxsize } // Chunk size if option.ChunkSize != "" { chunsize, err := getSize(option.ChunkSize) if err != nil { return nil, err } manager.chunsize = chunsize } // init allowedTypes if len(option.AllowedTypes) > 0 { for _, t := range option.AllowedTypes { t = strings.TrimSpace(t) if strings.HasSuffix(t, "*") { manager.allowedTypes.wildcards = append(manager.allowedTypes.wildcards, t) continue } manager.allowedTypes.mapping[t] = true } } return manager, nil } // LocalPath gets the local path of the file func (manager Manager) LocalPath(ctx context.Context, fileID string) (string, string, error) { // Get the real storage path from database storagePath, err := manager.getStoragePathFromDatabase(ctx, fileID) if err != nil { return "", "", err } // Call the storage implementation return manager.storage.LocalPath(ctx, storagePath) } // Upload uploads a file, Content-Sync must be true for chunked upload func (manager Manager) Upload(ctx context.Context, fileheader *FileHeader, reader io.Reader, option UploadOption) (*File, error) { file, err := manager.makeFile(fileheader, option) if err != nil { return nil, err } // Handle chunked upload if fileheader.IsChunk() { start, end, total, err := fileheader.GetChunkInfo() if err != nil { return nil, fmt.Errorf("invalid chunk info: %w", err) } // Store the chunk info chunkIndex := 0 if start == 0 { 
chunksize := end - start + 1 totalChunks := (total + chunksize - 1) / chunksize uploadChunks.LoadOrStore(file.ID, &UploadChunk{ Last: chunkIndex, Total: total, Chunksize: chunksize, TotalChunks: totalChunks, // Cache metadata from first chunk ContentType: file.ContentType, Filename: file.Filename, UserPath: file.UserPath, CompressImage: option.CompressImage, CompressSize: option.CompressSize, }) } // Update the chunk index v, ok := uploadChunks.Load(file.ID) if !ok { return nil, fmt.Errorf("chunk data not found") } chunkdata := v.(*UploadChunk) // Update the chunk index if start != 0 { chunkIndex = chunkdata.Last + 1 chunkdata.Last = chunkIndex uploadChunks.Store(file.ID, chunkdata) // For non-first chunks, use cached metadata from first chunk file.ContentType = chunkdata.ContentType file.Filename = chunkdata.Filename file.UserPath = chunkdata.UserPath } // Apply gzip compression if requested if option.Gzip { compressed, err := GzipFromReader(reader) if err != nil { return nil, fmt.Errorf("failed to gzip chunk: %w", err) } reader = bytes.NewReader(compressed) } // Upload chunk using the storage path err = manager.storage.UploadChunk(ctx, file.Path, chunkIndex, reader, file.ContentType) if err != nil { return nil, err } // Save to database on first chunk only if start == 0 { file.Status = "uploading" err = manager.saveFileToDatabase(ctx, file, file.Path, option) if err != nil { return nil, fmt.Errorf("failed to create database record for chunked upload: %w", err) } } // Fix the file size, the file size is the sum of all chunks file.Bytes = chunkIndex * int(chunkdata.Chunksize) file.Status = "uploading" // If this is the last chunk, merge all chunks if fileheader.Complete() { err = manager.storage.MergeChunks(ctx, file.Path, int(chunkdata.TotalChunks)) if err != nil { return nil, err } // Set initial file size from chunks file.Bytes = int(chunkdata.Total) // Apply image compression if requested and it's the final file // Use cached compress options from first chunk 
if chunkdata.CompressImage && strings.HasPrefix(file.ContentType, "image/") { // Create a temporary option with cached compress size compressOption := UploadOption{ CompressSize: chunkdata.CompressSize, } compressedBytes, err := manager.compressStoredImageAndGetSize(ctx, file, compressOption) if err != nil { return nil, err } // Update file size to compressed size file.Bytes = compressedBytes } // Remove the chunk data uploadChunks.Delete(file.ID) // Update status to uploaded file.Status = "uploaded" // Update only bytes and status for the last chunk err = manager.saveFileToDatabase(ctx, file, file.Path, option) if err != nil { return nil, fmt.Errorf("failed to update chunked file status: %w", err) } } return file, nil } // Handle single file upload var finalReader io.Reader = reader // Apply gzip compression if requested if option.Gzip { compressed, err := GzipFromReader(reader) if err != nil { return nil, fmt.Errorf("failed to gzip file: %w", err) } finalReader = bytes.NewReader(compressed) } // Apply image compression if requested if option.CompressImage && strings.HasPrefix(file.ContentType, "image/") { size := option.CompressSize if size == 0 { size = 1920 } // Read original data for fallback var originalData []byte var err error // If gzip was applied, we need to decompress first if option.Gzip { data, err := io.ReadAll(finalReader) if err != nil { return nil, err } decompressed, err := Gunzip(data) if err != nil { return nil, err } originalData = decompressed finalReader = bytes.NewReader(decompressed) } else { originalData, err = io.ReadAll(finalReader) if err != nil { return nil, err } finalReader = bytes.NewReader(originalData) } // Try to compress the image with failback mechanism compressed, err := CompressImage(finalReader, file.ContentType, size) if err != nil { // Log the error and use original file as fallback log.Warn("Failed to compress image (content-type: %s, file: %s): %v. 
Using original file.", file.ContentType, file.Filename, err) // Use original data compressed = originalData } // Re-apply gzip if it was requested if option.Gzip { gzipped, err := Gzip(compressed) if err != nil { return nil, err } finalReader = bytes.NewReader(gzipped) } else { finalReader = bytes.NewReader(compressed) } } // Upload the file to storage using the generated storage path actualStoragePath, err := manager.storage.Upload(ctx, file.Path, finalReader, file.ContentType) if err != nil { return nil, err } // Update the actual storage path if storage returns a different path if actualStoragePath != "" && actualStoragePath != file.Path { file.Path = actualStoragePath } // Update the file status file.Status = "uploaded" // Save file information to database err = manager.saveFileToDatabase(ctx, file, file.Path, option) if err != nil { return nil, fmt.Errorf("failed to save file to database: %w", err) } return file, nil } // compressStoredImageAndGetSize compresses the stored image and returns the compressed size func (manager Manager) compressStoredImageAndGetSize(ctx context.Context, file *File, option UploadOption) (int, error) { // Download the stored file using storage path reader, err := manager.storage.Reader(ctx, file.Path) if err != nil { return 0, err } defer reader.Close() size := option.CompressSize if size == 0 { size = 1920 } // Read original data for fallback originalData, err := io.ReadAll(reader) if err != nil { return 0, err } // Try to compress the image with failback mechanism compressed, err := CompressImage(bytes.NewReader(originalData), file.ContentType, size) if err != nil { // Log the error and keep original file log.Warn("Failed to compress stored image (content-type: %s, file: %s): %v. 
Keeping original file.", file.ContentType, file.Filename, err) // File is already stored (merged chunks), just return original size return len(originalData), nil } // Re-upload the compressed image using storage path _, err = manager.storage.Upload(ctx, file.Path, bytes.NewReader(compressed), file.ContentType) if err != nil { return 0, err } // Return the compressed size return len(compressed), nil } // Download downloads a file func (manager Manager) Download(ctx context.Context, fileID string) (*FileResponse, error) { // Get real storage path from database storagePath, err := manager.getStoragePathFromDatabase(ctx, fileID) if err != nil { return nil, err } reader, contentType, err := manager.storage.Download(ctx, storagePath) if err != nil { return nil, err } extension := filepath.Ext(storagePath) if extension == "" { // Try to get extension from content type extensions, err := mime.ExtensionsByType(contentType) if err == nil && len(extensions) > 0 { extension = extensions[0] } } return &FileResponse{ Reader: reader, ContentType: contentType, Extension: extension, }, nil } // Read reads a file and returns the content as bytes func (manager Manager) Read(ctx context.Context, fileID string) ([]byte, error) { // Get file info from database to check if it's gzipped file, err := manager.getFileFromDatabase(ctx, fileID) if err != nil { return nil, err } reader, err := manager.storage.Reader(ctx, file.Path) if err != nil { return nil, err } defer reader.Close() data, err := io.ReadAll(reader) if err != nil { return nil, err } // Storage layer already handles gzip decompression for .gz files // No need to decompress again at Manager level return data, nil } // ReadBase64 reads a file and returns the content as base64 encoded string func (manager Manager) ReadBase64(ctx context.Context, fileID string) (string, error) { data, err := manager.Read(ctx, fileID) if err != nil { return "", err } return base64.StdEncoding.EncodeToString(data), nil } // Info retrieves complete 
file information from database by file ID func (manager Manager) Info(ctx context.Context, fileID string) (*File, error) { return manager.getFileFromDatabase(ctx, fileID) } // List retrieves files from database with pagination and filtering func (manager Manager) List(ctx context.Context, option ListOption) (*ListResult, error) { m := model.Select("__yao.attachment") // Set default values page := option.Page if page <= 0 { page = 1 } pageSize := option.PageSize if pageSize <= 0 { pageSize = 20 } // Build query parameters queryParam := model.QueryParam{} // Add select fields if len(option.Select) > 0 { queryParam.Select = make([]interface{}, 0, len(option.Select)) for _, field := range option.Select { queryParam.Select = append(queryParam.Select, field) } } else { // Default: exclude the 'content' field (which may contain large text data) // Only include it if explicitly requested in Select queryParam.Select = []interface{}{ "id", "file_id", "uploader", "content_type", "name", "url", "description", "type", "user_path", "path", "groups", "gzip", "bytes", "status", "progress", "error", "preset", "public", "share", "created_at", "updated_at", "deleted_at", "__yao_created_by", "__yao_updated_by", "__yao_team_id", "__yao_tenant_id", } } // Add filters if len(option.Filters) > 0 { queryParam.Wheres = make([]model.QueryWhere, 0, len(option.Filters)) for field, value := range option.Filters { where := model.QueryWhere{ Column: field, Value: value, } // Handle special operators for wildcard matching if strValue, ok := value.(string); ok { if strings.Contains(strValue, "*") { // Wildcard matching for LIKE queries where.OP = "like" where.Value = strings.ReplaceAll(strValue, "*", "%") } } queryParam.Wheres = append(queryParam.Wheres, where) } } // Add advanced where clauses (for permission filtering, etc.) 
if len(option.Wheres) > 0 { if queryParam.Wheres == nil { queryParam.Wheres = make([]model.QueryWhere, 0, len(option.Wheres)) } queryParam.Wheres = append(queryParam.Wheres, option.Wheres...) } // Add ordering if option.OrderBy != "" { // Parse order by string like "created_at desc" or "name asc" parts := strings.Fields(option.OrderBy) if len(parts) >= 1 { orderField := parts[0] orderDirection := "asc" if len(parts) >= 2 { orderDirection = strings.ToLower(parts[1]) } queryParam.Orders = []model.QueryOrder{ { Column: orderField, Option: orderDirection, }, } } } else { // Default order by created_at desc queryParam.Orders = []model.QueryOrder{ { Column: "created_at", Option: "desc", }, } } // Use model's built-in Paginate method result, err := m.Paginate(queryParam, page, pageSize) if err != nil { return nil, fmt.Errorf("failed to paginate files: %w", err) } // Extract pagination info from result total := int64(0) if totalInterface, ok := result["total"]; ok { if totalInt, ok := totalInterface.(int); ok { total = int64(totalInt) } else if totalInt64, ok := totalInterface.(int64); ok { total = totalInt64 } } // Extract data from result - handle maps.MapStrAny type var records []map[string]interface{} if dataInterface, ok := result["data"]; ok { // The data is of type []maps.MapStrAny, need to convert if dataSlice, ok := dataInterface.([]interface{}); ok { records = make([]map[string]interface{}, len(dataSlice)) for i, item := range dataSlice { if record, ok := item.(map[string]interface{}); ok { records[i] = record } } } else { // Try to handle it as the actual type returned by gou using reflection dataValue := reflect.ValueOf(dataInterface) if dataValue.Kind() == reflect.Slice { length := dataValue.Len() records = make([]map[string]interface{}, length) for i := 0; i < length; i++ { item := dataValue.Index(i).Interface() // Convert the item to map[string]interface{} using reflection if itemValue := reflect.ValueOf(item); itemValue.Kind() == reflect.Map { record := 
make(map[string]interface{}) for _, key := range itemValue.MapKeys() { if keyStr := key.String(); keyStr != "" { record[keyStr] = itemValue.MapIndex(key).Interface() } } records[i] = record } } } } } // Convert records to File structs files := make([]*File, 0, len(records)) for _, record := range records { file := &File{} // Map required fields if fileID, ok := record["file_id"].(string); ok { file.ID = fileID } if name, ok := record["name"].(string); ok { file.Filename = name } if contentType, ok := record["content_type"].(string); ok { file.ContentType = contentType } if status, ok := record["status"].(string); ok { file.Status = status } // Map optional fields if userPath, ok := record["user_path"].(string); ok { file.UserPath = userPath } if path, ok := record["path"].(string); ok { file.Path = path } if bytes, ok := record["bytes"].(int64); ok { file.Bytes = int(bytes) } else if bytesInt, ok := record["bytes"].(int); ok { file.Bytes = bytesInt } if createdAt, ok := record["created_at"].(int64); ok { file.CreatedAt = int(createdAt) } else if createdAtInt, ok := record["created_at"].(int); ok { file.CreatedAt = createdAtInt } else { // Fallback to current time if not available file.CreatedAt = int(time.Now().Unix()) } files = append(files, file) } // Calculate total pages totalPages := int((total + int64(pageSize) - 1) / int64(pageSize)) return &ListResult{ Files: files, Total: total, Page: page, PageSize: pageSize, TotalPages: totalPages, }, nil } // validate validates the file and option func (manager Manager) makeFile(file *FileHeader, option UploadOption) (*File, error) { // Validate max size if manager.maxsize > 0 && file.Size > manager.maxsize { return nil, fmt.Errorf("file size %d exceeds the maximum size of %d", file.Size, manager.maxsize) } // Use original filename if provided, otherwise use the file header filename filename := file.Filename userPath := option.OriginalFilename if userPath != "" { // If user provided a path, extract just the filename for 
the filename field filename = filepath.Base(userPath) } extension := filepath.Ext(filename) // Get the content type // For chunked uploads, file.Header may have incorrect content-type (e.g., application/octet-stream for Blob) // Try to detect from filename extension first, then fallback to header contentType := file.Header.Get("Content-Type") if extension != "" { // Try to get content type from extension detectedType := mime.TypeByExtension(extension) if detectedType != "" { // If detected type is not the generic octet-stream, use it // This handles chunked uploads where the header has incorrect type if detectedType != "application/octet-stream" || contentType == "application/octet-stream" { contentType = detectedType } } } // Get the extension from the content type if not available from filename if extension == "" { // Special handling for common types switch contentType { case "text/plain": extension = ".txt" case "image/jpeg": extension = ".jpg" case "image/png": extension = ".png" case "application/pdf": extension = ".pdf" default: extensions, err := mime.ExtensionsByType(contentType) if err == nil && len(extensions) > 0 { // For text/plain, prefer .txt over .conf if contentType == "text/plain" { for _, ext := range extensions { if ext == ".txt" { extension = ext break } } if extension == "" { extension = ".txt" } } else { extension = extensions[0] } } } } // Validate allowed types if !manager.allowed(contentType, extension) { return nil, fmt.Errorf("%s type %s is not allowed", filename, contentType) } // Generate file ID and storage path using the new approach id, storagePath, err := manager.generateFilePaths(file, extension, option) if err != nil { return nil, err } // Set the path: use userPath if provided, otherwise use filename filePath := userPath if filePath == "" { filePath = filename } return &File{ ID: id, UserPath: userPath, // Keep user's original input exactly as provided Path: storagePath, // Complete storage path: Groups + filename Filename: 
filename, // Use just the filename (extracted from path or header) ContentType: contentType, Bytes: int(file.Size), CreatedAt: int(time.Now().Unix()), Status: "uploading", }, nil } func (manager Manager) allowed(contentType string, extension string) bool { // text/*, image/*, audio/*, video/*, application/yao-*, ... for _, t := range manager.allowedTypes.wildcards { prefix := strings.TrimSuffix(t, "*") if strings.HasPrefix(contentType, prefix) { return true } } // Accepted types if _, ok := manager.allowedTypes.mapping[contentType]; ok { return true } // Accepted extensions if _, ok := manager.allowedTypes.mapping[extension]; ok { return true } // Not allowed return false } // generateFileID generates file ID and storage path based on Groups and filename func (manager Manager) generateFilePaths(file *FileHeader, extension string, option UploadOption) (fileID string, storagePath string, err error) { // 1. Get the filename var filename string if file.Fingerprint() != "" { filename = file.Fingerprint() } else if file.IsChunk() { filename = file.UID() } else { // Generate unique filename to avoid conflicts var originalName string if option.OriginalFilename != "" { originalName = filepath.Base(option.OriginalFilename) } else { originalName = file.Filename } // Extract extension from original filename ext := filepath.Ext(originalName) if ext == "" && extension != "" { ext = extension } // Generate unique filename: MD5 hash of original name + timestamp + extension nameHash := generateID(originalName + fmt.Sprintf("%d", time.Now().UnixNano())) filename = nameHash[:16] + ext // Use first 16 chars of hash + extension } // 2. Build complete storage path: Groups + filename pathParts := []string{} // Add groups to path if len(option.Groups) > 0 { pathParts = append(pathParts, option.Groups...) } // Add filename pathParts = append(pathParts, filename) // Join to create complete storage path storagePath = strings.Join(pathParts, "/") // 3. 
// generateID returns the lowercase hexadecimal MD5 digest (32 characters) of
// storagePath. It is used as a URL-safe alias, not as a security primitive.
func generateID(storagePath string) string {
	hash := md5.Sum([]byte(storagePath))
	return hex.EncodeToString(hash[:])
}

// isValidPath reports whether path is acceptable as a storage path.
//
// A path is rejected when it is empty, contains "../", "..\", a backslash or
// "//", or when any "/"-separated segment is exactly "..". The segment check
// also catches a bare ".." and a trailing "/..", which the substring checks
// alone let through.
func isValidPath(path string) bool {
	if path == "" {
		return false
	}

	// Check for invalid characters that could cause issues
	for _, invalid := range []string{"../", "..\\", "\\", "//"} {
		if strings.Contains(path, invalid) {
			return false
		}
	}

	// Reject parent-directory segments wherever they appear
	for _, segment := range strings.Split(path, "/") {
		if segment == ".." {
			return false
		}
	}

	return true
}

// getSize converts a human-readable size such as "512", "512B", "2K", "50M",
// "1G" (also "2KB", "50MB", "1GB", any letter case, surrounding spaces
// allowed) to a number of bytes. Units are powers of 1024.
//
// It returns an error when the input is empty or "0", when the numeric part is
// not a base-10 integer, when the value is negative, or when the result does
// not fit in an int64. Rejecting negative values matters because a negative
// limit would silently switch off size checks that test for "limit > 0".
func getSize(size string) (int64, error) {
	size = strings.TrimSpace(size)
	if size == "" || size == "0" {
		return 0, fmt.Errorf("size is empty")
	}

	// Two-letter suffixes come first so "10KB" is not read as "10K" + "B".
	units := []struct {
		suffix string
		factor int64
	}{
		{"KB", 1 << 10},
		{"MB", 1 << 20},
		{"GB", 1 << 30},
		{"K", 1 << 10},
		{"M", 1 << 20},
		{"G", 1 << 30},
		{"B", 1},
	}

	upper := strings.ToUpper(size)
	digits, factor := upper, int64(1)
	for _, unit := range units {
		if strings.HasSuffix(upper, unit.suffix) {
			digits = strings.TrimSpace(strings.TrimSuffix(upper, unit.suffix))
			factor = unit.factor
			break
		}
	}

	value, err := strconv.ParseInt(digits, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid size: %s %w", size, err)
	}
	if value < 0 {
		return 0, fmt.Errorf("invalid size: %s", size)
	}

	// Guard the multiplication against int64 overflow.
	const maxInt64 = int64(^uint64(0) >> 1)
	if value > maxInt64/factor {
		return 0, fmt.Errorf("invalid size: %s", size)
	}

	return value * factor, nil
}
real storage path from database storagePath, err := manager.getStoragePathFromDatabase(ctx, fileID) if err != nil { return err } // Delete from storage err = manager.storage.Delete(ctx, storagePath) if err != nil { return err } // Delete from database m := model.Select("__yao.attachment") _, err = m.DeleteWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "file_id", Value: fileID}, }, }) if err != nil { return fmt.Errorf("failed to delete from database: %w", err) } return nil } // saveFileToDatabase saves file information to the database // For chunked uploads, it only updates bytes/status/progress if record exists func (manager Manager) saveFileToDatabase(ctx context.Context, file *File, storagePath string, option UploadOption) error { m := model.Select("__yao.attachment") // Check if record exists first records, err := m.Get(model.QueryParam{ Select: []interface{}{"file_id"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { return fmt.Errorf("failed to check existing record: %w", err) } if len(records) > 0 { // Record exists - this is a chunked upload update // Only update bytes, status, and progress (don't overwrite metadata) updateData := map[string]interface{}{ "bytes": int64(file.Bytes), "status": file.Status, } _, err = m.UpdateWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }, updateData) return err } // Record doesn't exist - create new record with full metadata // Set default value for share if empty share := option.Share if share == "" { share = "private" } // Prepare data for database data := map[string]interface{}{ "file_id": file.ID, "uploader": manager.Name, "content_type": file.ContentType, "name": file.Filename, "user_path": option.OriginalFilename, "path": storagePath, "bytes": int64(file.Bytes), "status": file.Status, "gzip": option.Gzip, "groups": option.Groups, "public": option.Public, "share": share, } // Add Yao permission fields if provided if 
option.YaoCreatedBy != "" { data["__yao_created_by"] = option.YaoCreatedBy } if option.YaoUpdatedBy != "" { data["__yao_updated_by"] = option.YaoUpdatedBy } if option.YaoTeamID != "" { data["__yao_team_id"] = option.YaoTeamID } if option.YaoTenantID != "" { data["__yao_tenant_id"] = option.YaoTenantID } // Create new record _, err = m.Create(data) return err } // getFileFromDatabase retrieves file information from database by file_id func (manager Manager) getFileFromDatabase(ctx context.Context, fileID string) (*File, error) { m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Select: []interface{}{ "file_id", "name", "content_type", "status", "user_path", "path", "bytes", "public", "share", "__yao_created_by", "__yao_team_id", "__yao_tenant_id", }, Wheres: []model.QueryWhere{ {Column: "file_id", Value: fileID}, }, Limit: 1, }) if err != nil { return nil, fmt.Errorf("failed to query file: %w", err) } if len(records) == 0 { return nil, fmt.Errorf("file not found") } record := records[0] // Convert database record to File struct file := &File{ ID: record["file_id"].(string), Filename: record["name"].(string), ContentType: record["content_type"].(string), Status: record["status"].(string), CreatedAt: int(time.Now().Unix()), // TODO: get from database } // Handle optional fields if userPath, ok := record["user_path"].(string); ok { file.UserPath = userPath } if path, ok := record["path"].(string); ok { file.Path = path } if bytes, ok := record["bytes"].(int64); ok { file.Bytes = int(bytes) } // Handle permission fields with safe conversion file.Public = toBool(record["public"]) file.Share = toString(record["share"]) file.YaoCreatedBy = toString(record["__yao_created_by"]) file.YaoTeamID = toString(record["__yao_team_id"]) file.YaoTenantID = toString(record["__yao_tenant_id"]) return file, nil } // getStoragePathFromDatabase retrieves the real storage path for a file_id func (manager Manager) getStoragePathFromDatabase(ctx context.Context, 
fileID string) (string, error) { m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Select: []interface{}{"path"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: fileID}, }, }) if err != nil { return "", fmt.Errorf("failed to query database: %w", err) } if len(records) == 0 { return "", fmt.Errorf("file not found: %s", fileID) } if path, ok := records[0]["path"].(string); ok && path != "" { return path, nil } return "", fmt.Errorf("invalid storage path for file ID: %s", fileID) } // GetText retrieves the parsed text content for a file by its ID // By default, returns the preview (first 2000 characters) from 'content_preview' field // Set fullContent to true to retrieve the complete text from 'content' field func (manager Manager) GetText(ctx context.Context, fileID string, fullContent ...bool) (string, error) { m := model.Select("__yao.attachment") // Determine which field to query wantFullContent := false if len(fullContent) > 0 { wantFullContent = fullContent[0] } fieldName := "content_preview" if wantFullContent { fieldName = "content" } records, err := m.Get(model.QueryParam{ Select: []interface{}{fieldName}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: fileID}, }, Limit: 1, }) if err != nil { return "", fmt.Errorf("failed to query text content: %w", err) } if len(records) == 0 { return "", fmt.Errorf("file not found: %s", fileID) } // Handle content field - it may be nil, string, or other types if content, ok := records[0][fieldName].(string); ok { return content, nil } // If content is nil or not a string, return empty string return "", nil } // SaveText saves the parsed text content for a file by its ID // Automatically saves both full content and preview (first 2000 characters) // Updates both 'content' and 'content_preview' fields in the attachment record func (manager Manager) SaveText(ctx context.Context, fileID string, text string) error { m := model.Select("__yao.attachment") // Check if record exists first 
records, err := m.Get(model.QueryParam{ Select: []interface{}{"file_id"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: fileID}, }, Limit: 1, }) if err != nil { return fmt.Errorf("failed to check file existence: %w", err) } if len(records) == 0 { return fmt.Errorf("file not found: %s", fileID) } // Create preview: first 2000 characters (or runes for proper UTF-8 handling) preview := text const maxPreviewLength = 2000 if len([]rune(text)) > maxPreviewLength { preview = string([]rune(text)[:maxPreviewLength]) } // Update both content and content_preview fields updateData := map[string]interface{}{ "content": text, "content_preview": preview, } _, err = m.UpdateWhere(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "file_id", Value: fileID}, }, }, updateData) if err != nil { return fmt.Errorf("failed to save text content: %w", err) } return nil } ================================================ FILE: attachment/manager_test.go ================================================ package attachment import ( "bytes" "context" "encoding/base64" "fmt" "mime/multipart" "os" "path/filepath" "strings" "testing" "time" "github.com/yaoapp/gou/model" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestMain(m *testing.M) { // Run tests code := m.Run() os.Exit(code) } func TestManagerUpload(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Create a local storage manager manager, err := New(ManagerOption{ Driver: "local", MaxSize: "10M", ChunkSize: "2M", AllowedTypes: []string{"text/*", "image/*", ".txt", ".jpg", ".png"}, Options: map[string]interface{}{ "path": "/tmp/test_attachments", }, }) if err != nil { t.Fatalf("Failed to create manager: %v", err) } // Test simple text file upload t.Run("SimpleTextUpload", func(t *testing.T) { content := "Hello, World!" 
reader := strings.NewReader(content) // Create a mock file header fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "test.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{Groups: []string{"user123"}} file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file: %v", err) } if file.Filename != "test.txt" { t.Errorf("Expected filename 'test.txt', got '%s'", file.Filename) } // Content type may include charset if !strings.HasPrefix(file.ContentType, "text/plain") { t.Errorf("Expected content type 'text/plain', got '%s'", file.ContentType) } // Test download response, err := manager.Download(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to download file: %v", err) } defer response.Reader.Close() downloadedContent, err := manager.Read(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to read file: %v", err) } if string(downloadedContent) != content { t.Errorf("Expected content '%s', got '%s'", content, string(downloadedContent)) } // Test ReadBase64 base64Content, err := manager.ReadBase64(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to read file as base64: %v", err) } expectedBase64 := base64.StdEncoding.EncodeToString([]byte(content)) if base64Content != expectedBase64 { t.Errorf("Expected base64 '%s', got '%s'", expectedBase64, base64Content) } }) // Test gzip compression t.Run("GzipUpload", func(t *testing.T) { content := "This is a test file that will be compressed with gzip." 
reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "test_gzip.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Gzip: true, Groups: []string{"user123"}, } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload gzipped file: %v", err) } // The stored file should be compressed, but when we read it back, // we should get the original content (if the storage handles decompression) downloadedContent, err := manager.Read(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to read gzipped file: %v", err) } if string(downloadedContent) != content { t.Errorf("Expected content '%s', got '%s'", content, string(downloadedContent)) } }) // Test chunked upload t.Run("ChunkedUpload", func(t *testing.T) { content := "This is a large file that will be uploaded in chunks. " + strings.Repeat("Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
", 100) chunkSize := 1024 totalSize := len(content) var lastFile *File for start := 0; start < totalSize; start += chunkSize { end := start + chunkSize - 1 if end >= totalSize { end = totalSize - 1 } chunk := []byte(content[start : end+1]) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "large_file.txt", Size: int64(len(chunk)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") fileHeader.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize)) fileHeader.Header.Set("Content-Uid", "unique-file-id-123") option := UploadOption{Groups: []string{"user123"}} file, err := manager.Upload(context.Background(), fileHeader, bytes.NewReader(chunk), option) if err != nil { t.Fatalf("Failed to upload chunk starting at %d: %v", start, err) } lastFile = file } // After uploading all chunks, read the complete file if lastFile != nil { downloadedContent, err := manager.Read(context.Background(), lastFile.ID) if err != nil { t.Fatalf("Failed to read chunked file: %v", err) } if string(downloadedContent) != content { t.Errorf("Chunked upload content mismatch. 
Expected length %d, got %d", len(content), len(downloadedContent))
			}
		}
	})
}

// TestManagerMultiLevelGroups uploads files through a local-driver manager
// with a five-level group path, a single group, and an empty group list, and
// checks the file ID returned for each upload.
func TestManagerMultiLevelGroups(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// Local storage manager shared by every subtest.
	mgr, err := New(ManagerOption{
		Driver:       "local",
		MaxSize:      "10M",
		AllowedTypes: []string{"text/*", "image/*"},
		Options: map[string]interface{}{
			"path": "/tmp/test_attachments",
		},
	})
	if err != nil {
		t.Fatalf("Failed to create manager: %v", err)
	}

	// textHeader builds a text/plain upload header sized for body.
	textHeader := func(name, body string) *FileHeader {
		hdr := &FileHeader{
			FileHeader: &multipart.FileHeader{
				Filename: name,
				Size:     int64(len(body)),
				Header:   make(map[string][]string),
			},
		}
		hdr.Header.Set("Content-Type", "text/plain")
		return hdr
	}

	// checkHexID reports an error if the ID is not 32 characters long, and one
	// error for every character that is not a lowercase hex digit.
	checkHexID := func(t *testing.T, f *File) {
		if len(f.ID) != 32 {
			t.Errorf("File ID should be 32 characters: %s (length %d)", f.ID, len(f.ID))
		}
		for _, ch := range f.ID {
			isDigit := ch >= '0' && ch <= '9'
			isLowerHex := ch >= 'a' && ch <= 'f'
			if !isDigit && !isLowerHex {
				t.Errorf("File ID contains non-hex character: %c", ch)
			}
		}
	}

	// Five nested group levels, followed by a read-back of the content.
	t.Run("MultiLevelGroups", func(t *testing.T) {
		body := "Test content for multi-level groups"
		opts := UploadOption{
			Groups: []string{"users", "user123", "chats", "chat456", "documents"},
		}

		uploaded, err := mgr.Upload(context.Background(), textHeader("multilevel.txt", body), strings.NewReader(body), opts)
		if err != nil {
			t.Fatalf("Failed to upload file with multi-level groups: %v", err)
		}

		// File ID should be 32 character lowercase hex (MD5 hash)
		checkHexID(t, uploaded)

		got, err := mgr.Read(context.Background(), uploaded.ID)
		if err != nil {
			t.Fatalf("Failed to read file with multi-level groups: %v", err)
		}
		if string(got) != body {
			t.Errorf("Content mismatch for multi-level groups file")
		}
	})

	// A single group (backward compatibility).
	t.Run("SingleGroup", func(t *testing.T) {
		body := "Test content for single group"
		opts := UploadOption{
			Groups: []string{"knowledge"},
		}

		uploaded, err := mgr.Upload(context.Background(), textHeader("single.txt", body), strings.NewReader(body), opts)
		if err != nil {
			t.Fatalf("Failed to upload file with single group: %v", err)
		}

		// File ID should be 32 character lowercase hex (MD5 hash)
		checkHexID(t, uploaded)
	})

	// An empty, non-nil group list (no grouping).
	t.Run("EmptyGroups", func(t *testing.T) {
		body := "Test content without groups"
		opts := UploadOption{
			Groups: []string{}, // Empty groups
		}

		uploaded, err := mgr.Upload(context.Background(), textHeader("nogroup.txt", body), strings.NewReader(body), opts)
		if err != nil {
			t.Fatalf("Failed to upload file without groups: %v", err)
		}

		// The upload must still yield a file ID.
		if uploaded.ID == "" {
			t.Error("File ID should not be empty")
		}
	})
}

// TestManagerValidation checks that Upload returns an error for a payload
// larger than the manager's MaxSize and for a content type that is not listed
// in AllowedTypes.
func TestManagerValidation(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	mgr, err := New(ManagerOption{
		Driver:       "local",
		MaxSize:      "1K", // Very small max size for testing
		AllowedTypes: []string{"text/plain"},
		Options: map[string]interface{}{
			"path": "/tmp/test_attachments",
		},
	})
	if err != nil {
		t.Fatalf("Failed to create manager: %v", err)
	}

	cases := []struct {
		name        string
		filename    string
		body        string
		contentType string
		failure     string // reported when Upload unexpectedly succeeds
	}{
		{
			name:        "FileSizeValidation",
			filename:    "large.txt",
			body:        strings.Repeat("a", 2048), // 2KB, exceeds 1KB limit
			contentType: "text/plain",
			failure:     "Expected error for file size exceeding limit",
		},
		{
			name:        "FileTypeValidation",
			filename:    "test.jpg",
			body:        "test",
			contentType: "image/jpeg", // Not allowed
			failure:     "Expected error for disallowed file type",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			hdr := &FileHeader{
				FileHeader: &multipart.FileHeader{
					Filename: tc.filename,
					Size:     int64(len(tc.body)),
					Header:   make(map[string][]string),
				},
			}
			hdr.Header.Set("Content-Type", tc.contentType)

			if _, err := mgr.Upload(context.Background(), hdr, strings.NewReader(tc.body), UploadOption{}); err == nil {
				t.Error(tc.failure)
			}
		})
	}
}

// TestManagerName registers a manager with RegisterDefault, uploads a file
// through it, and checks that the "uploader" column of the stored attachment
// row equals the manager name.
func TestManagerName(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	name := "test-manager"
	mgr, err := RegisterDefault(name)
	if err != nil {
		t.Fatalf("Failed to register manager: %v", err)
	}

	// The returned manager must carry the requested name.
	if mgr.Name != name {
		t.Errorf("Expected manager name '%s', got '%s'", name, mgr.Name)
	}

	// Upload one small text file so that a database row is written.
	body := "Test file content"
	hdr := &FileHeader{
		FileHeader: &multipart.FileHeader{
			Filename: "test-manager-name.txt",
			Size:     int64(len(body)),
			Header:   make(map[string][]string),
		},
	}
	hdr.Header.Set("Content-Type", "text/plain")

	uploaded, err := mgr.Upload(context.Background(), hdr, strings.NewReader(body), UploadOption{
		Groups: []string{"test"},
	})
	if err != nil {
		t.Fatalf("Failed to upload file: %v", err)
	}

	// Read the uploader column directly from the attachment model.
	rows, err := model.Select("__yao.attachment").Get(model.QueryParam{
		Select: []interface{}{"uploader"},
		Wheres: []model.QueryWhere{
			{Column: "file_id", Value: uploaded.ID},
		},
	})
	if err != nil {
		t.Fatalf("Failed to query database: %v", err)
	}
	if len(rows) == 0 {
		t.Fatal("No record found in database")
	}

	stored, isString := rows[0]["uploader"].(string)
	if !isString {
		t.Fatal("Uploader field is not a string")
	}
	if stored != name {
		t.Errorf("Expected stored uploader name '%s', got '%s'", name, stored)
	}
}

// TestUniqueFilenameGeneration uploads two files that share the same filename
// and checks that they receive different IDs and storage paths and can be
// read back independently.
func TestUniqueFilenameGeneration(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	mgr, err := RegisterDefault("test")
	if err != nil {
		t.Fatalf("Failed to register manager: %v", err)
	}

	firstBody := "First file content"
	secondBody := "Second file content"

	// duplicateHeader builds a text/plain header that always uses the same
	// filename, so only the payload differs between the two uploads.
	duplicateHeader := func(body string) *FileHeader {
		hdr := &FileHeader{
			FileHeader: &multipart.FileHeader{
				Filename: "duplicate.txt", // Same filename
				Size:     int64(len(body)),
				Header:   make(map[string][]string),
			},
		}
		hdr.Header.Set("Content-Type", "text/plain")
		return hdr
	}

	opts := UploadOption{
		Groups: []string{"test"},
	}

	first, err := mgr.Upload(context.Background(), duplicateHeader(firstBody), strings.NewReader(firstBody), opts)
	if err != nil {
		t.Fatalf("Failed to upload first file: %v", err)
	}

	// Sleep a bit to ensure different timestamps
	time.Sleep(time.Millisecond)

	second, err := mgr.Upload(context.Background(), duplicateHeader(secondBody), strings.NewReader(secondBody), opts)
	if err != nil {
		t.Fatalf("Failed to upload second file: %v", err)
	}

	// Neither the ID nor the storage path may collide.
	if first.ID == second.ID {
		t.Error("Files with same original name should have different IDs")
	}
	if first.Path == second.Path {
		t.Error("Files with same original name should have different storage paths")
	}

	// Each file must read back with its own content.
	firstData, err := mgr.Read(context.Background(), first.ID)
	if err != nil {
		t.Fatalf("Failed to read first file: %v", err)
	}
	secondData, err := mgr.Read(context.Background(), second.ID)
	if err != nil {
		t.Fatalf("Failed to read second file: %v", err)
	}

	if string(firstData) != firstBody {
		t.Errorf("First file content mismatch. Expected: %s, Got: %s", firstBody, string(firstData))
	}
	if string(secondData) != secondBody {
		t.Errorf("Second file content mismatch. Expected: %s, Got: %s", secondBody, string(secondData))
	}

	t.Logf("File 1 - ID: %s, Path: %s", first.ID, first.Path)
	t.Logf("File 2 - ID: %s, Path: %s", second.ID, second.Path)
}

// TestInfo uploads a file and checks that Manager.Info returns matching
// metadata for it, and that Info returns an error for an unknown file ID.
func TestInfo(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	mgr, err := RegisterDefault("test")
	if err != nil {
		t.Fatalf("Failed to register manager: %v", err)
	}

	// Upload the file whose metadata is inspected below.
	body := "Test file for info retrieval"
	hdr := &FileHeader{
		FileHeader: &multipart.FileHeader{
			Filename: "info-test.txt",
			Size:     int64(len(body)),
			Header:   make(map[string][]string),
		},
	}
	hdr.Header.Set("Content-Type", "text/plain")

	opts := UploadOption{
		Groups:           []string{"info", "test"},
		OriginalFilename: "original-info-test.txt",
		Public:           false,
		Share:            "private",
		Gzip:             false,
	}

	uploaded, err := mgr.Upload(context.Background(), hdr, strings.NewReader(body), opts)
	if err != nil {
		t.Fatalf("Failed to upload file: %v", err)
	}

	info, err := mgr.Info(context.Background(), uploaded.ID)
	if err != nil {
		t.Fatalf("Failed to get file info: %v", err)
	}

	// Compare the retrieved metadata with the upload result.
	if info.ID != uploaded.ID {
		t.Errorf("Expected file ID %s, got %s", uploaded.ID, info.ID)
	}
	if info.Filename != uploaded.Filename {
		t.Errorf("Expected filename %s, got %s", uploaded.Filename, info.Filename)
	}
	// Content type may include charset
	if !strings.HasPrefix(info.ContentType, "text/plain") {
		t.Errorf("Expected content type 'text/plain', got %s", info.ContentType)
	}
	if info.Status != "uploaded" {
		t.Errorf("Expected status 'uploaded', got %s", info.Status)
	}
	if info.UserPath != opts.OriginalFilename {
		t.Errorf("Expected user path %s, got %s", opts.OriginalFilename, info.UserPath)
	}
	if info.Path != uploaded.Path {
		t.Errorf("Expected path %s, got %s", uploaded.Path, info.Path)
	}

	// An unknown ID must produce an error.
	_, err = mgr.Info(context.Background(), "non-existent-id")
	if err == nil {
		t.Error("Expected error for non-existent file ID, got nil")
	}

	t.Logf("Retrieved file info - ID: %s, Path: %s, UserPath: %s", info.ID, info.Path, info.UserPath)
}

// TestList uploads five files under a uniquely named manager and exercises
// Manager.List with default options, pagination, filters, ordering and field
// selection.
func TestList(t *testing.T) {
	test.Prepare(t, config.Conf)
	defer test.Clean()

	// Use unique manager name for test isolation
	managerName := fmt.Sprintf("test-list-%d", time.Now().UnixNano())
	manager, err := RegisterDefault(managerName)
	if err != nil {
		t.Fatalf("Failed to register manager: %v", err)
	}

	// Clean up existing records first
	m := model.Select("__yao.attachment")
	_, err = m.DeleteWhere(model.QueryParam{
		Wheres: []model.QueryWhere{
			{Column: "uploader", Value: managerName},
		},
	})
	if err != nil {
		t.Logf("Warning: Failed to clean up existing records: %v", err)
	}

	// Upload multiple test files
	testFiles := []struct {
		filename    string
		content     string
		contentType string
		groups      []string
	}{
		{"test1.txt", "Content of test file 1", "text/plain", []string{"group1"}},
		{"test2.txt", "Content of test file 2", "text/plain", []string{"group1"}},
		{"image1.jpg", "Image content 1", "image/jpeg", []string{"group2", "images"}},
		{"doc1.pdf", "PDF content", "application/pdf", []string{"group2", "docs"}},
		{"test3.txt", "Content of test file 3", "text/plain", []string{"group1"}},
	}

	uploadedFiles := make([]*File, 0, len(testFiles))
	for _, tf := range testFiles {
		reader := strings.NewReader(tf.content)
		fileHeader := &FileHeader{
			FileHeader: &multipart.FileHeader{
				Filename:
tf.filename, Size: int64(len(tf.content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", tf.contentType) option := UploadOption{ Groups: tf.groups, } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file %s: %v", tf.filename, err) } uploadedFiles = append(uploadedFiles, file) } // Test basic listing (no filters, default pagination) t.Run("BasicList", func(t *testing.T) { result, err := manager.List(context.Background(), ListOption{ Filters: map[string]interface{}{ "uploader": managerName, }, }) if err != nil { t.Fatalf("Failed to list files: %v", err) } if len(result.Files) != len(testFiles) { t.Errorf("Expected %d files, got %d", len(testFiles), len(result.Files)) } if result.Total != int64(len(testFiles)) { t.Errorf("Expected total %d, got %d", len(testFiles), result.Total) } if result.Page != 1 { t.Errorf("Expected page 1, got %d", result.Page) } if result.PageSize != 20 { t.Errorf("Expected page size 20, got %d", result.PageSize) } }) // Test pagination t.Run("Pagination", func(t *testing.T) { result, err := manager.List(context.Background(), ListOption{ Page: 1, PageSize: 2, Filters: map[string]interface{}{ "uploader": managerName, }, }) if err != nil { t.Fatalf("Failed to list files with pagination: %v", err) } if len(result.Files) != 2 { t.Errorf("Expected 2 files, got %d", len(result.Files)) } if result.Total != int64(len(testFiles)) { t.Errorf("Expected total %d, got %d", len(testFiles), result.Total) } if result.Page != 1 { t.Errorf("Expected page 1, got %d", result.Page) } if result.PageSize != 2 { t.Errorf("Expected page size 2, got %d", result.PageSize) } if result.TotalPages != 3 { // 5 files / 2 per page = 3 pages t.Errorf("Expected 3 total pages, got %d", result.TotalPages) } }) // Test filtering by content type t.Run("FilterByContentType", func(t *testing.T) { result, err := manager.List(context.Background(), ListOption{ Wheres: 
[]model.QueryWhere{ {Column: "uploader", Value: managerName}, {Column: "content_type", Value: "text/plain%", OP: "like"}, }, }) if err != nil { t.Fatalf("Failed to list files with content type filter: %v", err) } expectedCount := 3 // test1.txt, test2.txt, test3.txt if len(result.Files) != expectedCount { t.Errorf("Expected %d text files, got %d", expectedCount, len(result.Files)) } // Verify all returned files are text/plain (may include charset) for _, file := range result.Files { if !strings.HasPrefix(file.ContentType, "text/plain") { t.Errorf("Expected content type 'text/plain', got '%s'", file.ContentType) } } }) // Test wildcard filtering t.Run("WildcardFilter", func(t *testing.T) { result, err := manager.List(context.Background(), ListOption{ Filters: map[string]interface{}{ "uploader": managerName, "content_type": "image/*", }, }) if err != nil { t.Fatalf("Failed to list files with wildcard filter: %v", err) } expectedCount := 1 // image1.jpg if len(result.Files) != expectedCount { t.Errorf("Expected %d image files, got %d", expectedCount, len(result.Files)) } }) // Test ordering t.Run("OrderBy", func(t *testing.T) { result, err := manager.List(context.Background(), ListOption{ OrderBy: "name asc", Filters: map[string]interface{}{ "uploader": managerName, }, }) if err != nil { t.Fatalf("Failed to list files with ordering: %v", err) } if len(result.Files) != len(testFiles) { t.Errorf("Expected %d files, got %d", len(testFiles), len(result.Files)) } // Files should be ordered by name ascending // Note: The actual filenames are generated, so we just check that they're sorted for i := 1; i < len(result.Files); i++ { if result.Files[i-1].Filename > result.Files[i].Filename { t.Errorf("Files are not sorted by name ascending") break } } }) // Test field selection t.Run("SelectFields", func(t *testing.T) { result, err := manager.List(context.Background(), ListOption{ Select: []string{"file_id", "name", "content_type"}, Filters: map[string]interface{}{ "uploader": 
managerName, }, }) if err != nil { t.Fatalf("Failed to list files with field selection: %v", err) } if len(result.Files) != len(testFiles) { t.Errorf("Expected %d files, got %d", len(testFiles), len(result.Files)) } // Verify selected fields are populated for _, file := range result.Files { if file.ID == "" { t.Error("Expected file_id to be populated") } if file.Filename == "" { t.Error("Expected filename to be populated") } if file.ContentType == "" { t.Error("Expected content_type to be populated") } } }) t.Logf("Successfully tested list functionality with %d files", len(uploadedFiles)) } func TestManagerLocalPath(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Test with local storage t.Run("LocalStorage", func(t *testing.T) { // Create a local storage manager manager, err := New(ManagerOption{ Driver: "local", MaxSize: "10M", AllowedTypes: []string{"text/*", "image/*", "application/*", ".txt", ".json", ".html", ".csv", ".yao"}, Options: map[string]interface{}{ "path": "/tmp/test_localpath_attachments", }, }) if err != nil { t.Fatalf("Failed to create local manager: %v", err) } manager.Name = "localpath-test" // Test different file types testFiles := []struct { filename string content string contentType string expectedCT string }{ {"test.txt", "Hello LocalPath", "text/plain", "text/plain"}, {"test.json", `{"localpath": "test"}`, "application/json", "application/json"}, {"test.html", "LocalPath Test", "text/html", "text/html"}, {"test.csv", "col1,col2\nlocalpath,test", "text/csv", "text/csv"}, {"test.yao", "localpath yao content", "application/yao", "application/yao"}, } for _, tf := range testFiles { // Upload file reader := strings.NewReader(tf.content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: tf.filename, Size: int64(len(tf.content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", tf.contentType) option := UploadOption{ Groups: []string{"localpath", "test"}, OriginalFilename: 
tf.filename, } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file %s: %v", tf.filename, err) } // Test LocalPath localPath, detectedCT, err := manager.LocalPath(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get local path for %s: %v", tf.filename, err) } // Verify path is absolute if !filepath.IsAbs(localPath) { t.Errorf("Expected absolute path for %s, got: %s", tf.filename, localPath) } // Verify content type if detectedCT != tf.expectedCT { t.Errorf("Expected content type %s for %s, got: %s", tf.expectedCT, tf.filename, detectedCT) } // Verify file exists if _, err := os.Stat(localPath); os.IsNotExist(err) { t.Errorf("File should exist at local path %s for %s", localPath, tf.filename) } // Verify file content fileContent, err := os.ReadFile(localPath) if err != nil { t.Fatalf("Failed to read file at local path for %s: %v", tf.filename, err) } if string(fileContent) != tf.content { t.Errorf("File content mismatch for %s. 
Expected: %s, Got: %s", tf.filename, tf.content, string(fileContent)) } t.Logf("File %s - ID: %s, LocalPath: %s, ContentType: %s", tf.filename, file.ID, localPath, detectedCT) } }) // Test with gzipped files in local storage t.Run("LocalStorage_Gzipped", func(t *testing.T) { manager, err := New(ManagerOption{ Driver: "local", MaxSize: "10M", AllowedTypes: []string{"text/*"}, Options: map[string]interface{}{ "path": "/tmp/test_localpath_gzip_attachments", }, }) if err != nil { t.Fatalf("Failed to create local manager: %v", err) } manager.Name = "localpath-gzip-test" content := "This content will be gzipped" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "gzipped.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"gzip", "test"}, OriginalFilename: "gzipped.txt", Gzip: true, // Enable gzip compression } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload gzipped file: %v", err) } // Test LocalPath - should get decompressed content localPath, contentType, err := manager.LocalPath(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get local path for gzipped file: %v", err) } // Verify content type if contentType != "text/plain" { t.Errorf("Expected content type text/plain, got: %s", contentType) } // For gzipped files in local storage, the storage path ends with .gz // but the content should be accessible normally through Read methods fileContent, err := manager.Read(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to read gzipped file: %v", err) } if string(fileContent) != content { t.Errorf("Gzipped file content mismatch. 
Expected: %s, Got: %s", content, string(fileContent)) } t.Logf("Gzipped file - ID: %s, LocalPath: %s, ContentType: %s", file.ID, localPath, contentType) }) } func TestManagerLocalPath_NonExistentFile(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager, err := New(ManagerOption{ Driver: "local", AllowedTypes: []string{"text/*"}, Options: map[string]interface{}{ "path": "/tmp/test_localpath_nonexistent", }, }) if err != nil { t.Fatalf("Failed to create manager: %v", err) } manager.Name = "nonexistent-test" // Test with non-existent file ID _, _, err = manager.LocalPath(context.Background(), "non-existent-file-id") if err == nil { t.Error("Expected error for non-existent file ID") } // Should contain "file not found" in the error chain if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' in error message, got: %s", err.Error()) } } func TestPublicAndShareFields(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Force re-migrate the attachment table to ensure schema is up to date m := model.Select("__yao.attachment") if m != nil { // Drop and recreate table to get latest schema err := m.DropTable() if err != nil { t.Logf("Warning: failed to drop table: %v", err) } err = m.Migrate(false) if err != nil { t.Fatalf("Failed to migrate table: %v", err) } } manager, err := RegisterDefault("test-public-share") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Test 1: Upload with public=true and share=team t.Run("PublicTeamShare", func(t *testing.T) { content := "Public team shared file" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "public-team.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"test"}, OriginalFilename: "public-team.txt", Public: true, Share: "team", } file, err := 
manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload public team file: %v", err) } // Verify in database m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { t.Fatalf("Failed to query database: %v", err) } if len(records) == 0 { t.Fatal("No record found in database") } // Debug: print all fields t.Logf("Record fields: %+v", records[0]) publicValue := toBool(records[0]["public"]) if !publicValue { t.Errorf("Expected public to be true, got: %v (type: %T)", records[0]["public"], records[0]["public"]) } shareValue := toString(records[0]["share"]) if shareValue != "team" { t.Errorf("Expected share to be 'team', got: %v (type: %T)", records[0]["share"], records[0]["share"]) } }) // Test 2: Upload with public=false and share=private (default) t.Run("PrivateShare", func(t *testing.T) { content := "Private file" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "private.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"test"}, OriginalFilename: "private.txt", Public: false, Share: "private", } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload private file: %v", err) } // Verify in database m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Select: []interface{}{"public", "share"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { t.Fatalf("Failed to query database: %v", err) } if len(records) == 0 { t.Fatal("No record found in database") } publicValue := toBool(records[0]["public"]) if publicValue { t.Errorf("Expected public to be false, got: %v", records[0]["public"]) } shareValue := 
toString(records[0]["share"]) if shareValue != "private" { t.Errorf("Expected share to be 'private', got: %v", records[0]["share"]) } }) // Test 3: Upload without specifying share (should default to private) t.Run("DefaultSharePrivate", func(t *testing.T) { content := "Default share file" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "default-share.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"test"}, OriginalFilename: "default-share.txt", Public: false, // Share not specified, should default to "private" } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file with default share: %v", err) } // Verify in database m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Select: []interface{}{"share"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { t.Fatalf("Failed to query database: %v", err) } if len(records) == 0 { t.Fatal("No record found in database") } shareValue := toString(records[0]["share"]) if shareValue != "private" { t.Errorf("Expected default share to be 'private', got: %v", records[0]["share"]) } }) } func TestYaoPermissionFields(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Force re-migrate the attachment table to ensure schema is up to date m := model.Select("__yao.attachment") if m != nil { // Drop and recreate table to get latest schema err := m.DropTable() if err != nil { t.Logf("Warning: failed to drop table: %v", err) } err = m.Migrate(false) if err != nil { t.Fatalf("Failed to migrate table: %v", err) } } manager, err := RegisterDefault("test-yao-permission") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Test 1: Upload with all Yao permission fields t.Run("AllYaoFields", func(t *testing.T) 
{ content := "File with all Yao permission fields" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "yao-all-fields.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"test"}, OriginalFilename: "yao-all-fields.txt", YaoCreatedBy: "user123", YaoUpdatedBy: "user123", YaoTeamID: "team456", YaoTenantID: "tenant789", } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file with Yao fields: %v", err) } // Verify in database m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Select: []interface{}{"__yao_created_by", "__yao_updated_by", "__yao_team_id", "__yao_tenant_id"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { t.Fatalf("Failed to query database: %v", err) } if len(records) == 0 { t.Fatal("No record found in database") } // Verify __yao_created_by createdBy := toString(records[0]["__yao_created_by"]) if createdBy != "user123" { t.Errorf("Expected __yao_created_by to be 'user123', got: %v", records[0]["__yao_created_by"]) } // Verify __yao_updated_by updatedBy := toString(records[0]["__yao_updated_by"]) if updatedBy != "user123" { t.Errorf("Expected __yao_updated_by to be 'user123', got: %v", records[0]["__yao_updated_by"]) } // Verify __yao_team_id teamID := toString(records[0]["__yao_team_id"]) if teamID != "team456" { t.Errorf("Expected __yao_team_id to be 'team456', got: %v", records[0]["__yao_team_id"]) } // Verify __yao_tenant_id tenantID := toString(records[0]["__yao_tenant_id"]) if tenantID != "tenant789" { t.Errorf("Expected __yao_tenant_id to be 'tenant789', got: %v", records[0]["__yao_tenant_id"]) } }) // Test 2: Upload with partial Yao fields (only team and tenant) t.Run("PartialYaoFields", func(t *testing.T) { content := "File with partial Yao 
fields" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "yao-partial-fields.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"test"}, OriginalFilename: "yao-partial-fields.txt", YaoTeamID: "team999", YaoTenantID: "tenant888", // YaoCreatedBy and YaoUpdatedBy not specified } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file with partial Yao fields: %v", err) } // Verify in database m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Select: []interface{}{"__yao_team_id", "__yao_tenant_id"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { t.Fatalf("Failed to query database: %v", err) } if len(records) == 0 { t.Fatal("No record found in database") } // Verify __yao_team_id teamID := toString(records[0]["__yao_team_id"]) if teamID != "team999" { t.Errorf("Expected __yao_team_id to be 'team999', got: %v", records[0]["__yao_team_id"]) } // Verify __yao_tenant_id tenantID := toString(records[0]["__yao_tenant_id"]) if tenantID != "tenant888" { t.Errorf("Expected __yao_tenant_id to be 'tenant888', got: %v", records[0]["__yao_tenant_id"]) } }) // Test 3: Upload without Yao fields (should be null/empty in database) t.Run("NoYaoFields", func(t *testing.T) { content := "File without Yao fields" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "yao-no-fields.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"test"}, OriginalFilename: "yao-no-fields.txt", // No Yao fields specified } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed 
to upload file without Yao fields: %v", err) } // Should succeed without errors if file.ID == "" { t.Error("File ID should not be empty") } t.Logf("Successfully uploaded file without Yao fields - ID: %s", file.ID) }) } func TestManagerLocalPath_ValidationFlow(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager, err := New(ManagerOption{ Driver: "local", AllowedTypes: []string{"text/*"}, Options: map[string]interface{}{ "path": "/tmp/test_localpath_validation", }, }) if err != nil { t.Fatalf("Failed to create manager: %v", err) } manager.Name = "validation-test" // Upload a file content := "Validation flow test content" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "validation.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"validation"}, OriginalFilename: "original-validation.txt", } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file: %v", err) } // Test complete flow: Upload -> LocalPath -> Verify -> Delete t.Run("CompleteFlow", func(t *testing.T) { // Get local path localPath, contentType, err := manager.LocalPath(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get local path: %v", err) } // Verify all properties if !filepath.IsAbs(localPath) { t.Error("Path should be absolute") } if contentType != "text/plain" { t.Errorf("Expected content type text/plain, got: %s", contentType) } // Verify file exists stat, err := os.Stat(localPath) if err != nil { t.Fatalf("File should exist at local path: %v", err) } if stat.Size() != int64(len(content)) { t.Errorf("File size mismatch. 
Expected: %d, Got: %d", len(content), stat.Size()) } // Verify file content matches fileContent, err := os.ReadFile(localPath) if err != nil { t.Fatalf("Failed to read file: %v", err) } if string(fileContent) != content { t.Errorf("Content mismatch. Expected: %s, Got: %s", content, string(fileContent)) } // Verify through manager's Read method as well managerContent, err := manager.Read(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to read through manager: %v", err) } if string(managerContent) != content { t.Errorf("Manager read content mismatch. Expected: %s, Got: %s", content, string(managerContent)) } t.Logf("Validation complete - LocalPath: %s, Size: %d bytes, ContentType: %s", localPath, stat.Size(), contentType) }) // Clean up err = manager.Delete(context.Background(), file.ID) if err != nil { t.Logf("Warning: Failed to delete test file: %v", err) } } func TestGetTextAndSaveText(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() manager, err := RegisterDefault("test-text-content") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Upload a test file content := "This is a test file for text content storage" reader := strings.NewReader(content) fileHeader := &FileHeader{ FileHeader: &multipart.FileHeader{ Filename: "test-text.txt", Size: int64(len(content)), Header: make(map[string][]string), }, } fileHeader.Header.Set("Content-Type", "text/plain") option := UploadOption{ Groups: []string{"test"}, OriginalFilename: "test-text.txt", } file, err := manager.Upload(context.Background(), fileHeader, reader, option) if err != nil { t.Fatalf("Failed to upload file: %v", err) } // Test 1: GetText on file without saved text (should return empty) t.Run("GetTextEmpty", func(t *testing.T) { text, err := manager.GetText(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get text: %v", err) } if text != "" { t.Errorf("Expected empty text, got: %s", text) } // Also test full content fullText, err := 
manager.GetText(context.Background(), file.ID, true) if err != nil { t.Fatalf("Failed to get full text: %v", err) } if fullText != "" { t.Errorf("Expected empty full text, got: %s", fullText) } }) // Test 2: SaveText and verify t.Run("SaveTextAndVerify", func(t *testing.T) { parsedText := "This is the parsed text content from the file. It could be extracted from PDF, Word, or image OCR." err := manager.SaveText(context.Background(), file.ID, parsedText) if err != nil { t.Fatalf("Failed to save text: %v", err) } // Retrieve the saved text retrievedText, err := manager.GetText(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get saved text: %v", err) } if retrievedText != parsedText { t.Errorf("Text mismatch. Expected: %s, Got: %s", parsedText, retrievedText) } t.Logf("Successfully saved and retrieved text content (%d characters)", len(retrievedText)) }) // Test 3: Update existing text t.Run("UpdateText", func(t *testing.T) { updatedText := "This is the updated parsed text content with additional information." err := manager.SaveText(context.Background(), file.ID, updatedText) if err != nil { t.Fatalf("Failed to update text: %v", err) } retrievedText, err := manager.GetText(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get updated text: %v", err) } if retrievedText != updatedText { t.Errorf("Updated text mismatch. Expected: %s, Got: %s", updatedText, retrievedText) } }) // Test 4: Save long text content and verify preview vs full content t.Run("SaveLongText", func(t *testing.T) { // Generate a large text content (10KB) longText := strings.Repeat("This is a long text content that simulates parsing from a large document like PDF or Word. 
", 100) err := manager.SaveText(context.Background(), file.ID, longText) if err != nil { t.Fatalf("Failed to save long text: %v", err) } // Get preview (default, should be limited to 2000 characters) previewText, err := manager.GetText(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get preview text: %v", err) } // Preview should be exactly 2000 characters (runes) previewRunes := []rune(previewText) if len(previewRunes) != 2000 { t.Errorf("Preview length mismatch. Expected: 2000 runes, Got: %d runes", len(previewRunes)) } // Get full content fullText, err := manager.GetText(context.Background(), file.ID, true) if err != nil { t.Fatalf("Failed to get full text: %v", err) } if fullText != longText { t.Errorf("Full text mismatch. Expected length: %d, Got: %d", len(longText), len(fullText)) } t.Logf("Successfully saved long text - Preview: %d chars, Full: %d chars", len(previewText), len(fullText)) }) // Test 5: Test UTF-8 character handling in preview t.Run("UTF8PreviewHandling", func(t *testing.T) { // Create text with multi-byte UTF-8 characters (Chinese, emoji, etc.) 
// Each Chinese character is 3 bytes, emoji is 4 bytes chineseText := strings.Repeat("这是一个测试文本,包含中文字符。", 150) // Should exceed 2000 chars emojiText := strings.Repeat("Hello 👋 World 🌍 ", 150) // Test with Chinese text err := manager.SaveText(context.Background(), file.ID, chineseText) if err != nil { t.Fatalf("Failed to save Chinese text: %v", err) } previewChinese, err := manager.GetText(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get Chinese preview: %v", err) } // Should be exactly 2000 runes (characters), not bytes if len([]rune(previewChinese)) != 2000 { t.Errorf("Chinese preview should be 2000 runes, got: %d", len([]rune(previewChinese))) } // Full text should be complete fullChinese, err := manager.GetText(context.Background(), file.ID, true) if err != nil { t.Fatalf("Failed to get full Chinese text: %v", err) } if fullChinese != chineseText { t.Errorf("Chinese text mismatch") } // Test with emoji text err = manager.SaveText(context.Background(), file.ID, emojiText) if err != nil { t.Fatalf("Failed to save emoji text: %v", err) } previewEmoji, err := manager.GetText(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get emoji preview: %v", err) } if len([]rune(previewEmoji)) != 2000 { t.Errorf("Emoji preview should be 2000 runes, got: %d", len([]rune(previewEmoji))) } t.Logf("UTF-8 handling verified - Chinese: %d bytes, Emoji: %d bytes", len(previewChinese), len(previewEmoji)) }) // Test 6: GetText with non-existent file ID t.Run("GetTextNonExistent", func(t *testing.T) { _, err := manager.GetText(context.Background(), "non-existent-id") if err == nil { t.Error("Expected error for non-existent file ID") } if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' error, got: %s", err.Error()) } }) // Test 7: SaveText with non-existent file ID t.Run("SaveTextNonExistent", func(t *testing.T) { err := manager.SaveText(context.Background(), "non-existent-id", "some text") if err == nil { 
t.Error("Expected error for non-existent file ID") } if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' error, got: %s", err.Error()) } }) // Test 8: Save empty text (clear content) t.Run("SaveEmptyText", func(t *testing.T) { err := manager.SaveText(context.Background(), file.ID, "") if err != nil { t.Fatalf("Failed to save empty text: %v", err) } retrievedText, err := manager.GetText(context.Background(), file.ID) if err != nil { t.Fatalf("Failed to get empty text: %v", err) } if retrievedText != "" { t.Errorf("Expected empty text, got: %s", retrievedText) } }) // Test 9: Verify List doesn't include content fields by default t.Run("ListExcludesContentByDefault", func(t *testing.T) { // Save some text content testText := "This text should not appear in list results by default" err := manager.SaveText(context.Background(), file.ID, testText) if err != nil { t.Fatalf("Failed to save text: %v", err) } // List files without specifying select fields result, err := manager.List(context.Background(), ListOption{ Filters: map[string]interface{}{ "file_id": file.ID, }, }) if err != nil { t.Fatalf("Failed to list files: %v", err) } if len(result.Files) == 0 { t.Fatal("Expected to find at least one file") } // The List method returns File structs, but we need to verify // the database query doesn't fetch the content field // We can verify this by checking the database directly m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { t.Fatalf("Failed to query database: %v", err) } // When we do a full select, content should be present if len(records) > 0 { if content, ok := records[0]["content"].(string); ok && content == testText { t.Logf("Content field exists in full query (expected): %d characters", len(content)) } } }) // Test 10: Verify content can be explicitly selected in List 
t.Run("ListIncludesContentWhenExplicitlySelected", func(t *testing.T) { // Save some text content testText := "This text SHOULD appear when explicitly selected" err := manager.SaveText(context.Background(), file.ID, testText) if err != nil { t.Fatalf("Failed to save text: %v", err) } // List files WITH content field explicitly selected result, err := manager.List(context.Background(), ListOption{ Select: []string{"file_id", "name", "content"}, Filters: map[string]interface{}{ "file_id": file.ID, }, }) if err != nil { t.Fatalf("Failed to list files with content: %v", err) } if len(result.Files) == 0 { t.Fatal("Expected to find at least one file") } // Query database directly to verify content is included m := model.Select("__yao.attachment") records, err := m.Get(model.QueryParam{ Select: []interface{}{"file_id", "name", "content"}, Wheres: []model.QueryWhere{ {Column: "file_id", Value: file.ID}, }, }) if err != nil { t.Fatalf("Failed to query database: %v", err) } if len(records) == 0 { t.Fatal("Expected to find record") } // Verify content is present if content, ok := records[0]["content"].(string); ok { if content != testText { t.Errorf("Expected content '%s', got '%s'", testText, content) } t.Logf("Content field correctly included when explicitly selected: %d characters", len(content)) } else { t.Error("Content field not found when explicitly selected") } }) // Clean up err = manager.Delete(context.Background(), file.ID) if err != nil { t.Logf("Warning: Failed to delete test file: %v", err) } } ================================================ FILE: attachment/process.go ================================================ package attachment import ( "archive/zip" "bytes" "context" "encoding/base64" "fmt" "mime" "mime/multipart" "net/textproto" "path/filepath" "strings" "github.com/yaoapp/gou/model" "github.com/yaoapp/gou/process" "github.com/yaoapp/kun/any" "github.com/yaoapp/kun/maps" ) // Init registers all attachment processes func Init() { 
process.RegisterGroup("attachment", map[string]process.Handler{ "Save": processSave, "Read": processRead, "Info": processInfo, "List": processList, "Delete": processDelete, "Exists": processExists, "URL": processURL, "SaveText": processSaveText, "GetText": processGetText, "Zip": processZip, }) } // processSave saves a file from base64 data URI // Args: // - uploaderID: string - the uploader/manager ID // - content: string - base64 data URI (e.g., "data:image/png;base64,xxxx") or plain base64 // - filename: string (optional) - original filename // - option: map (optional) - upload options (groups, gzip, compress_image, public, share) // // Returns: *File - uploaded file info // // Example: // // Process("attachment.Save", "default", "data:image/png;base64,iVBORw0KGgo...", "photo.png") // Process("attachment.Save", "default", "data:text/plain;base64,SGVsbG8=", "hello.txt", {"share": "team"}) func processSave(p *process.Process) interface{} { p.ValidateArgNums(2) uploaderID := p.ArgsString(0) content := p.ArgsString(1) // Get manager manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } // Parse data URI and decode content contentType, data, err := parseDataURI(content) if err != nil { return fmt.Errorf("failed to parse content: %v", err) } // Get filename from args or generate from content type filename := "" if p.NumOfArgs() > 2 { filename = p.ArgsString(2) } if filename == "" { filename = generateFilename(contentType) } // Create file header header := createFileHeader(filename, contentType, int64(len(data))) // Create upload options option := createUploadOption(p, filename) // Upload ctx := context.Background() file, err := manager.Upload(ctx, header, strings.NewReader(string(data)), option) if err != nil { return fmt.Errorf("failed to save file: %v", err) } return file } // processRead reads file content as base64 data URI // Args: // - uploaderID: string - the uploader/manager ID // - fileID: string - the 
file ID // // Returns: string - base64 data URI (e.g., "data:image/png;base64,xxxx") // // Example: // // const dataURI = Process("attachment.Read", "default", "abc123") func processRead(p *process.Process) interface{} { p.ValidateArgNums(2) uploaderID := p.ArgsString(0) fileID := p.ArgsString(1) manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() // Get file info for content type and permission check fileInfo, err := manager.Info(ctx, fileID) if err != nil { return fmt.Errorf("file not found: %v", err) } // Check permission if err := checkFilePermission(p, fileInfo, true); err != nil { return err } // Read content as base64 base64Data, err := manager.ReadBase64(ctx, fileID) if err != nil { return fmt.Errorf("failed to read file: %v", err) } // Return as data URI return fmt.Sprintf("data:%s;base64,%s", fileInfo.ContentType, base64Data) } // processInfo gets file information // Args: // - uploaderID: string - the uploader/manager ID // - fileID: string - the file ID // // Returns: *File - file info func processInfo(p *process.Process) interface{} { p.ValidateArgNums(2) uploaderID := p.ArgsString(0) fileID := p.ArgsString(1) manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() fileInfo, err := manager.Info(ctx, fileID) if err != nil { return fmt.Errorf("file not found: %v", err) } // Check permission if err := checkFilePermission(p, fileInfo, true); err != nil { return err } return fileInfo } // processList lists files with pagination and filtering // Args: // - uploaderID: string - the uploader/manager ID // - option: map (optional) - list options (page, page_size, filters, order_by, select) // // Returns: *ListResult - paginated file list func processList(p *process.Process) interface{} { p.ValidateArgNums(1) uploaderID := p.ArgsString(0) manager, exists := Managers[uploaderID] if !exists { 
return fmt.Errorf("uploader not found: %s", uploaderID) } // Parse list options listOption := ListOption{ Page: 1, PageSize: 20, } if p.NumOfArgs() > 1 { optionRaw := p.ArgsMap(1) option := maps.MapOf(optionRaw).Dot() if page := any.Of(option.Get("page")).CInt(); page > 0 { listOption.Page = page } if pageSize := any.Of(option.Get("page_size")).CInt(); pageSize > 0 && pageSize <= 100 { listOption.PageSize = pageSize } if filters, ok := option.Get("filters").(map[string]interface{}); ok { listOption.Filters = filters } if orderBy, ok := option.Get("order_by").(string); ok { listOption.OrderBy = orderBy } if selectFields, ok := option.Get("select").([]interface{}); ok { for _, field := range selectFields { if f, ok := field.(string); ok { listOption.Select = append(listOption.Select, f) } } } } // Always filter by uploader if listOption.Filters == nil { listOption.Filters = make(map[string]interface{}) } listOption.Filters["uploader"] = uploaderID // Add permission-based filtering listOption.Wheres = append(listOption.Wheres, model.QueryWhere{ Column: "uploader", Value: uploaderID, }) listOption.Wheres = append(listOption.Wheres, buildPermissionWheres(p)...) 
ctx := context.Background() result, err := manager.List(ctx, listOption) if err != nil { return fmt.Errorf("failed to list files: %v", err) } return result } // processDelete deletes a file // Args: // - uploaderID: string - the uploader/manager ID // - fileID: string - the file ID // // Returns: bool - success func processDelete(p *process.Process) interface{} { p.ValidateArgNums(2) uploaderID := p.ArgsString(0) fileID := p.ArgsString(1) manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() // Get file info first fileInfo, err := manager.Info(ctx, fileID) if err != nil { return fmt.Errorf("file not found: %v", err) } // Check write permission if err := checkFilePermission(p, fileInfo, false); err != nil { return err } // Delete file if err := manager.Delete(ctx, fileID); err != nil { return fmt.Errorf("failed to delete file: %v", err) } return true } // processExists checks if file exists // Args: // - uploaderID: string - the uploader/manager ID // - fileID: string - the file ID // // Returns: bool func processExists(p *process.Process) interface{} { p.ValidateArgNums(2) uploaderID := p.ArgsString(0) fileID := p.ArgsString(1) manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() return manager.Exists(ctx, fileID) } // processURL gets file URL // Args: // - uploaderID: string - the uploader/manager ID // - fileID: string - the file ID // // Returns: string - file URL func processURL(p *process.Process) interface{} { p.ValidateArgNums(2) uploaderID := p.ArgsString(0) fileID := p.ArgsString(1) manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() // Get file info for permission check fileInfo, err := manager.Info(ctx, fileID) if err != nil { return fmt.Errorf("file not found: %v", err) } // Check permission if err 
:= checkFilePermission(p, fileInfo, true); err != nil { return err } return manager.storage.URL(ctx, fileID) } // processSaveText saves parsed text content for a file // Args: // - uploaderID: string - the uploader/manager ID // - fileID: string - the file ID // - text: string - the text content to save // // Returns: bool - success func processSaveText(p *process.Process) interface{} { p.ValidateArgNums(3) uploaderID := p.ArgsString(0) fileID := p.ArgsString(1) text := p.ArgsString(2) manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() // Get file info first to check write permission fileInfo, err := manager.Info(ctx, fileID) if err != nil { return fmt.Errorf("file not found: %v", err) } // Check write permission if err := checkFilePermission(p, fileInfo, false); err != nil { return err } if err := manager.SaveText(ctx, fileID, text); err != nil { return fmt.Errorf("failed to save text: %v", err) } return true } // processGetText gets parsed text content for a file // Args: // - uploaderID: string - the uploader/manager ID // - fileID: string - the file ID // - fullContent: bool (optional) - whether to get full content (default: false, returns preview) // // Returns: string - text content func processGetText(p *process.Process) interface{} { p.ValidateArgNums(2) uploaderID := p.ArgsString(0) fileID := p.ArgsString(1) fullContent := false if p.NumOfArgs() > 2 { fullContent = p.ArgsBool(2) } manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() // Get file info for permission check fileInfo, err := manager.Info(ctx, fileID) if err != nil { return fmt.Errorf("file not found: %v", err) } // Check permission if err := checkFilePermission(p, fileInfo, true); err != nil { return err } text, err := manager.GetText(ctx, fileID, fullContent) if err != nil { return fmt.Errorf("failed to get text: %v", err) 
// parseDataURI interprets content in one of two ways:
//  1. Data URI (content starts with "data:", e.g. data:image/png;base64,xxxxx):
//     the part after the first comma is base64-decoded and the media type is
//     taken from the header, defaulting to application/octet-stream when the
//     header names none.
//  2. Anything else: treated as plain text, returned byte-for-byte with
//     content type text/plain.
//
// Note that the payload of a data URI is always base64-decoded, whether or
// not the header carries the ";base64" marker.
//
// Returns the content type, the decoded bytes, and an error when a data URI
// has no comma separator or its payload is not valid base64.
func parseDataURI(content string) (string, []byte, error) {
	// Plain text content - store as-is
	if !strings.HasPrefix(content, "data:") {
		return "text/plain", []byte(content), nil
	}

	// Split header from payload at the first comma only; base64 never
	// contains commas, but splitting once keeps the intent explicit.
	parts := strings.SplitN(content, ",", 2)
	if len(parts) != 2 {
		return "", nil, fmt.Errorf("invalid data URI format")
	}

	// Header looks like "data:image/png;base64"; the media type is the first
	// ";"-separated field after the scheme.
	contentType := "application/octet-stream"
	header := strings.TrimPrefix(parts[0], "data:")
	if mediaType := strings.Split(header, ";")[0]; mediaType != "" {
		contentType = mediaType
	}

	data, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return "", nil, fmt.Errorf("failed to decode base64: %w", err)
	}

	return contentType, data, nil
}

// generateFilename builds a default filename ("file" + extension) for the
// given content type.
//
// A curated extension is used for common types first. mime.ExtensionsByType
// returns its result sorted alphabetically, so relying on its first entry
// yields surprising names such as "file.jpeg" or "file.jfif" for image/jpeg
// depending on the host's MIME database. Other types fall back to the first
// registered extension, and finally to "file.bin".
func generateFilename(contentType string) string {
	// Normalize "Text/Plain; charset=utf-8" to "text/plain" so parameters
	// and letter case do not defeat the lookup below.
	mediaType := strings.ToLower(strings.TrimSpace(contentType))
	if parsed, _, err := mime.ParseMediaType(contentType); err == nil {
		mediaType = parsed
	}

	// Preferred extensions for common types
	switch mediaType {
	case "image/png":
		return "file.png"
	case "image/jpeg":
		return "file.jpg"
	case "image/gif":
		return "file.gif"
	case "image/webp":
		return "file.webp"
	case "text/plain":
		return "file.txt"
	case "application/pdf":
		return "file.pdf"
	case "application/json":
		return "file.json"
	}

	// Anything else: ask the MIME database
	if exts, err := mime.ExtensionsByType(contentType); err == nil && len(exts) > 0 {
		return "file" + exts[0]
	}

	return "file.bin"
}
} // Parse option from fourth argument if provided if p.NumOfArgs() > 3 { optionRaw := p.ArgsMap(3) optionMap := maps.MapOf(optionRaw).Dot() // Groups if groups, ok := optionMap.Get("groups").([]interface{}); ok { for _, g := range groups { if gs, ok := g.(string); ok { option.Groups = append(option.Groups, gs) } } } else if groupsStr, ok := optionMap.Get("groups").(string); ok { option.Groups = strings.Split(groupsStr, ",") for i := range option.Groups { option.Groups[i] = strings.TrimSpace(option.Groups[i]) } } // Gzip if gzip, ok := optionMap.Get("gzip").(bool); ok { option.Gzip = gzip } // Compress image if compress, ok := optionMap.Get("compress_image").(bool); ok { option.CompressImage = compress } if size := any.Of(optionMap.Get("compress_size")).CInt(); size > 0 { option.CompressSize = size } // Public/Share if public, ok := optionMap.Get("public").(bool); ok { option.Public = public } if share, ok := optionMap.Get("share").(string); ok { option.Share = share } } // Set permission fields from process.Authorized if p.Authorized != nil { option.YaoCreatedBy = p.Authorized.UserID option.YaoTeamID = p.Authorized.TeamID option.YaoTenantID = p.Authorized.TenantID } return option } // createFileHeader creates a FileHeader from parameters func createFileHeader(filename, contentType string, size int64) *FileHeader { header := &multipart.FileHeader{ Filename: filename, Size: size, Header: make(textproto.MIMEHeader), } header.Header.Set("Content-Type", contentType) // Set extension from filename if ext := filepath.Ext(filename); ext != "" { header.Header.Set("Content-Extension", ext) } return &FileHeader{FileHeader: header} } // checkFilePermission checks if user has permission to access the file // readable: true for read permission, false for write permission func checkFilePermission(p *process.Process, fileInfo *File, readable bool) error { auth := p.Authorized // No auth info - allow access (for non-authenticated operations) if auth == nil { return nil } // No 
constraints - allow access if !auth.Constraints.TeamOnly && !auth.Constraints.OwnerOnly { return nil } // Public files are readable by everyone if readable && fileInfo.Public { return nil } // Combined Team and Owner permission validation if auth.Constraints.TeamOnly && auth.Constraints.OwnerOnly { if fileInfo.YaoCreatedBy == auth.UserID && fileInfo.YaoTeamID == auth.TeamID { return nil } } // Owner only permission validation if auth.Constraints.OwnerOnly { if fileInfo.YaoCreatedBy != "" && fileInfo.YaoCreatedBy == auth.UserID { return nil } } // Team only permission validation if auth.Constraints.TeamOnly { switch fileInfo.Share { case "team": if fileInfo.YaoTeamID == auth.TeamID { return nil } case "private": if fileInfo.YaoCreatedBy == auth.UserID { return nil } } } return fmt.Errorf("forbidden: no permission to access file") } // buildPermissionWheres builds where clauses for permission filtering func buildPermissionWheres(p *process.Process) []model.QueryWhere { auth := p.Authorized if auth == nil { return nil } // No constraints - no additional filtering needed if !auth.Constraints.TeamOnly && !auth.Constraints.OwnerOnly { return nil } var wheres []model.QueryWhere // Team only - User can access: // 1. Public files (public = true) // 2. Files in their team where: // - They uploaded the file (__yao_created_by matches) // - OR the file is shared with team (share = "team") if auth.Constraints.TeamOnly && auth.TeamID != "" { wheres = append(wheres, model.QueryWhere{ Wheres: []model.QueryWhere{ {Column: "public", Value: true, Method: "orwhere"}, {Wheres: []model.QueryWhere{ {Column: "__yao_team_id", Value: auth.TeamID}, {Wheres: []model.QueryWhere{ {Column: "__yao_created_by", Value: auth.UserID}, {Column: "share", Value: "team", Method: "orwhere"}, }}, }, Method: "orwhere"}, }, }) return wheres } // Owner only - User can access: // 1. Public files (public = true) // 2. 
Files they uploaded where: // - __yao_team_id is null (not team files) // - __yao_created_by matches their user ID if auth.Constraints.OwnerOnly && auth.UserID != "" { wheres = append(wheres, model.QueryWhere{ Wheres: []model.QueryWhere{ {Column: "public", Value: true, Method: "orwhere"}, {Wheres: []model.QueryWhere{ {Column: "__yao_team_id", OP: "null"}, {Column: "__yao_created_by", Value: auth.UserID}, }, Method: "orwhere"}, }, }) return wheres } return wheres } // processZip packages multiple attachment files into a single zip archive // and uploads it back to the same uploader storage. // // Args: // - uploaderID: string - the uploader/manager ID (e.g. "__yao.attachment") // - fileIDs: []string - list of file IDs to package // - zipFilename: string - output zip filename (e.g. "my-notes.zip") // - option: map (optional) - upload options (groups, gzip, public, share) // // Returns: *File - uploaded zip file info (file_id, filename, bytes, etc.) // // Permission: checks read permission on each source file via p.Authorized. // The uploaded zip inherits permission fields (created_by, team_id, tenant_id) // from the process context automatically via createUploadOption. 
// // Example: // // Process("attachment.Zip", "__yao.attachment", ["abc123", "def456"], "archive.zip") // Process("attachment.Zip", "__yao.attachment", ["abc123"], "archive.zip", {"public": true}) func processZip(p *process.Process) interface{} { p.ValidateArgNums(3) uploaderID := p.ArgsString(0) fileIDs := toStringSlice(p.Args[1]) zipFilename := p.ArgsString(2) if len(fileIDs) == 0 { return fmt.Errorf("attachment.Zip: fileIDs is empty") } if zipFilename == "" { zipFilename = "archive.zip" } // Ensure .zip extension if !strings.HasSuffix(strings.ToLower(zipFilename), ".zip") { zipFilename += ".zip" } // Get manager manager, exists := Managers[uploaderID] if !exists { return fmt.Errorf("uploader not found: %s", uploaderID) } ctx := context.Background() // Build zip in memory buf := new(bytes.Buffer) zipWriter := zip.NewWriter(buf) usedNames := map[string]int{} // for deduplication for _, fid := range fileIDs { // Get file info fileInfo, err := manager.Info(ctx, fid) if err != nil { return fmt.Errorf("attachment.Zip: failed to get file info for %s: %v", fid, err) } // Check read permission if err := checkFilePermission(p, fileInfo, true); err != nil { return fmt.Errorf("attachment.Zip: permission denied for file %s: %v", fid, err) } // Read file content data, err := manager.Read(ctx, fid) if err != nil { return fmt.Errorf("attachment.Zip: failed to read file %s: %v", fid, err) } // Deduplicate filename name := dedupFilename(fileInfo.Filename, usedNames) // Write to zip w, err := zipWriter.Create(name) if err != nil { return fmt.Errorf("attachment.Zip: failed to create zip entry %s: %v", name, err) } if _, err := w.Write(data); err != nil { return fmt.Errorf("attachment.Zip: failed to write zip entry %s: %v", name, err) } } if err := zipWriter.Close(); err != nil { return fmt.Errorf("attachment.Zip: failed to close zip writer: %v", err) } // Upload the zip file header := createFileHeader(zipFilename, "application/zip", int64(buf.Len())) option := 
// toStringSlice converts a loosely typed list into []string.
// A []string is returned unchanged (same backing array); a []interface{}
// yields its string elements in order, silently dropping anything that is
// not a string; every other input yields nil.
func toStringSlice(v interface{}) []string {
	switch val := v.(type) {
	case []string:
		return val
	case []interface{}:
		result := make([]string, 0, len(val))
		for _, item := range val {
			if s, ok := item.(string); ok {
				result = append(result, s)
			}
		}
		return result
	default:
		return nil
	}
}

// dedupFilename returns a name that is unique (case-insensitively) within one
// zip archive, recording it in used.
//
// used maps a lower-cased name to the next numeric suffix to try for that
// name; callers pass the same map for every entry of an archive. An empty
// name is treated as "file". If "photo.png" is already taken the result is
// "photo(1).png", then "photo(2).png", and so on. Generated names are
// registered as well, so a later file that is literally called
// "photo(1).png" cannot collide with a generated entry (it becomes
// "photo(1)(1).png").
func dedupFilename(name string, used map[string]int) string {
	if name == "" {
		name = "file"
	}

	key := strings.ToLower(name)
	next, taken := used[key]
	if !taken {
		used[key] = 1
		return name
	}

	ext := filepath.Ext(name)
	base := strings.TrimSuffix(name, ext)

	// Probe suffixes until one is free; a candidate may already be occupied
	// by a real file with that exact name or by an earlier generated name.
	for {
		candidate := fmt.Sprintf("%s(%d)%s", base, next, ext)
		next++

		candidateKey := strings.ToLower(candidate)
		if _, clash := used[candidateKey]; clash {
			continue
		}

		used[key] = next
		used[candidateKey] = 1
		return candidate
	}
}
base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) p := process.New("attachment.Save", "data.local", dataURI, "hello.txt") result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save file: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } if file.ID == "" { t.Error("File ID should not be empty") } if file.Filename != "hello.txt" { t.Errorf("Expected filename 'hello.txt', got '%s'", file.Filename) } if !strings.HasPrefix(file.ContentType, "text/plain") { t.Errorf("Expected content type 'text/plain', got '%s'", file.ContentType) } t.Logf("Saved file - ID: %s, Filename: %s, ContentType: %s", file.ID, file.Filename, file.ContentType) }) // Test 2: Save with plain base64 (no data URI header) - use text/plain to pass allowed types t.Run("SaveWithPlainBase64", func(t *testing.T) { content := "Plain base64 content" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) // Without data URI header, we need to provide a filename with allowed extension // or use data URI format. Let's test with text file extension. 
dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) p := process.New("attachment.Save", "data.local", dataURI, "plain.txt") result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save file: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } if file.ID == "" { t.Error("File ID should not be empty") } // With data URI, content type should be text/plain if !strings.HasPrefix(file.ContentType, "text/plain") { t.Errorf("Expected content type 'text/plain', got '%s'", file.ContentType) } }) // Test 3: Save image with data URI t.Run("SaveImageDataURI", func(t *testing.T) { // Minimal valid PNG (1x1 pixel transparent PNG) pngBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" dataURI := fmt.Sprintf("data:image/png;base64,%s", pngBase64) p := process.New("attachment.Save", "data.local", dataURI, "pixel.png") result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save image: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } if file.ContentType != "image/png" { t.Errorf("Expected content type 'image/png', got '%s'", file.ContentType) } }) // Test 4: Save with options - verify via Info since File struct may not have all fields t.Run("SaveWithOptions", func(t *testing.T) { content := "Content with options" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) options := map[string]interface{}{ "groups": []interface{}{"test", "unit"}, "public": true, "share": "team", } p := process.New("attachment.Save", "data.local", dataURI, "options.txt", options) result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save file with options: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } // File should be saved successfully if file.ID == "" 
{ t.Error("File ID should not be empty") } // Get info to verify public and share fields infoP := process.New("attachment.Info", "data.local", file.ID) infoResult := processInfo(infoP) info, ok := infoResult.(*File) if !ok { t.Fatalf("Failed to get file info: %v", infoResult) } if !info.Public { t.Error("Expected file to be public") } if info.Share != "team" { t.Errorf("Expected share 'team', got '%s'", info.Share) } t.Logf("Saved file with options - ID: %s, Public: %v, Share: %s", file.ID, info.Public, info.Share) }) // Test 5: Save without filename (auto-generate) t.Run("SaveWithoutFilename", func(t *testing.T) { content := "Auto filename content" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:application/json;base64,%s", base64Content) p := process.New("attachment.Save", "data.local", dataURI) result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save file: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } // Should auto-generate a filename if file.Filename == "" { t.Error("Filename should not be empty") } t.Logf("Auto-generated filename: %s", file.Filename) }) // Test 6: Save with invalid uploader t.Run("SaveWithInvalidUploader", func(t *testing.T) { content := "Test content" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) p := process.New("attachment.Save", "non-existent-uploader", dataURI, "test.txt") result := processSave(p) err, ok := result.(error) if !ok { t.Fatal("Expected error for non-existent uploader") } if !strings.Contains(err.Error(), "uploader not found") { t.Errorf("Expected 'uploader not found' error, got: %s", err.Error()) } }) // Test 7: Save with invalid base64 t.Run("SaveWithInvalidBase64", func(t *testing.T) { invalidDataURI := "data:text/plain;base64,not-valid-base64!!!" 
p := process.New("attachment.Save", "data.local", invalidDataURI, "invalid.txt") result := processSave(p) _, ok := result.(error) if !ok { t.Fatal("Expected error for invalid base64") } }) // Test 8: Save plain text directly (no data URI) t.Run("SavePlainText", func(t *testing.T) { content := "This is plain text content without data URI encoding." p := process.New("attachment.Save", "data.local", content, "plain-text.txt") result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save plain text: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } if file.ID == "" { t.Error("File ID should not be empty") } // Content type should be text/plain for plain text if !strings.HasPrefix(file.ContentType, "text/plain") { t.Errorf("Expected content type 'text/plain', got '%s'", file.ContentType) } // Read back and verify content readP := process.New("attachment.Read", "data.local", file.ID) readResult := processRead(readP) dataURI, ok := readResult.(string) if !ok { t.Fatalf("Expected string, got %T: %v", readResult, readResult) } // Decode from data URI parts := strings.SplitN(dataURI, ",", 2) if len(parts) != 2 { t.Fatalf("Invalid data URI format") } decoded, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { t.Fatalf("Failed to decode: %v", err) } if string(decoded) != content { t.Errorf("Content mismatch: expected %q, got %q", content, string(decoded)) } }) // Test 9: Save Chinese text directly (UTF-8) t.Run("SaveChineseText", func(t *testing.T) { content := "这是一段中文内容,测试UTF-8编码。\n第二行内容。" p := process.New("attachment.Save", "data.local", content, "chinese.txt") result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save Chinese text: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } // Read back and verify content readP := process.New("attachment.Read", "data.local", file.ID) readResult := processRead(readP) dataURI, ok := 
readResult.(string) if !ok { t.Fatalf("Expected string, got %T: %v", readResult, readResult) } // Decode from data URI parts := strings.SplitN(dataURI, ",", 2) if len(parts) != 2 { t.Fatalf("Invalid data URI format") } decoded, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { t.Fatalf("Failed to decode: %v", err) } if string(decoded) != content { t.Errorf("Chinese content mismatch: expected %q, got %q", content, string(decoded)) } }) } func TestProcessRead(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Register default uploader for testing _, err := RegisterDefault("data.local") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // First, save a file to read content := "Content to read" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) saveP := process.New("attachment.Save", "data.local", dataURI, "read-test.txt") saveResult := processSave(saveP) file, ok := saveResult.(*File) if !ok { t.Fatalf("Failed to save test file: %v", saveResult) } // Test 1: Read file as data URI t.Run("ReadAsDataURI", func(t *testing.T) { p := process.New("attachment.Read", "data.local", file.ID) result := processRead(p) if err, ok := result.(error); ok { t.Fatalf("Failed to read file: %v", err) } resultDataURI, ok := result.(string) if !ok { t.Fatalf("Expected string, got %T", result) } // Should return data URI format if !strings.HasPrefix(resultDataURI, "data:text/plain") { t.Errorf("Expected data URI starting with 'data:text/plain', got: %s", resultDataURI[:50]) } if !strings.Contains(resultDataURI, ";base64,") { t.Error("Expected data URI to contain ';base64,'") } // Decode and verify content parts := strings.SplitN(resultDataURI, ",", 2) if len(parts) != 2 { t.Fatal("Invalid data URI format") } decodedContent, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { t.Fatalf("Failed to decode base64: %v", err) } if string(decodedContent) != 
content { t.Errorf("Content mismatch. Expected: %s, Got: %s", content, string(decodedContent)) } t.Logf("Read file successfully - Data URI length: %d", len(resultDataURI)) }) // Test 2: Read non-existent file t.Run("ReadNonExistent", func(t *testing.T) { p := process.New("attachment.Read", "data.local", "non-existent-file-id") result := processRead(p) err, ok := result.(error) if !ok { t.Fatal("Expected error for non-existent file") } if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' error, got: %s", err.Error()) } }) // Test 3: Read with invalid uploader t.Run("ReadWithInvalidUploader", func(t *testing.T) { p := process.New("attachment.Read", "non-existent-uploader", file.ID) result := processRead(p) err, ok := result.(error) if !ok { t.Fatal("Expected error for non-existent uploader") } if !strings.Contains(err.Error(), "uploader not found") { t.Errorf("Expected 'uploader not found' error, got: %s", err.Error()) } }) } func TestProcessInfo(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Register default uploader for testing _, err := RegisterDefault("data.local") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Save a file with options content := "Info test content" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) options := map[string]interface{}{ "groups": []interface{}{"info", "test"}, "public": true, "share": "team", } saveP := process.New("attachment.Save", "data.local", dataURI, "info-test.txt", options) saveResult := processSave(saveP) file, ok := saveResult.(*File) if !ok { t.Fatalf("Failed to save test file: %v", saveResult) } // Test 1: Get file info t.Run("GetFileInfo", func(t *testing.T) { p := process.New("attachment.Info", "data.local", file.ID) result := processInfo(p) if err, ok := result.(error); ok { t.Fatalf("Failed to get file info: %v", err) } info, ok := result.(*File) if !ok { 
t.Fatalf("Expected *File, got %T", result) } if info.ID != file.ID { t.Errorf("Expected ID %s, got %s", file.ID, info.ID) } if info.Filename != file.Filename { t.Errorf("Expected filename %s, got %s", file.Filename, info.Filename) } if !strings.HasPrefix(info.ContentType, "text/plain") { t.Errorf("Expected content type 'text/plain', got %s", info.ContentType) } if !info.Public { t.Error("Expected file to be public") } if info.Share != "team" { t.Errorf("Expected share 'team', got %s", info.Share) } t.Logf("File info - ID: %s, Filename: %s, Bytes: %d, Public: %v, Share: %s", info.ID, info.Filename, info.Bytes, info.Public, info.Share) }) // Test 2: Get info for non-existent file t.Run("GetInfoNonExistent", func(t *testing.T) { p := process.New("attachment.Info", "data.local", "non-existent-file-id") result := processInfo(p) err, ok := result.(error) if !ok { t.Fatal("Expected error for non-existent file") } if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' error, got: %s", err.Error()) } }) } func TestProcessList(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Use unique manager name for test isolation managerName := fmt.Sprintf("data.local.list.%d", time.Now().UnixNano()) _, err := RegisterDefault(managerName) if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Upload multiple test files testFiles := []struct { content string filename string contentType string }{ {"File 1 content", "file1.txt", "text/plain"}, {"File 2 content", "file2.txt", "text/plain"}, {"File 3 content", "file3.txt", "text/plain"}, {`{"key": "value"}`, "data.json", "application/json"}, {"CSV,Data\n1,2", "data.csv", "text/csv"}, } uploadedIDs := make([]string, 0, len(testFiles)) for _, tf := range testFiles { base64Content := base64.StdEncoding.EncodeToString([]byte(tf.content)) dataURI := fmt.Sprintf("data:%s;base64,%s", tf.contentType, base64Content) p := process.New("attachment.Save", managerName, dataURI, tf.filename) 
result := processSave(p) file, ok := result.(*File) if !ok { t.Fatalf("Failed to save file %s: %v", tf.filename, result) } uploadedIDs = append(uploadedIDs, file.ID) } // Test 1: Basic list t.Run("BasicList", func(t *testing.T) { p := process.New("attachment.List", managerName) result := processList(p) if err, ok := result.(error); ok { t.Fatalf("Failed to list files: %v", err) } listResult, ok := result.(*ListResult) if !ok { t.Fatalf("Expected *ListResult, got %T", result) } if len(listResult.Files) != len(testFiles) { t.Errorf("Expected %d files, got %d", len(testFiles), len(listResult.Files)) } if listResult.Total != int64(len(testFiles)) { t.Errorf("Expected total %d, got %d", len(testFiles), listResult.Total) } t.Logf("List result - Total: %d, Page: %d, PageSize: %d", listResult.Total, listResult.Page, listResult.PageSize) }) // Test 2: List with pagination t.Run("ListWithPagination", func(t *testing.T) { options := map[string]interface{}{ "page": 1, "page_size": 2, } p := process.New("attachment.List", managerName, options) result := processList(p) if err, ok := result.(error); ok { t.Fatalf("Failed to list files with pagination: %v", err) } listResult, ok := result.(*ListResult) if !ok { t.Fatalf("Expected *ListResult, got %T", result) } if len(listResult.Files) != 2 { t.Errorf("Expected 2 files, got %d", len(listResult.Files)) } if listResult.PageSize != 2 { t.Errorf("Expected page size 2, got %d", listResult.PageSize) } if listResult.TotalPages != 3 { // 5 files / 2 per page = 3 pages t.Errorf("Expected 3 total pages, got %d", listResult.TotalPages) } }) // Test 3: List with filters - use content_type wildcard t.Run("ListWithFilters", func(t *testing.T) { options := map[string]interface{}{ "filters": map[string]interface{}{ "content_type": "text/*", }, } p := process.New("attachment.List", managerName, options) result := processList(p) if err, ok := result.(error); ok { t.Fatalf("Failed to list files with filters: %v", err) } listResult, ok := 
result.(*ListResult) if !ok { t.Fatalf("Expected *ListResult, got %T", result) } // Should find text/plain and text/csv files // Note: The filter implementation may vary, so we just check the call succeeds t.Logf("List with content_type filter - Total: %d files", listResult.Total) }) // Test 4: List with ordering t.Run("ListWithOrdering", func(t *testing.T) { options := map[string]interface{}{ "order_by": "name asc", } p := process.New("attachment.List", managerName, options) result := processList(p) if err, ok := result.(error); ok { t.Fatalf("Failed to list files with ordering: %v", err) } listResult, ok := result.(*ListResult) if !ok { t.Fatalf("Expected *ListResult, got %T", result) } // Verify files are sorted for i := 1; i < len(listResult.Files); i++ { if listResult.Files[i-1].Filename > listResult.Files[i].Filename { t.Errorf("Files are not sorted ascending by name") break } } }) } func TestProcessDelete(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Register default uploader for testing _, err := RegisterDefault("data.local") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Save a file to delete content := "Content to delete" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) saveP := process.New("attachment.Save", "data.local", dataURI, "delete-test.txt") saveResult := processSave(saveP) file, ok := saveResult.(*File) if !ok { t.Fatalf("Failed to save test file: %v", saveResult) } // Test 1: Delete existing file t.Run("DeleteExistingFile", func(t *testing.T) { // Verify file exists first existsP := process.New("attachment.Exists", "data.local", file.ID) existsResult := processExists(existsP) if exists, ok := existsResult.(bool); !ok || !exists { t.Fatal("File should exist before deletion") } // Delete the file p := process.New("attachment.Delete", "data.local", file.ID) result := processDelete(p) if err, ok := result.(error); ok { 
t.Fatalf("Failed to delete file: %v", err) } success, ok := result.(bool) if !ok || !success { t.Errorf("Expected true, got %v", result) } // Verify file no longer exists existsP2 := process.New("attachment.Exists", "data.local", file.ID) existsResult2 := processExists(existsP2) if exists, ok := existsResult2.(bool); ok && exists { t.Error("File should not exist after deletion") } }) // Test 2: Delete non-existent file t.Run("DeleteNonExistent", func(t *testing.T) { p := process.New("attachment.Delete", "data.local", "non-existent-file-id") result := processDelete(p) err, ok := result.(error) if !ok { t.Fatal("Expected error for non-existent file") } if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' error, got: %s", err.Error()) } }) } func TestProcessExists(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Register default uploader for testing _, err := RegisterDefault("data.local") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Save a file content := "Exists test content" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) saveP := process.New("attachment.Save", "data.local", dataURI, "exists-test.txt") saveResult := processSave(saveP) file, ok := saveResult.(*File) if !ok { t.Fatalf("Failed to save test file: %v", saveResult) } // Test 1: Existing file t.Run("FileExists", func(t *testing.T) { p := process.New("attachment.Exists", "data.local", file.ID) result := processExists(p) exists, ok := result.(bool) if !ok { t.Fatalf("Expected bool, got %T", result) } if !exists { t.Error("File should exist") } }) // Test 2: Non-existent file t.Run("FileNotExists", func(t *testing.T) { p := process.New("attachment.Exists", "data.local", "non-existent-file-id") result := processExists(p) exists, ok := result.(bool) if !ok { t.Fatalf("Expected bool, got %T", result) } if exists { t.Error("File should not exist") } 
}) } func TestProcessURL(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Register default uploader for testing _, err := RegisterDefault("data.local") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Save a file content := "URL test content" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) saveP := process.New("attachment.Save", "data.local", dataURI, "url-test.txt") saveResult := processSave(saveP) file, ok := saveResult.(*File) if !ok { t.Fatalf("Failed to save test file: %v", saveResult) } // Test 1: Get URL t.Run("GetURL", func(t *testing.T) { p := process.New("attachment.URL", "data.local", file.ID) result := processURL(p) if err, ok := result.(error); ok { t.Fatalf("Failed to get URL: %v", err) } url, ok := result.(string) if !ok { t.Fatalf("Expected string, got %T", result) } if url == "" { t.Error("URL should not be empty") } t.Logf("File URL: %s", url) }) // Test 2: Get URL for non-existent file t.Run("GetURLNonExistent", func(t *testing.T) { p := process.New("attachment.URL", "data.local", "non-existent-file-id") result := processURL(p) err, ok := result.(error) if !ok { t.Fatal("Expected error for non-existent file") } if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' error, got: %s", err.Error()) } }) } func TestProcessSaveTextAndGetText(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Register default uploader for testing _, err := RegisterDefault("data.local") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Save a file content := "Original file content" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) saveP := process.New("attachment.Save", "data.local", dataURI, "text-test.txt") saveResult := processSave(saveP) file, ok := saveResult.(*File) if !ok { 
t.Fatalf("Failed to save test file: %v", saveResult) } // Test 1: Get text from file without saved text (should be empty) t.Run("GetTextEmpty", func(t *testing.T) { p := process.New("attachment.GetText", "data.local", file.ID) result := processGetText(p) if err, ok := result.(error); ok { t.Fatalf("Failed to get text: %v", err) } text, ok := result.(string) if !ok { t.Fatalf("Expected string, got %T", result) } if text != "" { t.Errorf("Expected empty text, got: %s", text) } }) // Test 2: Save text and retrieve t.Run("SaveTextAndRetrieve", func(t *testing.T) { parsedText := "This is the parsed/extracted text content from the file." // Save text saveTextP := process.New("attachment.SaveText", "data.local", file.ID, parsedText) saveTextResult := processSaveText(saveTextP) if err, ok := saveTextResult.(error); ok { t.Fatalf("Failed to save text: %v", err) } success, ok := saveTextResult.(bool) if !ok || !success { t.Errorf("Expected true, got %v", saveTextResult) } // Retrieve text getTextP := process.New("attachment.GetText", "data.local", file.ID) getTextResult := processGetText(getTextP) if err, ok := getTextResult.(error); ok { t.Fatalf("Failed to get text: %v", err) } retrievedText, ok := getTextResult.(string) if !ok { t.Fatalf("Expected string, got %T", getTextResult) } if retrievedText != parsedText { t.Errorf("Text mismatch. Expected: %s, Got: %s", parsedText, retrievedText) } t.Logf("Saved and retrieved text: %s", retrievedText) }) // Test 3: Get full text vs preview t.Run("GetTextFullVsPreview", func(t *testing.T) { // Save a long text longText := strings.Repeat("This is a long text content. 
", 200) // > 2000 chars saveTextP := process.New("attachment.SaveText", "data.local", file.ID, longText) saveTextResult := processSaveText(saveTextP) if err, ok := saveTextResult.(error); ok { t.Fatalf("Failed to save long text: %v", err) } // Get preview (default) previewP := process.New("attachment.GetText", "data.local", file.ID) previewResult := processGetText(previewP) previewText, _ := previewResult.(string) // Preview should be 2000 runes if len([]rune(previewText)) != 2000 { t.Errorf("Preview should be 2000 runes, got %d", len([]rune(previewText))) } // Get full content fullP := process.New("attachment.GetText", "data.local", file.ID, true) fullResult := processGetText(fullP) fullText, _ := fullResult.(string) if fullText != longText { t.Errorf("Full text length mismatch. Expected: %d, Got: %d", len(longText), len(fullText)) } }) // Test 4: Save/Get text for non-existent file t.Run("SaveTextNonExistent", func(t *testing.T) { p := process.New("attachment.SaveText", "data.local", "non-existent-id", "some text") result := processSaveText(p) err, ok := result.(error) if !ok { t.Fatal("Expected error for non-existent file") } if !strings.Contains(err.Error(), "file not found") { t.Errorf("Expected 'file not found' error, got: %s", err.Error()) } }) } func TestProcessWithAuthorizedPermission(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() // Register default uploader for testing _, err := RegisterDefault("data.local") if err != nil { t.Fatalf("Failed to register manager: %v", err) } // Test 1: Save with Authorized info - verify via database query since File struct // does not expose these fields in JSON (they are marked with json:"-") t.Run("SaveWithAuthorizedInfo", func(t *testing.T) { content := "Content with permission" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) p := process.New("attachment.Save", "data.local", dataURI, "perm-test.txt") // Set authorized 
info p.WithAuthorized(process.AuthorizedInfo{ UserID: "user123", TeamID: "team456", TenantID: "tenant789", }) result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save file: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } // File should be saved successfully if file.ID == "" { t.Error("File ID should not be empty") } // Note: The YaoCreatedBy, YaoTeamID, YaoTenantID fields in File struct // are marked with json:"-" and may not be populated in the returned struct. // The permission fields are stored in the database during upload via UploadOption. // To verify, we would need to query the database directly. t.Logf("File saved with authorized info - ID: %s", file.ID) }) // Test 2: Save without Authorized (should still work) t.Run("SaveWithoutAuthorized", func(t *testing.T) { content := "Content without permission" base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) p := process.New("attachment.Save", "data.local", dataURI, "no-perm-test.txt") // Don't set authorized info result := processSave(p) if err, ok := result.(error); ok { t.Fatalf("Failed to save file: %v", err) } file, ok := result.(*File) if !ok { t.Fatalf("Expected *File, got %T", result) } // File should be saved successfully without permission fields if file.ID == "" { t.Error("File ID should not be empty") } t.Logf("File saved without permissions - ID: %s", file.ID) }) } func TestParseDataURI(t *testing.T) { // Test 1: Valid data URI with content type t.Run("ValidDataURI", func(t *testing.T) { content := "Hello, World!" 
base64Content := base64.StdEncoding.EncodeToString([]byte(content)) dataURI := fmt.Sprintf("data:text/plain;base64,%s", base64Content) contentType, data, err := parseDataURI(dataURI) if err != nil { t.Fatalf("Failed to parse data URI: %v", err) } if contentType != "text/plain" { t.Errorf("Expected content type 'text/plain', got '%s'", contentType) } if string(data) != content { t.Errorf("Expected content '%s', got '%s'", content, string(data)) } }) // Test 2: Plain text (no data URI header) - treated as plain text, not base64 t.Run("PlainText", func(t *testing.T) { content := "Plain text content" contentType, data, err := parseDataURI(content) if err != nil { t.Fatalf("Failed to parse plain text: %v", err) } if contentType != "text/plain" { t.Errorf("Expected content type 'text/plain', got '%s'", contentType) } if string(data) != content { t.Errorf("Expected content '%s', got '%s'", content, string(data)) } }) // Test 3: Data URI with image t.Run("ImageDataURI", func(t *testing.T) { pngBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" dataURI := fmt.Sprintf("data:image/png;base64,%s", pngBase64) contentType, _, err := parseDataURI(dataURI) if err != nil { t.Fatalf("Failed to parse image data URI: %v", err) } if contentType != "image/png" { t.Errorf("Expected content type 'image/png', got '%s'", contentType) } }) // Test 4: Invalid base64 t.Run("InvalidBase64", func(t *testing.T) { dataURI := "data:text/plain;base64,not-valid!!!" 
_, _, err := parseDataURI(dataURI) if err == nil { t.Fatal("Expected error for invalid base64") } }) // Test 5: Invalid data URI format t.Run("InvalidDataURIFormat", func(t *testing.T) { dataURI := "data:text/plain" // Missing base64 part _, _, err := parseDataURI(dataURI) if err == nil { t.Fatal("Expected error for invalid data URI format") } }) } func TestGenerateFilename(t *testing.T) { // Note: mime.ExtensionsByType may return different extensions on different systems // (e.g., Linux may return .jfif for image/jpeg, .asc for text/plain) // So we verify the filename has a proper extension format and is not empty testCases := []struct { contentType string expectedPrefix string }{ {"image/png", "file"}, {"image/jpeg", "file"}, {"image/gif", "file"}, {"image/webp", "file"}, {"text/plain", "file"}, {"application/pdf", "file"}, {"application/json", "file"}, {"application/octet-stream", "file"}, {"unknown/type", "file"}, } for _, tc := range testCases { t.Run(tc.contentType, func(t *testing.T) { filename := generateFilename(tc.contentType) // Check prefix if !strings.HasPrefix(filename, tc.expectedPrefix) { t.Errorf("For content type '%s', expected prefix '%s', got '%s'", tc.contentType, tc.expectedPrefix, filename) } // Check filename has an extension (starts with dot and has at least one character) dotIndex := strings.LastIndex(filename, ".") if dotIndex == -1 || dotIndex == len(filename)-1 { t.Errorf("For content type '%s', expected filename with extension, got '%s'", tc.contentType, filename) } // Extension should not be empty ext := filename[dotIndex:] if len(ext) < 2 { t.Errorf("For content type '%s', expected non-empty extension, got '%s'", tc.contentType, ext) } }) } } ================================================ FILE: attachment/s3/storage.go ================================================ package s3 import ( "bytes" "compress/gzip" "context" "fmt" "image" "image/jpeg" "image/png" "io" "mime" "net/http" "os" "path/filepath" "strings" "time" 
// Storage the S3 storage driver
//
// The exported fields carry the connection settings; New fills them from an
// options map and then builds the client held in the unexported fields.
type Storage struct {
	// Endpoint is an optional custom endpoint; when non-empty, New passes it
	// to the client as its base endpoint.
	Endpoint string `json:"endpoint" yaml:"endpoint"`
	// Region is the region handed to the client; New defaults it to "auto".
	Region string `json:"region" yaml:"region"`
	// Key and Secret form the static credentials; New rejects empty values.
	Key    string `json:"key" yaml:"key"`
	Secret string `json:"secret" yaml:"secret"`
	// Bucket is the bucket every request targets; New rejects an empty value.
	Bucket string `json:"bucket" yaml:"bucket"`
	// Expiration defaults to DefaultExpiration in New.
	// NOTE(review): presumably the lifetime of presigned URLs, going by the
	// DefaultExpiration comment; the code using it is outside this view.
	Expiration time.Duration `json:"expiration" yaml:"expiration"`
	// CacheDir is a local directory that New creates if missing; it defaults
	// to the system temp directory when no "cache_dir" option is given.
	CacheDir string `json:"cache_dir" yaml:"cache_dir"`

	// client is the S3 client constructed by New; methods return an error
	// when it is nil.
	client *s3.Client
	// prefix is joined in front of every object path to form the object key.
	prefix string
	// compression defaults to true and can be overridden with the
	// "compression" option.
	// NOTE(review): how it is consumed is not visible in this part of the file.
	compression bool
}
Credentials: credentials.NewStaticCredentialsProvider(storage.Key, storage.Secret, ""), UsePathStyle: true, } if storage.Endpoint != "" { // Remove bucket name from endpoint if present endpoint := storage.Endpoint if strings.Contains(endpoint, "/"+storage.Bucket) { endpoint = strings.TrimSuffix(endpoint, "/"+storage.Bucket) } opts.BaseEndpoint = aws.String(endpoint) } storage.client = s3.New(opts) // Ensure cache directory exists if err := os.MkdirAll(storage.CacheDir, 0755); err != nil { return nil, fmt.Errorf("failed to create cache directory %s: %w", storage.CacheDir, err) } return storage, nil } // Upload upload file to S3 func (storage *Storage) Upload(ctx context.Context, path string, reader io.Reader, contentType string) (string, error) { if storage.client == nil { return "", fmt.Errorf("s3 client not initialized") } key := filepath.Join(storage.prefix, path) // Upload file _, err := storage.client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(key), Body: reader, ContentType: aws.String(contentType), }) if err != nil { return "", fmt.Errorf("failed to upload file %s: %w", path, err) } return path, nil } // UploadChunk uploads a chunk of a file to S3 func (storage *Storage) UploadChunk(ctx context.Context, path string, chunkIndex int, reader io.Reader, contentType string) error { if storage.client == nil { return fmt.Errorf("s3 client not initialized") } // Store chunks with a special prefix chunkKey := filepath.Join(storage.prefix, ".chunks", path, fmt.Sprintf("chunk_%d", chunkIndex)) _, err := storage.client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(chunkKey), Body: reader, ContentType: aws.String(contentType), }) if err != nil { return fmt.Errorf("failed to upload chunk %s %d: %w", path, chunkIndex, err) } return nil } // MergeChunks merges all chunks into the final file in S3 func (storage *Storage) MergeChunks(ctx context.Context, path string, totalChunks int) error { 
if storage.client == nil { return fmt.Errorf("s3 client not initialized") } finalKey := filepath.Join(storage.prefix, path) // Create a buffer to hold the merged content var mergedContent bytes.Buffer var contentType string // Download and merge chunks in order for i := 0; i < totalChunks; i++ { chunkKey := filepath.Join(storage.prefix, ".chunks", path, fmt.Sprintf("chunk_%d", i)) result, err := storage.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(chunkKey), }) if err != nil { return fmt.Errorf("failed to get chunk %d: %w", i, err) } // Get content type from the first chunk if i == 0 && result.ContentType != nil { contentType = *result.ContentType } _, err = io.Copy(&mergedContent, result.Body) result.Body.Close() if err != nil { return fmt.Errorf("failed to copy chunk %s %d: %w", path, i, err) } } // Default content type if not found if contentType == "" { contentType = "application/octet-stream" } // Upload the merged content as the final file with proper content type _, err := storage.client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(finalKey), Body: bytes.NewReader(mergedContent.Bytes()), ContentType: aws.String(contentType), }) if err != nil { return fmt.Errorf("failed to upload merged file %s: %w", path, err) } // Clean up chunks for i := 0; i < totalChunks; i++ { chunkKey := filepath.Join(storage.prefix, ".chunks", path, fmt.Sprintf("chunk_%d", i)) storage.client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(chunkKey), }) } return nil } // Reader read file from S3 func (storage *Storage) Reader(ctx context.Context, path string) (io.ReadCloser, error) { if storage.client == nil { return nil, fmt.Errorf("s3 client not initialized") } key := filepath.Join(storage.prefix, path) result, err := storage.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(key), }) if err != nil { return 
nil, fmt.Errorf("failed to get file %s: %w", path, err) } // If the file is a gzip file, decompress it if strings.HasSuffix(path, ".gz") { reader, err := gzip.NewReader(result.Body) if err != nil { return nil, err } return reader, nil } return result.Body, nil } // Download download file from S3 func (storage *Storage) Download(ctx context.Context, path string) (io.ReadCloser, string, error) { if storage.client == nil { return nil, "", fmt.Errorf("s3 client not initialized") } key := filepath.Join(storage.prefix, path) // Get object result, err := storage.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(key), }) if err != nil { return nil, "", fmt.Errorf("failed to download file %s: %w", path, err) } contentType := "application/octet-stream" if result.ContentType != nil { contentType = *result.ContentType } // Try to detect content type from file extension ext := filepath.Ext(strings.TrimSuffix(path, ".gz")) switch strings.ToLower(ext) { case ".txt": contentType = "text/plain" case ".html": contentType = "text/html" case ".css": contentType = "text/css" case ".js": contentType = "application/javascript" case ".json": contentType = "application/json" case ".jpg", ".jpeg": contentType = "image/jpeg" case ".png": contentType = "image/png" case ".gif": contentType = "image/gif" case ".pdf": contentType = "application/pdf" case ".mp4": contentType = "video/mp4" case ".mp3": contentType = "audio/mpeg" case ".wav": contentType = "audio/wav" case ".ogg": contentType = "audio/ogg" case ".webm": contentType = "video/webm" case ".webp": contentType = "image/webp" case ".zip": } // If the file is a gzip file, decompress it if strings.HasSuffix(path, ".gz") { reader, err := gzip.NewReader(result.Body) if err != nil { return nil, "", err } return reader, contentType, nil } return result.Body, contentType, nil } // GetContent gets file content as bytes func (storage *Storage) GetContent(ctx context.Context, path string) ([]byte, error) 
{ reader, err := storage.Reader(ctx, path) if err != nil { return nil, err } defer reader.Close() return io.ReadAll(reader) } // URL get file url with expiration func (storage *Storage) URL(ctx context.Context, path string) string { if storage.client == nil { return "" } key := filepath.Join(storage.prefix, path) presignClient := s3.NewPresignClient(storage.client) request, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(key), }, s3.WithPresignExpires(storage.Expiration)) if err != nil { return "" } return request.URL } // Exists checks if a file exists in S3 func (storage *Storage) Exists(ctx context.Context, path string) bool { if storage.client == nil { return false } key := filepath.Join(storage.prefix, path) _, err := storage.client.HeadObject(ctx, &s3.HeadObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(key), }) return err == nil } // Delete deletes a file from S3 func (storage *Storage) Delete(ctx context.Context, path string) error { if storage.client == nil { return fmt.Errorf("s3 client not initialized") } key := filepath.Join(storage.prefix, path) _, err := storage.client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(key), }) if err != nil { return fmt.Errorf("failed to delete file: %w", err) } return nil } func (storage *Storage) makeID(filename string, ext string) string { date := time.Now().Format("20060102") name := strings.TrimSuffix(filepath.Base(filename), ext) return fmt.Sprintf("%s/%s-%d%s", date, name, time.Now().UnixNano(), ext) } // isImage checks if the content type is an image func isImage(contentType string) bool { return strings.HasPrefix(contentType, "image/") } // compressImage compresses the image while maintaining aspect ratio func compressImage(data []byte, contentType string) ([]byte, error) { // Decode image img, _, err := image.Decode(bytes.NewReader(data)) if err != nil { return nil, fmt.Errorf("failed 
to decode image: %w", err) } // Calculate new dimensions bounds := img.Bounds() width := bounds.Dx() height := bounds.Dy() var newWidth, newHeight int if width > height { if width > MaxImageSize { newWidth = MaxImageSize newHeight = int(float64(height) * (float64(MaxImageSize) / float64(width))) } else { return data, nil // No need to resize } } else { if height > MaxImageSize { newHeight = MaxImageSize newWidth = int(float64(width) * (float64(MaxImageSize) / float64(height))) } else { return data, nil // No need to resize } } // Create new image with new dimensions newImg := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight)) // Scale the image using bilinear interpolation for y := 0; y < newHeight; y++ { for x := 0; x < newWidth; x++ { srcX := float64(x) * float64(width) / float64(newWidth) srcY := float64(y) * float64(height) / float64(newHeight) newImg.Set(x, y, img.At(int(srcX), int(srcY))) } } // Encode image var buf bytes.Buffer switch contentType { case "image/jpeg": err = jpeg.Encode(&buf, newImg, &jpeg.Options{Quality: 85}) case "image/png": err = png.Encode(&buf, newImg) default: return data, nil // Unsupported format, return original } if err != nil { return nil, fmt.Errorf("failed to encode image: %w", err) } return buf.Bytes(), nil } // LocalPath downloads the file to cache directory and returns absolute path with content type func (storage *Storage) LocalPath(ctx context.Context, path string) (string, string, error) { if storage.client == nil { return "", "", fmt.Errorf("s3 client not initialized") } // Create cache file path using the same structure as storage path cacheFilePath := filepath.Join(storage.CacheDir, "s3_cache", path) // Create directory for cache file dir := filepath.Dir(cacheFilePath) if err := os.MkdirAll(dir, 0755); err != nil { return "", "", fmt.Errorf("failed to create cache directory: %w", err) } // Check if file already exists in cache and is not outdated if _, err := os.Stat(cacheFilePath); err == nil { // File exists in 
cache, detect content type and return contentType, err := detectContentType(cacheFilePath) if err != nil { return "", "", fmt.Errorf("failed to detect content type: %w", err) } return cacheFilePath, contentType, nil } // Download file from S3 to cache key := filepath.Join(storage.prefix, path) result, err := storage.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(storage.Bucket), Key: aws.String(key), }) if err != nil { return "", "", fmt.Errorf("failed to download file %s: %w", path, err) } defer result.Body.Close() // Create cache file cacheFile, err := os.Create(cacheFilePath) if err != nil { return "", "", fmt.Errorf("failed to create cache file: %w", err) } defer cacheFile.Close() // Handle gzipped files - decompress during download var reader io.Reader = result.Body if strings.HasSuffix(path, ".gz") { gzipReader, err := gzip.NewReader(result.Body) if err != nil { return "", "", fmt.Errorf("failed to create gzip reader: %w", err) } defer gzipReader.Close() reader = gzipReader // Remove .gz extension from cache file path since we're decompressing newCacheFilePath := strings.TrimSuffix(cacheFilePath, ".gz") cacheFile.Close() os.Remove(cacheFilePath) cacheFile, err = os.Create(newCacheFilePath) if err != nil { return "", "", fmt.Errorf("failed to create decompressed cache file: %w", err) } defer cacheFile.Close() cacheFilePath = newCacheFilePath } // Copy file content to cache _, err = io.Copy(cacheFile, reader) if err != nil { return "", "", fmt.Errorf("failed to copy file to cache: %w", err) } // For files that were decompressed from .gz, we need to detect the original content type var contentType string if strings.HasSuffix(path, ".gz") { // Original path was gzipped, detect content type of decompressed content originalPath := strings.TrimSuffix(path, ".gz") ext := filepath.Ext(originalPath) // First try to detect by original file extension contentType, err = detectContentTypeFromExtension(ext) if err != nil || contentType == 
"application/octet-stream" { // Fallback: detect from decompressed content contentType, err = detectContentType(cacheFilePath) if err != nil { return "", "", fmt.Errorf("failed to detect content type: %w", err) } } } else { // Regular file content type detection contentType, err = detectContentType(cacheFilePath) if err != nil { return "", "", fmt.Errorf("failed to detect content type: %w", err) } } return cacheFilePath, contentType, nil } // detectContentType detects content type based on file extension and content func detectContentType(filePath string) (string, error) { // First try to detect by file extension ext := strings.ToLower(filepath.Ext(filePath)) // Common file extensions mapping switch ext { case ".txt": return "text/plain", nil case ".html", ".htm": return "text/html", nil case ".css": return "text/css", nil case ".js": return "application/javascript", nil case ".json": return "application/json", nil case ".xml": return "application/xml", nil case ".jpg", ".jpeg": return "image/jpeg", nil case ".png": return "image/png", nil case ".gif": return "image/gif", nil case ".webp": return "image/webp", nil case ".svg": return "image/svg+xml", nil case ".pdf": return "application/pdf", nil case ".doc": return "application/msword", nil case ".docx": return "application/vnd.openxmlformats-officedocument.wordprocessingml.document", nil case ".xls": return "application/vnd.ms-excel", nil case ".xlsx": return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil case ".ppt": return "application/vnd.ms-powerpoint", nil case ".pptx": return "application/vnd.openxmlformats-officedocument.presentationml.presentation", nil case ".zip": return "application/zip", nil case ".tar": return "application/x-tar", nil case ".gz": return "application/gzip", nil case ".mp3": return "audio/mpeg", nil case ".wav": return "audio/wav", nil case ".m4a": return "audio/mp4", nil case ".ogg": return "audio/ogg", nil case ".mp4": return "video/mp4", nil case ".avi": 
// extensionContentTypes maps a lower-case file extension, including the
// leading dot, to the content type reported for it.
var extensionContentTypes = map[string]string{
	".txt":  "text/plain",
	".html": "text/html",
	".htm":  "text/html",
	".css":  "text/css",
	".js":   "application/javascript",
	".json": "application/json",
	".xml":  "application/xml",
	".jpg":  "image/jpeg",
	".jpeg": "image/jpeg",
	".png":  "image/png",
	".gif":  "image/gif",
	".webp": "image/webp",
	".svg":  "image/svg+xml",
	".pdf":  "application/pdf",
	".doc":  "application/msword",
	".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
	".xls":  "application/vnd.ms-excel",
	".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
	".ppt":  "application/vnd.ms-powerpoint",
	".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
	".zip":  "application/zip",
	".tar":  "application/x-tar",
	".mp3":  "audio/mpeg",
	".wav":  "audio/wav",
	".m4a":  "audio/mp4",
	".ogg":  "audio/ogg",
	".mp4":  "video/mp4",
	".avi":  "video/x-msvideo",
	".mov":  "video/quicktime",
	".webm": "video/webm",
	".md":   "text/markdown",
	".mdx":  "text/markdown",
	".yao":  "application/yao",
	".csv":  "text/csv",
}

// detectContentTypeFromExtension maps a file extension (with its leading dot,
// in any letter case) to a content type without touching the file system.
// The built-in table is consulted first, then the mime package; when neither
// knows the extension, "application/octet-stream" is returned. The returned
// error is always nil.
func detectContentTypeFromExtension(ext string) (string, error) {
	normalized := strings.ToLower(ext)

	if known, ok := extensionContentTypes[normalized]; ok {
		return known, nil
	}

	// Try to detect by MIME package
	if guessed := mime.TypeByExtension(normalized); guessed != "" {
		return guessed, nil
	}

	// Return default if not found
	return "application/octet-stream", nil
}
err := New(options) if os.Getenv("S3_ACCESS_KEY") == "" || os.Getenv("S3_SECRET_KEY") == "" || os.Getenv("S3_BUCKET") == "" { // Should fail without credentials assert.Error(t, err) return } assert.NoError(t, err) assert.NotNil(t, storage) if storage != nil { assert.Equal(t, os.Getenv("S3_API"), storage.Endpoint) assert.Equal(t, "auto", storage.Region) assert.Equal(t, os.Getenv("S3_ACCESS_KEY"), storage.Key) assert.Equal(t, os.Getenv("S3_SECRET_KEY"), storage.Secret) assert.Equal(t, os.Getenv("S3_BUCKET"), storage.Bucket) assert.Equal(t, "attachment-test", storage.prefix) assert.Equal(t, 5*time.Minute, storage.Expiration) assert.True(t, storage.compression) } }) t.Run("Upload and Download Text File", func(t *testing.T) { skipIfNoS3Config(t) storage, err := New(getS3Config()) assert.NoError(t, err) content := []byte("test content") reader := bytes.NewReader(content) fileID := generateTestFileName("upload-test", ".txt") _, err = storage.Upload(context.Background(), fileID, reader, "text/plain") assert.NoError(t, err) assert.NotEmpty(t, fileID) // Get presigned URL url := storage.URL(context.Background(), fileID) assert.NotEmpty(t, url) assert.Contains(t, url, "X-Amz-Signature") assert.Contains(t, url, "X-Amz-Expires") // Download reader2, contentType, err := storage.Download(context.Background(), fileID) assert.NoError(t, err) assert.Contains(t, contentType, "text/plain") downloaded, err := io.ReadAll(reader2) assert.NoError(t, err) assert.Equal(t, content, downloaded) reader2.Close() // Clean up storage.Delete(context.Background(), fileID) }) t.Run("Chunked Upload", func(t *testing.T) { skipIfNoS3Config(t) storage, err := New(getS3Config()) assert.NoError(t, err) fileID := generateTestFileName("test-chunked", ".txt") content1 := []byte("chunk1") content2 := []byte("chunk2") // Upload chunks err = storage.UploadChunk(context.Background(), fileID, 0, bytes.NewReader(content1), "text/plain") assert.NoError(t, err) err = storage.UploadChunk(context.Background(), fileID, 
1, bytes.NewReader(content2), "text/plain") assert.NoError(t, err) // Merge chunks err = storage.MergeChunks(context.Background(), fileID, 2) assert.NoError(t, err) // Download and verify reader, contentType, err := storage.Download(context.Background(), fileID) assert.NoError(t, err) assert.Equal(t, "text/plain", contentType) downloaded, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, append(content1, content2...), downloaded) reader.Close() // Clean up storage.Delete(context.Background(), fileID) }) t.Run("File Operations", func(t *testing.T) { skipIfNoS3Config(t) storage, err := New(getS3Config()) assert.NoError(t, err) fileID := generateTestFileName("test-ops", ".txt") content := []byte("test content") // Upload file _, err = storage.Upload(context.Background(), fileID, bytes.NewReader(content), "text/plain") assert.NoError(t, err) // Check if file exists exists := storage.Exists(context.Background(), fileID) assert.True(t, exists) // Read file reader, err := storage.Reader(context.Background(), fileID) assert.NoError(t, err) defer reader.Close() data, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, content, data) // Get file content directly directContent, err := storage.GetContent(context.Background(), fileID) assert.NoError(t, err) assert.Equal(t, content, directContent) // Delete file err = storage.Delete(context.Background(), fileID) assert.NoError(t, err) // Check if file no longer exists exists = storage.Exists(context.Background(), fileID) assert.False(t, exists) }) t.Run("Download Non-existent File", func(t *testing.T) { skipIfNoS3Config(t) storage, err := New(getS3Config()) assert.NoError(t, err) // Use UUID for non-existent file to avoid any potential conflicts nonExistentFileID := generateTestFileName("non-existent", ".txt") _, _, err = storage.Download(context.Background(), nonExistentFileID) assert.Error(t, err) }) t.Run("Invalid Configuration", func(t *testing.T) { // Test with missing required fields _, err := 
New(map[string]interface{}{ "endpoint": "https://s3.amazonaws.com", "region": "us-east-1", // Missing key and secret }) assert.Error(t, err) assert.Contains(t, err.Error(), "key and secret are required") // Test with missing bucket _, err = New(map[string]interface{}{ "endpoint": "https://s3.amazonaws.com", "region": "us-east-1", "key": "test-key", "secret": "test-secret", // Missing bucket }) assert.Error(t, err) assert.Contains(t, err.Error(), "bucket is required") }) t.Run("LocalPath", func(t *testing.T) { skipIfNoS3Config(t) // Create storage with custom cache directory tempCacheDir, err := os.MkdirTemp("", "s3_cache_test") assert.NoError(t, err) defer os.RemoveAll(tempCacheDir) config := getS3Config() config["cache_dir"] = tempCacheDir storage, err := New(config) assert.NoError(t, err) // Test different file types testFiles := []struct { name string content []byte contentType string expectedCT string }{ {"localpath-test.txt", []byte("Hello S3 World"), "text/plain", "text/plain"}, {"localpath-test.json", []byte(`{"s3": "test"}`), "application/json", "application/json"}, {"localpath-test.html", []byte("S3 Test"), "text/html", "text/html"}, {"localpath-test.csv", []byte("s3,test\nval1,val2"), "text/csv", "text/csv"}, {"localpath-test.md", []byte("# S3 Markdown"), "text/markdown", "text/markdown"}, {"localpath-test.yao", []byte("s3 yao content"), "application/yao", "application/yao"}, } for _, tf := range testFiles { // Upload file to S3 fileID := generateTestFileName("s3-localpath", "-"+tf.name) _, err = storage.Upload(context.Background(), fileID, bytes.NewReader(tf.content), tf.contentType) assert.NoError(t, err, "Failed to upload %s", tf.name) // Get local path - first call should download to cache localPath1, detectedCT1, err := storage.LocalPath(context.Background(), fileID) assert.NoError(t, err, "Failed to get local path for %s", tf.name) assert.NotEmpty(t, localPath1, "Local path should not be empty for %s", tf.name) assert.Equal(t, tf.expectedCT, 
detectedCT1, "Content type mismatch for %s", tf.name) // Verify the path is absolute assert.True(t, filepath.IsAbs(localPath1), "Path should be absolute for %s", tf.name) // Verify the file exists at the returned path _, err = os.Stat(localPath1) assert.NoError(t, err, "File should exist at local path for %s", tf.name) // Verify file content fileContent, err := os.ReadFile(localPath1) assert.NoError(t, err, "Failed to read file at local path for %s", tf.name) assert.Equal(t, tf.content, fileContent, "File content mismatch for %s", tf.name) // Get local path again - should use cached version localPath2, detectedCT2, err := storage.LocalPath(context.Background(), fileID) assert.NoError(t, err, "Failed to get cached local path for %s", tf.name) assert.Equal(t, localPath1, localPath2, "Cached path should be same as first call for %s", tf.name) assert.Equal(t, detectedCT1, detectedCT2, "Cached content type should be same as first call for %s", tf.name) // Clean up from S3 storage.Delete(context.Background(), fileID) } }) t.Run("LocalPath_GzippedFile", func(t *testing.T) { skipIfNoS3Config(t) // Create storage with custom cache directory tempCacheDir, err := os.MkdirTemp("", "s3_cache_gzip_test") assert.NoError(t, err) defer os.RemoveAll(tempCacheDir) config := getS3Config() config["cache_dir"] = tempCacheDir storage, err := New(config) assert.NoError(t, err) // Create gzipped content originalContent := []byte("This content will be gzipped and stored in S3") var gzipBuf bytes.Buffer gzipWriter := gzip.NewWriter(&gzipBuf) _, err = gzipWriter.Write(originalContent) assert.NoError(t, err) gzipWriter.Close() // Upload gzipped file fileID := generateTestFileName("gzipped", ".txt.gz") _, err = storage.Upload(context.Background(), fileID, bytes.NewReader(gzipBuf.Bytes()), "text/plain") assert.NoError(t, err) // Get local path - should decompress during download localPath, contentType, err := storage.LocalPath(context.Background(), fileID) assert.NoError(t, err) 
assert.NotEmpty(t, localPath) // Verify the file is decompressed in cache (path should not end with .gz) assert.False(t, strings.HasSuffix(localPath, ".gz"), "Cached file should be decompressed") // Verify content is decompressed cachedContent, err := os.ReadFile(localPath) assert.NoError(t, err) assert.Equal(t, originalContent, cachedContent, "Cached file should contain decompressed content") // Verify content type assert.Equal(t, "text/plain", contentType) // Clean up storage.Delete(context.Background(), fileID) }) t.Run("LocalPath_NonExistentFile", func(t *testing.T) { skipIfNoS3Config(t) storage, err := New(getS3Config()) assert.NoError(t, err) // Test with non-existent file nonExistentFileID := generateTestFileName("non-existent-localpath", ".txt") _, _, err = storage.LocalPath(context.Background(), nonExistentFileID) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to download file") }) t.Run("LocalPath_CustomCacheDir", func(t *testing.T) { skipIfNoS3Config(t) // Create custom cache directory customCacheDir, err := os.MkdirTemp("", "custom_s3_cache") assert.NoError(t, err) defer os.RemoveAll(customCacheDir) config := getS3Config() config["cache_dir"] = customCacheDir storage, err := New(config) assert.NoError(t, err) // Verify cache directory is set correctly assert.Equal(t, customCacheDir, storage.CacheDir) // Upload a test file content := []byte("Custom cache directory test") fileID := generateTestFileName("custom-cache", ".txt") _, err = storage.Upload(context.Background(), fileID, bytes.NewReader(content), "text/plain") assert.NoError(t, err) // Get local path localPath, contentType, err := storage.LocalPath(context.Background(), fileID) assert.NoError(t, err) assert.NotEmpty(t, localPath) assert.Equal(t, "text/plain", contentType) // Verify the file is cached in the custom directory assert.True(t, strings.HasPrefix(localPath, customCacheDir), "File should be cached in custom directory") // Clean up storage.Delete(context.Background(), fileID) 
// FileManager defines the interface for file management operations.
// This interface provides abstraction for file operations, making it easier to:
//   - Write unit tests with mock implementations
//   - Switch between different storage backends
//   - Maintain consistent API across different implementations
//
// Files are addressed by their file ID in every method except Upload, which
// creates the file and returns its record.
//
// Example usage:
//
//	var fileManager FileManager = manager // Manager implements FileManager
//	file, err := fileManager.Upload(ctx, header, reader, options)
//	data, err := fileManager.Read(ctx, file.ID)
type FileManager interface {
	// Upload uploads a file with optional chunked upload support.
	// The content is read from reader; fileheader and option describe the upload.
	Upload(ctx context.Context, fileheader *FileHeader, reader io.Reader, option UploadOption) (*File, error)

	// Download downloads a file by its ID.
	// The returned FileResponse carries a Reader (an io.ReadCloser), the
	// content type and the extension.
	// NOTE(review): the caller is presumably responsible for closing
	// FileResponse.Reader; confirm against the implementation.
	Download(ctx context.Context, fileID string) (*FileResponse, error)

	// Read reads a file content as bytes
	Read(ctx context.Context, fileID string) ([]byte, error)

	// ReadBase64 reads a file content as base64 encoded string
	ReadBase64(ctx context.Context, fileID string) (string, error)

	// Info retrieves complete file information from database by file ID
	Info(ctx context.Context, fileID string) (*File, error)

	// List retrieves files from database with pagination and filtering
	List(ctx context.Context, option ListOption) (*ListResult, error)

	// Exists checks if a file exists
	Exists(ctx context.Context, fileID string) bool

	// Delete deletes a file
	Delete(ctx context.Context, fileID string) error

	// LocalPath gets the local path of the file.
	// NOTE(review): the two strings are presumably the absolute path and the
	// content type, as the Storage interface documents for its own LocalPath;
	// confirm against the implementation.
	LocalPath(ctx context.Context, fileID string) (string, string, error)

	// GetText retrieves the parsed text content for a file
	// By default returns preview (first 2000 chars), set fullContent=true for complete text
	GetText(ctx context.Context, fileID string, fullContent ...bool) (string, error)

	// SaveText saves the parsed text content for a file
	// Automatically saves both full content and preview
	SaveText(ctx context.Context, fileID string, text string) error
}
// Storage is the interface a storage backend implements so that a Manager can
// store and retrieve file content. Every method addresses a file by its
// storage path.
type Storage interface {
	// Upload stores the content read from reader under path with the given
	// content type.
	// NOTE(review): the meaning of the returned string is not visible here;
	// confirm against the implementations.
	Upload(ctx context.Context, path string, reader io.Reader, contentType string) (string, error)

	// UploadChunk stores one chunk of a chunked upload for path. chunkIndex
	// identifies the chunk's position in the final file.
	UploadChunk(ctx context.Context, path string, chunkIndex int, reader io.Reader, contentType string) error

	// MergeChunks assembles the previously uploaded chunks of path into the
	// final file. totalChunks is the number of chunks to merge.
	MergeChunks(ctx context.Context, path string, totalChunks int) error

	// Download opens the file for reading and also returns its content type.
	// The caller is responsible for closing the returned reader.
	Download(ctx context.Context, path string) (io.ReadCloser, string, error)

	// Reader opens the file for reading. The caller is responsible for
	// closing the returned reader.
	Reader(ctx context.Context, path string) (io.ReadCloser, error)

	// GetContent returns the whole file content as bytes.
	GetContent(ctx context.Context, path string) ([]byte, error)

	// URL returns a URL for the file. There is no error return.
	// NOTE(review): implementations presumably signal failure with an empty
	// string; confirm per backend.
	URL(ctx context.Context, path string) string

	// Exists reports whether the file exists.
	Exists(ctx context.Context, path string) bool

	// Delete removes the file.
	Delete(ctx context.Context, path string) error

	// LocalPath makes the file available on the local file system.
	LocalPath(ctx context.Context, path string) (string, string, error) // Returns absolute path and content type
}
// UploadOption carries the per-upload settings passed to FileManager.Upload.
// Each field can be supplied as JSON or as a form value under the name given
// in its struct tags.
type UploadOption struct {
	// CompressImage requests image compression for the upload. Optional.
	// NOTE(review): the previous comment stated "default is true", but the
	// zero value of a Go bool is false; confirm where that default is applied.
	CompressImage bool `json:"compress_image,omitempty" form:"compress_image"`

	// CompressSize is the target size used when CompressImage is set.
	// Optional; the documented default is 1920.
	CompressSize int `json:"compress_size,omitempty" form:"compress_size"`

	// Gzip requests gzip compression of the file. Optional, default is false.
	Gzip bool `json:"gzip,omitempty" form:"gzip"`

	// OriginalFilename is the original filename, sent separately to avoid
	// encoding issues.
	OriginalFilename string `json:"original_filename,omitempty" form:"original_filename"`

	// Groups lists multi-level groups such as
	// ["user", "user123", "chat", "chat456"]. Optional, default is empty.
	Groups []string `json:"groups,omitempty" form:"groups"`

	// Public marks the attachment as shared across all teams in the platform.
	Public bool `json:"public,omitempty" form:"public"`

	// Share is the attachment sharing scope: "private" or "team".
	Share string `json:"share,omitempty" form:"share"`

	// Yao custom fields for permission control
	YaoCreatedBy string `json:"__yao_created_by,omitempty" form:"__yao_created_by"` // User who created the attachment
	YaoUpdatedBy string `json:"__yao_updated_by,omitempty" form:"__yao_updated_by"` // User who last updated the attachment
	YaoTeamID    string `json:"__yao_team_id,omitempty" form:"__yao_team_id"`       // Team ID for team-based access control
	YaoTenantID  string `json:"__yao_tenant_id,omitempty" form:"__yao_tenant_id"`   // Tenant ID for multi-tenancy support
}
// FileHeader is the file header passed to FileManager.Upload. It embeds
// *multipart.FileHeader, so the fields and methods of the standard library
// type are available directly on FileHeader.
type FileHeader struct {
	*multipart.FileHeader
}
# find_app_root prints the application root directory on stdout.
#
# Starting at the current working directory it walks towards the filesystem
# root and prints the first directory that contains app.yao, app.json or
# app.jsonc. The filesystem root itself is checked as well, so an application
# living directly in "/" is found from any sub-directory. When no marker file
# exists anywhere on the way up, the current directory is printed as fallback.
find_app_root() {
  # Declare and assign separately so `local` does not mask a failing pwd.
  local dir
  dir="$(pwd)"

  while true; do
    if [ -f "$dir/app.yao" ] || [ -f "$dir/app.json" ] || [ -f "$dir/app.jsonc" ]; then
      echo "$dir"
      return 0
    fi

    # Stop only after "/" itself has been inspected.
    if [ "$dir" = "/" ]; then
      break
    fi
    dir="$(dirname "$dir")"
  done

  # If not found, use current directory as fallback
  pwd
}
} ================================================ FILE: cert/cert_test.go ================================================ package cert import ( "testing" "github.com/stretchr/testify/assert" "github.com/yaoapp/gou/process" "github.com/yaoapp/gou/ssl" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/test" ) func TestLoad(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() Load(config.Conf) check(t) } func check(t *testing.T) { ids := map[string]bool{} for id := range ssl.Certificates { ids[id] = true } assert.True(t, ids["cert.pem"]) assert.True(t, ids["cert.key"]) assert.True(t, ids["cert.pub"]) } func TestProcessSign(t *testing.T) { Load(config.Conf) args := []interface{}{"hello world", "cert.key", "SHA256"} signature, err := process.New("ssl.Sign", args...).Exec() if err != nil { t.Fatal(err) } assert.Equal(t, "EDHf3C9TXEk7y8LzIk5czLefXZyGxcMDVMcbNuBBegDkTqnPsRQnhFtNOgCdox8lI3MzLatwjoljoMY4Qk+sHGd5mAHMpiREa1gRFSVYpA2xvXZ3+KsfOHAdICQrfUdy59QaJGo6iGPNGG8PQOXHPTVNn6LMfryat9+f4l21DPAZiT0RyCUgFZE3/Qv8Z/6J4AsIXMSKZD6BGPPHUxGe7UBrXZvcR5dX25EiNjuH2OO38YJnDiTRVw14UI5fk/mQrwRdezj5tSKFCyHt912BZExXtkHISiYFNTZ/2RhOup5Xx6o3GvrEOdshrnN80Lwu1Aaju+lnZp13hDz4P6hU7w==", signature) } func TestProcessVerify(t *testing.T) { Load(config.Conf) signature := "EDHf3C9TXEk7y8LzIk5czLefXZyGxcMDVMcbNuBBegDkTqnPsRQnhFtNOgCdox8lI3MzLatwjoljoMY4Qk+sHGd5mAHMpiREa1gRFSVYpA2xvXZ3+KsfOHAdICQrfUdy59QaJGo6iGPNGG8PQOXHPTVNn6LMfryat9+f4l21DPAZiT0RyCUgFZE3/Qv8Z/6J4AsIXMSKZD6BGPPHUxGe7UBrXZvcR5dX25EiNjuH2OO38YJnDiTRVw14UI5fk/mQrwRdezj5tSKFCyHt912BZExXtkHISiYFNTZ/2RhOup5Xx6o3GvrEOdshrnN80Lwu1Aaju+lnZp13hDz4P6hU7w==" args := []interface{}{"hello world", signature, "cert.pem", "SHA256"} res, err := process.New("ssl.Verify", args...).Exec() if err != nil { t.Fatal(err) } assert.True(t, res.(bool)) } ================================================ FILE: cmd/README.md ================================================ # Yao CLI Commands The Yao CLI provides a set of commands for managing, 
running, and testing Yao applications. ## Installation ```bash # Build from source go build -o yao . # Or install via go install go install github.com/yaoapp/yao@latest ``` ## Global Flags | Flag | Short | Description | | -------- | ----- | ------------------------------- | | `--app` | `-a` | Application directory path | | `--file` | `-f` | Application package file (.yaz) | | `--key` | `-k` | Application license key | ## Environment Variables | Variable | Description | | ---------- | -------------------------------------------- | | `YAO_ROOT` | Application root directory | | `YAO_LANG` | Language setting (e.g., `zh-CN` for Chinese) | ## Commands ### `yao start` Start the Yao application engine. ```bash # Start in current directory yao start # Start with specific app directory yao start -a /path/to/app # Start in debug mode yao start --debug ``` **Flags:** | Flag | Description | | -------------------- | ----------------------------- | | `--debug` | Enable development/debug mode | | `--disable-watching` | Disable file watching | --- ### `yao run` Execute a Yao process. ```bash # Run a process yao run models.user.Find 1 # Run with JSON arguments yao run models.user.Create '::[{"name":"John","age":30}]' # Run in silent mode (JSON output only) yao run -s models.user.Find 1 ``` **Flags:** | Flag | Short | Description | | ---------- | ----- | ---------------------------------------- | | `--silent` | `-s` | Silent mode - output result as JSON only | **Argument Syntax:** - Regular arguments: `arg1 arg2` - JSON arguments: `'::[{"key":"value"}]'` (prefix with `::`) - Escaped `::`: `'\::literal'` --- ### `yao migrate` Update database schema based on model definitions. 
```bash # Migrate all models yao migrate # Migrate specific model yao migrate -n user # Force migrate in production mode yao migrate --force # Reset (drop and recreate) tables yao migrate --reset ``` **Flags:** | Flag | Short | Description | | --------- | ----- | -------------------------------- | | `--name` | `-n` | Specific model name to migrate | | `--force` | | Force migrate in production mode | | `--reset` | | Drop tables before migration | --- ### `yao inspect` Display application configuration. ```bash yao inspect ``` --- ### `yao version` Show Yao version information. ```bash # Show version yao version # Show all version details yao version --all ``` **Flags:** | Flag | Description | | ------- | -------------------------------------------------------------------- | | `--all` | Print all version information (Go version, commit, build time, etc.) | --- ## Agent Commands Commands for testing and managing AI agents. ### `yao agent test` Test an agent with input cases from a JSONL file, direct message, or script tests. 
```bash # Test with direct message (development mode) yao agent test -i "Extract keywords from: AI and machine learning" -n workers.system.keyword # Test with JSONL file yao agent test -i tests/inputs.jsonl # Test with custom output file yao agent test -i tests/inputs.jsonl -o report.html # Test with specific connector yao agent test -i tests/inputs.jsonl -c openai.gpt4 # Stability testing (multiple runs) yao agent test -i tests/inputs.jsonl --runs 5 # Parallel execution yao agent test -i tests/inputs.jsonl --parallel 4 # Verbose output yao agent test -i tests/inputs.jsonl -v # Script tests (test agent handler scripts) yao agent test -i scripts.expense.setup -v # Script tests with test filtering yao agent test -i scripts.expense.setup --run "TestSystemReady" # Script tests with custom context yao agent test -i scripts.expense.setup --ctx tests/context.json -v ``` **Flags:** | Flag | Short | Description | | ------------- | ----- | ---------------------------------------------------------------- | | `--input` | `-i` | Input: JSONL file path, message, or script ID (required) | | `--output` | `-o` | Output file path (default: `output-{timestamp}.jsonl`) | | `--name` | `-n` | Agent ID (default: auto-detect from path) | | `--connector` | `-c` | Override default connector | | `--user` | `-u` | Test user ID (default: `test-user`) | | `--team` | `-t` | Test team ID (default: `test-team`) | | `--ctx` | | Path to context JSON file for custom authorization | | `--reporter` | `-r` | Reporter agent ID for custom report generation | | `--runs` | | Number of runs per test case for stability analysis (default: 1) | | `--run` | | Regex pattern to filter which tests to run | | `--timeout` | | Timeout per test case (default: `5m`) | | `--parallel` | | Number of parallel test cases (default: 1) | | `--verbose` | `-v` | Enable verbose output | | `--fail-fast` | | Stop on first failure | | `--app` | `-a` | Application directory | | `--env` | `-e` | Environment file | **Input Modes:** 1. 
**Direct Message Mode**: For quick development/debugging ```bash yao agent test -i "Hello world" -n my.agent ``` - Outputs result directly to stdout - No report file generated - Ideal for iterative development 2. **File Mode**: For comprehensive testing ```bash yao agent test -i tests/inputs.jsonl ``` - Reads test cases from JSONL file - Generates detailed report - Supports stability analysis 3. **Script Test Mode**: For testing agent handler scripts ```bash yao agent test -i scripts.expense.setup -v ``` - Tests TypeScript/JavaScript handler scripts (hooks, tools, setup functions) - Input format: `scripts.<assistant>.<module>` (e.g., `scripts.expense.setup`) - Automatically discovers and runs all `Test*` functions - Uses Go-like testing interface with assertions **Script Test Function Signature:** ```typescript // assistants/expense/src/setup_test.ts import { SystemReady } from "./setup"; export function TestSystemReady(t: testing.T, ctx: agent.Context) { const result = SystemReady(ctx); t.assert.True(result.success, "SystemReady should succeed"); t.assert.Equal(result.status, "ready", "Status should be ready"); } ``` **Context JSON Format (for `--ctx` flag):** ```json { "authorized": { "sub": "user-12345", "client_id": "my-app", "user_id": "admin", "team_id": "team-001", "tenant_id": "acme-corp", "constraints": { "owner_only": true, "team_only": false, "extra": { "department": "engineering" } } }, "metadata": { "request_id": "req-123" }, "client": { "type": "web", "ip": "192.168.1.100" }, "locale": "zh-cn" } ``` **JSONL Input Format:** ```jsonl {"id": "T001", "input": "Simple text input"} {"id": "T002", "input": {"role": "user", "content": "Message with role"}} {"id": "T003", "input": [{"role": "system", "content": "System prompt"}, {"role": "user", "content": "User message"}]} {"id": "T004", "input": "Test with timeout", "timeout": "30s"} {"id": "T005", "input": "Skip this test", "skip": true} {"id": "T006", "input": "Test with specific user", "user": "alice", "team":
"engineering"} ``` **Output Formats:** | Extension | Format | Description | | --------- | -------- | -------------------------- | | `.jsonl` | JSONL | Streaming format (default) | | `.json` | JSON | Complete structured report | | `.md` | Markdown | Human-readable with tables | | `.html` | HTML | Interactive web report | **Agent Resolution:** The agent is resolved in the following priority order: 1. Explicit `-n` flag: `yao agent test -i msg -n my.agent` 2. `YAO_ROOT` environment variable 3. Auto-detect from input file path (traverses up to find `package.yao`) 4. Auto-detect from current working directory --- ## SUI Commands SUI (Serverless UI) template engine commands. ### `yao sui watch` Auto-build templates when files change. ```bash yao sui watch <sui> <template> [data] # Example yao sui watch default index '::{}' ``` ### `yao sui build` Build a template. ```bash yao sui build <sui> <template> [data] # Example yao sui build default index '::{}' # Debug mode yao sui build default index '::{}' --debug ``` ### `yao sui trans` Translate template content. ```bash yao sui trans <sui> <template> # With specific locales yao sui trans default index -l "en-US,zh-CN,ja-JP" ``` **SUI Flags:** | Flag | Short | Description | | ----------- | ----- | ----------------------------------------- | | `--data` | `-d` | Session data as JSON (prefix with `::`) | | `--debug` | `-D` | Enable debug mode | | `--locales` | `-l` | Locales for translation (comma-separated) | --- ## Examples ### Development Workflow ```bash # Start development server yao start --debug # Run a process yao run scripts.test.Hello "World" # Test an agent interactively yao agent test -i "What is the weather today?"
-n assistant.weather # Watch and auto-build templates yao sui watch default home ``` ### Testing Workflow ```bash # Run comprehensive agent tests yao agent test -i tests/inputs.jsonl -o report.html -v # Run script tests for agent handlers yao agent test -i scripts.expense.setup -v # Run specific script tests with filtering yao agent test -i scripts.expense.setup --run "TestSystem.*" -v # Run script tests with custom context yao agent test -i scripts.expense.setup --ctx tests/context.json -v # Stability analysis (run each test 10 times) yao agent test -i tests/inputs.jsonl --runs 10 -o stability-report.json # Parallel testing with timeout yao agent test -i tests/inputs.jsonl --parallel 4 --timeout 2m # CI/CD integration yao agent test -i tests/inputs.jsonl -o results.jsonl && echo "Tests passed" ``` ### Database Migration ```bash # Migrate all models yao migrate # Migrate specific model with reset yao migrate -n user --reset --force ``` --- ## Exit Codes | Code | Description | | ---- | --------------------- | | 0 | Success | | 1 | Error or test failure | --- ## Directory Structure ``` myapp/ ├── app.yao # Application configuration ├── .env # Environment variables ├── models/ # Data models ├── apis/ # API definitions ├── flows/ # Business flows ├── scripts/ # JavaScript/TypeScript scripts ├── assistants/ # AI agents │ └── my-agent/ │ ├── package.yao # Agent configuration │ ├── prompts.yml # Agent prompts │ └── tests/ │ └── inputs.jsonl # Test cases └── public/ # Static files ``` --- ## See Also - [Yao Documentation](https://yaoapps.com/docs) - [Agent Test Design](../agent/test/DESIGN.md) - [SUI Documentation](https://yaoapps.com/docs/sui) ================================================ FILE: cmd/agent/add.go ================================================ package agent import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" agentmgr "github.com/yaoapp/yao/registry/manager/agent" ) var agentAddForce bool // AddCmd 
implements "yao agent add @scope/name" var AddCmd = &cobra.Command{ Use: "add [package]", Short: L("Install an assistant package from the registry"), Long: L("Install an assistant package from the registry. Example: yao agent add @yao/keeper"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { Boot() pkgID := args[0] version, _ := cmd.Flags().GetString("version") client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := agentmgr.New(client, config.Conf.Root, nil) if err := mgr.Add(pkgID, agentmgr.AddOptions{ Version: version, Force: agentAddForce, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { AddCmd.Flags().StringP("version", "v", "latest", L("Package version or dist-tag")) AddCmd.Flags().BoolVarP(&agentAddForce, "force", "", false, L("Force reinstall")) AddCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) AddCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/agent/agent.go ================================================ package agent import ( "os" "path/filepath" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" ) var appPath string var envFile string var langs = map[string]string{ "Test an agent with input cases": "使用测试用例测试智能体", "Test an agent with input cases from JSONL file or direct message": "使用 JSONL 文件或直接消息测试智能体", "Application directory": "应用目录", "Environment file": "环境变量文件", "Input: JSONL file path or message (required)": "输入: JSONL 文件路径或消息 (必需)", "Path to output file (default: output-{timestamp}.jsonl)": "输出文件路径 (默认: output-{timestamp}.jsonl)", "Explicit agent ID (default: auto-detect)": "指定智能体 ID (默认: 自动检测)", "Override connector": "覆盖连接器", "Test user ID (default: test-user)": "测试用户 ID (默认: test-user)", "Test team ID (default: test-team)": "测试团队 ID (默认: 
test-team)", "Path to context JSON file for custom authorization": "自定义认证信息的 JSON 文件路径", "Reporter agent ID for custom report": "自定义报告生成器智能体 ID", "Number of runs for stability analysis": "稳定性分析的运行次数", "Regex pattern to filter which tests to run": "用于过滤测试的正则表达式", "Default timeout per test case": "每个测试用例的默认超时时间", "Number of parallel test cases": "并行测试用例数", "Verbose output": "详细输出", "Stop on first failure": "遇到第一个失败时停止", "Error: input is required (-i flag)": "错误: 需要输入 (-i 参数)", "Error: failed to get current directory": "错误: 获取当前目录失败", "Error: agent (-n) is required when using direct message input and not in an agent directory": "错误: 使用直接消息输入且不在智能体目录时需要指定 -n 参数", "Hint: Make sure you're in a Yao application directory or specify --app flag": "提示: 确保在 Yao 应用目录中或使用 --app 参数指定", "Error: invalid timeout format": "错误: 无效的超时格式", // Registry commands "Install an assistant package from the registry": "从注册中心安装助手包", "Update an installed assistant package": "更新已安装的助手包", "Push an assistant package to the registry": "推送助手包到注册中心", "Fork an assistant to a local scope": "Fork 一个助手到本地范围", "Package version or dist-tag": "包版本或 dist-tag", "Force reinstall": "强制重新安装", "Package version (required)": "包版本 (必填)", "Target version or dist-tag": "目标版本或 dist-tag", // Extract command "Extract test results to individual files for review": "提取测试结果到单独的文件供审查", "Extract test results from output JSONL file to individual Markdown or JSON files": "从输出 JSONL 文件中提取测试结果到单独的 Markdown 或 JSON 文件", "Output directory (default: same as input file)": "输出目录 (默认: 与输入文件相同)", "Output format: markdown, json": "输出格式: markdown, json", } // L Language switch func L(words string) string { var lang = os.Getenv("YAO_LANG") if lang == "" { return words } if trans, has := langs[words]; has { return trans } return words } // Boot sets the configuration func Boot() { // Use root from Init() unless appPath is explicitly specified root := config.Conf.Root if appPath != "" { r, err := filepath.Abs(appPath) if err != nil { 
exception.New("Root error %s", 500, err.Error()).Throw() } root = r } // Load .env file, preserving the correct root if envFile != "" { config.Conf = config.LoadFromWithRoot(envFile, root) } else { config.Conf = config.LoadFromWithRoot(filepath.Join(root, ".env"), root) } config.ApplyMode() } ================================================ FILE: cmd/agent/extract.go ================================================ package agent import ( "fmt" "os" "path/filepath" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/yao/agent/test" ) // Extract command flags var ( extractOutput string extractFormat string ) // ExtractCmd is the agent extract command var ExtractCmd = &cobra.Command{ Use: "extract ", Short: L("Extract test results to individual files for review"), Long: L("Extract test results from output JSONL file to individual Markdown or JSON files"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { inputFile := args[0] // Resolve absolute path absPath, err := filepath.Abs(inputFile) if err != nil { color.Red("Error: %s\n", err.Error()) os.Exit(1) } // Check if file exists if _, err := os.Stat(absPath); os.IsNotExist(err) { color.Red("Error: file not found: %s\n", absPath) os.Exit(1) } // Build extract options opts := &test.ExtractOptions{ InputFile: absPath, OutputDir: extractOutput, Format: extractFormat, } // Create extractor and run extractor := test.NewExtractor(opts) files, err := extractor.Extract() if err != nil { color.Red("Error: %s\n", err.Error()) os.Exit(1) } // Print results fmt.Println() color.New(color.FgGreen, color.Bold).Println("═══════════════════════════════════════════════════════════════") color.New(color.FgGreen, color.Bold).Println(" Extract Complete") color.New(color.FgGreen, color.Bold).Println("═══════════════════════════════════════════════════════════════") fmt.Println() for _, file := range files { color.New(color.FgGreen).Printf("✓ ") fmt.Printf("Written: %s\n", filepath.Base(file)) } 
fmt.Println() color.New(color.FgWhite).Printf(" Total: ") color.New(color.FgCyan).Printf("%d files\n", len(files)) if extractOutput != "" { color.New(color.FgWhite).Printf(" Output: ") color.New(color.FgCyan).Printf("%s\n", extractOutput) } else { color.New(color.FgWhite).Printf(" Output: ") color.New(color.FgCyan).Printf("%s\n", filepath.Dir(absPath)) } fmt.Println() }, } func init() { // Extract command flags ExtractCmd.Flags().StringVarP(&extractOutput, "output", "o", "", L("Output directory (default: same as input file)")) ExtractCmd.Flags().StringVar(&extractFormat, "format", "markdown", L("Output format: markdown, json")) } ================================================ FILE: cmd/agent/fork.go ================================================ package agent import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" agentmgr "github.com/yaoapp/yao/registry/manager/agent" ) // ForkCmd implements "yao agent fork @scope/name [@target-scope]" var ForkCmd = &cobra.Command{ Use: "fork [package] [target-scope]", Short: L("Fork an assistant to a local scope"), Long: L("Fork an assistant for local modification. 
Example: yao agent fork @yao/keeper"), Args: cobra.RangeArgs(1, 2), Run: func(cmd *cobra.Command, args []string) { Boot() pkgID := args[0] var targetScope string if len(args) > 1 { targetScope = args[1] } client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := agentmgr.New(client, config.Conf.Root, nil) if err := mgr.Fork(pkgID, agentmgr.ForkOptions{ TargetScope: targetScope, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { ForkCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) ForkCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/agent/push.go ================================================ package agent import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" agentmgr "github.com/yaoapp/yao/registry/manager/agent" ) // PushCmd implements "yao agent push scope.name --version x.y.z" var PushCmd = &cobra.Command{ Use: "push [yao-id]", Short: L("Push an assistant package to the registry"), Long: L("Package and push an assistant to the registry. 
Example: yao agent push max.keeper --version 1.0.0"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { Boot() yaoID := args[0] version, _ := cmd.Flags().GetString("version") force, _ := cmd.Flags().GetBool("force") client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := agentmgr.New(client, config.Conf.Root, nil) if err := mgr.Push(yaoID, agentmgr.PushOptions{ Version: version, Force: force, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { PushCmd.Flags().StringP("version", "v", "", L("Package version (required)")) PushCmd.Flags().Bool("force", false, L("Overwrite existing version")) PushCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) PushCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/agent/test.go ================================================ package agent import ( "fmt" "os" "path/filepath" "time" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/gou/plugin" "github.com/yaoapp/yao/agent" "github.com/yaoapp/yao/agent/test" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/kb" "github.com/yaoapp/yao/share" ) // Test command flags var ( testInput string testOutput string testAgent string testConnector string testUser string testTeam string testContext string // --ctx flag for custom context JSON file testReporter string testRuns int testRun string // --run flag for test filtering (regex pattern) testTimeout string testParallel int testVerbose bool testFailFast bool testBefore string // --before flag for global BeforeAll hook testAfter string // --after flag for global AfterAll hook testDryRun bool // --dry-run flag for generating tests without running testSimulator string // --simulator flag for default simulator agent in dynamic 
mode ) // TestCmd is the agent test command var TestCmd = &cobra.Command{ Use: "test", Short: L("Test an agent with input cases"), Long: L("Test an agent with input cases from JSONL file or direct message"), Run: func(cmd *cobra.Command, args []string) { defer share.SessionStop() defer plugin.KillAll() // Validate input if testInput == "" { color.Red(L("Error: input is required (-i flag)") + "\n") os.Exit(1) } // Detect input mode inputMode := test.DetectInputMode(testInput) // For message mode, agent must be specified or resolvable from cwd if inputMode == test.InputModeMessage && testAgent == "" { // Try to find app root from current directory cwd, err := os.Getwd() if err != nil { color.Red(L("Error: failed to get current directory")+": %s\n", err.Error()) os.Exit(1) } // Try to find package.yao from cwd resolver := test.NewResolver() _, err = resolver.ResolveFromPath(cwd) if err != nil { color.Red(L("Error: agent (-n) is required when using direct message input and not in an agent directory") + "\n") os.Exit(1) } } // Find app root directory // Priority: -a flag > YAO_ROOT env > auto-detect from path var err error if appPath == "" { // Check YAO_ROOT environment variable if yaoRoot := os.Getenv("YAO_ROOT"); yaoRoot != "" { appPath = yaoRoot } } if appPath == "" { // Auto-detect from path if inputMode == test.InputModeFile { // For file mode, find app root from input file path appPath, err = findAppRoot(testInput) } else { // For message mode, find app root from current directory cwd, _ := os.Getwd() appPath, err = findAppRoot(cwd) } if err != nil { color.Red("Error: %s\n", err.Error()) color.Yellow(L("Hint: Make sure you're in a Yao application directory or specify --app flag") + "\n") os.Exit(1) } } // Boot the application Boot() // Set Runtime Mode config.Conf.Runtime.Mode = "standard" cfg := config.Conf cfg.Session.IsCLI = true // Load engine _, err = engine.Load(cfg, engine.LoadOption{Action: "agent-test"}) if err != nil { color.Red("Engine: %s\n", 
err.Error()) os.Exit(1) } // Load KB (required for agent KB features) _, err = kb.Load(cfg) if err != nil { color.Red("KB: %s\n", err.Error()) os.Exit(1) } // Load agent err = agent.Load(cfg) if err != nil { color.Red("Agent: %s\n", err.Error()) os.Exit(1) } // Parse timeout timeout := 5 * time.Minute if testTimeout != "" { d, err := time.ParseDuration(testTimeout) if err != nil { color.Red(L("Error: invalid timeout format")+": %s\n", testTimeout) os.Exit(1) } timeout = d } // Build test options opts := &test.Options{ Input: testInput, InputMode: inputMode, OutputFile: testOutput, AgentID: testAgent, Connector: testConnector, UserID: testUser, TeamID: testTeam, ContextFile: testContext, ReporterID: testReporter, Runs: testRuns, Run: testRun, Timeout: timeout, Parallel: testParallel, Verbose: testVerbose, FailFast: testFailFast, BeforeAll: testBefore, AfterAll: testAfter, DryRun: testDryRun, Simulator: testSimulator, } // Merge with defaults opts = test.MergeOptions(opts, test.DefaultOptions()) // Resolve output path (only for file mode, direct message mode outputs to stdout) if inputMode == test.InputModeFile { opts.OutputFile = test.ResolveOutputPath(opts) } // Run tests runner := test.NewRunner(opts) report, err := runner.Run() if err != nil { color.Red("Error: %s\n", err.Error()) os.Exit(1) } // Exit with appropriate code if report.HasFailures() { os.Exit(1) } }, } // findAppRoot finds the Yao application root directory by looking for app.yao // It traverses up from the given path until it finds app.yao or reaches the filesystem root func findAppRoot(startPath string) (string, error) { // Get absolute path absPath, err := filepath.Abs(startPath) if err != nil { return "", fmt.Errorf("failed to get absolute path: %w", err) } // If it's a file, start from its directory info, err := os.Stat(absPath) if err != nil { return "", fmt.Errorf("path not found: %s", absPath) } var dir string if info.IsDir() { dir = absPath } else { dir = filepath.Dir(absPath) } // Traverse 
up to find app.yao for { // Check for app.yao, app.json, or app.jsonc for _, appFile := range []string{"app.yao", "app.json", "app.jsonc"} { appFilePath := filepath.Join(dir, appFile) if _, err := os.Stat(appFilePath); err == nil { return dir, nil } } // Move to parent directory parent := filepath.Dir(dir) if parent == dir { // Reached root, no app.yao found break } dir = parent } return "", fmt.Errorf("no app.yao found in path hierarchy of %s", startPath) } func init() { // Test command flags TestCmd.Flags().StringVarP(&appPath, "app", "a", "", L("Application directory")) TestCmd.Flags().StringVarP(&envFile, "env", "e", "", L("Environment file")) TestCmd.Flags().StringVarP(&testInput, "input", "i", "", L("Input: JSONL file path or message (required)")) TestCmd.Flags().StringVarP(&testOutput, "output", "o", "", L("Path to output file (default: output-{timestamp}.jsonl)")) TestCmd.Flags().StringVarP(&testAgent, "name", "n", "", L("Explicit agent ID (default: auto-detect)")) TestCmd.Flags().StringVarP(&testConnector, "connector", "c", "", L("Override connector")) TestCmd.Flags().StringVarP(&testUser, "user", "u", "", L("Test user ID (default: test-user)")) TestCmd.Flags().StringVarP(&testTeam, "team", "t", "", L("Test team ID (default: test-team)")) TestCmd.Flags().StringVar(&testContext, "ctx", "", L("Path to context JSON file for custom authorization")) TestCmd.Flags().StringVarP(&testReporter, "reporter", "r", "", L("Reporter agent ID for custom report")) TestCmd.Flags().IntVar(&testRuns, "runs", 1, L("Number of runs for stability analysis")) TestCmd.Flags().StringVar(&testRun, "run", "", L("Regex pattern to filter which tests to run")) TestCmd.Flags().StringVar(&testTimeout, "timeout", "5m", L("Default timeout per test case")) TestCmd.Flags().IntVar(&testParallel, "parallel", 1, L("Number of parallel test cases")) TestCmd.Flags().BoolVarP(&testVerbose, "verbose", "v", false, L("Verbose output")) TestCmd.Flags().BoolVar(&testFailFast, "fail-fast", false, L("Stop 
on first failure")) TestCmd.Flags().StringVar(&testBefore, "before", "", L("Global BeforeAll hook (e.g., env_test.BeforeAll)")) TestCmd.Flags().StringVar(&testAfter, "after", "", L("Global AfterAll hook (e.g., env_test.AfterAll)")) TestCmd.Flags().BoolVar(&testDryRun, "dry-run", false, L("Generate test cases without running them")) TestCmd.Flags().StringVar(&testSimulator, "simulator", "", L("Default simulator agent for dynamic mode (e.g., tests.simulator-agent)")) // Mark input as required TestCmd.MarkFlagRequired("input") } ================================================ FILE: cmd/agent/update.go ================================================ package agent import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" agentmgr "github.com/yaoapp/yao/registry/manager/agent" ) // UpdateCmd implements "yao agent update @scope/name" var UpdateCmd = &cobra.Command{ Use: "update [package]", Short: L("Update an installed assistant package"), Long: L("Update an installed assistant to a newer version. 
Example: yao agent update @yao/keeper"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { Boot() pkgID := args[0] version, _ := cmd.Flags().GetString("version") client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := agentmgr.New(client, config.Conf.Root, nil) if err := mgr.Update(pkgID, agentmgr.UpdateOptions{ Version: version, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { UpdateCmd.Flags().StringP("version", "v", "latest", L("Target version or dist-tag")) UpdateCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) UpdateCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/ci-token/main.go ================================================ //go:build ci package main import ( "flag" "fmt" "os" "path/filepath" "time" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/openapi/oauth" ) func main() { appPath := flag.String("app", envOr("YAO_CI_APP_PATH", "."), "Yao application directory") clientID := flag.String("client-id", envOr("YAO_CI_OAUTH_CLIENT_ID", "ci-tai"), "OAuth client ID embedded in token") subject := flag.String("subject", envOr("YAO_CI_OAUTH_SUBJECT", "ci-tai"), "JWT subject claim") scope := flag.String("scope", envOr("YAO_CI_OAUTH_SCOPE", "tai:tunnel"), "Token scope (space-separated)") ttl := flag.String("ttl", envOr("YAO_CI_OAUTH_TTL", "24h"), "Token TTL (e.g. 
// envOr returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func envOr(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
type Credential struct {
	// Server is the base URL of the remote Yao server the token was issued by.
	Server string `json:"server"`
	// GRPCAddr is the gRPC address advertised by the server, if any.
	GRPCAddr string `json:"grpc_addr,omitempty"`
	// AccessToken is the OAuth access token.
	AccessToken string `json:"access_token"`
	// RefreshToken is the OAuth refresh token, if one was issued.
	RefreshToken string `json:"refresh_token,omitempty"`
	// Scope is the scope string returned with the token.
	Scope string `json:"scope,omitempty"`
	// User is a display name for the token owner.
	User string `json:"user,omitempty"`
	// ExpiresAt is the expiry instant in RFC 3339 form; empty means unknown.
	ExpiresAt string `json:"expires_at,omitempty"`
}

// Expired reports whether ExpiresAt holds an RFC 3339 timestamp that already
// lies in the past. An empty or unparsable ExpiresAt is reported as not
// expired.
func (c *Credential) Expired() bool {
	if c.ExpiresAt != "" {
		if deadline, err := time.Parse(time.RFC3339, c.ExpiresAt); err == nil {
			return time.Now().After(deadline)
		}
	}
	return false
}

// credentialPath returns the location of the credential file, which is
// ".yao/credentials" below the current user's home directory.
func credentialPath() (string, error) {
	homeDir, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(homeDir, ".yao", "credentials"), nil
	}
	return "", fmt.Errorf("cannot determine home directory: %w", err)
}

// LoadCredential reads ~/.yao/credentials, base64-decodes its content and
// unmarshals the resulting JSON into a Credential.
//
// It returns (nil, nil) when the file does not exist, and a wrapped error when
// the file cannot be read, decoded or unmarshalled.
func LoadCredential() (*Credential, error) {
	file, err := credentialPath()
	if err != nil {
		return nil, err
	}

	content, err := os.ReadFile(file)
	switch {
	case os.IsNotExist(err):
		// A missing file is not an error: it simply means "not logged in".
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("read credentials: %w", err)
	}

	payload, err := base64.StdEncoding.DecodeString(string(content))
	if err != nil {
		return nil, fmt.Errorf("decode credentials: %w", err)
	}

	cred := new(Credential)
	if err := json.Unmarshal(payload, cred); err != nil {
		return nil, fmt.Errorf("unmarshal credentials: %w", err)
	}
	return cred, nil
}

// LoadCredentialFrom reads and decodes a credential file from a custom path.
func LoadCredentialFrom(path string) (*Credential, error) { raw, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read credentials from %s: %w", path, err) } decoded, err := base64.StdEncoding.DecodeString(string(raw)) if err != nil { return nil, fmt.Errorf("decode credentials: %w", err) } var cred Credential if err := json.Unmarshal(decoded, &cred); err != nil { return nil, fmt.Errorf("unmarshal credentials: %w", err) } return &cred, nil } // SaveCredential encodes and writes the credential to ~/.yao/credentials. func SaveCredential(cred *Credential) error { path, err := credentialPath() if err != nil { return err } dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0700); err != nil { return fmt.Errorf("create directory %s: %w", dir, err) } data, err := json.Marshal(cred) if err != nil { return fmt.Errorf("marshal credentials: %w", err) } encoded := base64.StdEncoding.EncodeToString(data) if err := os.WriteFile(path, []byte(encoded), 0600); err != nil { return fmt.Errorf("write credentials: %w", err) } return nil } // RemoveCredential deletes ~/.yao/credentials. 
func RemoveCredential() error { path, err := credentialPath() if err != nil { return err } if err := os.Remove(path); err != nil && !os.IsNotExist(err) { return fmt.Errorf("remove credentials: %w", err) } return nil } ================================================ FILE: cmd/dump.go ================================================ package cmd import ( "archive/zip" "errors" "fmt" "io/ioutil" "os" "path/filepath" "strings" "time" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/gou/model" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" ) var dumpModel string var dumpCmd = &cobra.Command{ Use: "dump", Short: L("Dump the application data"), Long: L("Dump the application data"), Run: func(cmd *cobra.Command, args []string) { defer func() { err := exception.Catch(recover()) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) } }() Boot() path, err := filepath.Abs(".") if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } output := filepath.Join(fmt.Sprintf("%s-%s.zip", filepath.Base(path), time.Now().Format("20060102150405"))) if len(args) > 0 { output = args[0] } output, err = filepath.Abs(output) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } _, err = os.Stat(output) if !errors.Is(err, os.ErrNotExist) { fmt.Println(color.RedString("%s exists", output)) os.Exit(1) } // Load model loadWarnings, err := engine.Load(config.Conf, engine.LoadOption{Action: "dump"}) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } if len(loadWarnings) > 0 { for _, warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } } if dumpModel != "" { fmt.Println(color.YellowString(L("Not supported yet"))) os.Exit(1) return } // Export models files := []string{} for _, mod := range model.Models { fmt.Printf("\r%s", color.GreenString(L("Export the 
models: %s (%s)"), mod.Name, mod.MetaData.Table.Name)) jsonfiles, err := mod.Export(5000, func(curr, total int) { fmt.Printf("\r%s", strings.Repeat(" ", 80)) fmt.Printf("\r%s", color.GreenString(L("Export the models: %s (%s) %d/%d"), mod.Name, mod.MetaData.Table.Name, curr, total)) }) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } files = append(files, jsonfiles...) } fmt.Printf("\r%s", strings.Repeat(" ", 80)) fmt.Printf("\r%s\n", color.GreenString(L("Export the models: ✨DONE✨"))) // Compress files err = zipfiles(files, output, func(file string) { fmt.Printf("\r%s", strings.Repeat(" ", 80)) fmt.Printf("\r%s", color.GreenString(L("Compress the files: %s"), file)) }) fmt.Printf("\r%s", strings.Repeat(" ", 80)) fmt.Printf("\r%s\n", color.GreenString(L("Compress the files: ✨DONE✨"))) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } fmt.Println(color.GreenString("File: %s", output)) }, } // func init() { // // dumpCmd.PersistentFlags().StringVarP(&dumpModel, "name", "n", "", L("Model name")) // } // gzipfiles func zipfiles(files []string, output string, process func(file string)) error { outpath := filepath.Dir(output) os.MkdirAll(outpath, 0755) outfile, err := os.Create(output) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } defer outfile.Close() w := zip.NewWriter(outfile) defer func() { w.Close() if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } }() for _, file := range files { addFile(w, file, "model", process) } // Add data path dataPath := filepath.Join(config.Conf.Root, "data") _, err = os.Stat(dataPath) if err == nil { addFolder(w, dataPath, "data", process) } return nil } func addFile(w *zip.Writer, file, baseInZip string, process func(file string)) { process(filepath.Join(baseInZip, filepath.Base(file))) dat, err := ioutil.ReadFile(file) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), 
err.Error())) os.Exit(1) } // Add some files to the archive. f, err := w.Create(filepath.Join(baseInZip, filepath.Base(file))) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } _, err = f.Write(dat) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } os.Remove(file) } func addFolder(w *zip.Writer, basePath, baseInZip string, process func(file string)) { // Open the Directory files, err := ioutil.ReadDir(basePath) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } for _, file := range files { process(filepath.Join(baseInZip, file.Name())) if !file.IsDir() { dat, err := ioutil.ReadFile(filepath.Join(basePath, file.Name())) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } // Add some files to the archive. f, err := w.Create(filepath.Join(baseInZip, file.Name())) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } _, err = f.Write(dat) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } } else if file.IsDir() { // Recurse newBase := filepath.Join(basePath, file.Name()) addFolder(w, newBase, filepath.Join(baseInZip, file.Name()), process) } } } ================================================ FILE: cmd/get/get.go ================================================ package get import ( "archive/zip" "fmt" "io" "io/ioutil" "net/http" "os" "path/filepath" "strings" "time" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/fs/system" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/widgets/app" ) const ( // Application application Application uint = iota // Widgets ? 
model & table & flow Widgets // Table table widget Table // Form form widget Form // Model model model Model // Flow data flow Flow ) // Package package type Package struct { Name string Team string Type uint Remote string Origin string Temp string Tag string From string } // New create a package via name func New(repo string) (*Package, error) { team, name, tag, err := parse(repo) if err != nil { return nil, err } pkg := &Package{ Origin: repo, Team: team, Name: name, Tag: tag, Type: Application, } url := pkg.InfraURL() if urlExists(url) { pkg.Remote = url pkg.From = "LetsInfra.com" return pkg, nil } // @Todo: Download from Github return nil, fmt.Errorf("%s not found", repo) } // InfraURL infra package url func parse(repo string) (string, string, string, error) { tag := "latest" repo = strings.TrimSpace(repo) if !strings.Contains(repo, "/") { repo = fmt.Sprintf("yaoapp/%s", repo) } if strings.Contains(repo, "@") { arr := strings.Split(repo, "@") repo = arr[0] tag = arr[1] } arr := strings.Split(repo, "/") if len(arr) != 2 { return "", "", "", fmt.Errorf("REPO: %s format error", repo) } team := arr[0] name := arr[1] return team, name, tag, nil } // InfraURL infra package url func (pkg *Package) InfraURL() string { return fmt.Sprintf("https://mirrors.yao.run/apps/%s/%s/%s.zip", pkg.Team, pkg.Name, pkg.Tag) } // GithubURL github package url func (pkg *Package) GithubURL() string { return fmt.Sprintf("mirrors.letsinfra.com/apps/%s/%s/%s", pkg.Team, pkg.Name, pkg.Tag) } // urlExists check the http url is exists func urlExists(url string) bool { resp, err := http.Get(url) if err != nil { return false } if resp.Body != nil { defer resp.Body.Close() } return resp.StatusCode == 200 } // Download a package from remote func (pkg *Package) Download() error { if pkg.Remote == "" { return fmt.Errorf("remote url is required") } root, err := os.MkdirTemp("", "*-yao-zip") if err != nil { return fmt.Errorf("Can't Create temp dir %s", err.Error()) } name := 
fmt.Sprintf("%s-%s-%d.zip", pkg.Team, pkg.Name, time.Now().UnixMicro()) file := filepath.Join(root, name) out, err := os.Create(file) defer out.Close() if err != nil { return fmt.Errorf("Can't Create file: %s", err.Error()) } resp, err := http.Get(pkg.Remote) if err != nil { return fmt.Errorf("Download Error: %s", err.Error()) } defer resp.Body.Close() _, err = io.Copy(out, resp.Body) if err != nil { return fmt.Errorf("Copy Error: %s", err.Error()) } pkg.Temp = file return nil } // Validate a package files func (pkg *Package) Validate() error { if pkg.Temp == "" { return fmt.Errorf("temp file not found") } return nil } // Unpack a package to current dir func (pkg *Package) Unpack(dest string) (*app.DSL, error) { dest, err := filepath.Abs(dest) if err != nil { return nil, err } files, err := ioutil.ReadDir(dest) if err != nil { return nil, err } for _, f := range files { if !strings.HasPrefix(f.Name(), "logs") { return nil, fmt.Errorf("current folder shoud be empty") } } temp, err := os.MkdirTemp("", "*-yao-unzip") if err != nil { return nil, err } defer os.RemoveAll(temp) // Read zip file r, err := zip.OpenReader(pkg.Temp) if err != nil { return nil, err } defer r.Close() defer os.Remove(pkg.Temp) path := "" for i, f := range r.File { if i == 0 { path = filepath.Join(temp, strings.TrimRight(f.Name, "/")) } err := extractFile(f, temp) if err != nil { return nil, err } } data, err := os.ReadFile(filepath.Join(path, "app.yao")) if err != nil { return nil, err } var setting app.DSL err = jsoniter.Unmarshal(data, &setting) if err != nil { return nil, err } fs := system.New("/") err = fs.Copy(path, dest) if err != nil { return nil, err } // Remove env fs.Remove(filepath.Join(dest, ".env")) return &setting, nil } // extractFile extract and save file to the dest path func extractFile(f *zip.File, dest string) error { rc, err := f.Open() if err != nil { return err } defer rc.Close() path := filepath.Join(dest, f.Name) // Check for ZipSlip (Directory traversal) if 
!strings.HasPrefix(path, filepath.Clean(dest)+string(os.PathSeparator)) { return fmt.Errorf("illegal file path: %s", path) } if f.FileInfo().IsDir() { os.MkdirAll(path, f.Mode()) } else { os.MkdirAll(filepath.Dir(path), f.Mode()) f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { return err } defer func() { if err := f.Close(); err != nil { log.Error("repo unzip extractFile: %s", err.Error()) } }() _, err = io.Copy(f, rc) if err != nil { return err } } return nil } ================================================ FILE: cmd/get/get_test.go ================================================ package get import ( "testing" ) func TestUnpack(t *testing.T) { // pkg, err := New("yaoapp/demo-app") // if err != nil { // t.Fatal(err) // } // if err := pkg.Download(); err != nil { // t.Fatal(err) // } // dest, err := os.MkdirTemp("", "*-unit-test") // if err != nil { // t.Fatal(err) // } // defer os.RemoveAll(dest) // app, err := pkg.Unpack(dest) // if err != nil { // t.Fatal(err) // } // assert.NotNil(t, app.Name) } ================================================ FILE: cmd/get.go ================================================ package cmd import ( "fmt" "os" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/yao/cmd/get" "github.com/yaoapp/yao/share" ) var getCmd = &cobra.Command{ Use: "get", Short: L("Get an application"), Long: L("Get an application"), Run: func(cmd *cobra.Command, args []string) { if len(args) < 1 { fmt.Println(color.RedString(L("Not enough arguments"))) fmt.Println(color.WhiteString(share.BUILDNAME + " help")) return } repo := args[0] pkg, err := get.New(repo) if err != nil { fmt.Println(color.RedString(err.Error())) os.Exit(1) } fmt.Println(color.WhiteString("From Yao: %s", pkg.Remote)) fmt.Println(color.WhiteString("Visit: https://yaoapps.com")) err = pkg.Download() if err != nil { fmt.Println(color.RedString(err.Error())) os.Exit(1) } dest, err := os.Getwd() if err != nil { 
fmt.Println(color.RedString(err.Error())) os.Exit(1) } // dest, err = os.MkdirTemp(dest, "*-unit-test") // if err != nil { // fmt.Println(color.RedString(err.Error())) // os.Exit(1) // } // os.MkdirAll(dest, os.ModePerm) app, err := pkg.Unpack(dest) if err != nil { fmt.Println(color.RedString(err.Error())) os.Exit(1) } fmt.Println(color.GreenString(app.Name), color.WhiteString(app.Version)) fmt.Println(color.GreenString(L("✨DONE✨"))) }, } ================================================ FILE: cmd/help.go ================================================ package cmd import ( "github.com/spf13/cobra" ) var helpCmd = &cobra.Command{ Use: "help", Short: L("Help for yao"), Long: L("Help for yao"), Run: func(cmd *cobra.Command, args []string) { cmd.Help() }, } ================================================ FILE: cmd/init.go ================================================ package cmd import ( "fmt" "os" "path/filepath" "time" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/setup" "github.com/yaoapp/yao/share" ) var initCmd = &cobra.Command{ Use: "init", Short: L("Initialize project"), Long: L("Initialize a new Yao application in the current directory"), Run: func(cmd *cobra.Command, args []string) { Boot() // First check if we're inside an existing Yao app (including parent directories) if setup.InYaoApp(config.Conf.Root) { fmt.Println(color.YellowString(L("Directory is inside an existing Yao application"))) fmt.Println(color.WhiteString("Please run 'yao init' outside of any Yao project")) os.Exit(1) } // Check if this is an empty directory if !setup.IsEmptyDir(config.Conf.Root) { // Directory is not empty fmt.Println(color.RedString(L("Directory is not empty"))) fmt.Println(color.WhiteString("Please run 'yao init' in an empty directory")) os.Exit(1) } startTime := time.Now() fmt.Println(color.CyanString("Initializing Yao application...")) // Install the init app (copy embedded files) 
if err := setup.Install(config.Conf.Root); err != nil { fmt.Println(color.RedString(L("Install: %s"), err.Error())) os.Exit(1) } fmt.Printf(" %s %s\n", color.GreenString("✓"), "Copied application files") // Reload configuration after install Boot() // Load the application engine loadWarnings, err := engine.Load(config.Conf, engine.LoadOption{Action: "init"}) if err != nil { fmt.Println(color.RedString(L("Load: %s"), err.Error())) os.Exit(1) } fmt.Printf(" %s %s\n", color.GreenString("✓"), "Loaded application engine") // Initialize (migrate + setup hook) if err := setup.Initialize(config.Conf.Root, config.Conf); err != nil { fmt.Println(color.RedString(L("Initialize: %s"), err.Error())) os.Exit(1) } fmt.Printf(" %s %s\n", color.GreenString("✓"), "Initialized database and data") initDuration := time.Since(startTime) // Print warnings if any if len(loadWarnings) > 0 { fmt.Println(color.YellowString("\n---------------------------------")) fmt.Println(color.YellowString(L("Warnings"))) fmt.Println(color.YellowString("---------------------------------")) for _, warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } } // Print success message fmt.Printf("\n%s Application initialized successfully in %s\n\n", color.GreenString("✓"), color.CyanString("%v", initDuration)) // Print application info root, _ := filepath.Abs(config.Conf.Root) fmt.Println(color.WhiteString("---------------------------------")) fmt.Println(color.WhiteString(L("Application Info"))) fmt.Println(color.WhiteString("---------------------------------")) fmt.Println(color.WhiteString(L("Name")), color.GreenString(" %s", share.App.Name)) fmt.Println(color.WhiteString(L("Version")), color.GreenString(" %s", share.App.Version)) fmt.Println(color.WhiteString(L("Root")), color.GreenString(" %s", root)) // Print welcome message printInitWelcome() }, } func printInitWelcome() { fmt.Println(color.CyanString("\n---------------------------------")) 
fmt.Println(color.CyanString(L("🎉 Application Ready 🎉"))) fmt.Println(color.CyanString("---------------------------------")) fmt.Println(color.WhiteString("📚 Documentation: "), color.CyanString("https://yaoapps.com/docs")) fmt.Println(color.WhiteString("🏡 Join Yao Community: "), color.CyanString("https://yaoapps.com/community")) fmt.Println(color.WhiteString("🤖 Build Your Digital Workforce:"), color.CyanString("https://yaoagents.com")) fmt.Println("") fmt.Println(color.WhiteString(L("NEXT:"))) fmt.Println(color.GreenString(" 1. Edit .env to configure your application")) fmt.Println(color.GreenString(" 2. Run 'yao start' to start the server")) fmt.Println("") } func init() { // Register init command rootCmd.AddCommand(initCmd) } ================================================ FILE: cmd/inspect.go ================================================ package cmd import ( "github.com/spf13/cobra" "github.com/yaoapp/kun/maps" "github.com/yaoapp/kun/utils" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/share" ) var inspectCmd = &cobra.Command{ Use: "inspect", Short: L("Show app configure"), Long: L("Show app configure"), Run: func(cmd *cobra.Command, args []string) { Boot() engine.InspectExtTools() res := maps.Map{ "version": share.VERSION, "config": config.Conf, } if share.Tools != nil { res["tools"] = share.Tools } utils.Dump(res) }, } ================================================ FILE: cmd/login.go ================================================ package cmd import ( "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "strings" "time" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/yao/engine" ) var loginServer string var loginCmd = &cobra.Command{ Use: "login", Short: L("Login to remote Yao server"), Long: L("Login to remote Yao server using device authorization flow"), Run: func(cmd *cobra.Command, args []string) { if loginServer == "" { color.Red(L("Missing --server flag\n")) 
fmt.Println(" yao login --server https://yaoagents.com") os.Exit(1) } serverURL := strings.TrimRight(loginServer, "/") // 1. Discover OAuth endpoints via well-known metadata endpoints, err := discoverEndpoints(serverURL) if err != nil { color.Red(" %s %s\n", L("Server discovery failed:"), err) os.Exit(1) } // 2. Compute deterministic client_id from machine fingerprint machine, err := engine.GetMachineID() if err != nil { color.Red("Failed to compute machine ID: %s\n", err) os.Exit(1) } clientID := machine.ID // 3. Register the client (idempotent for same client_id) if endpoints.RegistrationEndpoint != "" { if err := registerClient(endpoints.RegistrationEndpoint, clientID); err != nil { color.Red("Client registration failed: %s\n", err) os.Exit(1) } } // 4. Start device authorization deviceResp, err := requestDeviceAuthorization(endpoints.DeviceAuthorizationEndpoint, clientID) if err != nil { color.Red("Device authorization failed: %s\n", err) os.Exit(1) } // 5. Display the code to the user dashboard := endpoints.Dashboard if dashboard == "" { dashboard = "/admin" } verifyURI := strings.TrimRight(serverURL, "/") + dashboard + "/auth/device" verifyURIComplete := verifyURI + "?user_code=" + deviceResp.UserCode fmt.Println() color.White(" %s %s\n", L("Open:"), color.CyanString(verifyURIComplete)) fmt.Println() color.White(" %s %s\n", L("Or visit:"), color.CyanString(verifyURI)) color.White(" %s %s\n", L("Enter code:"), color.YellowString(deviceResp.UserCode)) fmt.Println() // 6. Poll for token interval := deviceResp.Interval if interval < 5 { interval = 5 } color.White(" %s", L("Waiting for authorization...")) tokenResp, err := pollForToken(endpoints.TokenEndpoint, clientID, deviceResp.DeviceCode, interval, deviceResp.ExpiresIn) if err != nil { fmt.Println() color.Red("\n %s %s\n", L("Login failed:"), err) os.Exit(1) } // 6. 
Save credential expiresAt := "" if tokenResp.ExpiresIn > 0 { expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).UTC().Format(time.RFC3339) } cred := &Credential{ Server: serverURL, GRPCAddr: endpoints.GRPCAddr, AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, Scope: tokenResp.Scope, User: parseJWTSubject(tokenResp.AccessToken), ExpiresAt: expiresAt, } if err := SaveCredential(cred); err != nil { color.Red("\n Failed to save credentials: %s\n", err) os.Exit(1) } fmt.Print("\033[2J\033[H") color.Green(" ✓ %s\n", L("Login successful")) color.White(" %s %s\n", L("Server:"), serverURL) if cred.GRPCAddr != "" { color.White(" %s %s\n", L("gRPC:"), cred.GRPCAddr) } if cred.User != "" { color.White(" %s %s\n", L("User:"), cred.User) } if cred.ExpiresAt != "" { color.White(" %s %s\n", L("Expires:"), cred.ExpiresAt) } fmt.Println() }, } func init() { loginCmd.PersistentFlags().StringVar(&loginServer, "server", "", L("Remote Yao server URL")) } // --- types --- type oauthEndpoints struct { RegistrationEndpoint string `json:"registration_endpoint"` DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` RevocationEndpoint string `json:"revocation_endpoint"` Dashboard string `json:"-"` GRPCAddr string `json:"-"` } type deviceAuthResponse struct { DeviceCode string `json:"device_code"` UserCode string `json:"user_code"` VerificationURI string `json:"verification_uri"` VerificationURIComplete string `json:"verification_uri_complete"` ExpiresIn int `json:"expires_in"` Interval int `json:"interval"` } type tokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token"` Scope string `json:"scope"` } type oauthError struct { Error string `json:"error"` ErrorDescription string `json:"error_description"` } // --- HTTP helpers --- // discoverEndpoints fetches 
OAuth endpoint URLs from /.well-known/yao, // using the openapi base prefix to construct correct API paths. func discoverEndpoints(serverURL string) (*oauthEndpoints, error) { return discoverFromYaoMetadata(serverURL) } type yaoMetadataResponse struct { OpenAPI string `json:"openapi"` Dashboard string `json:"dashboard"` GRPC string `json:"grpc"` } func discoverFromYaoMetadata(serverURL string) (*oauthEndpoints, error) { resp, err := http.Get(serverURL + "/.well-known/yao") if err != nil { return nil, fmt.Errorf("network error: %w", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("/.well-known/yao returned %d", resp.StatusCode) } var meta yaoMetadataResponse if err := json.Unmarshal(body, &meta); err != nil { return nil, fmt.Errorf("invalid /.well-known/yao response: %w", err) } base := strings.TrimRight(serverURL, "/") + meta.OpenAPI return &oauthEndpoints{ RegistrationEndpoint: base + "/oauth/register", DeviceAuthorizationEndpoint: base + "/oauth/device_authorization", TokenEndpoint: base + "/oauth/token", RevocationEndpoint: base + "/oauth/revoke", Dashboard: meta.Dashboard, GRPCAddr: meta.GRPC, }, nil } func registerClient(endpoint, clientID string) error { body := fmt.Sprintf( `{"client_id":"%s","client_name":"yao-cli","grant_types":["urn:ietf:params:oauth:grant-type:device_code"],"token_endpoint_auth_method":"none","redirect_uris":["http://localhost"]}`, clientID, ) resp, err := http.Post(endpoint, "application/json", strings.NewReader(body)) if err != nil { return fmt.Errorf("network error: %w", err) } defer resp.Body.Close() if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { return nil } respBody, _ := io.ReadAll(resp.Body) var oerr oauthError if json.Unmarshal(respBody, &oerr) == nil && oerr.Error == "invalid_client_metadata" { return nil // client already registered, idempotent } return fmt.Errorf("registration returned %d: %s", resp.StatusCode, 
string(respBody)) } func requestDeviceAuthorization(endpoint, clientID string) (*deviceAuthResponse, error) { data := url.Values{ "client_id": {clientID}, "scope": {"grpc:run grpc:stream grpc:shell grpc:mcp grpc:llm grpc:agent"}, } resp, err := http.PostForm(endpoint, data) if err != nil { return nil, fmt.Errorf("network error: %w", err) } defer resp.Body.Close() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { var oerr oauthError json.Unmarshal(respBody, &oerr) if oerr.ErrorDescription != "" { return nil, fmt.Errorf("%s", oerr.ErrorDescription) } return nil, fmt.Errorf("server returned %d: %s", resp.StatusCode, string(respBody)) } var result deviceAuthResponse if err := json.Unmarshal(respBody, &result); err != nil { return nil, fmt.Errorf("invalid response: %w", err) } return &result, nil } func pollForToken(endpoint, clientID, deviceCode string, interval, expiresIn int) (*tokenResponse, error) { deadline := time.Now().Add(time.Duration(expiresIn) * time.Second) ticker := time.NewTicker(time.Duration(interval) * time.Second) defer ticker.Stop() for range ticker.C { if time.Now().After(deadline) { return nil, fmt.Errorf("device code expired") } data := url.Values{ "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, "client_id": {clientID}, "device_code": {deviceCode}, } resp, err := http.PostForm(endpoint, data) if err != nil { continue } respBody, _ := io.ReadAll(resp.Body) resp.Body.Close() if resp.StatusCode == http.StatusOK { var tok tokenResponse if err := json.Unmarshal(respBody, &tok); err != nil { return nil, fmt.Errorf("invalid token response: %w", err) } return &tok, nil } var oerr oauthError json.Unmarshal(respBody, &oerr) switch oerr.Error { case "authorization_pending": fmt.Print(".") continue case "slow_down": interval += 5 ticker.Reset(time.Duration(interval) * time.Second) continue case "expired_token": return nil, fmt.Errorf("device code expired") case "access_denied": return nil, fmt.Errorf("authorization 
denied by user") default: desc := oerr.ErrorDescription if desc == "" { desc = oerr.Error } return nil, fmt.Errorf("%s", desc) } } return nil, fmt.Errorf("device code expired") } // parseJWTSubject extracts the "sub" claim from a JWT access token // without verifying the signature (display-only). func parseJWTSubject(token string) string { parts := strings.Split(token, ".") if len(parts) != 3 { return "" } payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return "" } var claims struct { Sub string `json:"sub"` } if json.Unmarshal(payload, &claims) != nil { return "" } return claims.Sub } ================================================ FILE: cmd/logout.go ================================================ package cmd import ( "net/http" "net/url" "os" "strings" "github.com/fatih/color" "github.com/spf13/cobra" ) var logoutCmd = &cobra.Command{ Use: "logout", Short: L("Logout from remote Yao server"), Long: L("Revoke token and remove stored credentials"), Run: func(cmd *cobra.Command, args []string) { cred, err := LoadCredential() if err != nil { color.Red(" %s %s\n", L("Failed to read credentials:"), err) os.Exit(1) } if cred == nil { color.Yellow(" %s\n", L("Not logged in")) return } // Best-effort token revocation via discovery if cred.AccessToken != "" && cred.Server != "" { if ep, err := discoverEndpoints(cred.Server); err == nil && ep.RevocationEndpoint != "" { revokeToken(ep.RevocationEndpoint, cred.AccessToken) } } if err := RemoveCredential(); err != nil { color.Red(" %s %s\n", L("Failed to remove credentials:"), err) os.Exit(1) } color.Green(" ✓ %s\n", L("Logged out")) if cred.Server != "" { color.White(" %s %s\n", L("Server:"), cred.Server) } }, } func revokeToken(endpoint, token string) { data := url.Values{"token": {token}} req, err := http.NewRequest("POST", endpoint, strings.NewReader(data.Encode())) if err != nil { return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") http.DefaultClient.Do(req) } func 
init() { // Add i18n entries langs["Login to remote Yao server"] = "登录远程 Yao 服务器" langs["Login to remote Yao server using device authorization flow"] = "使用设备授权流程登录远程 Yao 服务器" langs["Remote Yao server URL"] = "远程 Yao 服务器地址" langs["Logout from remote Yao server"] = "登出远程 Yao 服务器" langs["Revoke token and remove stored credentials"] = "撤销令牌并移除存储的凭证" langs["Missing --server flag"] = "缺少 --server 参数" langs["Open:"] = "打开:" langs["Or visit:"] = "或访问:" langs["Enter code:"] = "输入设备码:" langs["Waiting for authorization..."] = "等待授权..." langs["Login failed:"] = "登录失败:" langs["Login successful"] = "登录成功" langs["Server:"] = "服务器:" langs["Scope:"] = "授权范围:" langs["Failed to read credentials:"] = "读取凭证失败:" langs["Not logged in"] = "未登录" langs["Failed to remove credentials:"] = "移除凭证失败:" langs["Logged out"] = "已登出" langs["Path to credentials file"] = "凭证文件路径" langs["Failed to load credentials:"] = "加载凭证失败:" langs["Server discovery failed:"] = "服务发现失败:" } ================================================ FILE: cmd/mcp/add.go ================================================ package mcp import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" mcpmgr "github.com/yaoapp/yao/registry/manager/mcp" ) var mcpAddForce bool // AddCmd implements "yao mcp add @scope/name" var AddCmd = &cobra.Command{ Use: "add [package]", Short: L("Install an MCP package from the registry"), Long: L("Install an MCP package from the registry. 
Example: yao mcp add @yao/rag-tools"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { Boot() pkgID := args[0] version, _ := cmd.Flags().GetString("version") client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := mcpmgr.New(client, config.Conf.Root, nil) if err := mgr.Add(pkgID, mcpmgr.AddOptions{ Version: version, Force: mcpAddForce, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { AddCmd.Flags().StringP("version", "v", "latest", L("Package version or dist-tag")) AddCmd.Flags().BoolVarP(&mcpAddForce, "force", "", false, L("Force reinstall")) AddCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) AddCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/mcp/fork.go ================================================ package mcp import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" mcpmgr "github.com/yaoapp/yao/registry/manager/mcp" ) // ForkCmd implements "yao mcp fork @scope/name [@target-scope]" var ForkCmd = &cobra.Command{ Use: "fork [package] [target-scope]", Short: L("Fork an MCP to a local scope"), Long: L("Fork an MCP for local modification. 
Example: yao mcp fork @yao/rag-tools"), Args: cobra.RangeArgs(1, 2), Run: func(cmd *cobra.Command, args []string) { Boot() pkgID := args[0] var targetScope string if len(args) > 1 { targetScope = args[1] } client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := mcpmgr.New(client, config.Conf.Root, nil) if err := mgr.Fork(pkgID, mcpmgr.ForkOptions{ TargetScope: targetScope, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { ForkCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) ForkCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/mcp/mcp.go ================================================ package mcp import ( "os" "path/filepath" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" ) var appPath string var envFile string var langs = map[string]string{ "Install an MCP package from the registry": "从注册中心安装 MCP 包", "Update an installed MCP package": "更新已安装的 MCP 包", "Push an MCP package to the registry": "推送 MCP 包到注册中心", "Fork an MCP to a local scope": "Fork 一个 MCP 到本地范围", "Package version or dist-tag": "包版本或 dist-tag", "Force reinstall": "强制重新安装", "Package version (required)": "包版本 (必填)", "Target version or dist-tag": "目标版本或 dist-tag", "Application directory": "应用目录", "Environment file": "环境变量文件", } // L Language switch func L(words string) string { var lang = os.Getenv("YAO_LANG") if lang == "" { return words } if trans, has := langs[words]; has { return trans } return words } // Boot sets the configuration func Boot() { root := config.Conf.Root if appPath != "" { r, err := filepath.Abs(appPath) if err != nil { exception.New("Root error %s", 500, err.Error()).Throw() } root = r } if envFile != "" { config.Conf = config.LoadFromWithRoot(envFile, root) } else { config.Conf = 
config.LoadFromWithRoot(filepath.Join(root, ".env"), root) } config.ApplyMode() } ================================================ FILE: cmd/mcp/push.go ================================================ package mcp import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" mcpmgr "github.com/yaoapp/yao/registry/manager/mcp" ) // PushCmd implements "yao mcp push scope.name --version x.y.z" var PushCmd = &cobra.Command{ Use: "push [yao-id]", Short: L("Push an MCP package to the registry"), Long: L("Package and push an MCP to the registry. Example: yao mcp push max.rag-tools --version 1.0.0"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { Boot() yaoID := args[0] version, _ := cmd.Flags().GetString("version") force, _ := cmd.Flags().GetBool("force") client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := mcpmgr.New(client, config.Conf.Root, nil) if err := mgr.Push(yaoID, mcpmgr.PushOptions{ Version: version, Force: force, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { PushCmd.Flags().StringP("version", "v", "", L("Package version (required)")) PushCmd.Flags().Bool("force", false, L("Overwrite existing version")) PushCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) PushCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/mcp/update.go ================================================ package mcp import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" mcpmgr "github.com/yaoapp/yao/registry/manager/mcp" ) // UpdateCmd implements "yao mcp update @scope/name" var UpdateCmd = &cobra.Command{ Use: "update [package]", Short: L("Update an installed MCP package"), Long: L("Update an installed MCP 
package. Example: yao mcp update @yao/rag-tools"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { Boot() pkgID := args[0] version, _ := cmd.Flags().GetString("version") client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := mcpmgr.New(client, config.Conf.Root, nil) if err := mgr.Update(pkgID, mcpmgr.UpdateOptions{ Version: version, }); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } func init() { UpdateCmd.Flags().StringP("version", "v", "latest", L("Target version or dist-tag")) UpdateCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) UpdateCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/migrate.go ================================================ package cmd import ( "fmt" "os" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/gou/model" "github.com/yaoapp/gou/process" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/share" ) var name string var force bool = false var resetModel bool = false var migrateCmd = &cobra.Command{ Use: "migrate", Short: L("Update database schema"), Long: L("Update database schema"), Run: func(cmd *cobra.Command, args []string) { defer func() { err := exception.Catch(recover()) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) } }() Boot() if !force && config.Conf.Mode == "production" { fmt.Println(color.WhiteString(L("TRY:")), color.GreenString("%s migrate --force", share.BUILDNAME)) exception.New(L("Migrate is not allowed on production mode."), 403).Throw() } // 加载数据模型 loadWarnings, err := engine.Load(config.Conf, engine.LoadOption{Action: "migrate"}) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } if len(loadWarnings) > 0 { 
for _, warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } } if name != "" { mod, has := model.Models[name] if !has { fmt.Println(color.RedString(L("Model: %s does not exits"), name)) return } fmt.Print(color.WhiteString(fmt.Sprintf(L("Update schema model: %s (%s) "), mod.Name, mod.MetaData.Table.Name)) + "\t") if resetModel { err := mod.DropTable() if err != nil { fmt.Print(color.RedString(fmt.Sprintf(L("FAILURE\n%s"), err.Error())) + "\n") return } } err := mod.Migrate(false) if err != nil { fmt.Print(color.RedString(fmt.Sprintf(L("FAILURE\n%s"), err.Error())) + "\n") return } fmt.Print(color.GreenString(L("SUCCESS")) + "\n") return } // Do Stuff Here for _, mod := range model.Models { fmt.Print(color.WhiteString(fmt.Sprintf(L("Update schema model: %s (%s) "), mod.Name, mod.MetaData.Table.Name)) + "\t") if resetModel { err := mod.DropTable() if err != nil { fmt.Print(color.RedString(fmt.Sprintf(L("FAILURE\n%s"), err.Error())) + "\n") continue } } err := mod.Migrate(false) if err != nil { fmt.Print(color.RedString(fmt.Sprintf(L("FAILURE\n%s"), err.Error())) + "\n") continue } fmt.Print(color.GreenString(L("SUCCESS")) + "\n") } // After Migrate Hook if share.App.AfterMigrate != "" { option := map[string]any{"force": force, "reset": resetModel, "mode": config.Conf.Mode} p, err := process.Of(share.App.AfterMigrate, option) if err != nil { fmt.Println(color.RedString(L("AfterMigrate: %s %v"), share.App.AfterMigrate, err)) return } _, err = p.Exec() if err != nil { fmt.Println(color.RedString(L("AfterMigrate: %s %v"), share.App.AfterMigrate, err)) } } // fmt.Println(color.GreenString(L("✨DONE✨"))) }, } func init() { migrateCmd.PersistentFlags().StringVarP(&name, "name", "n", "", L("Model name")) migrateCmd.PersistentFlags().BoolVarP(&force, "force", "", false, L("Force migrate")) migrateCmd.PersistentFlags().BoolVarP(&resetModel, "reset", "", false, L("Drop the table if exist")) } 
================================================ FILE: cmd/pack.go ================================================ package cmd import ( "bufio" "fmt" "os" "path/filepath" "strings" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/gou/application/yaz" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/pack" ) var packOutput = "" var packLicense = "" var packCmd = &cobra.Command{ Use: "pack", Short: L("Package the application"), Long: L("Package the application into a single file"), Run: func(cmd *cobra.Command, args []string) { cfg := config.Conf output, err := filepath.Abs(filepath.Join(cfg.Root, "dist")) if err != nil { color.Red(err.Error()) } if packOutput != "" { output, err = filepath.Abs(packOutput) if err != nil { color.Red(err.Error()) } } stat, err := os.Stat(output) if err != nil && os.IsNotExist(err) { color.Green("Creating directory %s", output) err = os.MkdirAll(output, 0755) if err != nil { color.Red(err.Error()) os.Exit(1) } } else if err != nil { color.Red(err.Error()) os.Exit(1) } else if !stat.IsDir() { color.Red("Output directory %s is not a directory.\n", output) os.Exit(1) } outputFile := filepath.Join(output, "app.yaz") _, err = os.Stat(outputFile) if !os.IsNotExist(err) { color.Yellow("%s already exists", outputFile) fmt.Printf("%s", color.RedString("Do you want to overwrite it? 
(y/n): ")) scanner := bufio.NewScanner(os.Stdin) if scanner.Scan() { if strings.ToLower(scanner.Text()) != "y" { os.Exit(0) return } } } os.Remove(outputFile) if packLicense != "" { pack.SetCipher(packLicense) err = yaz.PackTo(cfg.Root, outputFile, pack.Cipher) } else { err = yaz.CompressTo(cfg.Root, outputFile) } if err != nil { color.Red(err.Error()) os.Exit(1) } color.Green("Packaged to %s", outputFile) }, } func init() { packCmd.PersistentFlags().StringVarP(&packOutput, "output", "o", "", L("Output Directory")) packCmd.PersistentFlags().StringVarP(&packLicense, "license", "l", "", L("Pack with the license")) } ================================================ FILE: cmd/restore.go ================================================ package cmd import ( "archive/zip" "errors" "fmt" "io" "io/ioutil" "os" "path/filepath" "strings" "time" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/gou/model" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/share" ) var restoreForce bool = false var migrateNoInsert bool = false var restoreCmd = &cobra.Command{ Use: "restore", Short: L("Restore the application data"), Long: L("Restore the application data"), Run: func(cmd *cobra.Command, args []string) { defer func() { err := exception.Catch(recover()) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) } }() if len(args) < 1 { fmt.Println(color.RedString(L("Not enough arguments"))) fmt.Println(color.WhiteString(share.BUILDNAME + " help")) os.Exit(1) } zipfile, err := filepath.Abs(args[0]) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } Boot() if !restoreForce && config.Conf.Mode == "production" { fmt.Println(color.WhiteString(L("TRY:")), color.GreenString("%s restore --force", share.BUILDNAME)) exception.New(L("Retore is not allowed on production mode."), 403).Throw() } // Unzip files dst := unzipFile(zipfile, func(file string) { 
fmt.Printf("\r%s", strings.Repeat(" ", 80)) fmt.Printf("\r%s", color.GreenString(L("Unzip the file: %s"), file)) }) // 加载数据模型 loadWarnings, err := engine.Load(config.Conf, engine.LoadOption{Action: "restore"}) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } if len(loadWarnings) > 0 { for _, warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } } // Restore models restoreModels(filepath.Join(dst, "model"), []model.MigrateOption{ model.WithDonotInsertValues(migrateNoInsert), }) // Restore Data restoreData(filepath.Join(dst, "data")) // Clean os.RemoveAll(dst) fmt.Println(color.GreenString(L("✨DONE✨"))) }, } func init() { restoreCmd.PersistentFlags().BoolVarP(&restoreForce, "force", "", false, L("Force restore")) restoreCmd.PersistentFlags().BoolVarP(&migrateNoInsert, "migrate-no-insert", "", false, L("Do not insert values when migrating")) } func restoreData(basePath string) { _, err := os.Stat(basePath) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } // Clean Data dataPath := filepath.Join(config.Conf.Root, "data") _, err = os.Stat(dataPath) if err == nil { os.RemoveAll(dataPath) } err = os.Rename(basePath, dataPath) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } } func restoreModels(basePath string, migOpts []model.MigrateOption) { files, err := ioutil.ReadDir(basePath) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } // Migrate models for _, mod := range model.Models { fmt.Printf("\r%s", strings.Repeat(" ", 80)) fmt.Print(color.GreenString(fmt.Sprintf(L("\rUpdate schema model: %s (%s) "), mod.Name, mod.MetaData.Table.Name))) err := mod.Migrate(true, migOpts...) 
if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } } fmt.Println("") for _, file := range files { namer := strings.Split(file.Name(), ".") name := strings.Join(namer[:len(namer)-2], ".") if mod, has := model.Models[name]; has { fmt.Printf("\r%s", strings.Repeat(" ", 80)) fmt.Print(color.GreenString(fmt.Sprintf(L("\rRestore model: %s (%s) "), mod.Name, mod.MetaData.Table.Name))) err := mod.Import(filepath.Join(basePath, file.Name())) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } } } } func unzipFile(file string, process func(file string)) string { _, err := os.Stat(file) if errors.Is(err, os.ErrNotExist) { fmt.Println(color.RedString("%s not exists", file)) os.Exit(1) } if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } dst := filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s", filepath.Base(file), time.Now().Format("20060102150405"))) os.MkdirAll(dst, 0755) archive, err := zip.OpenReader(file) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } defer archive.Close() for _, f := range archive.File { filePath := filepath.Join(dst, f.Name) process(f.Name) if !strings.HasPrefix(filePath, filepath.Clean(dst)+string(os.PathSeparator)) { fmt.Println(color.RedString(L("Fatal: invalid file path"))) os.Exit(1) } if f.FileInfo().IsDir() { os.MkdirAll(filePath, os.ModePerm) continue } if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } dstFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } fileInArchive, err := f.Open() if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) } if _, err := io.Copy(dstFile, fileInArchive); err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) os.Exit(1) 
} dstFile.Close() fileInArchive.Close() } return dst } ================================================ FILE: cmd/robot/add.go ================================================ package robot import ( "fmt" "os" "github.com/spf13/cobra" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/registry" robotmgr "github.com/yaoapp/yao/registry/manager/robot" ) // AddCmd implements "yao robot add @scope/name --team TEAM_ID" var AddCmd = &cobra.Command{ Use: "add [package]", Short: L("Install a robot package from the registry"), Long: L("Install a robot and its dependencies. Example: yao robot add @yao/keeper --team team-123"), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { Boot() pkgID := args[0] version, _ := cmd.Flags().GetString("version") teamID, _ := cmd.Flags().GetString("team") client := registry.New(config.Conf.Registry, registry.WithAuth( os.Getenv("YAO_REGISTRY_USER"), os.Getenv("YAO_REGISTRY_PASS"), ), ) mgr := robotmgr.New(client, config.Conf.Root, nil) robot, err := mgr.Add(pkgID, robotmgr.AddOptions{ Version: version, TeamID: teamID, }) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } // The actual member record creation requires database access. // In P0 we print the robot config for the CLI layer to handle. 
fmt.Printf("Robot ready: %s (display_name: %s)\n", pkgID, robot.DisplayName) fmt.Println("Note: Member record must be created via Mission Control or database.") }, } func init() { AddCmd.Flags().StringP("version", "v", "latest", L("Package version or dist-tag")) AddCmd.Flags().StringP("team", "t", "", L("Team ID (required)")) AddCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) AddCmd.PersistentFlags().StringVarP(&envFile, "env", "e", "", L("Environment file")) } ================================================ FILE: cmd/robot/robot.go ================================================ package robot import ( "os" "path/filepath" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" ) var appPath string var envFile string var langs = map[string]string{ "Install a robot package from the registry": "从注册中心安装 Robot 包", "Team ID (required)": "团队 ID (必填)", "Package version or dist-tag": "包版本或 dist-tag", "Application directory": "应用目录", "Environment file": "环境变量文件", } // L Language switch func L(words string) string { var lang = os.Getenv("YAO_LANG") if lang == "" { return words } if trans, has := langs[words]; has { return trans } return words } // Boot sets the configuration func Boot() { root := config.Conf.Root if appPath != "" { r, err := filepath.Abs(appPath) if err != nil { exception.New("Root error %s", 500, err.Error()).Throw() } root = r } if envFile != "" { config.Conf = config.LoadFromWithRoot(envFile, root) } else { config.Conf = config.LoadFromWithRoot(filepath.Join(root, ".env"), root) } config.ApplyMode() } ================================================ FILE: cmd/root.go ================================================ package cmd import ( "fmt" "os" "path/filepath" "github.com/spf13/cobra" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/cmd/agent" "github.com/yaoapp/yao/cmd/mcp" "github.com/yaoapp/yao/cmd/robot" "github.com/yaoapp/yao/cmd/sui" "github.com/yaoapp/yao/config" 
"github.com/yaoapp/yao/pack" "github.com/yaoapp/yao/share" ) var appPath string var yazFile string var licenseKey string var lang = os.Getenv("YAO_LANG") var langs = map[string]string{ "Start Engine": "启动 YAO 应用引擎", "Get an application": "下载应用源码", "One or more arguments are not correct": "参数错误", "Application directory": "指定应用路径", "Environment file": "指定环境变量文件", "Help for yao": "显示命令帮助文档", "Show app configure": "显示应用配置信息", "Update database schema": "更新数据表结构", "Execute process": "运行处理器", "Show version": "显示当前版本号", "Development mode": "使用开发模式启动", "Enabled unstable features": "启用内测功能", "Fatal: %s": "失败: %s", "Service stopped": "服务已关闭", "API": " API接口", "API List": "API列表", "Root": "应用目录", "Data": "数据目录", "Frontend": "前台地址", "Dashboard": "管理后台", "Not enough arguments": "参数错误: 缺少参数", "Run: %s": "运行: %s", "Arguments: %s": "参数错误: %s", "%s Response": "%s 返回结果", "Update schema model: %s (%s) ": "更新表结构 model: %s (%s)", "Model name": "模型名称", "Initialize project": "项目初始化", "✨DONE✨": "✨完成✨", "NEXT:": "下一步:", "Listening": " 监听", "✨LISTENING✨": "✨服务正在运行✨", "✨STOPPED✨": "✨服务已停止✨", "SessionPort": "会话服务端口", "Force migrate": "强制更新数据表结构", "Migrate is not allowed on production mode.": "Migrate 不能再生产环境下使用", "Upgrade yao to latest version": "升级 yao 到最新版本", "🎉Current version is the latest🎉": "🎉当前版本是最新的🎉", "Do you want to update to %s ? (y/n): ": "是否更新到 %s ? 
(y/n): ", "Invalid input": "输入错误", "Canceled upgrade": "已取消更新", "Error occurred while updating binary: %s": "更新二进制文件时出错: %s", "🎉Successfully updated to version: %s🎉": "🎉成功更新到版本: %s🎉", "Print all version information": "显示详细版本信息", "SUI Template Engine": "SUI 模板引擎命令", "MCP commands": "MCP 包管理命令", "MCP package management commands": "MCP 包管理命令", "Robot commands": "Robot 包管理命令", "Robot package management commands": "Robot 包管理命令", } // L Language switch func L(words string) string { if lang == "" { return words } if trans, has := langs[words]; has { return trans } return words } // RootCmd export the rootCmd to support customized commands when use yao as lib var RootCmd = rootCmd var rootCmd = &cobra.Command{ Use: share.BUILDNAME, Short: "Yao App Engine", Long: `Yao App Engine`, CompletionOptions: cobra.CompletionOptions{ DisableDefaultCmd: true, }, Run: func(cmd *cobra.Command, args []string) { if len(args) > 0 { switch args[0] { case "fuxi": fuxi() return } fmt.Fprintln(os.Stderr, L("One or more arguments are not correct"), args) os.Exit(1) return } // No arguments - show help cmd.Help() }, } var suiCmd = &cobra.Command{ Use: "sui", Short: L("SUI Template Engine"), Long: L("SUI Template Engine"), CompletionOptions: cobra.CompletionOptions{ DisableDefaultCmd: true, }, Run: func(cmd *cobra.Command, args []string) { cmd.Help() }, } var agentCmd = &cobra.Command{ Use: "agent", Short: L("Agent commands"), Long: L("Agent commands for testing and management"), CompletionOptions: cobra.CompletionOptions{ DisableDefaultCmd: true, }, Run: func(cmd *cobra.Command, args []string) { cmd.Help() }, } var mcpCmd = &cobra.Command{ Use: "mcp", Short: L("MCP commands"), Long: L("MCP package management commands"), CompletionOptions: cobra.CompletionOptions{ DisableDefaultCmd: true, }, Run: func(cmd *cobra.Command, args []string) { cmd.Help() }, } var robotCmd = &cobra.Command{ Use: "robot", Short: L("Robot commands"), Long: L("Robot package management commands"), CompletionOptions: 
cobra.CompletionOptions{ DisableDefaultCmd: true, }, Run: func(cmd *cobra.Command, args []string) { cmd.Help() }, } // Command initialize func init() { // Sui suiCmd.AddCommand(sui.WatchCmd) suiCmd.AddCommand(sui.BuildCmd) suiCmd.AddCommand(sui.TransCmd) // Agent agentCmd.AddCommand(agent.TestCmd) agentCmd.AddCommand(agent.ExtractCmd) agentCmd.AddCommand(agent.AddCmd) agentCmd.AddCommand(agent.UpdateCmd) agentCmd.AddCommand(agent.PushCmd) agentCmd.AddCommand(agent.ForkCmd) // MCP mcpCmd.AddCommand(mcp.AddCmd) mcpCmd.AddCommand(mcp.UpdateCmd) mcpCmd.AddCommand(mcp.PushCmd) mcpCmd.AddCommand(mcp.ForkCmd) // Robot robotCmd.AddCommand(robot.AddCmd) rootCmd.AddCommand( versionCmd, migrateCmd, inspectCmd, startCmd, runCmd, loginCmd, logoutCmd, // getCmd, // dumpCmd, // restoreCmd, // socketCmd, // websocketCmd, // packCmd, suiCmd, agentCmd, mcpCmd, robotCmd, // upgradeCmd, ) // rootCmd.SetHelpCommand(helpCmd) rootCmd.PersistentFlags().StringVarP(&appPath, "app", "a", "", L("Application directory")) rootCmd.PersistentFlags().StringVarP(&yazFile, "file", "f", "", L("Application package file")) rootCmd.PersistentFlags().StringVarP(&licenseKey, "key", "k", "", L("Application license key")) } // Execute Command func Execute() { if err := rootCmd.Execute(); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } // Boot Setting func Boot() { root := config.Conf.Root if appPath != "" { r, err := filepath.Abs(appPath) if err != nil { exception.New("Root error %s", 500, err.Error()).Throw() } root = r } config.Conf = config.LoadFrom(filepath.Join(root, ".env")) if share.BUILDIN { os.Setenv("YAO_APP_SOURCE", "::binary") config.Conf.AppSource = "::binary" } if yazFile != "" { os.Setenv("YAO_APP_SOURCE", yazFile) config.Conf.AppSource = yazFile } if config.Conf.Mode == "production" { config.Production() } else if config.Conf.Mode == "development" { config.Development() } // set license if licenseKey != "" { pack.SetCipher(licenseKey) } } 
================================================ FILE: cmd/run.go ================================================ package cmd import ( "context" "fmt" "os" "path/filepath" "strings" "github.com/fatih/color" jsoniter "github.com/json-iterator/go" "github.com/spf13/cobra" "github.com/yaoapp/gou/helper" "github.com/yaoapp/gou/plugin" "github.com/yaoapp/gou/process" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" grpcclient "github.com/yaoapp/yao/grpc/client" ischedule "github.com/yaoapp/yao/schedule" "github.com/yaoapp/yao/share" itask "github.com/yaoapp/yao/task" ) var runSilent = false var runAuthPath string var runCmd = &cobra.Command{ Use: "run", Short: L("Execute process"), Long: L("Execute process"), Run: func(cmd *cobra.Command, args []string) { // Resolve credential: --auth flag > ~/.yao/credentials > nil (local mode) cred := resolveCredential() if cred != nil { runGRPC(cred, args) return } runLocal(args) }, } func init() { runCmd.PersistentFlags().BoolVarP(&runSilent, "silent", "s", false, L("Silent mode")) runCmd.PersistentFlags().StringVar(&runAuthPath, "auth", "", L("Path to credentials file")) } // resolveCredential loads credential from --auth flag or default path. func resolveCredential() *Credential { if runAuthPath != "" { cred, err := LoadCredentialFrom(runAuthPath) if err != nil { color.Red(" %s %s\n", L("Failed to load credentials:"), err) os.Exit(1) } return cred } cred, _ := LoadCredential() return cred } // runGRPC executes a process via the remote gRPC server. func runGRPC(cred *Credential, args []string) { if len(args) < 1 { if !runSilent { color.Red(L("Not enough arguments\n")) color.White(share.BUILDNAME + " help\n") } else { fmt.Print(L("Not enough arguments\n")) } os.Exit(1) } if cred.GRPCAddr == "" { color.Red(" %s\n", L("No gRPC address in credentials. 
Please re-login.")) os.Exit(1) } name := args[0] if !runSilent { color.Green(L("Run: %s gRPC: %s\n"), name, cred.GRPCAddr) } pargs := parseRunArgs(args) argsJSON, err := jsoniter.Marshal(pargs) if err != nil { color.Red(" %s %s\n", L("Arguments:"), err.Error()) os.Exit(1) } tm := grpcclient.NewTokenManager(cred.AccessToken, cred.RefreshToken, "") client, err := grpcclient.Dial(cred.GRPCAddr, tm) if err != nil { color.Red(" %s %s\n", L("gRPC connect failed:"), err.Error()) os.Exit(1) } defer client.Close() data, err := client.Run(context.Background(), name, argsJSON, 0) if err != nil { if !runSilent { color.Red(" %s %s\n", L("Process:"), err.Error()) } else { fmt.Printf("%s\n", err.Error()) } os.Exit(1) } if !runSilent { color.White("--------------------------------------\n") color.White(L("%s Response\n"), name) color.White("--------------------------------------\n") var res interface{} if jsoniter.Unmarshal(data, &res) == nil { helper.Dump(res) } else { fmt.Printf("%s\n", data) } color.White("--------------------------------------\n") fmt.Printf("\033[32m✨DONE✨\033[0m \033[90mgRPC: %s\033[0m\n", cred.GRPCAddr) } else { fmt.Printf("%s\n", data) } } // runLocal executes a process locally (existing behavior). 
func runLocal(args []string) { defer share.SessionStop() defer plugin.KillAll() defer func() { err := exception.Catch(recover()) if err != nil { if !runSilent { color.Red(L("Fatal: %s\n"), err.Error()) return } fmt.Printf("%s\n", err.Error()) } }() // Auto-detect app root if not specified if appPath == "" { cwd, err := os.Getwd() if err == nil { if root, err := findAppRootFromPath(cwd); err == nil { appPath = root } } } Boot() // Set Runtime Mode config.Conf.Runtime.Mode = "standard" cfg := config.Conf cfg.Session.IsCLI = true if len(args) < 1 { if !runSilent { color.Red(L("Not enough arguments\n")) color.White(share.BUILDNAME + " help\n") return } fmt.Print(L("Not enough arguments\n")) return } loadWarnings, err := engine.Load(cfg, engine.LoadOption{Action: "run"}) if err != nil { if !runSilent { color.Red(L("Engine: %s\n"), err.Error()) return } fmt.Printf("%s\n", err.Error()) return } name := args[0] if !runSilent { color.Green(L("Run: %s\n"), name) } pargs := parseRunArgs(args) // Start Tasks itask.Start() defer itask.Stop() // Start Schedules ischedule.Start() defer ischedule.Stop() p := process.NewWithContext(context.Background(), name, pargs...) 
res, err := p.Exec() if err != nil { if !runSilent { color.Red(L("Process: %s\n"), fmt.Sprintf("%s", strings.TrimPrefix(err.Error(), "Exception|404:"))) return } fmt.Printf("%s\n", err.Error()) return } if !runSilent { if len(loadWarnings) > 0 { fmt.Println(color.YellowString("---------------------------------")) fmt.Println(color.YellowString(L("Warnings"))) fmt.Println(color.YellowString("---------------------------------")) for _, warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } fmt.Printf("\n") } color.White("--------------------------------------\n") color.White(L("%s Response\n"), name) color.White("--------------------------------------\n") helper.Dump(res) color.White("--------------------------------------\n") color.Green(L("✨DONE✨\n")) return } // Silent mode output switch res.(type) { case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: fmt.Printf("%v\n", res) return case string, []byte: fmt.Printf("%s\n", res) return default: txt, err := jsoniter.Marshal(res) if err != nil { fmt.Printf("%s\n", err.Error()) } fmt.Printf("%s\n", txt) } } // parseRunArgs parses the CLI arguments into process arguments, handling :: prefixed JSON. 
func parseRunArgs(args []string) []interface{} { pargs := []interface{}{} for i, arg := range args { if i == 0 { continue } if strings.HasPrefix(arg, "::") { raw := strings.TrimPrefix(arg, "::") var v interface{} err := jsoniter.Unmarshal([]byte(raw), &v) if err != nil { color.Red(L("Arguments: %s\n"), err.Error()) return pargs } pargs = append(pargs, v) if !runSilent { color.White("args[%d]: %s\n", i-1, raw) } } else if strings.HasPrefix(arg, "\\::") { cleaned := "::" + strings.TrimPrefix(arg, "\\::") pargs = append(pargs, cleaned) if !runSilent { color.White("args[%d]: %s\n", i-1, cleaned) } } else { pargs = append(pargs, arg) if !runSilent { color.White("args[%d]: %s\n", i-1, arg) } } } return pargs } // findAppRootFromPath finds the Yao application root directory by looking for app.yao // It traverses up from the given path until it finds app.yao or reaches the filesystem root func findAppRootFromPath(startPath string) (string, error) { // Get absolute path absPath, err := filepath.Abs(startPath) if err != nil { return "", fmt.Errorf("failed to get absolute path: %w", err) } // If it's a file, start from its directory info, err := os.Stat(absPath) if err != nil { return "", fmt.Errorf("path not found: %s", absPath) } var dir string if info.IsDir() { dir = absPath } else { dir = filepath.Dir(absPath) } // Traverse up to find app.yao for { // Check for app.yao, app.json, or app.jsonc for _, appFile := range []string{"app.yao", "app.json", "app.jsonc"} { appFilePath := filepath.Join(dir, appFile) if _, err := os.Stat(appFilePath); err == nil { return dir, nil } } // Move to parent directory parent := filepath.Dir(dir) if parent == dir { // Reached root, no app.yao found break } dir = parent } return "", fmt.Errorf("no app.yao found in path hierarchy of %s", startPath) } ================================================ FILE: cmd/socket.go ================================================ package cmd import ( "fmt" "strings" "github.com/fatih/color" jsoniter 
"github.com/json-iterator/go" "github.com/spf13/cobra" "github.com/yaoapp/gou/plugin" "github.com/yaoapp/gou/socket" "github.com/yaoapp/kun/exception" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/share" ) var socketCmd = &cobra.Command{ Use: "socket", Short: L("Open a socket connection"), Long: L("Open a socket connection"), Run: func(cmd *cobra.Command, args []string) { defer share.SessionStop() defer plugin.KillAll() defer func() { err := exception.Catch(recover()) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) } }() Boot() cfg := config.Conf cfg.Session.IsCLI = true engine.Load(cfg, engine.LoadOption{Action: "socket"}) if len(args) < 1 { fmt.Println(color.RedString(L("Not enough arguments"))) fmt.Println(color.WhiteString(share.BUILDNAME + " help")) return } name := args[0] pargs := []interface{}{} for i, arg := range args { if i == 0 { continue } // 解析参数 if strings.HasPrefix(arg, "::") { arg := strings.TrimPrefix(arg, "::") var v interface{} err := jsoniter.Unmarshal([]byte(arg), &v) if err != nil { fmt.Println(color.RedString(L("Arguments: %s"), err.Error())) return } pargs = append(pargs, v) fmt.Println(color.WhiteString("args[%d]: %s", i-1, arg)) } else if strings.HasPrefix(arg, "\\::") { arg := "::" + strings.TrimPrefix(arg, "\\::") pargs = append(pargs, arg) fmt.Println(color.WhiteString("args[%d]: %s", i-1, arg)) } else { pargs = append(pargs, arg) fmt.Println(color.WhiteString("args[%d]: %s", i-1, arg)) } } socket, has := socket.Sockets[name] if !has { fmt.Println(color.RedString(L("%s not exists!"), name)) return } if socket.Mode != "client" { fmt.Println(color.RedString(L("%s not supported yet!"), socket.Mode)) return } host := socket.Host port := socket.Port argsLen := len(pargs) if argsLen > 0 { if inputHost, ok := pargs[0].(string); ok { host = inputHost } } if argsLen > 1 { if inputPort, ok := pargs[1].(string); ok { port = inputPort } } 
fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(socket.Name)) fmt.Println(color.WhiteString("---------------------------------")) fmt.Println(color.GreenString("Mode: %s", socket.Mode)) fmt.Println(color.GreenString("Host: %s://%s", socket.Protocol, host)) fmt.Println(color.GreenString("Port: %s", port)) fmt.Println(color.WhiteString("--------------------------------------")) err := socket.Open(pargs...) if err != nil { fmt.Println(color.RedString(L("%s"), err.Error())) return } }, } ================================================ FILE: cmd/start.go ================================================ package cmd import ( "fmt" "net" "os" "os/signal" "path/filepath" "strconv" "strings" "syscall" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/yaoapp/gou/api" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/fs" "github.com/yaoapp/gou/mcp" "github.com/yaoapp/gou/plugin" "github.com/yaoapp/gou/schedule" "github.com/yaoapp/gou/store" "github.com/yaoapp/gou/task" "github.com/yaoapp/gou/websocket" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" yaogrpc "github.com/yaoapp/yao/grpc" _ "github.com/yaoapp/yao/grpc/auth" sandboxhandler "github.com/yaoapp/yao/grpc/sandbox" "github.com/yaoapp/yao/openapi" sandbox "github.com/yaoapp/yao/sandbox/v2" ischedule "github.com/yaoapp/yao/schedule" "github.com/yaoapp/yao/service" "github.com/yaoapp/yao/setup" "github.com/yaoapp/yao/share" itask "github.com/yaoapp/yao/task" ) var startDebug = false var startDisableWatching = false var startCmd = &cobra.Command{ Use: "start", Short: L("Start Engine"), Long: L("Start Engine"), Run: func(cmd *cobra.Command, args []string) { defer share.SessionStop() defer plugin.KillAll() // recive interrupt signal interrupt := make(chan os.Signal, 1) signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT) Boot() // Setup isnew := false // Check if current directory is a Yao app 
root if !setup.IsYaoApp(config.Conf.Root) { // Check if we're inside a Yao app (subdirectory) if setup.InYaoApp(config.Conf.Root) { fmt.Println(color.RedString(L("Please run the command in the root directory of project"))) os.Exit(1) } // Not in a Yao app, check if empty to install if setup.IsEmptyDir(config.Conf.Root) { // Install the init app if err := install(); err != nil { fmt.Println(color.RedString(L("Install: %s"), err.Error())) os.Exit(1) } isnew = true } else { // Directory not empty and no app.yao fmt.Println(color.RedString("The app.yao file is missing")) os.Exit(1) } } // force debug if startDebug { config.Development() } // load the application engine loadWarnings, err := engine.Load(config.Conf, engine.LoadOption{ Action: "start", }) if err != nil { fmt.Println(color.RedString(L("Load: %s"), err.Error())) os.Exit(1) } port := fmt.Sprintf(":%d", config.Conf.Port) if port == ":80" { port = "" } // variables for the service fs, err := fs.Get("system") if err != nil { fmt.Println(color.RedString(L("FileSystem: %s"), err.Error())) os.Exit(1) } mode := config.Conf.Mode host := config.Conf.Host dataRoot := fs.Root() runtimeMode := config.Conf.Runtime.Mode fmt.Println(color.WhiteString("\n--------------------------------------------")) fmt.Println( color.WhiteString(strings.TrimPrefix(share.App.Name, "::")), color.WhiteString(share.App.Version), mode, ) fmt.Println(color.WhiteString("--------------------------------------------")) if !share.BUILDIN { root, _ := filepath.Abs(config.Conf.Root) fmt.Println(color.WhiteString(L("Root")), color.GreenString(" %s", root)) } fmt.Println(color.WhiteString(L("Runtime")), color.GreenString(" %s", runtimeMode)) fmt.Println(color.WhiteString(L("Data")), color.GreenString(" %s", dataRoot)) fmt.Println(color.WhiteString(L("Listening")), color.GreenString(" %s:%d", config.Conf.Host, config.Conf.Port)) // print the messages under the development mode if mode == "development" { printApis(false) printTasks(false) 
printSchedules(false) printConnectors(false) printStores(false) printMCPs(false) } root, _ := adminRoot() endpoints := []setup.Endpoint{{URL: fmt.Sprintf("http://%s%s", "127.0.0.1", port), Interface: "localhost"}} switch host { case "0.0.0.0": if values, err := setup.Endpoints(config.Conf); err == nil { endpoints = append(endpoints, values...) } case "127.0.0.1": // Localhost only default: matched := false endpoints = []setup.Endpoint{} if values, err := setup.Endpoints(config.Conf); err == nil { for _, value := range values { if strings.HasPrefix(value.URL, fmt.Sprintf("http://%s:", host)) { endpoints = append(endpoints, value) matched = true } } } if !matched { fmt.Println(color.RedString(L("Host %s not found"), host)) os.Exit(1) } } // Print welcome message for the new application if isnew { printWelcome() } // Start Tasks itask.Start() defer itask.Stop() // Start Schedules ischedule.Start() defer ischedule.Stop() // Pre-flight: detect port conflicts before attempting to start servers. if occupied, proc := portOccupied(config.Conf.Host, config.Conf.Port); occupied { fmt.Println(color.RedString(L("Fatal: HTTP port %d is already in use%s"), config.Conf.Port, proc)) return } if strings.ToLower(config.Conf.GRPC.Enabled) != "off" { for _, h := range yaogrpc.ExpandHosts(config.Conf.GRPC.Host) { if occupied, proc := portOccupied(h, config.Conf.GRPC.Port); occupied { fmt.Println(color.RedString(L("Fatal: gRPC port %d is already in use%s"), config.Conf.GRPC.Port, proc)) return } } } // Wire gRPC heartbeat → sandbox Manager so container liveness is tracked. yaogrpc.SetSandboxOnBeat(func(data *sandboxhandler.HeartbeatData) string { sandbox.M().Heartbeat(data.SandboxID, true, int(data.RunningProcs)) return "ok" }) // Start all servers (gRPC + HTTP) as a single unit. // Start() blocks until HTTP port is bound (READY) or returns error. 
svc, err := service.Start(config.Conf, service.ServerHooks{ Start: yaogrpc.StartServer, Stop: yaogrpc.Stop, Addrs: yaogrpc.Addr, }) if err != nil { fmt.Println(color.RedString(L("Fatal: %s"), err.Error())) return } // Access Points (printed after servers are up so addresses are known) fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(L("Access Points"))) fmt.Println(color.WhiteString("---------------------------------")) if grpcAddrs := svc.HookAddrs(); len(grpcAddrs) > 0 { fmt.Println(color.CyanString("\ngRPC")) fmt.Println(color.WhiteString("--------------------------")) for _, addr := range grpcAddrs { fmt.Println(color.WhiteString(L("Server")), color.GreenString(" %s", addr)) } } apiRoot := "/api" if openapi.Server != nil { apiRoot = openapi.Server.Config.BaseURL } for _, endpoint := range endpoints { fmt.Println(color.CyanString("\n%s", endpoint.Interface)) fmt.Println(color.WhiteString("--------------------------")) fmt.Println(color.WhiteString(L("Website")), color.GreenString(" %s", endpoint.URL)) fmt.Println(color.WhiteString(L("Dashboard")), color.GreenString(" %s/%s/auth/entry", endpoint.URL, strings.Trim(root, "/"))) if openapi.Server != nil { fmt.Println(color.WhiteString(L("OpenAPI")), color.GreenString(" %s%s", endpoint.URL, apiRoot)) } else { fmt.Println(color.WhiteString(L("API")), color.GreenString(" %s%s", endpoint.URL, apiRoot)) } } fmt.Println("") // Start watching watchDone := make(chan uint8, 1) if mode == "development" && !startDisableWatching { go svc.Watch(watchDone) } // Print the messages under the production mode if mode == "production" { printApis(true) printTasks(true) printSchedules(true) printConnectors(true) printStores(true) printMCPs(true) } // Print the warnings if len(loadWarnings) > 0 { fmt.Println(color.YellowString("---------------------------------")) fmt.Println(color.YellowString(L("Warnings"))) fmt.Println(color.YellowString("---------------------------------")) for _, 
warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } fmt.Printf("\n") } fmt.Println(color.GreenString(L("Server is up and running..."))) fmt.Println(color.GreenString("Ctrl+C to stop")) for { select { case <-interrupt: fmt.Println(color.WhiteString("\nShutting down...")) svc.Stop() fmt.Println(color.GreenString(L("✨Exited successfully!"))) watchDone <- 1 return } } }, } func install() error { // Copy the app source files from the binary err := setup.Install(config.Conf.Root) if err != nil { return err } // Reload the application engine Boot() // load the application engine loadWarnings, err := engine.Load(config.Conf, engine.LoadOption{Action: "start"}) if err != nil { return err } // Print the warnings if len(loadWarnings) > 0 { for _, warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } fmt.Printf("\n\n") } err = setup.Initialize(config.Conf.Root, config.Conf) if err != nil { return err } return nil } func adminRoot() (string, int) { adminRoot := "/yao/" if share.App.AdminRoot != "" { root := strings.TrimPrefix(share.App.AdminRoot, "/") root = strings.TrimSuffix(root, "/") adminRoot = fmt.Sprintf("/%s/", root) } adminRootLen := len(adminRoot) return adminRoot, adminRootLen } func printWelcome() { fmt.Println(color.CyanString("\n---------------------------------")) fmt.Println(color.CyanString(L("🎉 Welcome to Yao 🎉 "))) fmt.Println(color.CyanString("---------------------------------")) fmt.Println(color.WhiteString("📚 Documentation: "), color.CyanString("https://yaoapps.com/docs")) fmt.Println(color.WhiteString("🏡 Join Yao Community: "), color.CyanString("https://yaoapps.com/community")) fmt.Println(color.WhiteString("🤖 Build Your Digital Workforce:"), color.CyanString("https://yaoagents.com")) fmt.Println("") } func printConnectors(silent bool) { if len(connector.Connectors) == 0 { return } if silent { for name := range connector.Connectors { 
log.Info("[Connector] %s loaded", name) } return } fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(L("Connectors List (%d)"), len(connector.Connectors))) fmt.Println(color.WhiteString("---------------------------------")) for name := range connector.Connectors { fmt.Print(color.CyanString("[Connector]")) fmt.Print(color.WhiteString(" %s\t loaded\n", name)) } } func printStores(silent bool) { if len(store.Pools) == 0 { return } if silent { for name := range store.Pools { log.Info("[Store] %s loaded", name) } return } fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(L("Stores List (%d)"), len(store.Pools))) fmt.Println(color.WhiteString("---------------------------------")) for name := range store.Pools { fmt.Print(color.CyanString("[Store]")) fmt.Print(color.WhiteString(" %s\t loaded\n", name)) } } func printSchedules(silent bool) { if len(schedule.Schedules) == 0 { return } if silent { for name, sch := range schedule.Schedules { process := fmt.Sprintf("Process: %s", sch.Process) if sch.TaskName != "" { process = fmt.Sprintf("Task: %s", sch.TaskName) } log.Info("[Schedule] %s %s %s %s", sch.Schedule, name, sch.Name, process) } return } fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(L("Schedules List (%d)"), len(schedule.Schedules))) fmt.Println(color.WhiteString("---------------------------------")) for name, sch := range schedule.Schedules { process := fmt.Sprintf("Process: %s", sch.Process) if sch.TaskName != "" { process = fmt.Sprintf("Task: %s", sch.TaskName) } fmt.Print(color.CyanString("[Schedule] %s %s", sch.Schedule, name)) fmt.Print(color.WhiteString("\t%s\t%s\n", sch.Name, process)) } } func printTasks(silent bool) { if len(task.Tasks) == 0 { return } if silent { for _, t := range task.Tasks { log.Info("[Task] %s workers:%d", t.Option.Name, t.Option.WorkerNums) } return } 
fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(L("Tasks List (%d)"), len(task.Tasks))) fmt.Println(color.WhiteString("---------------------------------")) for _, t := range task.Tasks { fmt.Print(color.CyanString("[Task] %s", t.Option.Name)) fmt.Print(color.WhiteString("\t workers: %d\n", t.Option.WorkerNums)) } } func printApis(silent bool) { // Determine API root based on OpenAPI mode apiRoot := "/api" if openapi.Server != nil { apiRoot = openapi.Server.Config.BaseURL } if silent { for _, api := range api.APIs { if len(api.HTTP.Paths) <= 0 { continue } log.Info("[API] %s(%d)", api.ID, len(api.HTTP.Paths)) for _, p := range api.HTTP.Paths { log.Info("%s %s %s", p.Method, filepath.Join(apiRoot, api.HTTP.Group, p.Path), p.Process) } } for name, upgrader := range websocket.Upgraders { // WebSocket log.Info("[WebSocket] GET /websocket/%s process:%s", name, upgrader.Process) } return } // Skip detailed API list when OpenAPI is enabled if openapi.Server != nil { return } fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(L("APIs List"))) fmt.Println(color.WhiteString("---------------------------------")) for _, api := range api.APIs { // API info if len(api.HTTP.Paths) <= 0 { continue } deprecated := "" if strings.HasPrefix(api.ID, "xiang.") { deprecated = " WILL BE DEPRECATED" } fmt.Printf("%s%s\n", color.CyanString("\n%s(%d)", api.ID, len(api.HTTP.Paths)), color.RedString(deprecated)) for _, p := range api.HTTP.Paths { fmt.Println( colorMehtod(p.Method), color.WhiteString(filepath.Join(apiRoot, api.HTTP.Group, p.Path)), "\tprocess:", p.Process) } } if len(websocket.Upgraders) > 0 { fmt.Print(color.CyanString(fmt.Sprintf("\n%s(%d)\n", "WebSocket", len(websocket.Upgraders)))) for name, upgrader := range websocket.Upgraders { // WebSocket fmt.Println( colorMehtod("GET"), color.WhiteString(filepath.Join("/websocket", name)), "\tprocess:", upgrader.Process) } } } func 
printMCPs(silent bool) { clients := mcp.ListClients() if len(clients) == 0 { return } if silent { for _, clientID := range clients { log.Info("[MCP] %s loaded", clientID) } return } // Separate agent MCPs from standard MCPs by Type field agentClients := []string{} standardClients := []string{} for _, clientID := range clients { client, err := mcp.Select(clientID) if err != nil { standardClients = append(standardClients, clientID) continue } info := client.Info() if info != nil && info.Type == "agent" { agentClients = append(agentClients, clientID) } else { standardClients = append(standardClients, clientID) } } fmt.Println(color.WhiteString("\n---------------------------------")) fmt.Println(color.WhiteString(L("MCP Clients List (%d)"), len(clients))) fmt.Println(color.WhiteString("---------------------------------")) if len(standardClients) > 0 { fmt.Println(color.WhiteString("\n%s (%d)", "Standard MCPs", len(standardClients))) fmt.Println(color.WhiteString("--------------------------")) for _, clientID := range standardClients { client, err := mcp.Select(clientID) if err != nil { fmt.Print(color.CyanString("[MCP] %s", clientID)) fmt.Print(color.WhiteString("\tloaded\n")) continue } info := client.Info() transport := "unknown" label := clientID if info != nil { if info.Transport != "" { transport = string(info.Transport) } if info.Label != "" { label = info.Label } } fmt.Print(color.CyanString("[MCP] %s", label)) fmt.Print(color.WhiteString("\t%s\tid: %s", transport, clientID)) // Only show tools count for process transport if transport == "process" { toolsCount := 0 mapping, err := mcp.GetClientMapping(clientID) if err == nil && mapping.Tools != nil { toolsCount = len(mapping.Tools) } fmt.Print(color.WhiteString("\ttools: %d", toolsCount)) } fmt.Print("\n") } } if len(agentClients) > 0 { fmt.Println(color.WhiteString("\n%s (%d)", "Agent MCPs", len(agentClients))) fmt.Println(color.WhiteString("--------------------------")) for _, clientID := range agentClients { 
client, err := mcp.Select(clientID) if err != nil { fmt.Print(color.CyanString("[MCP] %s", clientID)) fmt.Print(color.WhiteString("\tloaded\n")) continue } info := client.Info() transport := "unknown" label := clientID if info != nil { if info.Transport != "" { transport = string(info.Transport) } if info.Label != "" { label = info.Label } } fmt.Print(color.CyanString("[MCP] %s", label)) fmt.Print(color.WhiteString("\t%s\tid: %s", transport, clientID)) // Only show tools count for process transport if transport == "process" { toolsCount := 0 mapping, err := mcp.GetClientMapping(clientID) if err == nil && mapping.Tools != nil { toolsCount = len(mapping.Tools) } fmt.Print(color.WhiteString("\ttools: %d", toolsCount)) } fmt.Print("\n") } } } func colorMehtod(method string) string { method = strings.ToUpper(method) switch method { case "GET": return color.GreenString("GET") case "POST": return color.YellowString("POST") default: return color.WhiteString(method) } } // portOccupied probes whether host:port is already bound. // Returns (true, " (pid XXXX)") when occupied, (false, "") otherwise. 
func portOccupied(host string, port int) (bool, string) { addr := net.JoinHostPort(host, strconv.Itoa(port)) ln, err := net.Listen("tcp", addr) if err != nil { return true, fmt.Sprintf(" (%s)", err.Error()) } ln.Close() return false, "" } func init() { startCmd.PersistentFlags().BoolVarP(&startDebug, "debug", "", false, L("Development mode")) startCmd.PersistentFlags().BoolVarP(&startDisableWatching, "disable-watching", "", false, L("Disable watching")) } ================================================ FILE: cmd/sui/build.go ================================================ package sui import ( "fmt" "os" "path/filepath" "strings" "time" "github.com/fatih/color" "github.com/google/uuid" jsoniter "github.com/json-iterator/go" "github.com/spf13/cobra" "github.com/yaoapp/gou/session" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/sui/core" ) // BuildCmd command var BuildCmd = &cobra.Command{ Use: "build", Short: L("Build the template"), Long: L("Build the template"), Run: func(cmd *cobra.Command, args []string) { if len(args) < 1 { fmt.Fprintln(os.Stderr, color.RedString(L("yao sui build [template] [data]"))) return } Boot() cfg := config.Conf loadWarnings, err := engine.Load(cfg, engine.LoadOption{Action: "sui.build"}) if err != nil { fmt.Fprintln(os.Stderr, color.RedString(err.Error())) return } id := args[0] template := "default" if len(args) >= 2 { template = args[1] } // For agent SUI, use "agent" as default template if id == "agent" && template == "default" { template = "agent" } var sessionData map[string]interface{} err = jsoniter.UnmarshalFromString(strings.TrimPrefix(data, "::"), &sessionData) if err != nil { fmt.Fprintln(os.Stderr, color.RedString(err.Error())) return } sid := uuid.New().String() if sessionData != nil && len(sessionData) > 0 { session.Global().ID(sid).SetMany(sessionData) } sui, has := core.SUIs[id] if !has { fmt.Fprint(os.Stderr, color.RedString("the sui "+id+" does not exist")) return } 
sui.WithSid(sid) tmpl, err := sui.GetTemplate(template) if err != nil { fmt.Fprintln(os.Stderr, color.RedString(err.Error())) return } // - publicRoot, err := sui.PublicRootWithSid(sid) assetRoot := filepath.Join(publicRoot, "assets") if err != nil { fmt.Fprintln(os.Stderr, color.RedString(err.Error())) return } fmt.Println(color.WhiteString("-----------------------")) fmt.Println(color.WhiteString("Public Root: /public%s", publicRoot)) fmt.Println(color.WhiteString(" Template: %s", tmpl.GetRoot())) fmt.Println(color.WhiteString(" Session: %s", strings.TrimLeft(data, "::"))) fmt.Println(color.WhiteString("-----------------------")) // Timecost start := time.Now() minify := true mode := "production" if debug { minify = false mode = "development" } warnings, err := tmpl.Build(&core.BuildOption{SSR: true, AssetRoot: assetRoot, ExecScripts: true, ScriptMinify: minify, StyleMinify: minify}) if err != nil { fmt.Fprintln(os.Stderr, color.RedString(err.Error())) return } end := time.Now() timecost := end.Sub(start).Truncate(time.Millisecond) if debug { fmt.Println(color.YellowString("Build succeeded for %s in %s", mode, timecost)) return } if len(loadWarnings) > 0 { for _, warning := range loadWarnings { fmt.Println(color.YellowString("[%s] %s", warning.Widget, warning.Error)) } } if len(warnings) > 0 { for _, warning := range warnings { fmt.Println(color.YellowString("Warning: %s", warning)) } } fmt.Println(color.GreenString("Build succeeded for %s in %s", mode, timecost)) }, } ================================================ FILE: cmd/sui/sui.go ================================================ package sui var data string var locales string var debug bool func init() { WatchCmd.PersistentFlags().StringVarP(&data, "data", "d", "::{}", L("Session Data")) BuildCmd.PersistentFlags().StringVarP(&data, "data", "d", "::{}", L("Session Data")) BuildCmd.PersistentFlags().BoolVarP(&debug, "debug", "D", false, L("Debug mode")) TransCmd.PersistentFlags().StringVarP(&data, "data", 
"d", "::{}", L("Session Data")) TransCmd.PersistentFlags().BoolVarP(&debug, "debug", "D", false, L("Debug mode")) TransCmd.PersistentFlags().StringVarP(&locales, "locales", "l", "", L("Locales, separated by commas")) } ================================================ FILE: cmd/sui/trans.go ================================================ package sui import ( "fmt" "os" "path/filepath" "strings" "time" "github.com/fatih/color" "github.com/google/uuid" jsoniter "github.com/json-iterator/go" "github.com/spf13/cobra" "github.com/yaoapp/gou/session" "github.com/yaoapp/yao/config" "github.com/yaoapp/yao/engine" "github.com/yaoapp/yao/sui/core" "golang.org/x/text/language" ) // TransCmd command var TransCmd = &cobra.Command{ Use: "trans", Short: L("Translate the template"), Long: L("Translate the template"), Run: func(cmd *cobra.Command, args []string) { if len(args) < 2 { fmt.Fprintln(os.Stderr, color.RedString(L("yao sui trans