Repository: foreveryh/mentis Branch: main Commit: 7859b536b98b Files: 240 Total size: 1.4 MB Directory structure: gitextract_zde6lsy3/ ├── .gitignore ├── README.md ├── __init__.py ├── api/ │ ├── __init__.py │ ├── agent/ │ │ ├── __init__.py │ │ └── loader.py │ ├── server.py │ └── utils.py ├── core/ │ ├── __init__.py │ ├── a2a/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent_task_manager.py │ │ ├── client/ │ │ │ ├── __init__.py │ │ │ ├── card_resolver.py │ │ │ └── client.py │ │ ├── config.json │ │ ├── server/ │ │ │ ├── __init__.py │ │ │ ├── server.py │ │ │ ├── task_manager.py │ │ │ └── utils.py │ │ ├── types.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── in_memory_cache.py │ │ └── push_notification_auth.py │ ├── agents/ │ │ ├── __init__.py │ │ ├── base/ │ │ │ ├── base_agent.py │ │ │ ├── create_react_agent_wrapper.py │ │ │ └── react_agent.py │ │ ├── react_based_supervisor/ │ │ │ ├── __init__.py │ │ │ ├── agent_name.py │ │ │ ├── handoff.py │ │ │ ├── planning_handler.py │ │ │ ├── simple_planning_tool.py │ │ │ ├── state_schema.py │ │ │ └── supervisor.py │ │ ├── react_supervisor_agent.py │ │ ├── sb_supervisor_agent.py │ │ ├── state_based_supervisor/ │ │ │ ├── __init__.py │ │ │ ├── agent_name.py │ │ │ ├── evaluate_result_node.py │ │ │ ├── handoff.py │ │ │ ├── planner_node.py │ │ │ ├── planning_handler.py │ │ │ ├── prompt.py │ │ │ ├── state_schema.py │ │ │ ├── supervisor_graph.py │ │ │ └── supervisor_node.py │ │ └── sub_agents/ │ │ ├── __init__.py │ │ ├── coder_agent.py │ │ ├── data_analyst_agent.py │ │ ├── designer_agent.py │ │ ├── reporter_agent.py │ │ └── research_agent.py │ ├── llm/ │ │ ├── llm_manager.py │ │ └── model_config.py │ ├── mcp/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── client.py │ │ ├── config_loader.py │ │ ├── mcp_server_config.json │ │ ├── run_server.py │ │ ├── server.py │ │ └── test/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── minimal_fastmcp_test.py │ │ └── test_minimal_client.py │ ├── tools/ │ │ ├── __init__.py │ │ ├── e2b_tool.py │ │ ├── firecrawl_tool.py │ │ ├── registry.py │ │ └── replicate_flux_tool.py │ └── utils/ │ ├── agent_utils.py │ └── timezone.py ├── examples/ │ ├── 01_supervisor_test.py │ ├── 02_supervisor_agent_test.py │ ├── 03_tavily_tools_test.py │ ├── 04_react_agent_test.py │ ├── 05_react_agent_user_input.py │ ├── 06_web_extraction_tools_test.py │ ├── 07_web_extraction_with_filesystem.py │ ├── 08_react_agent_tool_registry_test.py │ ├── 09_e2b_code_interpreter_test.py │ ├── 10_financial_data_analysis.py │ ├── 11_e2b_sandbox_test.py │ ├── 12_planning_supervisor_test.py │ ├── 13_multi_agent_roles_test.py │ ├── 14_mcp_client_fetch_test.py │ ├── 15_mcp_agent_test.py │ ├── 16_google_a2a/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent_task_manager_test.py │ │ ├── client_example.py │ │ ├── currency_agent_test.py │ │ ├── currency_agent_test_README.md │ │ └── langgraph_integration.py │ ├── TODO_computer_tool_demo.py │ ├── __init__.py │ ├── state_based_supervisor_examples/ │ │ ├── 01_simple.py │ │ ├── 02_tavily.py │ │ └── 03_multi_agents.py │ └── web_agents/ │ ├── README.md │ ├── README_SPEC.md │ ├── __init__.py │ ├── research_assistant/ │ │ ├── README.md │ │ ├── __init__.py │ │ └── graph.py │ └── weather_agent/ │ ├── README.md │ └── __init__.py ├── instructions/ │ ├── 00.Langgraph 和 React Agent.md │ ├── 01.supervisor_pattern.md │ ├── 02.supervisor_pattern_agent.md │ ├── 03.tavily_search_integration.md │ ├── 04.react_agent.md │ ├── 05.react_agent_user_input.md │ ├── 06.web_extraction_tools.md │ ├── 07.web_extraction_with_filesystem.md │ ├── 
08.react_agent_tool_registry.md │ └── 09.e2b_sandbox_integration.md ├── log_analyzer.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── super_agents/ │ ├── __init__.py │ ├── browser_use/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent/ │ │ │ ├── __init__.py │ │ │ ├── graph.py │ │ │ ├── nodes.py │ │ │ ├── prompts.py │ │ │ ├── schemas.py │ │ │ ├── state.py │ │ │ └── tools.py │ │ ├── agent.py │ │ ├── browser/ │ │ │ ├── browser.py │ │ │ ├── detector.py │ │ │ ├── findVisibleInteractiveElements.js │ │ │ ├── models.py │ │ │ └── utils.py │ │ ├── llm.py │ │ └── main.py │ ├── customized_deep_research/ │ │ ├── PRD_README.md │ │ ├── README.md │ │ ├── __init__.py │ │ ├── main.py │ │ └── reason_graph/ │ │ ├── __init__.py │ │ ├── graph.py │ │ ├── nodes.py │ │ ├── prompt.py │ │ ├── schemas.py │ │ ├── state.py │ │ └── tools.py │ └── deep_research/ │ ├── README.md │ ├── __init__.py │ ├── a2a_adapter/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── client_example.py │ │ ├── deep_research_task_manager.py │ │ ├── dr_terminal_output.md │ │ ├── run_server.py │ │ └── setup.py │ ├── main.py │ ├── output/ │ │ ├── research_report_analyze_smartvalue_co_ltds_9417t_core_business_key_productsservices_eg_government_cloud_solutions_mo_20250418_125137.md │ │ ├── research_report_id_like_a_thorough_analysis_of_li_auto_stock_including_summary_company_overview_key_metrics_performa_20250327_121800.md │ │ └── research_report_id_like_a_thorough_analysis_of_xpev_stock_including_summary_company_overview_key_metrics_performance_20250327_105350.md │ ├── reason_graph/ │ │ ├── __init__.py │ │ ├── graph.py │ │ ├── nodes.py │ │ ├── prompt.py │ │ ├── schemas.py │ │ ├── state.py │ │ └── tools.py │ └── tests/ │ ├── __init__.py │ └── test_graph.py ├── web/ │ ├── .gitignore │ ├── README.md │ ├── app/ │ │ ├── api/ │ │ │ └── agent/ │ │ │ └── route.ts │ │ ├── chat/ │ │ │ ├── [id]/ │ │ │ │ ├── agent-types.ts │ │ │ │ ├── components/ │ │ │ │ │ ├── chatbot-node.tsx │ │ │ │ │ ├── checkpoint-card.tsx │ │ │ │ │ ├── node-card.tsx │ │ │ │ │ ├── reminder.tsx │ │ │ │ │ ├── research/ │ │ │ │ │ │ ├── report-preview.tsx │ │ │ │ │ │ ├── research-node.tsx │ │ │ │ │ │ ├── research-status.tsx │ │ │ │ │ │ └── search-results.tsx │ │ │ │ │ └── weather/ │ │ │ │ │ ├── cloudy.tsx │ │ │ │ │ ├── rainy.tsx │ │ │ │ │ ├── snowy.tsx │ │ │ │ │ ├── sunny.tsx │ │ │ │ │ └── weather-node.tsx │ │ │ │ └── page.tsx │ │ │ └── page.tsx │ │ ├── deep-research/ │ │ │ ├── [id]/ │ │ │ │ └── page.tsx │ │ │ └── page.tsx │ │ ├── globals.css │ │ ├── layout.tsx │ │ └── page.tsx │ ├── components/ │ │ ├── app-sidebar.tsx │ │ ├── theme-provider.tsx │ │ ├── theme-switcher.tsx │ │ └── ui/ │ │ ├── badge.tsx │ │ ├── button.tsx │ │ ├── card.tsx │ │ ├── checkbox.tsx │ │ ├── dialog.tsx │ │ ├── input.tsx │ │ ├── popover.tsx │ │ ├── progress.tsx │ │ ├── separator.tsx │ │ ├── sheet.tsx │ │ ├── sidebar.tsx │ │ ├── skeleton.tsx │ │ ├── textarea.tsx │ │ └── tooltip.tsx │ ├── components.json │ ├── eslint.config.mjs │ ├── hooks/ │ │ ├── use-mobile.tsx │ │ └── useLangGraphAgent/ │ │ ├── actions.ts │ │ ├── api.ts │ │ ├── ascii-tree.ts │ │ ├── types.ts │ │ └── useLangGraphAgent.tsx │ ├── next.config.ts │ ├── package.json │ ├── postcss.config.mjs │ ├── stores/ │ │ └── chat-store.tsx │ ├── tailwind.config.ts │ └── tsconfig.json └── web_for_a2a/ ├── .gitignore ├── Instruction.md ├── README.md ├── app/ │ ├── api/ │ │ └── a2a/ │ │ └── route.ts │ ├── deepresearch/ │ │ └── page.tsx │ ├── globals.css │ ├── layout.tsx │ └── page.tsx ├── package.json ├── postcss.config.js ├── tailwind.config.js └── tsconfig.json 
================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Python __pycache__/ *.py[cod] *$py.class *.so .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # Virtual Environment venv/ env/ ENV/ # IDE .idea/ .vscode/ *.swp *.swo # OS specific .DS_Store Thumbs.db # LangSmith .langchain.db .langsmith/ # Logs *.log # Env .env # output examples/logs/ examples/output/ examples/output/sandbox_test ================================================ FILE: README.md ================================================ # Mentis - Agent Development Kit [![Python Version](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) ## Overview Mentis is an extensible multi-agent ADK (Agent Development Kit) built on LangGraph. At its core is a **state-driven, planning Supervisor Agent** that understands complex user requests, produces an execution plan, and intelligently coordinates a set of Specialist Agents with different skills to complete the task together. The framework aims to automate complex tasks: through collaboration between agents it provides more powerful and flexible problem solving than any single agent. ## Core Features * **Multi-agent architecture**: a central Supervisor coordinates several specialist sub-agents (Research, Coder, Reporter, Designer, Data Analyst). * **State-based planning**: a dedicated `Planner` node produces the initial plan, the `Supervisor` focuses on execution and dispatch based on the plan state, and an `Evaluator` node assesses sub-agent results and updates the state. Plan state is persisted through LangGraph (a Checkpointer must be configured). * **Modular agent design**: built on `BaseAgent` and `ReactAgent`, making it easy to add or modify sub-agents with different capabilities. * **Tool registration and management**: `core/tools/registry.py` provides centralized registration, categorization, and dynamic loading of tools. * **Configurable LLMs**: `LLMManager` (or environment variables) lets you configure and switch between LLM providers (OpenAI, DeepSeek, XAI Grok via a compatible endpoint) and models. * **Persistence support**: LangGraph's Checkpointer mechanism enables persistence of conversation state and plans. * **Clear execution flow**: Planner -> Supervisor -> (Handoff -> Agent -> Evaluator -> Supervisor loop) -> final output/Reporter. * **A2A protocol support**: implements Google's Agent-to-Agent (A2A) protocol, so Mentis agents can interoperate with other A2A-capable systems. ## Architecture Overview 1. **User request (input)**: the user submits a task through an entry point (`main.py` or the API). 2. **Planner node**: analyzes the request, generates an initial plan (`Plan`) with task steps and suggested agents, and writes it into the graph state (`PlanningAgentState`). 3. **Supervisor node**: receives the state carrying the plan and decides the next action from the plan state and message history: * start a new task (mark it 'in_progress'); * delegate an 'in_progress' task to a suitable sub-agent (via a handoff tool); * wait for the sub-agent to finish; * judge whether the plan is finally complete; * choose the final output path (summarize itself, or call the Reporter). 4. **Handoff executor**: handles the `transfer_to_` tool calls issued by the Supervisor and passes control and state to the target sub-agent. 5. **Specialist agent nodes**: inherit from `ReactAgent` or `BaseAgent` and carry out the concrete work (research, coding, report/image generation, data analysis), possibly calling their own tools. 6. **Evaluate result node**: receives a sub-agent's output, performs a deterministic evaluation (success/failure), and updates the task's status and the overall plan status. 7. **Loop and finish**: the flow cycles between Evaluator -> Supervisor until the Supervisor judges the plan complete, then routes to `END` or the `ReporterAgent` (a minimal invocation sketch follows below).
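The flow above compiles to a regular LangGraph graph, so driving it looks like driving any compiled graph. A minimal sketch, assuming a hypothetical `build_supervisor_graph()` factory (the real construction lives in `core/agents/state_based_supervisor/supervisor_graph.py`) and the `PlanningAgentState` keys (`messages`, `plan`) used by `api/server.py`:

```python
# Minimal driving sketch, not the project's actual entry point.
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver

from core.agents.state_based_supervisor.supervisor_graph import build_supervisor_graph  # hypothetical name

# Compile the Planner -> Supervisor -> (Handoff -> Agent -> Evaluator) graph with an
# in-memory checkpointer so the Plan persists across steps of the same thread.
graph = build_supervisor_graph(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "demo-thread-1"}}
initial_state = {"messages": [HumanMessage(content="Write a Python function to calculate factorial and run it for 5.")]}

for node_update in graph.stream(initial_state, config, stream_mode="updates"):
    # Each update is keyed by the node that produced it (planner, supervisor, evaluator, ...).
    print(node_update)
```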
## Getting Started ### 1. Prerequisites * Python 3.11+ * Manage dependencies with `pip`, `uv`, or a similar tool. ### 2. Installation Run from the project root (uv is recommended): ```bash uv venv source .venv/bin/activate uv sync ``` ```bash # pip install -r requirements.txt # or uv pip install -r requirements.txt ``` (The `requirements.txt` is not actively maintained; make sure it contains all the required libraries, such as `langchain`, `langgraph`, `langchain-openai`, `e2b` (if using E2B), `replicate` (if using Replicate), `tavily-python`, `exa-py`, `python-dotenv`, `anyio`, `tiktoken`, etc.) ### 3. Configuration * Copy `.env.example` to `.env`. * Fill in the API keys/tokens you need in `.env`: * `OPENAI_API_KEY` (if using OpenAI models) * `DEEPSEEK_API_KEY` (if using DeepSeek models) * `XAI_API_KEY` (if using XAI Grok; confirm the base URL) * `REPLICATE_API_TOKEN` (if using the Replicate tool) * `E2B_API_KEY` (if using the E2B Code Interpreter, recommended!) * `TAVILY_API_KEY` (if using Tavily search, recommended!) * `EXA_API_KEY` (if using Exa search) * `LANGCHAIN_TRACING_V2="true"` (strongly recommended, for LangSmith debugging) * `LANGCHAIN_API_KEY="ls_..."` (your LangSmith key) * `LANGCHAIN_PROJECT="Your_Project_Name"` (your LangSmith project name) * **LLM configuration**: * If you use `LLMManager` (as the examples do), check and configure the model config file it reads (e.g. `config/models.yaml`; the path may differ). * If you initialize LLMs directly from environment variables in `tools.py`, make sure the corresponding variables are set, e.g. `LLM_PROVIDER`, `LLM_MODEL_NAME`, `LLM_BASE_URL` (for compatible APIs). * **Tool configuration**: make sure the tool pre-registration logic in `core/tools/__init__.py` or `registry.py` can find and initialize the tools you need. ### 4. Running Examples The project ships example scripts that demonstrate the framework: ```bash # Run from the project root (mentis/) python examples/state_based_supervisor_examples/03_multi_agents.py ``` The script prompts for an initial request. Some simple things to try: * `"What is the capital of France?"` (basic test) * `"Write a short, four-line poem about spring."` (tests the Reporter) * `"Generate an image of a cat wearing a top hat, oil painting style."` (tests the Designer) * `"Write a Python function to calculate factorial and run it for 5."` (tests the Coder) ## Project Structure ``` mentis/ ├── api/ # (optional) API service code ├── core/ # Core framework code │ ├── a2a/ # A2A protocol client and server implementation │ ├── agents/ # Agent definitions (base, react, supervisor, sub-agents) │ │ ├── base/ │ │ ├── state_based_supervisor/ # Supervisor internals (graph, node, planner, evaluator) │ │ ├── sub_agents/ # Concrete sub-agent implementations (research, coder, etc.) │ │ └── sb_supervisor_agent.py # SupervisorAgent class definition │ ├── llm/ # (optional) LLM management/configuration │ ├── tools/ # Tool definitions and registry (registry, e2b, replicate, etc.) │ └── utils/ # Shared helpers ├── examples/ # Example and test scripts │ └── state_based_supervisor_examples/ │ └── 03_multi_agents.py # The test script used above ├── super_agents/ # Standalone, end-to-end agent implementations │ └── deep_research/ # DeepResearch agent │ └── a2a_adapter/ # A2A protocol adapter for DeepResearch ├── web/ # (optional) Web client code ├── web_for_a2a/ # A2A-based web UI ├── .env.example # Example environment variables ├── requirements.txt # Python dependencies └── README.md # This file ``` ## Super Agents (Standalone Functional Agents) Besides the single-skill Specialist Agents coordinated by the Supervisor (Coder, Researcher, ...), the framework also supports building and integrating more complex **"Super Agents"**. A Super Agent is a **standalone agent graph with end-to-end capability that can complete a relatively large, self-contained task**. It can carry its own planning, execution, and even internal coordination logic. Super Agents can **run on their own** to finish a specific large task, or be treated by a higher-level coordinator (such as our Supervisor Agent) as a powerful **"capability" or "tool"** to handle one step of a bigger plan. ### DeepResearch Agent (the first instance) https://github.com/user-attachments/assets/2a685709-5be0-43a3-9e2d-934ef5fa3315 The `DeepResearch Agent` is the first Super Agent built on this idea (an early version of it was the foundation on which we developed this multi-agent framework). * **Core function**: automatically runs a **deep research** workflow on **any topic** the user provides. * **Internal workflow**: it has its own complete set of steps, roughly: 1. **Plan Research**: analyze the topic and generate initial search queries and analysis angles. 2. **Multi-Source Search**: call web search (Tavily), academic search (Exa), and other tools to gather information. 3. **(Optional) Perform Analysis**: run first-pass analyses of the search results (sentiment, SWOT, etc.). 4. **Gap Analysis**: assess the collected information and identify knowledge gaps and limitations. 5. **(Optional) Gap Filling**: run additional, more targeted searches against the identified gaps. 6. **Final Synthesis**: integrate all the information and distill key findings and remaining uncertainties. 7. **Report Generation**: turn the synthesis and context into a detailed, cited Markdown research report. * **Current status**: the agent's core logic and nodes are largely in place, and it now supports the A2A protocol and a dedicated web UI. #### A2A Protocol Support We implemented a full A2A protocol adapter for the DeepResearch Agent, so it can: * be discovered and invoked as a standard A2A service * receive research tasks through the `tasks/send` and `tasks/sendSubscribe` endpoints * stream research progress updates in real time * return structured research results * support the push notification mechanism This lets the DeepResearch Agent integrate easily with other A2A-capable systems (such as Google Assistant) or be called from a custom frontend application; a minimal client sketch follows.
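A minimal client sketch for the streaming path, using this repo's `A2AClient`. The payload fields follow the `TaskSendParams` usage seen in `core/a2a/` (`id`, `sessionId`, `message.parts`, `acceptedOutputModes`), but the URL, IDs, and topic are illustrative:

```python
# Sketch: stream a research task from a locally running DeepResearch A2A server.
import asyncio
import uuid

from core.a2a.client.client import A2AClient


async def main():
    client = A2AClient(url="http://localhost:8000/")  # assumed server address
    payload = {
        "id": uuid.uuid4().hex,         # task id
        "sessionId": uuid.uuid4().hex,  # conversation/session id
        "acceptedOutputModes": ["text"],
        "message": {
            "role": "user",
            "parts": [{"type": "text", "text": "State of open-source LLM agents in 2025"}],
        },
    }
    # send_task_streaming corresponds to the tasks/sendSubscribe endpoint.
    async for update in client.send_task_streaming(payload):
        print(update.model_dump(exclude_none=True))


asyncio.run(main())
```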
#### Dedicated Web UI https://github.com/user-attachments/assets/640365c7-839b-4765-b9ac-ee0ac961ceb8 We also built a modern Next.js web UI dedicated to interacting with the DeepResearch A2A service: * an intuitive interface for entering a research topic and starting a research task * live display of research progress and intermediate updates (via Server-Sent Events) * a polished rendering of the final research report * a demonstration of handling A2A streaming responses with native browser APIs in a frontend app **How to try the DeepResearch Agent:** 1. **Standalone mode**: * Check your environment: make sure `.env` contains all the required API keys (e.g. `OPENAI_API_KEY`/`DEEPSEEK_API_KEY`, `TAVILY_API_KEY`, `EXA_API_KEY`). * Run the script from the project root: ```bash python super_agents/deep_research/main.py ``` * Enter a topic and inspect the result: generated reports are usually saved under the `output/` folder. 2. **A2A service mode**: * Start the A2A server: ```bash cd super_agents/deep_research/a2a_adapter python run_server.py ``` * The server starts on the default port (usually 8000) and exposes A2A-compliant API endpoints. 3. **Web UI mode**: * Make sure the A2A server is running. * Start the web UI: ```bash cd web_for_a2a npm install npm run dev ``` * Open http://localhost:3000/deepresearch in a browser to interact with the DeepResearch Agent through the UI. ## Future Work / Contributing * Flesh out the sub-agents' toolsets and prompts. * Strengthen the Evaluator node's evaluation logic. * Add handling for more complex task dependencies. * Improve management of long conversation histories. * Integrate a persistent Checkpointer (SQLite, Redis, ...). * Issues and pull requests are welcome! * You can also reach me on WeChat: brown🩷cony999 ## License This project is licensed under the MIT License - see the LICENSE file for details. ================================================ FILE: __init__.py ================================================ # Project package initialization ================================================ FILE: api/__init__.py ================================================ ================================================ FILE: api/agent/__init__.py ================================================ ================================================ FILE: api/agent/loader.py ================================================ # Agent Loader Module # This module is responsible for loading agents from the web_agents directory import importlib import os import sys from typing import Dict, Optional, Any, List from langgraph.graph import StateGraph from langgraph.graph.graph import CompiledGraph # Add this import # Try to import deep_research_app try: # Adjust this import path based on your project structure from super_agents.deep_research.reason_graph.graph import web_app as deep_research_app except ImportError: print("Warning: Failed to import deep_research_app.
DeepResearchAgent will be unavailable.") deep_research_app = None # Add examples directory to Python path to allow importing web_agents examples_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'examples') if examples_path not in sys.path: sys.path.append(examples_path) def list_available_agents() -> Dict[str, str]: """List all available agents in the web_agents directory Returns: Dict[str, str]: A dictionary mapping agent names to their descriptions """ agents = {} web_agents_dir = os.path.join(examples_path, 'web_agents') # Check if web_agents directory exists if not os.path.exists(web_agents_dir) or not os.path.isdir(web_agents_dir): pass # Continue with empty agents dict else: # Iterate through subdirectories in web_agents for item in os.listdir(web_agents_dir): agent_dir = os.path.join(web_agents_dir, item) # Skip non-directories and special directories if not os.path.isdir(agent_dir) or item.startswith('__') or item.startswith('.'): continue # Check if the directory contains an __init__.py file with get_graph function init_file = os.path.join(agent_dir, '__init__.py') if os.path.exists(init_file): # Try to get description from README.md readme_file = os.path.join(agent_dir, 'README.md') description = item # Default description is the directory name if os.path.exists(readme_file): try: with open(readme_file, 'r', encoding='utf-8') as f: first_line = f.readline().strip() if first_line.startswith('# '): description = first_line[2:] except Exception: pass agents[item] = description # Add deep_research to available agents if it's imported successfully if deep_research_app is not None: agents["deep_research"] = "Deep Research Agent for in-depth topic exploration" return agents def load_agent(agent_name: str) -> Optional[CompiledGraph]: """Load an agent from the web_agents directory or special agents Args: agent_name (str): The name of the agent to load Returns: Optional[CompiledGraph]: The compiled graph for the agent, or None if the agent could not be loaded """ # Special case for deep_research agent if agent_name == "deep_research": if deep_research_app: return deep_research_app else: print(f"ERROR: DeepResearchAgent requested but not available.") return None # Standard agents from web_agents directory try: # Import the agent module module = importlib.import_module(f'web_agents.{agent_name}') # Check if the module has a get_graph function if hasattr(module, 'get_graph'): # Call the get_graph function to get the compiled graph return module.get_graph() else: print(f"Error: Agent '{agent_name}' does not have a get_graph function") return None except ImportError as e: print(f"Error importing agent '{agent_name}': {e}") return None except Exception as e: print(f"Error loading agent '{agent_name}': {e}") return None # Default agent to use if none is specified DEFAULT_AGENT = 'research_assistant' # DEFAULT_AGENT = 'weather_agent' def get_default_agent() -> Optional[CompiledGraph]: """Get the default agent Returns: Optional[CompiledGraph]: The compiled graph for the default agent, or None if it could not be loaded """ return load_agent(DEFAULT_AGENT) ================================================ FILE: api/server.py ================================================ import uvicorn from langgraph.types import Command, Interrupt from fastapi import FastAPI, Request, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from sse_starlette.sse import EventSourceResponse from typing import AsyncGenerator, Dict, Optional, Union, Any from api.utils 
import message_chunk_event, interrupt_event, custom_event, checkpoint_event, format_state_snapshot, stream_update_event import asyncio import traceback import json from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig # Import the agent loader from api.agent.loader import load_agent, list_available_agents, get_default_agent # Load the default agent graph = get_default_agent() # Track active connections active_connections: Dict[str, asyncio.Event] = {} app = FastAPI( title="LangGraph API", description="API for LangGraph interactions", version="0.1.0" ) # Configure CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], # In production, replace with specific origins allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/agents") async def list_agents(): """Endpoint returning a list of available agents.""" return list_available_agents() @app.get("/state") async def state(thread_id: str | None = None, agent: Optional[str] = Query(None)): """Endpoint returning current graph state.""" if not thread_id: raise HTTPException(status_code=400, detail="thread_id is required") # Load the specified agent if provided current_graph = load_agent(agent) if agent else graph if not current_graph: raise HTTPException(status_code=404, detail=f"Agent '{agent}' not found") config: RunnableConfig = {"configurable": {"thread_id": thread_id}} state = await current_graph.aget_state(config) return format_state_snapshot(state) @app.get("/history") async def history(thread_id: str | None = None, agent: Optional[str] = Query(None)): """Endpoint returning complete state history. Used for restoring graph.""" if not thread_id: raise HTTPException(status_code=400, detail="thread_id is required") # Load the specified agent if provided current_graph = load_agent(agent) if agent else graph if not current_graph: raise HTTPException(status_code=404, detail=f"Agent '{agent}' not found") config: RunnableConfig = {"configurable": {"thread_id": thread_id}} records = [] async for state in current_graph.aget_state_history(config): records.append(format_state_snapshot(state)) return records @app.post("/agent/stop") async def stop_agent(request: Request): """Endpoint for stopping the running agent.""" body = await request.json() thread_id = body.get("thread_id") if not thread_id: raise HTTPException(status_code=400, detail="thread_id is required") if thread_id in active_connections: active_connections[thread_id].set() return {"status": "stopped", "thread_id": thread_id} raise HTTPException(status_code=404, detail="Thread is not running") @app.post("/agent") async def agent(request: Request): """Endpoint for running the agent.""" body = await request.json() request_type = body.get("type") if not request_type: raise HTTPException(status_code=400, detail="type is required") thread_id = body.get("thread_id") if not thread_id: raise HTTPException(status_code=400, detail="thread_id is required") # Get the agent name if provided agent_name = body.get("agent") # Load the specified agent if provided current_graph = load_agent(agent_name) if agent_name else graph if not current_graph: raise HTTPException(status_code=404, detail=f"Agent '{agent_name or 'default'}' not found") stop_event = asyncio.Event() active_connections[thread_id] = stop_event config: RunnableConfig = {"configurable": {"thread_id": thread_id}} initial_graph_state: Dict[str, Any] = {} input_for_astream: Optional[Union[Dict, Command]] = None # input for astream # Get initial state or messages from 
frontend initial_state_input = body.get("state", {"messages": []}) if not isinstance(initial_state_input, dict): raise HTTPException(status_code=400, detail="state must be a dictionary") if agent_name == "deep_research": # --- Prepare state for DeepResearch Agent --- print("Preparing state for DeepResearchAgent...") # Extract topic from the first message in state['messages'] first_message_content = "" try: # Ensure initial_state_input['messages'] is a list and not empty if isinstance(initial_state_input.get('messages'), list) and initial_state_input['messages']: # Assume the first message's content is the topic first_message_content = initial_state_input['messages'][0]['content'] else: # Try to get topic from other fields in state (alternative) first_message_content = initial_state_input.get('topic', '') except Exception as e: print(f"Warning: Could not extract topic from initial state input: {e}") if not first_message_content or not isinstance(first_message_content, str): raise HTTPException(status_code=400, detail="A valid 'topic' string is required for deep_research agent, expected in state.messages[0].content or state.topic") # Build the ResearchState needed by DeepResearch Agent (at least topic and depth) initial_graph_state = { "topic": first_message_content, "depth": initial_state_input.get("depth", "advanced"), # Optional: allow frontend to specify depth "messages": [], # DeepResearch manages its own message history "stream_updates": [], # Initialize stream_updates # Initialize other ResearchState fields to None or default values "plan": None, "research_plan": None, "search_results": [], "gap_analysis": None, "final_synthesis": None, "final_report_markdown": None, } print(f"Initial ResearchState: {{'topic': '{initial_graph_state['topic']}', 'depth': '{initial_graph_state['depth']}', ...}}") # DeepResearch Agent's astream input is the complete initial state if request_type == "run": input_for_astream = initial_graph_state elif request_type == "resume": # DeepResearch Agent might not support or need different resume approach print("Warning: 'resume' might not be fully supported for DeepResearchAgent yet.") # Assume resume Command can be understood by the graph input_for_astream = Command(resume=body.get("resume")) config["configurable"]["checkpoint_id"] = body.get("resume") # Resume usually needs checkpoint ID else: # Fork, Replay typically only need config config_from_request = body.get("config") if not config_from_request: raise HTTPException(status_code=400, detail="config is required for fork/replay") config = config_from_request # Use complete config provided in the request input_for_astream = None else: # For Supervisor or other Agents (assume using PlanningAgentState) print("Preparing state for Supervisor/Other Agent...") # --- Prepare PlanningAgentState --- # Ensure messages list contains correct BaseMessage objects (or let BaseAgent preprocess) initial_messages = initial_state_input.get("messages", []) initial_graph_state = { "messages": initial_messages, "plan": None, # Planner node will create it "error": None # Add other fields needed by PlanningAgentState and set to None or default values } # --- Set astream input (logic similar to before) --- if request_type == "run": # For PlanningAgentState, initial input typically only contains messages input_for_astream = {"messages": initial_messages} elif request_type == "resume": resume_val = body.get("resume") if not resume_val: raise HTTPException(status_code=400, detail="resume value is required") input_for_astream = 
Command(resume=resume_val) # Ensure config includes checkpoint_id for resuming if "configurable" not in config: config["configurable"] = {} config["configurable"]["checkpoint_id"] = resume_val elif request_type == "fork": config_from_request = body.get("config") if not config_from_request: raise HTTPException(status_code=400, detail="config is required for fork") config = config_from_request # Fork uses complete config provided # Fork typically starts from specified checkpoint, no extra state dict input needed input_for_astream = None elif request_type == "replay": config_from_request = body.get("config") if not config_from_request: raise HTTPException(status_code=400, detail="config is required for replay") config = config_from_request input_for_astream = None else: raise HTTPException(status_code=400, detail="invalid request type") # Ensure config always has thread_id (important for all agents) if "configurable" not in config: config["configurable"] = {} config["configurable"]["thread_id"] = thread_id # --- State and Input preparation complete --- async def generate_events() -> AsyncGenerator[dict, None]: try: # Set recursion_limit to 100 to avoid hitting the recursion limit during deep research if agent_name == "deep_research" and "recursion_limit" not in config: config["recursion_limit"] = 100 async for chunk in current_graph.astream( input_for_astream, # Use prepared input config, # Use prepared config stream_mode=["debug", "messages", "updates", "custom"], ): if stop_event.is_set(): break chunk_type, chunk_data = chunk if chunk_type == "debug": # type can be checkpoint, task, task_result if isinstance(chunk_data, dict) and "type" in chunk_data: debug_type = chunk_data["type"] if debug_type == "checkpoint": yield checkpoint_event(chunk_data) elif debug_type == "task_result": interrupts = chunk_data["payload"].get( "interrupts", []) if interrupts and len(interrupts) > 0: yield interrupt_event(interrupts) elif chunk_type == "messages": # Make sure chunk_data is a list/tuple with at least two elements and the second element is a dict containing langgraph_node if isinstance(chunk_data, (list, tuple)) and len(chunk_data) > 1 and isinstance(chunk_data[1], dict) and "langgraph_node" in chunk_data[1]: yield message_chunk_event(chunk_data[1]["langgraph_node"], chunk_data[0]) else: print(f"Warning: Unexpected messages chunk_data format: {chunk_data}") # Fall back to safe defaults node_name = chunk_data[1].get("langgraph_node", "unknown") if isinstance(chunk_data, (list, tuple)) and len(chunk_data) > 1 and isinstance(chunk_data[1], dict) else "unknown" message = chunk_data[0] if isinstance(chunk_data, (list, tuple)) and len(chunk_data) > 0 else None if message is not None: yield message_chunk_event(node_name, message) elif chunk_type == "custom": # Check if this is a StreamUpdate if isinstance(chunk_data, dict) and all(k in chunk_data for k in ['id', 'type', 'status', 'title']): yield stream_update_event(chunk_data) else: yield custom_event(chunk_data) elif chunk_type == "updates": # Handle state update events (e.g., real-time Plan updates) pass # Currently ignore updates events, rely on checkpoint or custom # --- Loop ended --- yield {"event": "end", "data": "{}"} # Send an end event to frontend except Exception as e: print(f"Error during agent execution stream: {e}") traceback.print_exc() # Send error event to frontend yield {"event": "error", "data": json.dumps({"message": f"Agent execution error: {e}"})} finally: if thread_id in active_connections: del active_connections[thread_id] return EventSourceResponse(generate_events()) def main(): uvicorn.run("api.server:app", host="0.0.0.0", port=8000, reload=True) if
__name__ == "__main__": import sys import os # Add the project root to the Python path sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) main() ================================================ FILE: api/utils.py ================================================ import json from typing import Dict, Any, List, Optional from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, ToolMessage from langgraph.types import StateSnapshot def checkpoint_event(value): """Create a checkpoint event for the client.""" def format_values(values: dict): formatted_values = values.copy() if "messages" in formatted_values: formatted_values["messages"] = [ { "type": msg.get("type") if isinstance(msg, dict) else msg.type, "content": msg.get("content") if isinstance(msg, dict) else msg.content, "id": msg.get("id") if isinstance(msg, dict) else msg.id, "tool_calls": msg.get("tool_calls") if isinstance(msg, dict) else (msg.tool_calls if hasattr(msg, 'tool_calls') else None) } for msg in formatted_values["messages"] ] return formatted_values def format_writes(writes: dict): if writes is None: return None formatted_writes = {} for key, value in writes.items(): if isinstance(value, dict): formatted_writes[key] = format_values(value) elif isinstance(value, list): formatted_writes[key] = [format_values(item) if isinstance( item, dict) else item for item in value] else: formatted_writes[key] = value return formatted_writes configurable = value["payload"]["config"]["configurable"] data = { "next": value["payload"]["next"], "values": format_values(value["payload"]["values"]), "config": { "configurable": { "checkpoint_id": configurable["checkpoint_id"], "checkpoint_ns": configurable["checkpoint_ns"], "thread_id": configurable["thread_id"] } }, "metadata": { "source": value["payload"]["metadata"]["source"], "step": value["payload"]["metadata"]["step"], "writes": format_writes(value["payload"]["metadata"]["writes"]), "parents": value["payload"]["metadata"]["parents"] } } return { "event": "checkpoint", "data": json.dumps(data) } def message_chunk_event(node_name, message_chunk): """Create a message chunk event for the client.""" def format_messages(value): """Format message chunk into a serializable dictionary. This is needed because the message class is not serializable.
""" return { "content": value.content, "id": value.id, "tool_calls": value.tool_calls if hasattr(value, 'tool_calls') else None, "tool_call_chunks": value.tool_call_chunks if hasattr(value, 'tool_call_chunks') else None } return { "event": "message_chunk", "data": json.dumps({ "node_name": node_name, "message_chunk": format_messages(message_chunk) }) } def interrupt_event(interrupts): """Create an interrupt event for the client.""" formatted_interrupts = [{"value": interrupt["value"]} for interrupt in interrupts] return { "event": "interrupt", "data": json.dumps(formatted_interrupts) } def custom_event(value): """Create a custom event for the client.""" return { "event": "custom", "data": json.dumps(value) } def format_state_snapshot(snapshot: StateSnapshot): interrupts = [] for task in snapshot.tasks: for interrupt in task.interrupts: interrupts.append({"value": interrupt.value}) return { "values": snapshot.values, "next": snapshot.next, "config": snapshot.config, "interrupts": interrupts, "parent_config": snapshot.parent_config, "metadata": snapshot.metadata } def stream_update_event(data: dict): """为 DeepResearch Agent 的 StreamUpdateData 创建一个 stream_update 事件。 Args: data: 从 add_stream_update 产生的、符合 StreamUpdateData 结构的字典。 Returns: 符合 SSE EventSourceResponse 格式的字典。 """ if not isinstance(data, dict): # 如果传入的不是字典,返回一个错误事件 return { "event": "error", "data": json.dumps({"message": "Internal server error: Invalid stream update data type."}) } return { "event": "stream_update", "data": json.dumps(data, default=str) } ================================================ FILE: core/__init__.py ================================================ # Core module initialization ================================================ FILE: core/a2a/README.md ================================================ # Mentis A2A (Agent2Agent) 协议集成 本目录 (`core/a2a/`) 包含用于实现 Agent2Agent (A2A) 协议的客户端和服务器实现,使 Mentis Agents 能够与其他支持 A2A 协议的代理系统进行通信和协作。 ## 背景 A2A 是由 Google 发起的开放标准,旨在使不同框架(如 LangGraph、CrewAI、Google ADK、Genkit)或不同供应商构建的 AI 代理能够发现彼此的能力,协商交互模式(文本、文件、数据等),并在任务上进行协作。 ## 核心组件 ### 1. A2A 客户端 (`A2AClient`) `A2AClient` 类(位于 `client/client.py`)提供了与支持 A2A 协议的服务器进行交互的功能: * **代理发现:** 支持通过 `.well-known/agent.json` 端点自动发现代理能力(Agent Card)。 * **任务管理:** 提供发送、获取和取消任务的方法。 * **推送通知:** 支持设置和获取任务的推送通知配置。 * **流式响应:** 支持通过流式API接收任务执行的实时更新。 * **异步架构:** 基于 `asyncio` 和 `httpx` 构建,适合异步应用。 ### 2. A2A 服务器 (`A2AServer`) `A2AServer` 类(位于 `server/server.py`)允许将现有的 Mentis Agent 暴露为支持 A2A 协议的服务: * **基于 Starlette:** 使用 Starlette 框架提供 HTTP 和 SSE 端点。 * **任务处理:** 支持任务的创建、执行和状态跟踪。 * **流式更新:** 通过 Server-Sent Events (SSE) 提供任务执行的实时更新。 * **Agent Card:** 通过 `.well-known/agent.json` 端点公开代理能力。 ### 3. 辅助工具 #### 推送通知认证 (`PushNotificationAuth`) `PushNotificationAuth` 类(位于 `utils/push_notification_auth.py`)提供了安全的推送通知机制: * **发送方认证 (`PushNotificationSenderAuth`):** - 生成和管理 JWT 密钥对 - 验证推送通知 URL - 签名并发送推送通知 - 提供 JWKS 端点供接收方获取公钥 * **接收方认证 (`PushNotificationReceiverAuth`):** - 从 JWKS URL 加载公钥 - 验证接收到的推送通知的完整性和时效性 - 防止重放攻击 #### 内存缓存 (`InMemoryCache`) `InMemoryCache` 类(位于 `utils/in_memory_cache.py`)提供了线程安全的内存缓存实现: * **单例模式:** 确保应用中只有一个缓存实例 * **TTL 支持:** 支持设置缓存项的过期时间 * **线程安全:** 使用锁机制确保并发安全 ## 数据类型 A2A 协议定义了几个关键数据类型(位于 `types.py`): * **AgentCard:** 描述代理的元数据,包括名称、描述、URL、能力和技能。 * **Task:** 表示代理执行的任务,包含状态、内容和产物。 * **Part:** 内容的一部分,可以是文本、文件或数据。 * **Artifact:** 代理产生的产物,如结果、生成的文件等。 * **TaskState:** 任务状态枚举(已提交、进行中、需要输入、已完成、已取消、失败)。 * **PushNotificationConfig:** 推送通知配置,包含回调URL和认证信息。 ## 如何使用 ### 1. 
================================================ FILE: core/__init__.py ================================================ # Core module initialization ================================================ FILE: core/a2a/README.md ================================================ # Mentis A2A (Agent2Agent) Protocol Integration This directory (`core/a2a/`) contains the client and server implementations of the Agent2Agent (A2A) protocol, enabling Mentis agents to communicate and collaborate with other agent systems that support A2A. ## Background A2A is an open standard initiated by Google. It lets AI agents built on different frameworks (such as LangGraph, CrewAI, Google ADK, Genkit) or by different vendors discover each other's capabilities, negotiate interaction modes (text, files, data, etc.), and collaborate on tasks. ## Core Components ### 1. A2A Client (`A2AClient`) The `A2AClient` class (in `client/client.py`) interacts with servers that support the A2A protocol: * **Agent discovery:** resolves agent capabilities (the Agent Card) automatically via the `.well-known/agent.json` endpoint. * **Task management:** methods to send, fetch, and cancel tasks. * **Push notifications:** set and get a task's push notification configuration. * **Streaming responses:** receive real-time task updates through the streaming API. * **Async architecture:** built on `asyncio` and `httpx`, suited to async applications. ### 2. A2A Server (`A2AServer`) The `A2AServer` class (in `server/server.py`) exposes an existing Mentis agent as an A2A-compliant service: * **Starlette-based:** serves HTTP and SSE endpoints with the Starlette framework. * **Task handling:** supports task creation, execution, and status tracking. * **Streaming updates:** real-time task updates over Server-Sent Events (SSE). * **Agent Card:** publishes agent capabilities at the `.well-known/agent.json` endpoint. ### 3. Utilities #### Push Notification Auth (`PushNotificationAuth`) The classes in `utils/push_notification_auth.py` provide a secure push notification mechanism: * **Sender auth (`PushNotificationSenderAuth`):** - generates and manages JWT key pairs - verifies push notification URLs - signs and sends push notifications - exposes a JWKS endpoint so receivers can fetch the public key * **Receiver auth (`PushNotificationReceiverAuth`):** - loads public keys from a JWKS URL - verifies the integrity and freshness of incoming push notifications - prevents replay attacks #### In-Memory Cache (`InMemoryCache`) The `InMemoryCache` class (in `utils/in_memory_cache.py`) is a thread-safe in-memory cache: * **Singleton:** guarantees a single cache instance per application * **TTL support:** entries can be given an expiry time * **Thread safety:** uses locking for safe concurrent access ## Data Types The A2A protocol defines several key data types (in `types.py`): * **AgentCard:** agent metadata, including name, description, URL, capabilities, and skills. * **Task:** a task executed by an agent, with status, content, and artifacts. * **Part:** one piece of content; can be text, a file, or data. * **Artifact:** something the agent produced, such as results or generated files. * **TaskState:** the task state enum (submitted, working, input required, completed, canceled, failed). * **PushNotificationConfig:** push notification configuration, with callback URL and auth info. ## How to Use ### 1. Creating and Using the A2A Client ```python import asyncio from core.a2a.types import AgentCard from core.a2a.client.client import A2AClient async def main(): # Option 1: create the client directly from a URL client = A2AClient(url="http://localhost:8000/a2a") # Send a task response = await client.send_task({"text": "Please research artificial intelligence for me"}) task_id = response.result.id # Fetch the task result task_response = await client.get_task({"id": task_id}) # Configure push notifications await client.set_task_callback({ "taskId": task_id, "callbackUrl": "https://your-callback-url.com/webhook" }) # Option 2: create the client from an Agent Card agent_card = AgentCard(name="Example Agent", url="http://localhost:8000/a2a") client = A2AClient(agent_card=agent_card) # Receive real-time updates through the streaming API async for update in client.send_task_streaming({"text": "Analyze the latest AI trends"}): print(update) # Run asyncio.run(main()) ``` ### 2. Creating an A2A Server ```python from core.a2a.server.server import A2AServer from core.a2a.server.task_manager import InMemoryTaskManager from core.a2a.types import AgentCard # Create the agent card agent_card = AgentCard( name="My Agent", description="An example agent", url="http://localhost:5000" ) # Create the task manager task_manager = InMemoryTaskManager() # Create the server server = A2AServer( host="0.0.0.0", port=5000, endpoint="/", agent_card=agent_card, task_manager=task_manager ) # Start the server server.start() ```
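Note that `InMemoryTaskManager` still leaves `on_send_task` and `on_send_task_subscribe` abstract (see `server/task_manager.py` below), so in practice you plug an agent in through `AgentTaskManager`. A sketch, assuming a stand-in agent exposing the `invoke`/`stream` interface and `SUPPORTED_CONTENT_TYPES` attribute that `AgentTaskManager` actually calls; the `AgentCard` field set mirrors the example above and may need more fields depending on `types.py`:

```python
# Sketch: expose a (LangGraph-backed) agent as an A2A service via AgentTaskManager.
from core.a2a.agent_task_manager import AgentTaskManager
from core.a2a.server.server import A2AServer
from core.a2a.types import AgentCard


class EchoAgent:
    """Stand-in agent; replace the bodies with real graph invocations."""
    SUPPORTED_CONTENT_TYPES = ["text"]

    def invoke(self, query: str, session_id: str) -> str:
        return f"echo: {query}"

    async def stream(self, query: str, session_id: str):
        # AgentTaskManager consumes dicts with exactly these three keys.
        yield {"is_task_complete": False, "require_user_input": False, "content": "working..."}
        yield {"is_task_complete": True, "require_user_input": False, "content": f"echo: {query}"}


agent_card = AgentCard(name="Echo Agent", description="A stand-in echo agent", url="http://localhost:5000")
server = A2AServer(
    host="0.0.0.0",
    port=5000,
    agent_card=agent_card,
    task_manager=AgentTaskManager(agent=EchoAgent()),
)
server.start()
```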
### 3. Configuring Push Notifications #### Sender configuration ```python from core.a2a.utils.push_notification_auth import PushNotificationSenderAuth # Create the sender-side auth sender_auth = PushNotificationSenderAuth() # Generate a key pair sender_auth.generate_jwk() # Add the JWKS endpoint to your server app.add_route("/.well-known/jwks.json", sender_auth.handle_jwks_endpoint) # Verify the receiver URL is_valid = await sender_auth.verify_push_notification_url("https://receiver-url.com/webhook") # Send a push notification if is_valid: await sender_auth.send_push_notification( "https://receiver-url.com/webhook", {"event": "task_completed", "taskId": "123"} ) ``` #### Receiver configuration ```python from core.a2a.utils.push_notification_auth import PushNotificationReceiverAuth from starlette.requests import Request # Create the receiver-side auth receiver_auth = PushNotificationReceiverAuth() # Load the sender's public keys await receiver_auth.load_jwks("https://sender-url.com/.well-known/jwks.json") # Verify push notifications inside your webhook handler async def webhook_handler(request: Request): is_valid = await receiver_auth.verify_push_notification(request) if is_valid: # Handle the push notification... data = await request.json() print(f"Received a valid push notification: {data}") ``` ### 4. Using the In-Memory Cache ```python from core.a2a.utils.in_memory_cache import InMemoryCache # Get the cache instance cache = InMemoryCache() # Set a cache entry (with a TTL) cache.set("api_result", {"data": "some_value"}, ttl=300) # expires in 5 minutes # Get a cache entry result = cache.get("api_result") if result: print(f"Fetched result from cache: {result}") else: print("Cache entry expired or missing") # Delete a cache entry cache.delete("api_result") # Clear the whole cache cache.clear() ``` ## Complete Example See `examples/16_a2a_integration_test.py` for a complete integration example, covering: 1. creating an A2A server that exposes an existing agent as an A2A service 2. connecting to the A2A server with an A2A client 3. building an agent that uses the A2A client as a tool Run the examples: ```bash # Start the A2A server python -m examples.16_a2a_integration_test server # Run the A2A client python -m examples.16_a2a_integration_test client # Run the agent with the A2A tool python -m examples.16_a2a_integration_test agent ``` ## Relationship with MCP Mentis supports both MCP (Model Context Protocol) and A2A (Agent2Agent): * **MCP:** focuses on the interaction between AI models and external tools/services, mainly extending a single agent's capabilities. * **A2A:** focuses on communication and collaboration between different agents, letting multiple agents work together. The two protocols are complementary and can be combined to build powerful agent systems. ================================================ FILE: core/a2a/__init__.py ================================================ ================================================ FILE: core/a2a/agent_task_manager.py ================================================ import asyncio import logging import traceback from typing import Dict, Any, Union, AsyncIterable, Optional from core.a2a.types import ( TaskState, TaskStatus, Task, Artifact, Message, TextPart, SendTaskRequest, SendTaskResponse, GetTaskRequest, GetTaskResponse, CancelTaskRequest, CancelTaskResponse, SendTaskStreamingRequest, SendTaskStreamingResponse, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, GetTaskPushNotificationRequest, GetTaskPushNotificationResponse, TaskResubscriptionRequest, TaskSendParams, JSONRPCResponse, InvalidParamsError, TaskNotFoundError, TaskNotCancelableError, PushNotificationNotSupportedError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, InternalError, TaskIdParams, PushNotificationConfig ) from core.a2a.server.task_manager import TaskManager, InMemoryTaskManager from core.a2a.server import utils logger = logging.getLogger(__name__) class AgentTaskManager(InMemoryTaskManager): """ AgentTaskManager is the key component that connects a LangGraph agent to the A2A protocol. It manages the task lifecycle, handles streaming responses, updates task status, and sends push notifications. """ def __init__(self, agent, notification_sender_auth=None): """ Initialize the AgentTaskManager Args: agent: the LangGraph agent instance notification_sender_auth: push notification auth (optional) """ super().__init__() self.agent = agent self.notification_sender_auth = notification_sender_auth async def _run_streaming_agent(self, request: SendTaskStreamingRequest): """ Run the streaming agent and handle its responses Args: request: the streaming task request """ task_send_params: TaskSendParams = request.params query = self._get_user_query(task_send_params) try: async for item in self.agent.stream(query, task_send_params.sessionId): is_task_complete = item["is_task_complete"] require_user_input = item["require_user_input"] artifact = None message = None parts = [{"type": "text", "text": item["content"]}] end_stream = False if not is_task_complete and not require_user_input: task_state = TaskState.WORKING message = Message(role="agent", parts=parts) elif require_user_input: task_state = TaskState.INPUT_REQUIRED message = Message(role="agent", parts=parts) end_stream = True else: task_state = TaskState.COMPLETED artifact = Artifact(parts=parts, index=0, append=False) end_stream = True task_status = TaskStatus(state=task_state, message=message) latest_task = await self.update_store( task_send_params.id, task_status, None if artifact is None else [artifact], ) await self.send_task_notification(latest_task) if artifact: task_artifact_update_event = TaskArtifactUpdateEvent( id=task_send_params.id, artifact=artifact ) await self.enqueue_events_for_sse( task_send_params.id, task_artifact_update_event ) task_update_event = TaskStatusUpdateEvent( id=task_send_params.id, status=task_status, final=end_stream ) await self.enqueue_events_for_sse( task_send_params.id, task_update_event ) except Exception as e: logger.error(f"An error occurred while streaming the response: {e}") await
self.enqueue_events_for_sse( task_send_params.id, InternalError(message=f"An error occurred while streaming the response: {e}") ) def _get_user_query(self, task_send_params: TaskSendParams) -> str: """ Extract the user query from the task parameters (using the Google demo's strict approach) Args: task_send_params: the task send parameters Returns: str: the user query text """ if not task_send_params.message or not task_send_params.message.parts: logger.warning(f"[_get_user_query] Message or parts are empty for task {task_send_params.id}") return "" # or raise an error, depending on your design # Grab the first part directly part = task_send_params.message.parts[0] logger.debug(f"[_get_user_query] First part: type={type(part)}, value={part!r}") # keep this debug log # Strictly check that the first part is a TextPart instance if not isinstance(part, TextPart): logger.error(f"[_get_user_query] First part is not a TextPart instance! Type: {type(part)}") # Raise immediately: this aborts the flow with a clear message raise ValueError(f"Expected first message part to be TextPart, but got {type(part)}") # The check passed: return the text logger.debug(f"[_get_user_query] Extracted query from TextPart: '{part.text}'") return part.text def _validate_request( self, request: Union[SendTaskRequest, SendTaskStreamingRequest] ) -> JSONRPCResponse | None: """ Validate the request parameters Args: request: the task request Returns: JSONRPCResponse | None: an error response, or None """ task_send_params: TaskSendParams = request.params if not utils.are_modalities_compatible( task_send_params.acceptedOutputModes, self.agent.SUPPORTED_CONTENT_TYPES ): logger.warning( "Unsupported output mode. Received %s, Support %s", task_send_params.acceptedOutputModes, self.agent.SUPPORTED_CONTENT_TYPES, ) return utils.new_incompatible_types_error(request.id) if task_send_params.pushNotification and not task_send_params.pushNotification.url: logger.warning("Push notification URL is missing") return JSONRPCResponse(id=request.id, error=InvalidParamsError(message="Push notification URL is missing")) return None async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: """ Handle a send-task request Args: request: the task request Returns: SendTaskResponse: the task response """ validation_error = self._validate_request(request) if validation_error: return SendTaskResponse(id=request.id, error=validation_error.error) if request.params.pushNotification: if not await self.set_push_notification_info(request.params.id, request.params.pushNotification): return SendTaskResponse(id=request.id, error=InvalidParamsError(message="Push notification URL is invalid")) await self.upsert_task(request.params) task = await self.update_store( request.params.id, TaskStatus(state=TaskState.WORKING), None ) await self.send_task_notification(task) task_send_params: TaskSendParams = request.params query = self._get_user_query(task_send_params) try: agent_response = self.agent.invoke(query, task_send_params.sessionId) # Process the agent response and update the task status parts = [{"type": "text", "text": agent_response}] artifact = Artifact(parts=parts, index=0, append=False) task = await self.update_store( task_send_params.id, TaskStatus(state=TaskState.COMPLETED), [artifact] ) await self.send_task_notification(task) return SendTaskResponse(id=request.id, result=task) except Exception as e: # Log the failure with full context before reporting it logger.error(f"Error during agent invocation or task processing: {e}", exc_info=True) # Record the failed state try: # Make sure the status is updated even inside the exception handler task_failed : Task = await self.update_store( task_send_params.id, TaskStatus(state=TaskState.FAILED, error={"message": str(e)}), None ) await self.send_task_notification(task_failed) except Exception as update_err: # If updating the status also fails, log that too logger.error(f"Failed to update task status to FAILED after initial error: {update_err}", exc_info=True) # Return a more appropriate error type and message # return SendTaskResponse(id=request.id, error=InvalidParamsError(message=f"Error processing task: {e}")) # InternalError fits better, since the failure happened inside the server's own processing
return SendTaskResponse(id=request.id, error=InternalError(message=f"Error processing task: {str(e) or type(e).__name__}")) async def on_send_task_subscribe( self, request: SendTaskStreamingRequest ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: """ Handle a streaming task request Args: request: the streaming task request Returns: AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: a stream of responses, or an error """ try: error = self._validate_request(request) if error: return error await self.upsert_task(request.params) if request.params.pushNotification: if not await self.set_push_notification_info(request.params.id, request.params.pushNotification): return JSONRPCResponse(id=request.id, error=InvalidParamsError(message="Push notification URL is invalid")) task_send_params: TaskSendParams = request.params sse_event_queue = await self.setup_sse_consumer(task_send_params.id, False) asyncio.create_task(self._run_streaming_agent(request)) return self.dequeue_events_for_sse( request.id, task_send_params.id, sse_event_queue ) except Exception as e: logger.error(f"Error in SSE stream: {e}") print(traceback.format_exc()) return JSONRPCResponse( id=request.id, error=InternalError( message="An error occurred while streaming the response" ), ) async def _process_agent_response( self, request: SendTaskRequest, agent_response: dict ) -> SendTaskResponse: """Processes the agent's response and updates the task store.""" task_send_params: TaskSendParams = request.params task_id = task_send_params.id history_length = task_send_params.historyLength task_status = None parts = [{"type": "text", "text": agent_response["content"]}] artifact = None if agent_response["require_user_input"]: task_status = TaskStatus( state=TaskState.INPUT_REQUIRED, message=Message(role="agent", parts=parts), ) else: task_status = TaskStatus(state=TaskState.COMPLETED) artifact = Artifact(parts=parts) task = await self.update_store( task_id, task_status, None if artifact is None else [artifact] ) task_result = self.append_task_history(task, history_length) await self.send_task_notification(task) return SendTaskResponse(id=request.id, result=task_result) async def on_resubscribe_to_task( self, request: TaskResubscriptionRequest ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: task_id_params: TaskIdParams = request.params try: sse_event_queue = await self.setup_sse_consumer(task_id_params.id, True) return self.dequeue_events_for_sse(request.id, task_id_params.id, sse_event_queue) except Exception as e: logger.error(f"Error while reconnecting to SSE stream: {e}") return JSONRPCResponse( id=request.id, error=InternalError( message=f"An error occurred while reconnecting to stream: {e}" ), ) async def send_task_notification(self, task: Task): if not await self.has_push_notification_info(task.id): logger.info(f"No push notification info found for task {task.id}") return push_info = await self.get_push_notification_info(task.id) logger.info(f"Notifying for task {task.id} => {task.status.state}") await self.notification_sender_auth.send_push_notification( push_info.url, data=task.model_dump(exclude_none=True) ) async def set_push_notification_info(self, task_id: str, push_notification_config: PushNotificationConfig): # Verify the ownership of notification URL by issuing a challenge request.
if self.notification_sender_auth: is_verified = await self.notification_sender_auth.verify_push_notification_url(push_notification_config.url) if not is_verified: return False await super().set_push_notification_info(task_id, push_notification_config) return True ================================================ FILE: core/a2a/client/__init__.py ================================================ ================================================ FILE: core/a2a/client/card_resolver.py ================================================ import httpx from core.a2a.types import ( AgentCard, A2AClientJSONError, ) import json class A2ACardResolver: def __init__(self, base_url, agent_card_path="/.well-known/agent.json"): self.base_url = base_url.rstrip("/") self.agent_card_path = agent_card_path.lstrip("/") def get_agent_card(self) -> AgentCard: with httpx.Client() as client: response = client.get(self.base_url + "/" + self.agent_card_path) response.raise_for_status() try: return AgentCard(**response.json()) except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e ================================================ FILE: core/a2a/client/client.py ================================================ import httpx from httpx_sse import connect_sse from typing import Any, AsyncIterable from core.a2a.types import ( AgentCard, GetTaskRequest, SendTaskRequest, SendTaskResponse, JSONRPCRequest, GetTaskResponse, CancelTaskResponse, CancelTaskRequest, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, GetTaskPushNotificationRequest, GetTaskPushNotificationResponse, A2AClientHTTPError, A2AClientJSONError, SendTaskStreamingRequest, SendTaskStreamingResponse, ) import json class A2AClient: def __init__(self, agent_card: AgentCard = None, url: str = None): if agent_card: self.url = agent_card.url elif url: self.url = url else: raise ValueError("Must provide either agent_card or url") async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse: request = SendTaskRequest(params=payload) return SendTaskResponse(**await self._send_request(request)) async def send_task_streaming( self, payload: dict[str, Any] ) -> AsyncIterable[SendTaskStreamingResponse]: request = SendTaskStreamingRequest(params=payload) with httpx.Client(timeout=None) as client: with connect_sse( client, "POST", self.url, json=request.model_dump() ) as event_source: try: for sse in event_source.iter_sse(): yield SendTaskStreamingResponse(**json.loads(sse.data)) except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e except httpx.RequestError as e: raise A2AClientHTTPError(400, str(e)) from e async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]: async with httpx.AsyncClient() as client: try: # Image generation could take time, adding timeout response = await client.post( self.url, json=request.model_dump(), timeout=30 ) response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse: request = GetTaskRequest(params=payload) return GetTaskResponse(**await self._send_request(request)) async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse: request = CancelTaskRequest(params=payload) return CancelTaskResponse(**await self._send_request(request)) async def set_task_callback( self, payload: dict[str, Any] ) -> SetTaskPushNotificationResponse: 
request = SetTaskPushNotificationRequest(params=payload) return SetTaskPushNotificationResponse(**await self._send_request(request)) async def get_task_callback( self, payload: dict[str, Any] ) -> GetTaskPushNotificationResponse: request = GetTaskPushNotificationRequest(params=payload) return GetTaskPushNotificationResponse(**await self._send_request(request)) ================================================ FILE: core/a2a/config.json ================================================ { "local_agent": { "url": "http://127.0.0.1:8000/", "auth": { "type": "none" } } } ================================================ FILE: core/a2a/server/__init__.py ================================================ ================================================ FILE: core/a2a/server/server.py ================================================ # core/a2a/server/server.py from starlette.applications import Starlette from starlette.responses import JSONResponse from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware # --- import Pydantic's ValidationError --- from pydantic import ValidationError # --- end import --- from core.a2a.types import ( A2ARequest, JSONRPCResponse, JSONRPCError, # referenced by _handle_exception below InvalidRequestError, JSONParseError, GetTaskRequest, CancelTaskRequest, SendTaskRequest, SetTaskPushNotificationRequest, GetTaskPushNotificationRequest, InternalError, AgentCard, TaskResubscriptionRequest, SendTaskStreamingRequest, MethodNotFoundError, # make sure Pydantic's ValidationError is NOT imported from here ) import json from typing import AsyncIterable, Any, Optional, Union from core.a2a.server.task_manager import TaskManager import logging logger = logging.getLogger(__name__) class A2AServer: def __init__( self, host="0.0.0.0", port=5000, endpoint="/", agent_card: AgentCard = None, task_manager: TaskManager = None, allowed_origins: Optional[list[str]] = None, ): self.host = host self.port = port self.endpoint = endpoint self.task_manager = task_manager self.agent_card = agent_card if allowed_origins is None: # Default to localhost:3000 only, for local development allowed_origins = ["http://localhost:3000"] logger.warning("CORS allow_origins set to 'http://localhost:3000' for local development.") else: logger.info(f"CORS allow_origins configured: {allowed_origins}") middleware = [ Middleware( CORSMiddleware, allow_origins=allowed_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) ] self.app = Starlette(middleware=middleware, debug=True) self.app.add_route(self.endpoint, self._process_request, methods=["POST"]) self.app.add_route( "/.well-known/agent.json", self._get_agent_card, methods=["GET"] ) logger.info(f"A2AServer initialized.
Endpoint: {self.endpoint}, Agent Card Endpoint: /.well-known/agent.json") def start(self): if self.agent_card is None: raise ValueError("agent_card must be provided to A2AServer") if self.task_manager is None: raise ValueError("task_manager must be provided to A2AServer") import uvicorn logger.info(f"Starting Uvicorn server on {self.host}:{self.port}...") uvicorn.run(self.app, host=self.host, port=self.port) def _get_agent_card(self, request: Request) -> JSONResponse: logger.debug("Received request for /.well-known/agent.json") if not self.agent_card: logger.error("Agent card requested but not configured in A2AServer.") return JSONResponse({"error": "Agent card not configured"}, status_code=500) return JSONResponse(self.agent_card.model_dump(exclude_none=True)) async def _process_request(self, request: Request) -> Union[JSONResponse, EventSourceResponse]: result = None; json_rpc_request = None; request_id_for_error = None try: try: body = await request.json(); logger.debug(f"Received request body: {body}") except json.JSONDecodeError as e: logger.error(f"JSON decoding failed: {e}"); raise JSONParseError() try: json_rpc_request = A2ARequest.validate_python(body); request_id_for_error = getattr(json_rpc_request, 'id', None) logger.info(f"Processing valid A2A request: Method='{json_rpc_request.method}', ID='{request_id_for_error}', TaskID='{getattr(json_rpc_request.params, 'id', 'N/A')}'") except ValidationError as e: logger.error(f"A2A request validation failed: {e}"); req_id_fallback = body.get('id') if isinstance(body, dict) else None # Note: the InvalidRequestError raised here is caught by the except Exception handler below raise InvalidRequestError(data=json.loads(e.json())) from e # Dispatch to the TaskManager if isinstance(json_rpc_request, GetTaskRequest): result = await self.task_manager.on_get_task(json_rpc_request) elif isinstance(json_rpc_request, SendTaskRequest): result = await self.task_manager.on_send_task(json_rpc_request) elif isinstance(json_rpc_request, SendTaskStreamingRequest): result = await self.task_manager.on_send_task_subscribe(json_rpc_request) elif isinstance(json_rpc_request, CancelTaskRequest): result = await self.task_manager.on_cancel_task(json_rpc_request) elif isinstance(json_rpc_request, SetTaskPushNotificationRequest): result = await self.task_manager.on_set_task_push_notification(json_rpc_request) elif isinstance(json_rpc_request, GetTaskPushNotificationRequest): result = await self.task_manager.on_get_task_push_notification(json_rpc_request) elif isinstance(json_rpc_request, TaskResubscriptionRequest): result = await self.task_manager.on_resubscribe_to_task(json_rpc_request) else: logger.warning(f"Unhandled validated request type: {type(json_rpc_request)}"); raise MethodNotFoundError(data={"method": getattr(json_rpc_request, 'method', 'unknown')}) logger.debug(f"[A2AServer] Result from TaskManager method '{json_rpc_request.method}': type={type(result)}") return self._create_response(result) # build the response via _create_response except Exception as e: # Handle every exception raised during request processing (validation and task manager calls) in one place logger.error(f"Exception during request processing: {e}", exc_info=True) return self._handle_exception(e, request_id=request_id_for_error) # delegate to _handle_exception def _handle_exception(self, e: Exception, request_id: Optional[Union[str, int]] = None) -> JSONResponse: status_code = 500; json_rpc_error: Optional[JSONRPCError] = None if isinstance(e, JSONParseError): json_rpc_error = e; status_code = 400 elif isinstance(e, InvalidRequestError): json_rpc_error = e; status_code = 400 elif isinstance(e, MethodNotFoundError):
def _handle_exception(self, e: Exception, request_id: Optional[Union[str, int]] = None) -> JSONResponse: status_code = 500; json_rpc_error: Optional[JSONRPCError] = None if isinstance(e, JSONParseError): json_rpc_error = e; status_code = 400 elif isinstance(e, InvalidRequestError): json_rpc_error = e; status_code = 400 elif isinstance(e, MethodNotFoundError): json_rpc_error = e; status_code = 404 # or 501 # --- Pydantic's ValidationError can now be caught correctly --- elif isinstance(e, ValidationError): logger.warning(f"Pydantic Validation error caught in handler: {e}") error_data = str(e) try: error_data = json.loads(e.json()) except Exception: pass # a Pydantic validation error raised while handling the request is a flavor of InvalidRequestError; # if it happens while building the response it is closer to an InternalError json_rpc_error = InvalidRequestError(message="Request/Response data validation failed", data=error_data) status_code = 400 # treated as a malformed client request or a malformed handler result # --- end catch --- elif isinstance(e, ValueError) and "Unexpected result type" in str(e): logger.error(f"Internal error due to unexpected result type: {e}", exc_info=False) json_rpc_error = InternalError(message="Server error: Unexpected result type from handler.") status_code = 500 elif isinstance(e, NotImplementedError): logger.error(f"Method not implemented: {e}", exc_info=True) json_rpc_error = MethodNotFoundError(message=f"Method not implemented: {e}") status_code = 501 else: logger.error(f"Unhandled internal exception: {e}", exc_info=True) json_rpc_error = InternalError(message=f"An internal server error occurred: {type(e).__name__}") status_code = 500 response = JSONRPCResponse(id=request_id, error=json_rpc_error) logger.debug(f"Returning error response: {response.model_dump(exclude_none=True)}") return JSONResponse(response.model_dump(exclude_none=True), status_code=status_code)
def _create_response(self, result: Any) -> Union[JSONResponse, EventSourceResponse]: if isinstance(result, AsyncIterable): logger.debug("[A2AServer] Creating EventSourceResponse (text/event-stream)") async def event_generator(stream_result: AsyncIterable) -> AsyncIterable[dict[str, str]]: try: async for item in stream_result: if hasattr(item, 'model_dump_json'): json_data = item.model_dump_json(exclude_none=True) logger.debug(f"A2AServer yielding SSE data: {json_data}") yield {"data": json_data} else: logger.warning(f"Yielding non-Pydantic object in event stream: {type(item)}") yield {"data": json.dumps(str(item))} except Exception as gen_err: logger.error(f"Error during SSE event generation: {gen_err}", exc_info=True) try: # try to yield a standard JSON-RPC error event error_payload = JSONRPCResponse(id=None, error=InternalError(message=f"Streaming generation error: {gen_err}")) yield {"event": "error", "data": error_payload.model_dump_json(exclude_none=True)} except Exception as yield_err: logger.error(f"Failed to yield error event to SSE stream: {yield_err}", exc_info=True) return EventSourceResponse(event_generator(result)) elif isinstance(result, JSONRPCResponse): logger.debug("[A2AServer] Creating JSONResponse (application/json)") return JSONResponse(result.model_dump(exclude_none=True)) else: logger.error(f"Unexpected result type received by _create_response: {type(result)}") raise ValueError(f"Unexpected result type: {type(result)}")
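For reference, a full round-trip against this endpoint is a plain JSON-RPC 2.0 POST. A minimal sketch, assuming the server is listening on the URL from config.json; the request id, task id, and message text are illustrative:

    import httpx

    payload = {
        "jsonrpc": "2.0",
        "id": "req-1",            # illustrative request id
        "method": "tasks/send",   # dispatched to TaskManager.on_send_task
        "params": {
            "id": "task-1",       # illustrative task id
            "message": {"role": "user", "parts": [{"type": "text", "text": "Hello"}]},
        },
    }
    resp = httpx.post("http://127.0.0.1:8000/", json=payload)
    # a JSONRPCResponse: a Task under "result" on success, a JSONRPCError under "error" otherwise
    print(resp.json())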
================================================ FILE: core/a2a/server/task_manager.py ================================================
from abc import ABC, abstractmethod from typing import Union, AsyncIterable, List from core.a2a.types import Task from core.a2a.types import ( JSONRPCResponse, TaskIdParams, TaskQueryParams, GetTaskRequest, TaskNotFoundError, SendTaskRequest, CancelTaskRequest, TaskNotCancelableError, SetTaskPushNotificationRequest, GetTaskPushNotificationRequest, GetTaskResponse, CancelTaskResponse, SendTaskResponse, SetTaskPushNotificationResponse, GetTaskPushNotificationResponse, PushNotificationNotSupportedError, TaskSendParams, TaskStatus, TaskState, TaskResubscriptionRequest, SendTaskStreamingRequest, SendTaskStreamingResponse, Artifact, PushNotificationConfig, TaskStatusUpdateEvent, JSONRPCError, TaskPushNotificationConfig, InternalError, ) from core.a2a.server.utils import new_not_implemented_error import asyncio import logging logger = logging.getLogger(__name__)
class TaskManager(ABC):
@abstractmethod async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: pass
@abstractmethod async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: pass
@abstractmethod async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: pass
@abstractmethod async def on_send_task_subscribe( self, request: SendTaskStreamingRequest ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: pass
@abstractmethod async def on_set_task_push_notification( self, request: SetTaskPushNotificationRequest ) -> SetTaskPushNotificationResponse: pass
@abstractmethod async def on_get_task_push_notification( self, request: GetTaskPushNotificationRequest ) -> GetTaskPushNotificationResponse: pass
@abstractmethod async def on_resubscribe_to_task( self, request: TaskResubscriptionRequest ) -> Union[AsyncIterable[SendTaskResponse], JSONRPCResponse]: pass
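Concrete servers implement this interface; InMemoryTaskManager below supplies the bookkeeping, so a subclass typically only needs on_send_task and on_send_task_subscribe. A minimal echo-style sketch, assuming the InMemoryTaskManager helpers defined below, with Message and TextPart imported from core.a2a.types; the class and its reply logic are illustrative, not part of the repo:

    class EchoTaskManager(InMemoryTaskManager):  # hypothetical example
        async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
            await self.upsert_task(request.params)  # record the incoming message
            # assume the first part is text; real agents would inspect all parts
            text = request.params.message.parts[0].text
            reply = Message(role="agent", parts=[TextPart(text=f"echo: {text}")])
            task = await self.update_store(
                request.params.id,
                TaskStatus(state=TaskState.COMPLETED, message=reply),
                artifacts=[],
            )
            return SendTaskResponse(id=request.id, result=task)

        async def on_send_task_subscribe(self, request):
            return new_not_implemented_error(request.id)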
class InMemoryTaskManager(TaskManager):
def __init__(self): self.tasks: dict[str, Task] = {} self.push_notification_infos: dict[str, PushNotificationConfig] = {} self.lock = asyncio.Lock() self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {} self.subscriber_lock = asyncio.Lock()
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: logger.info(f"Getting task {request.params.id}") task_query_params: TaskQueryParams = request.params async with self.lock: task = self.tasks.get(task_query_params.id) if task is None: return GetTaskResponse(id=request.id, error=TaskNotFoundError()) task_result = self.append_task_history( task, task_query_params.historyLength ) return GetTaskResponse(id=request.id, result=task_result)
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: logger.info(f"Cancelling task {request.params.id}") task_id_params: TaskIdParams = request.params async with self.lock: task = self.tasks.get(task_id_params.id) if task is None: return CancelTaskResponse(id=request.id, error=TaskNotFoundError()) return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
@abstractmethod async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: pass
@abstractmethod async def on_send_task_subscribe( self, request: SendTaskStreamingRequest ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: pass
async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig): async with self.lock: task = self.tasks.get(task_id) if task is None: raise ValueError(f"Task not found for {task_id}") self.push_notification_infos[task_id] = notification_config
async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig: async with self.lock: task = self.tasks.get(task_id) if task is None: raise ValueError(f"Task not found for {task_id}") return self.push_notification_infos[task_id]
async def has_push_notification_info(self, task_id: str) -> bool: async with self.lock: return task_id in self.push_notification_infos
async def on_set_task_push_notification( self, request: SetTaskPushNotificationRequest ) -> SetTaskPushNotificationResponse: logger.info(f"Setting task push notification {request.params.id}") task_notification_params: TaskPushNotificationConfig = request.params try: await self.set_push_notification_info(task_notification_params.id, task_notification_params.pushNotificationConfig) except Exception as e: logger.error(f"Error while setting push notification info: {e}") return SetTaskPushNotificationResponse( id=request.id, error=InternalError( message="An error occurred while setting push notification info" ), ) return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)
async def on_get_task_push_notification( self, request: GetTaskPushNotificationRequest ) -> GetTaskPushNotificationResponse: logger.info(f"Getting task push notification {request.params.id}") task_params: TaskIdParams = request.params try: notification_info = await self.get_push_notification_info(task_params.id) except Exception as e: logger.error(f"Error while getting push notification info: {e}") return GetTaskPushNotificationResponse( id=request.id, error=InternalError( message="An error occurred while getting push notification info" ), ) return GetTaskPushNotificationResponse(id=request.id, result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info))
async def upsert_task(self, task_send_params: TaskSendParams) -> Task: logger.info(f"Upserting task {task_send_params.id}") async with self.lock: task = self.tasks.get(task_send_params.id) if task is None: task = Task( id=task_send_params.id, sessionId=task_send_params.sessionId, status=TaskStatus(state=TaskState.SUBMITTED), history=[task_send_params.message], ) self.tasks[task_send_params.id] = task else: task.history.append(task_send_params.message) return task
async def on_resubscribe_to_task( self, request: TaskResubscriptionRequest ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: return new_not_implemented_error(request.id)
async def update_store( self, task_id: str, status: TaskStatus, artifacts: list[Artifact] ) -> Task: async with self.lock: try: task = self.tasks[task_id] except KeyError: logger.error(f"Task {task_id} not found for updating the task") raise ValueError(f"Task {task_id} not found") task.status = status if status.message is not None: task.history.append(status.message) if artifacts is not None: if task.artifacts is None: task.artifacts = [] task.artifacts.extend(artifacts) return task
def append_task_history(self, task: Task, historyLength: int | None): new_task = task.model_copy() if historyLength is not None and historyLength > 0: new_task.history = new_task.history[-historyLength:] else: new_task.history = [] return new_task
async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False): async with self.subscriber_lock: if task_id not in self.task_sse_subscribers: if is_resubscribe: raise ValueError("Task not found for resubscription") else: self.task_sse_subscribers[task_id] = [] sse_event_queue = asyncio.Queue(maxsize=0) # <=0 is unlimited self.task_sse_subscribers[task_id].append(sse_event_queue) return sse_event_queue
async def enqueue_events_for_sse(self, task_id, task_update_event): async with self.subscriber_lock: if task_id not in self.task_sse_subscribers: return current_subscribers = self.task_sse_subscribers[task_id] for subscriber in current_subscribers: await subscriber.put(task_update_event)
async def dequeue_events_for_sse( self, request_id, task_id, sse_event_queue: asyncio.Queue ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: try: while True: event = await sse_event_queue.get()
if isinstance(event, JSONRPCError): yield SendTaskStreamingResponse(id=request_id, error=event) break yield SendTaskStreamingResponse(id=request_id, result=event) if isinstance(event, TaskStatusUpdateEvent) and event.final: break finally: async with self.subscriber_lock: if task_id in self.task_sse_subscribers: self.task_sse_subscribers[task_id].remove(sse_event_queue)
================================================ FILE: core/a2a/server/utils.py ================================================
from core.a2a.types import ( JSONRPCResponse, ContentTypeNotSupportedError, UnsupportedOperationError, ) from typing import List
def are_modalities_compatible( server_output_modes: List[str], client_output_modes: List[str] ): """Modalities are compatible if either side leaves them unspecified (None or empty) or the two lists share at least one element.""" if client_output_modes is None or len(client_output_modes) == 0: return True if server_output_modes is None or len(server_output_modes) == 0: return True return any(x in server_output_modes for x in client_output_modes)
def new_incompatible_types_error(request_id): return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError())
def new_not_implemented_error(request_id): return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())
================================================ FILE: core/a2a/types.py ================================================
from typing import Union, Any from pydantic import BaseModel, Field, TypeAdapter from typing import Literal, List, Annotated, Optional from datetime import datetime from pydantic import model_validator, ConfigDict, field_serializer from uuid import uuid4 from enum import Enum from typing_extensions import Self
class TaskState(str, Enum): SUBMITTED = "submitted" WORKING = "working" INPUT_REQUIRED = "input-required" COMPLETED = "completed" CANCELED = "canceled" FAILED = "failed" UNKNOWN = "unknown"
class TextPart(BaseModel): type: Literal["text"] = "text" text: str metadata: dict[str, Any] | None = None
class FileContent(BaseModel): name: str | None = None mimeType: str | None = None bytes: str | None = None uri: str | None = None @model_validator(mode="after") def check_content(self) -> Self: if not (self.bytes or self.uri): raise ValueError("Either 'bytes' or 'uri' must be present in the file data") if self.bytes and self.uri: raise ValueError( "Only one of 'bytes' or 'uri' can be present in the file data" ) return self
class FilePart(BaseModel): type: Literal["file"] = "file" file: FileContent metadata: dict[str, Any] | None = None
class DataPart(BaseModel): type: Literal["data"] = "data" data: dict[str, Any] metadata: dict[str, Any] | None = None
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
class Message(BaseModel): role: Literal["user", "agent"] parts: List[Part] metadata: dict[str, Any] | None = None
class TaskStatus(BaseModel): state: TaskState message: Message | None = None timestamp: datetime = Field(default_factory=datetime.now) @field_serializer("timestamp") def serialize_dt(self, dt: datetime, _info): return dt.isoformat()
class Artifact(BaseModel): name: str | None = None description: str | None = None parts: List[Part] metadata: dict[str, Any] | None = None index: int = 0 append: bool | None = None lastChunk: bool | None = None
class Task(BaseModel): id: str sessionId: str | None = None status: TaskStatus artifacts: List[Artifact] | None = None history: List[Message] | None = None metadata: dict[str, Any] | None = None
class TaskStatusUpdateEvent(BaseModel): id:
str status: TaskStatus final: bool = False metadata: dict[str, Any] | None = None class TaskArtifactUpdateEvent(BaseModel): id: str artifact: Artifact metadata: dict[str, Any] | None = None class AuthenticationInfo(BaseModel): model_config = ConfigDict(extra="allow") schemes: List[str] credentials: str | None = None class PushNotificationConfig(BaseModel): url: str token: str | None = None authentication: AuthenticationInfo | None = None class TaskIdParams(BaseModel): id: str metadata: dict[str, Any] | None = None class TaskQueryParams(TaskIdParams): historyLength: int | None = None class TaskSendParams(BaseModel): id: str sessionId: str = Field(default_factory=lambda: uuid4().hex) message: Message acceptedOutputModes: Optional[List[str]] = None pushNotification: PushNotificationConfig | None = None historyLength: int | None = None metadata: dict[str, Any] | None = None class TaskPushNotificationConfig(BaseModel): id: str pushNotificationConfig: PushNotificationConfig ## RPC Messages class JSONRPCMessage(BaseModel): jsonrpc: Literal["2.0"] = "2.0" id: int | str | None = Field(default_factory=lambda: uuid4().hex) class JSONRPCRequest(JSONRPCMessage): method: str params: dict[str, Any] | None = None class JSONRPCError(BaseModel): code: int message: str data: Any | None = None class JSONRPCResponse(JSONRPCMessage): result: Any | None = None error: JSONRPCError | None = None class SendTaskRequest(JSONRPCRequest): method: Literal["tasks/send"] = "tasks/send" params: TaskSendParams class SendTaskResponse(JSONRPCResponse): result: Task | None = None class SendTaskStreamingRequest(JSONRPCRequest): method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe" params: TaskSendParams class SendTaskStreamingResponse(JSONRPCResponse): result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None class GetTaskRequest(JSONRPCRequest): method: Literal["tasks/get"] = "tasks/get" params: TaskQueryParams class GetTaskResponse(JSONRPCResponse): result: Task | None = None class CancelTaskRequest(JSONRPCRequest): method: Literal["tasks/cancel",] = "tasks/cancel" params: TaskIdParams class CancelTaskResponse(JSONRPCResponse): result: Task | None = None class SetTaskPushNotificationRequest(JSONRPCRequest): method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set" params: TaskPushNotificationConfig class SetTaskPushNotificationResponse(JSONRPCResponse): result: TaskPushNotificationConfig | None = None class GetTaskPushNotificationRequest(JSONRPCRequest): method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get" params: TaskIdParams class GetTaskPushNotificationResponse(JSONRPCResponse): result: TaskPushNotificationConfig | None = None class TaskResubscriptionRequest(JSONRPCRequest): method: Literal["tasks/resubscribe",] = "tasks/resubscribe" params: TaskIdParams A2ARequest = TypeAdapter( Annotated[ Union[ SendTaskRequest, GetTaskRequest, CancelTaskRequest, SetTaskPushNotificationRequest, GetTaskPushNotificationRequest, TaskResubscriptionRequest, SendTaskStreamingRequest, ], Field(discriminator="method"), ] ) ## Error types class JSONParseError(JSONRPCError): code: int = -32700 message: str = "Invalid JSON payload" data: Any | None = None class InvalidRequestError(JSONRPCError): code: int = -32600 message: str = "Request payload validation error" data: Any | None = None class MethodNotFoundError(JSONRPCError): code: int = -32601 message: str = "Method not found" data: None = None class InvalidParamsError(JSONRPCError): code: int = -32602 message: str = "Invalid 
parameters" data: Any | None = None class InternalError(JSONRPCError): code: int = -32603 message: str = "Internal error" data: Any | None = None class TaskNotFoundError(JSONRPCError): code: int = -32001 message: str = "Task not found" data: None = None class TaskNotCancelableError(JSONRPCError): code: int = -32002 message: str = "Task cannot be canceled" data: None = None class PushNotificationNotSupportedError(JSONRPCError): code: int = -32003 message: str = "Push Notification is not supported" data: None = None class UnsupportedOperationError(JSONRPCError): code: int = -32004 message: str = "This operation is not supported" data: None = None class ContentTypeNotSupportedError(JSONRPCError): code: int = -32005 message: str = "Incompatible content types" data: None = None class AgentProvider(BaseModel): organization: str url: str | None = None class AgentCapabilities(BaseModel): streaming: bool = False pushNotifications: bool = False stateTransitionHistory: bool = False class AgentAuthentication(BaseModel): schemes: List[str] credentials: str | None = None class AgentSkill(BaseModel): id: str name: str description: str | None = None tags: List[str] | None = None examples: List[str] | None = None inputModes: List[str] | None = None outputModes: List[str] | None = None class AgentCard(BaseModel): name: str description: str | None = None url: str provider: AgentProvider | None = None version: str documentationUrl: str | None = None capabilities: AgentCapabilities authentication: AgentAuthentication | None = None defaultInputModes: List[str] = ["text"] defaultOutputModes: List[str] = ["text"] skills: List[AgentSkill] class A2AClientError(Exception): pass class A2AClientHTTPError(A2AClientError): def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message super().__init__(f"HTTP Error {status_code}: {message}") class A2AClientJSONError(A2AClientError): def __init__(self, message: str): self.message = message super().__init__(f"JSON Error: {message}") class MissingAPIKeyError(Exception): """Exception for missing API key.""" pass ================================================ FILE: core/a2a/utils/__init__.py ================================================ ================================================ FILE: core/a2a/utils/in_memory_cache.py ================================================ """In Memory Cache utility.""" import threading import time from typing import Any, Dict, Optional class InMemoryCache: """A thread-safe Singleton class to manage cache data. Ensures only one instance of the cache exists across the application. """ _instance: Optional["InMemoryCache"] = None _lock: threading.Lock = threading.Lock() _initialized: bool = False def __new__(cls): """Override __new__ to control instance creation (Singleton pattern). Uses a lock to ensure thread safety during the first instantiation. Returns: The singleton instance of InMemoryCache. """ if cls._instance is None: with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): """Initialize the cache storage. Uses a flag (_initialized) to ensure this logic runs only on the very first creation of the singleton instance. 
""" if not self._initialized: with self._lock: if not self._initialized: # print("Initializing SessionCache storage") self._cache_data: Dict[str, Dict[str, Any]] = {} self._ttl: Dict[str, float] = {} self._data_lock: threading.Lock = threading.Lock() self._initialized = True def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: """Set a key-value pair. Args: key: The key for the data. value: The data to store. ttl: Time to live in seconds. If None, data will not expire. """ with self._data_lock: self._cache_data[key] = value if ttl is not None: self._ttl[key] = time.time() + ttl else: if key in self._ttl: del self._ttl[key] def get(self, key: str, default: Any = None) -> Any: """Get the value associated with a key. Args: key: The key for the data within the session. default: The value to return if the session or key is not found. Returns: The cached value, or the default value if not found. """ with self._data_lock: if key in self._ttl and time.time() > self._ttl[key]: del self._cache_data[key] del self._ttl[key] return default return self._cache_data.get(key, default) def delete(self, key: str) -> None: """Delete a specific key-value pair from a cache. Args: key: The key to delete. Returns: True if the key was found and deleted, False otherwise. """ with self._data_lock: if key in self._cache_data: del self._cache_data[key] if key in self._ttl: del self._ttl[key] return True return False def clear(self) -> bool: """Remove all data. Returns: True if the data was cleared, False otherwise. """ with self._data_lock: self._cache_data.clear() self._ttl.clear() return True return False ================================================ FILE: core/a2a/utils/push_notification_auth.py ================================================ from jwcrypto import jwk import uuid from starlette.responses import JSONResponse from starlette.requests import Request from typing import Any import jwt import time import json import hashlib import httpx import logging from jwt import PyJWK, PyJWKClient logger = logging.getLogger(__name__) AUTH_HEADER_PREFIX = 'Bearer ' class PushNotificationAuth: def _calculate_request_body_sha256(self, data: dict[str, Any]): """Calculates the SHA256 hash of a request body. This logic needs to be same for both the agent who signs the payload and the client verifier. """ body_str = json.dumps( data, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":"), ) return hashlib.sha256(body_str.encode()).hexdigest() class PushNotificationSenderAuth(PushNotificationAuth): def __init__(self): self.public_keys = [] self.private_key_jwk: PyJWK = None @staticmethod async def verify_push_notification_url(url: str) -> bool: async with httpx.AsyncClient(timeout=10) as client: try: validation_token = str(uuid.uuid4()) response = await client.get( url, params={"validationToken": validation_token} ) response.raise_for_status() is_verified = response.text == validation_token logger.info(f"Verified push-notification URL: {url} => {is_verified}") return is_verified except Exception as e: logger.warning(f"Error during sending push-notification for URL {url}: {e}") return False def generate_jwk(self): key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig") self.public_keys.append(key.export_public(as_dict=True)) self.private_key_jwk = PyJWK.from_json(key.export_private()) def handle_jwks_endpoint(self, _request: Request): """Allow clients to fetch public keys. 
""" return JSONResponse({ "keys": self.public_keys }) def _generate_jwt(self, data: dict[str, Any]): """JWT is generated by signing both the request payload SHA digest and time of token generation. Payload is signed with private key and it ensures the integrity of payload for client. Including iat prevents from replay attack. """ iat = int(time.time()) return jwt.encode( {"iat": iat, "request_body_sha256": self._calculate_request_body_sha256(data)}, key=self.private_key_jwk, headers={"kid": self.private_key_jwk.key_id}, algorithm="RS256" ) async def send_push_notification(self, url: str, data: dict[str, Any]): jwt_token = self._generate_jwt(data) headers = {'Authorization': f"Bearer {jwt_token}"} async with httpx.AsyncClient(timeout=10) as client: try: response = await client.post( url, json=data, headers=headers ) response.raise_for_status() logger.info(f"Push-notification sent for URL: {url}") except Exception as e: logger.warning(f"Error during sending push-notification for URL {url}: {e}") class PushNotificationReceiverAuth(PushNotificationAuth): def __init__(self): self.public_keys_jwks = [] self.jwks_client = None async def load_jwks(self, jwks_url: str): self.jwks_client = PyJWKClient(jwks_url) async def verify_push_notification(self, request: Request) -> bool: auth_header = request.headers.get("Authorization") if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX): print("Invalid authorization header") return False token = auth_header[len(AUTH_HEADER_PREFIX):] signing_key = self.jwks_client.get_signing_key_from_jwt(token) decode_token = jwt.decode( token, signing_key, options={"require": ["iat", "request_body_sha256"]}, algorithms=["RS256"], ) actual_body_sha256 = self._calculate_request_body_sha256(await request.json()) if actual_body_sha256 != decode_token["request_body_sha256"]: # Payload signature does not match the digest in signed token. raise ValueError("Invalid request body") if time.time() - decode_token["iat"] > 60 * 5: # Do not allow push-notifications older than 5 minutes. # This is to prevent replay attack. raise ValueError("Token is expired") return True ================================================ FILE: core/agents/__init__.py ================================================ # Agents module initialization ================================================ FILE: core/agents/base/base_agent.py ================================================ import json from typing import List, Dict, Any, Optional, Union, Callable, Sequence, TypeVar, cast from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models import LanguageModelLike from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, ToolMessage from langchain_core.tools import BaseTool from langchain_core.runnables import RunnableConfig from langgraph.graph import StateGraph from langgraph.types import Checkpointer from langgraph.graph.graph import CompiledGraph from langgraph.graph.state import CompiledStateGraph import logging try: import tiktoken TIKTOKEN_AVAILABLE = True except ImportError: TIKTOKEN_AVAILABLE = False print("Warning: Tiktoken not installed. 
Using naive token estimation.") logger = logging.getLogger(__name__) DEFAULT_MODEL_NAME = "gpt-4o-mini" StateSchema = TypeVar("StateSchema", bound=Union[dict, Any]) class BaseAgent: def __init__( self, name: str, model: Union[BaseChatModel, LanguageModelLike], tools: Optional[List[Union[BaseTool, Callable]]] = None, prompt: Optional[Union[str, SystemMessage, Callable]] = None, checkpointer: Optional[Checkpointer] = None, max_context_messages: Optional[int] = None, # Limit number of recent messages max_context_tokens: Optional[int] = None, # Limit total estimated tokens model_name: Optional[str] = "gpt-4o-mini", # Optional, used for future token estimation improvements description: str = "No description provided." ): if max_context_messages and max_context_tokens: raise ValueError("Only one of max_context_messages or max_context_tokens should be set.") if name is None or name == "LangGraph": raise ValueError("Agent name must be specified.") self.name = name self.model = model self.tools = tools or [] self.base_prompt = prompt self.checkpointer = checkpointer self.max_context_messages = max_context_messages self.max_context_tokens = max_context_tokens self.model_name = model_name or getattr(model, "model_name", DEFAULT_MODEL_NAME) self.description = description self._workflow: Optional[StateGraph] = None self._compiled_agent: Optional[CompiledGraph] = None # Stores the final compiled graph self._tokenizer = None if TIKTOKEN_AVAILABLE: try: self._tokenizer = tiktoken.encoding_for_model(self.model_name) except KeyError: try: self._tokenizer = tiktoken.get_encoding("cl100k_base") # print(f"Warning: Tiktoken encoding for model '{self.model_name}' not found. Using 'cl100k_base'.") except Exception as e: print(f"Error getting tiktoken encoding 'cl100k_base': {e}.") except Exception as e: print(f"Error initializing tiktoken for model '{self.model_name}': {e}.") def _estimate_tokens(self, message: BaseMessage) -> int: content_to_encode = "" if isinstance(message, (HumanMessage, SystemMessage, AIMessage)): if isinstance(message.content, str): content_to_encode = message.content elif isinstance(message.content, list): for block in message.content: if isinstance(block, dict) and block.get("type") == "text": content_to_encode += block.get("text", "") + "\n" elif isinstance(message, ToolMessage): content_to_encode = message.content if isinstance(message.content, str) else json.dumps(message.content) else: content_to_encode = str(message) if self._tokenizer: try: return len(self._tokenizer.encode(content_to_encode, disallowed_special=())) except Exception: pass return len(content_to_encode) // 2 def _truncate_by_tokens(self, messages: Sequence[BaseMessage]) -> List[BaseMessage]: if not self.max_context_tokens: return list(messages) truncated_messages: List[BaseMessage] = [] total_tokens = 0 preserved_system_message: Optional[SystemMessage] = None # Check if the first message is a SystemMessage, preserve it if so # Note: This assumes only ONE leading SystemMessage should be preserved. if messages and isinstance(messages[0], SystemMessage): preserved_system_message = messages[0] messages_to_truncate = messages[1:] try: system_tokens = self._estimate_tokens(preserved_system_message) # Only count if it doesn't exceed limit by itself if system_tokens <= self.max_context_tokens: total_tokens += system_tokens else: print(f"Warning: System message alone ({system_tokens} tokens) exceeds token limit ({self.max_context_tokens}). 
def _truncate_by_tokens(self, messages: Sequence[BaseMessage]) -> List[BaseMessage]: if not self.max_context_tokens: return list(messages) truncated_messages: List[BaseMessage] = [] total_tokens = 0 preserved_system_message: Optional[SystemMessage] = None # Check if the first message is a SystemMessage, preserve it if so # Note: This assumes only ONE leading SystemMessage should be preserved. if messages and isinstance(messages[0], SystemMessage): preserved_system_message = messages[0] messages_to_truncate = messages[1:] try: system_tokens = self._estimate_tokens(preserved_system_message) # Only count if it doesn't exceed limit by itself if system_tokens <= self.max_context_tokens: total_tokens += system_tokens else: print(f"Warning: System message alone ({system_tokens} tokens) exceeds token limit ({self.max_context_tokens}). It might be truncated if context grows.") # Don't add to total_tokens yet, let truncation logic handle it. preserved_system_message = None # Don't preserve if it's too big initially except Exception: pass # Ignore errors estimating system message else: messages_to_truncate = messages # Iterate backwards from the most recent message for msg in reversed(messages_to_truncate): try: msg_tokens = self._estimate_tokens(msg) # Check if adding this message exceeds the limit if total_tokens + msg_tokens <= self.max_context_tokens: truncated_messages.append(msg) total_tokens += msg_tokens else: print(f"Context Token Limit ({self.max_context_tokens}) reached. Truncating older messages.") break # Limit reached except Exception as e: print(f"Warning: Failed to estimate tokens for message, skipping: {e}") continue # Re-add the system message at the beginning if it was preserved final_list = list(reversed(truncated_messages)) if preserved_system_message: try: system_tokens = self._estimate_tokens(preserved_system_message) except Exception: system_tokens = 0 # Ensure adding system message doesn't push over limit *again* (edge case) if total_tokens - (msg_tokens if 'msg_tokens' in locals() and total_tokens + msg_tokens > self.max_context_tokens else 0) + system_tokens <= self.max_context_tokens: final_list.insert(0, preserved_system_message) elif not final_list: # If only system message fits return [preserved_system_message] # Else: System message doesn't fit with the truncated history, omit it. return final_list
def _truncate_messages(self, messages: Sequence[BaseMessage]) -> List[BaseMessage]: """Truncate the message history according to configuration (token budget first, then message count).""" if self.max_context_tokens is not None: return self._truncate_by_tokens(messages) elif self.max_context_messages is not None: if messages and isinstance(messages[0], SystemMessage): # Keep system message + last N-1 messages keep_count = self.max_context_messages - 1 return [messages[0]] + list(messages[-keep_count:]) if keep_count > 0 and len(messages) > 1 else [messages[0]] else: return list(messages[-self.max_context_messages:]) return list(messages)
def _get_state_value(self, state: StateSchema, key: str, default: Any = None) -> Any: return state.get(key, default) if isinstance(state, dict) else getattr(state, key, default)
def _format_tools_for_prompt(self, tools: List[Union[BaseTool, Callable]]) -> str: """Formats the tool list for inclusion in the prompt.""" if not tools: return "No tools available for use." # use getattr to access name and description safely return "\n".join([ f"- **{getattr(t, 'name', 'Unnamed Tool')}**: {getattr(t, 'description', 'No description available.')}" for t in tools ])
# --- build/compile/get_agent ---
def build(self) -> Optional[StateGraph]: """Build the agent's LangGraph workflow definition. Subclasses should implement this.""" raise NotImplementedError("Subclasses must implement build() or override compile() directly.")
) print(f"Compiling graph for agent: {self.name}") try: # 编译 StateGraph 并存储结果 self._compiled_agent = workflow.compile( checkpointer=self.checkpointer, debug=getattr(self, 'debug', False) # 传递 debug 标志 ) print(f"Graph compiled successfully for agent: {self.name}") return self._compiled_agent except Exception as e: print(f"!!! Error compiling graph for agent {self.name}: {e}") import traceback traceback.print_exc() raise e def get_agent(self) -> CompiledGraph: """获取编译后的核心图实例,如果未编译则先编译。""" if self._compiled_agent is None: print(f"Agent '{self.name}' not compiled yet. Compiling now.") self.compile() if self._compiled_agent is None: raise RuntimeError(f"Failed to get compiled agent for '{self.name}'.") return self._compiled_agent # --- invoke/ainvoke: 标准入口点,调用编译后的图 --- def invoke(self, state: Dict[str, Any], config: Optional[RunnableConfig] = None) -> Dict[str, Any]: """同步调用编译后的 Agent 图。""" try: compiled_agent = self.get_agent() # 获取 (或编译) 图 print(f"--- Invoking Agent: {self.name} ---") # 直接调用编译后的图,预处理由图内部的 prompt callable 处理 (如果使用 ReactAgent) # 或由 Supervisor 节点逻辑处理 (如果使用自定义 Supervisor) result = compiled_agent.invoke(state, config=config) print(f"--- Agent Invocation Complete: {self.name} ---") return cast(Dict[str, Any], result) # 假设返回字典 except Exception as e: print(f"!!! Error during {self.name} agent invocation: {e}") import traceback traceback.print_exc() # 返回带错误标记的状态 (可能是输入状态) state["error"] = f"Agent invocation failed: {e}" return state async def ainvoke(self, state: Dict[str, Any], config: Optional[RunnableConfig] = None) -> Dict[str, Any]: """异步调用编译后的 Agent 图。""" try: compiled_agent = self.get_agent() # 获取 (或编译) 图 print(f"--- Invoking Agent Async: {self.name} ---") # 直接调用编译后的图 result = await compiled_agent.ainvoke(state, config=config) print(f"--- Agent Invocation Complete Async: {self.name} ---") return cast(Dict[str, Any], result) # 假设返回字典 except Exception as e: print(f"!!! Error during {self.name} agent async invocation: {e}") import traceback traceback.print_exc() state["error"] = f"Agent async invocation failed: {e}" return state def run(self, state: Dict[str, Any]) -> Dict[str, Any]: """Run the supervisor workflow synchronously. Args: state: The input state for the workflow Returns: The output state from the workflow """ return self.invoke(state) async def arun(self, state: Dict[str, Any]) -> Dict[str, Any]: """Run the supervisor workflow asynchronously. Args: state: The input state for the workflow Returns: The output state from the workflow """ return await self.ainvoke(state) def reset(self): """重置编译状态,强制下次重新编译。""" print(f"Resetting compiled graph for agent '{self.name}'. Will recompile on next use.") self._compiled_agent = None self._workflow = None def add_tools(self, tools: List[Union[BaseTool, Callable]]) -> None: """添加工具到 Agent 的工具列表。""" print(f"Warning: Adding tools to {self.name} post-initialization. 
Agent needs recompilation.") self.tools.extend(tools) self.reset() ================================================ FILE: core/agents/base/create_react_agent_wrapper.py ================================================ import logging from typing import Optional, Callable, Dict from langgraph.utils.runnable import RunnableCallable from langchain_core.runnables.config import RunnableConfig logger = logging.getLogger(__name__) class CreateReactAgentWrapper(RunnableCallable): def __init__( self, agent, name: str = "agent", before_invoke: Optional[Callable[[dict], dict]] = None, before_ainvoke: Optional[Callable[[dict], dict]] = None, after_invoke: Optional[Callable[[dict, dict], None]] = None, after_ainvoke: Optional[Callable[[dict, dict], None]] = None ): """ :param agent: The underlying compiled graph or runnable :param name: Unique name for this wrapper (avoid duplicates) :param before_invoke: A sync callback that modifies the state before the wrapped agent call :param before_ainvoke: An async callback that modifies the state before the wrapped agent call :param after_invoke: A sync callback that inspects (state, output) after the wrapped call :param after_ainvoke: An async callback that inspects (state, output) after the wrapped call """ self._agent = agent self.name = name or getattr(agent, "name", "agent") self.before_invoke = before_invoke self.after_invoke = after_invoke self.before_ainvoke = before_ainvoke self.after_ainvoke = after_ainvoke # We define the sync/async "call" functions for RunnableCallable def call(state: Dict, config: Optional[RunnableConfig] = None, **kwargs) -> Dict: logger.info(f"[{self.name}] (sync) call() - started. State keys: {list(state.keys())}") # Or use print if you prefer # print(f"🟢 [Sync] Invoking wrapper: {self.name}, state keys: {list(state.keys())}") # before_invoke callback if self.before_invoke: state = self.before_invoke(state) # Call the underlying agent output = self._agent.invoke(state, config, **kwargs) # after_invoke callback if self.after_invoke: self.after_invoke(state, output) logger.info(f"[{self.name}] (sync) call() - finished. Output keys: {list(output.keys())}") return output async def acall(state: Dict, config: Optional[RunnableConfig] = None, **kwargs) -> Dict: logger.info(f"[{self.name}] (async) acall() - started. State keys: {list(state.keys())}") # print(f"🟢 [Async] Invoking wrapper: {self.name}, state keys: {list(state.keys())}") if self.before_ainvoke: state = await self.before_ainvoke(state) output = await self._agent.ainvoke(state, config, **kwargs) if self.after_ainvoke: await self.after_ainvoke(state, output) logger.info(f"[{self.name}] (async) acall() - finished. 
================================================ FILE: core/agents/base/react_agent.py ================================================
from typing import Any, Callable, Dict, List, Optional, Type, Union, Literal, Sequence from langchain_core.language_models import LanguageModelLike, LanguageModelInput from langchain_core.tools import BaseTool from langgraph.graph import StateGraph from langgraph.graph.graph import CompiledGraph from langgraph.types import Checkpointer from langgraph.store.base import BaseStore from langchain_core.messages import BaseMessage, SystemMessage # import SystemMessage from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import ( AgentState, StateSchemaType, StructuredResponseSchema, ) from core.agents.base.base_agent import BaseAgent import logging logger = logging.getLogger(__name__)
class ReactAgent(BaseAgent):
"""ReAct Agent class for reasoning and acting with tools. This class provides a high-level interface for creating a ReAct agent workflow that can perform multi-step reasoning and tool calling. """
def __init__( self, model: LanguageModelLike, tools: Optional[List[Union[BaseTool, Callable]]] = None, prompt: Optional[str] = None, response_format: Optional[ Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] ] = None, state_schema: StateSchemaType = AgentState, config_schema: Type[Any] = None, checkpointer: Optional[Checkpointer] = None, store: Optional[BaseStore] = None, interrupt_before: Optional[List[str]] = None, interrupt_after: Optional[List[str]] = None, debug: bool = False, version: Literal["v1", "v2"] = "v1", name: str = "react_agent", description: str = "ReAct agent for reasoning and acting with tools.", max_context_messages: Optional[int] = None, max_context_tokens: Optional[int] = None, model_name: Optional[str] = "gpt-4o-mini", ):
"""Initialize a ReAct agent. Args: model: Language model to use for the agent tools: Optional list of tools available to the agent prompt: Optional prompt to use for the agent response_format: Optional schema for structured output state_schema: State schema to use for the agent graph config_schema: Optional schema for configuration interrupt_before: Optional list of nodes to interrupt before execution interrupt_after: Optional list of nodes to interrupt after execution debug: Whether to enable debug mode version: Version of the ReAct agent ("v1" or "v2") name: Name of the agent max_context_messages: Optional limit on number of recent messages max_context_tokens: Optional limit on total estimated tokens model_name: Optional model name for token estimation """
# Call BaseAgent's __init__ to initialize parent class attributes
super().__init__( name=name, model=model, tools=tools or [], prompt=prompt, description=description, checkpointer=checkpointer, max_context_messages=max_context_messages, max_context_tokens=max_context_tokens, model_name=model_name )
# Initialize ReactAgent specific attributes
self.response_format = response_format self.react_state_schema = state_schema self.react_config_schema = config_schema self.react_store = store self.react_interrupt_before = interrupt_before self.react_interrupt_after = interrupt_after self.react_debug = debug self.react_version = version
def _prepare_llm_input(self, state: Dict[str, Any]) -> LanguageModelInput: """Prepare the LLM input: truncate the message history and prepend the base system prompt (if any). Passed as the prompt callable to create_react_agent.""" # 1. read messages from state (BaseAgent helper) messages = self._get_state_value(state, "messages", []) # 2. truncate messages (BaseAgent helper); only the list sent to the LLM is truncated here -- the full history in the checkpointer is unaffected # (optional debug prints of the history before/after truncation omitted) truncated_messages = self._truncate_messages(messages) # 3. prepend the base system prompt (if any) final_messages: List[BaseMessage] = [] if self.base_prompt: if isinstance(self.base_prompt, str): final_messages.append(SystemMessage(content=self.base_prompt)) elif isinstance(self.base_prompt, SystemMessage): final_messages.append(self.base_prompt) # if self.base_prompt were some other Runnable or Callable it would need handling here, but create_react_agent prompts are normally str or SystemMessage final_messages.extend(truncated_messages) # return the final message list to the LLM return final_messages
Build returns None.") self._workflow = None return None def compile(self) -> CompiledGraph: """使用 create_react_agent 构建并编译核心 ReAct 工作流,存储在 _compiled_agent。""" if self._compiled_agent is not None: return self._compiled_agent print(f"[[DEBUG]] Compiling core ReAct agent for: {self.name} using create_react_agent") try: # 使用 create_react_agent 创建编译后的图 # 将 self._prepare_llm_input 作为 prompt callable 传入 compiled_agent = create_react_agent( model=self.model, tools=self.tools, prompt=self._prepare_llm_input, # <--- 关键改动:传入准备函数 state_schema=self.react_state_schema, config_schema=self.react_config_schema, checkpointer=self.checkpointer, store=self.react_store, interrupt_before=self.react_interrupt_before, interrupt_after=self.react_interrupt_after, debug=self.react_debug, version=self.react_version, name=self.name, ) # 存储编译好的图 self._compiled_agent = compiled_agent print(f"Core ReAct graph compiled successfully for agent: {self.name}") return self._compiled_agent except Exception as e: print(f"!!! Error compiling graph for agent {self.name} using create_react_agent: {e}") import traceback traceback.print_exc() self._compiled_agent = None raise e ================================================ FILE: core/agents/react_based_supervisor/__init__.py ================================================ # 从当前目录导入create_supervisor函数 from .supervisor import create_supervisor __all__ = ["create_supervisor"] ================================================ FILE: core/agents/react_based_supervisor/agent_name.py ================================================ import re from typing import Literal from langchain_core.language_models import LanguageModelLike from langchain_core.messages import AIMessage, BaseMessage from langchain_core.runnables import RunnableLambda NAME_PATTERN = re.compile(r"(.*?)", re.DOTALL) CONTENT_PATTERN = re.compile(r"(.*?)", re.DOTALL) AgentNameMode = Literal["inline"] def _is_content_blocks_content(content: list[dict] | str) -> bool: return ( isinstance(content, list) and len(content) > 0 and isinstance(content[0], dict) and "type" in content[0] ) def add_inline_agent_name(message: BaseMessage) -> BaseMessage: """Add name and content XML tags to the message content. Examples: >>> add_inline_agent_name(AIMessage(content="Hello", name="assistant")) AIMessage(content="assistantHello", name="assistant") >>> add_inline_agent_name(AIMessage(content=[{"type": "text", "text": "Hello"}], name="assistant")) AIMessage(content=[{"type": "text", "text": "assistantHello"}], name="assistant") """ if not isinstance(message, AIMessage) or not message.name: return message formatted_message = message.model_copy() if _is_content_blocks_content(formatted_message.content): text_blocks = [block for block in message.content if block["type"] == "text"] non_text_blocks = [block for block in message.content if block["type"] != "text"] content = text_blocks[0]["text"] if text_blocks else "" formatted_content = f"{message.name}{content}" formatted_message.content = non_text_blocks + [{"type": "text", "text": formatted_content}] else: formatted_message.content = ( f"{message.name}{formatted_message.content}" ) return formatted_message def remove_inline_agent_name(message: BaseMessage) -> BaseMessage: """Remove explicit name and content XML tags from the AI message content. 
================================================ FILE: core/agents/react_based_supervisor/agent_name.py ================================================
import re from typing import Literal from langchain_core.language_models import LanguageModelLike from langchain_core.messages import AIMessage, BaseMessage from langchain_core.runnables import RunnableLambda
NAME_PATTERN = re.compile(r"<name>(.*?)</name>", re.DOTALL) CONTENT_PATTERN = re.compile(r"<content>(.*?)</content>", re.DOTALL) AgentNameMode = Literal["inline"]
def _is_content_blocks_content(content: list[dict] | str) -> bool: return ( isinstance(content, list) and len(content) > 0 and isinstance(content[0], dict) and "type" in content[0] )
def add_inline_agent_name(message: BaseMessage) -> BaseMessage: """Add name and content XML tags to the message content. Examples: >>> add_inline_agent_name(AIMessage(content="Hello", name="assistant")) AIMessage(content="<name>assistant</name><content>Hello</content>", name="assistant") >>> add_inline_agent_name(AIMessage(content=[{"type": "text", "text": "Hello"}], name="assistant")) AIMessage(content=[{"type": "text", "text": "<name>assistant</name><content>Hello</content>"}], name="assistant") """ if not isinstance(message, AIMessage) or not message.name: return message formatted_message = message.model_copy() if _is_content_blocks_content(formatted_message.content): text_blocks = [block for block in message.content if block["type"] == "text"] non_text_blocks = [block for block in message.content if block["type"] != "text"] content = text_blocks[0]["text"] if text_blocks else "" formatted_content = f"<name>{message.name}</name><content>{content}</content>" formatted_message.content = non_text_blocks + [{"type": "text", "text": formatted_content}] else: formatted_message.content = ( f"<name>{message.name}</name><content>{formatted_message.content}</content>" ) return formatted_message
def remove_inline_agent_name(message: BaseMessage) -> BaseMessage: """Remove explicit name and content XML tags from the AI message content. Examples: >>> remove_inline_agent_name(AIMessage(content="<name>assistant</name><content>Hello</content>", name="assistant")) AIMessage(content="Hello", name="assistant") >>> remove_inline_agent_name(AIMessage(content=[{"type": "text", "text": "<name>assistant</name><content>Hello</content>"}], name="assistant")) AIMessage(content=[{"type": "text", "text": "Hello"}], name="assistant") """ if not isinstance(message, AIMessage) or not message.name: return message is_content_blocks_content = _is_content_blocks_content(message.content) if is_content_blocks_content: text_blocks = [block for block in message.content if block["type"] == "text"] if not text_blocks: return message non_text_blocks = [block for block in message.content if block["type"] != "text"] content = text_blocks[0]["text"] else: content = message.content name_match: re.Match | None = NAME_PATTERN.search(content) content_match: re.Match | None = CONTENT_PATTERN.search(content) if not name_match or not content_match: return message if name_match.group(1) != message.name: return message parsed_content = content_match.group(1) parsed_message = message.model_copy() if is_content_blocks_content: content_blocks = non_text_blocks if parsed_content: content_blocks.append({"type": "text", "text": parsed_content}) parsed_message.content = content_blocks else: parsed_message.content = parsed_content return parsed_message
def with_agent_name( model: LanguageModelLike, agent_name_mode: AgentNameMode, ) -> LanguageModelLike: """Attach formatted agent names to the messages passed to and from a language model. This is useful for making a message history with multiple agents more coherent. NOTE: agent name is consumed from the message.name field. If you're using an agent built with create_react_agent, name is automatically set. If you're building a custom agent, make sure to set the name on the AI message returned by the LLM. Args: model: Language model to add agent name formatting to. agent_name_mode: Use to specify how to expose the agent name to the LLM. - "inline": Add the agent name directly into the content field of the AI message using XML-style tags. Example: "How can I help you?" -> "<name>agent_name</name><content>How can I help you?</content>". """ if agent_name_mode == "inline": process_input_message = add_inline_agent_name process_output_message = remove_inline_agent_name else: raise ValueError( f"Invalid agent name mode: {agent_name_mode}. Needs to be one of: {AgentNameMode.__args__}" ) def process_input_messages(messages: list[BaseMessage]) -> list[BaseMessage]: return [process_input_message(message) for message in messages] model = ( process_input_messages | model | RunnableLambda(process_output_message, name="process_output_message") ) return model
================================================ FILE: core/agents/react_based_supervisor/handoff.py ================================================
import re import uuid from langchain_core.messages import AIMessage, ToolCall, ToolMessage from langchain_core.tools import BaseTool, InjectedToolCallId, tool from langgraph.prebuilt import InjectedState from langgraph.types import Command from typing_extensions import Annotated
WHITESPACE_RE = re.compile(r"\s+")
def _normalize_agent_name(agent_name: str) -> str: """Normalize an agent name to be used inside the tool name.""" return WHITESPACE_RE.sub("_", agent_name.strip()).lower()
def create_handoff_tool(*, agent_name: str) -> BaseTool: """Create a tool that can hand off control to the requested agent. Args: agent_name: The name of the agent to hand control to, i.e. the name of the agent node in the multi-agent graph. Agent names should be simple, clear and unique, preferably in snake_case; the only hard limits are the node names accepted by LangGraph and the tool names accepted by LLM providers (the tool name will look like this: `transfer_to_<agent_name>`). """ tool_name = f"transfer_to_{_normalize_agent_name(agent_name)}" @tool(tool_name) def handoff_to_agent( state: Annotated[dict, InjectedState], tool_call_id: Annotated[str, InjectedToolCallId], ): """Ask another agent for help.""" tool_message = ToolMessage( content=f"Successfully transferred to {agent_name}", name=tool_name, tool_call_id=tool_call_id, ) return Command( goto=agent_name, graph=Command.PARENT, update={"messages": state["messages"] + [tool_message]}, ) return handoff_to_agent
def create_handoff_back_messages( agent_name: str, supervisor_name: str ) -> tuple[AIMessage, ToolMessage]: """Create a pair of (AIMessage, ToolMessage) to add to the message history when returning control to the supervisor.""" tool_call_id = str(uuid.uuid4()) tool_name = f"transfer_back_to_{_normalize_agent_name(supervisor_name)}" tool_calls = [ToolCall(name=tool_name, args={}, id=tool_call_id)] return ( AIMessage( content=f"Transferring back to {supervisor_name}", tool_calls=tool_calls, name=agent_name, ), ToolMessage( content=f"Successfully transferred back to {supervisor_name}", name=tool_name, tool_call_id=tool_call_id, ), )
================================================ FILE: core/agents/react_based_supervisor/planning_handler.py ================================================
import uuid import datetime from typing import List, Dict, Optional
class PlanningStateHandler:
""" Manages a project plan. A plan is a dict with: - title (str) - description (str) - status (str): "planning", "in_progress", or "completed" - tasks (list): each task is a dict with: id, description, status, agent, notes, evaluation - current_task_id (str or None) - created_at (str) - updated_at (str) """
@staticmethod def _now() -> str: return datetime.datetime.now().isoformat()
@staticmethod def _gen_id() -> str: return str(uuid.uuid4())
@staticmethod def create_plan(title: str, description: str) -> Dict: now = PlanningStateHandler._now() return { "title": title, "description": description, "status": "planning", # initial status "tasks": [], "current_task_id": None, "created_at": now, "updated_at": now }
@staticmethod def create_task(description: str, status: str = "pending", agent: str = "", notes: str = "", evaluation: str = "") -> Dict: return { "id": PlanningStateHandler._gen_id(), "description": description.strip(), "status": status.strip() if status else "pending", "agent": agent.strip(), "notes": notes.strip(), "evaluation": evaluation.strip() }
@staticmethod def add_tasks(plan: Dict, tasks_data: List[Dict]) -> Dict: for tinfo in tasks_data: desc = tinfo.get("description", "Untitled Task") status = tinfo.get("status", "pending") agent = tinfo.get("agent", "") notes = tinfo.get("notes", "") eval_ = tinfo.get("evaluation", "") task = PlanningStateHandler.create_task(desc, status, agent, notes, eval_) plan["tasks"].append(task) plan["updated_at"] = PlanningStateHandler._now() return plan
""" if not by_id: raise ValueError("Must provide 'by_id' to update a task.") task = next((t for t in plan["tasks"] if t["id"] == by_id), None) if not task: raise ValueError("No matching task found with the given ID.") if new_desc is not None: task["description"] = new_desc.strip() if new_status is not None: task["status"] = new_status.strip() if new_agent is not None: task["agent"] = new_agent.strip() if new_notes is not None: task["notes"] = new_notes.strip() if new_evaluation is not None: task["evaluation"] = new_evaluation.strip() plan["updated_at"] = PlanningStateHandler._now() # Determine overall plan status if any(t["status"] == "in_progress" for t in plan["tasks"]): plan["status"] = "in_progress" if all(t["status"] == "completed" for t in plan["tasks"]) and plan["tasks"]: plan["status"] = "completed" return plan @staticmethod def set_current_task(plan: Dict, task_id: str) -> Dict: found = any(t["id"] == task_id for t in plan["tasks"]) if not found: raise ValueError("Task ID not found in plan.") plan["current_task_id"] = task_id plan["updated_at"] = PlanningStateHandler._now() return plan @staticmethod def finish_plan(plan: Dict) -> Dict: """ Forcefully mark the plan as completed. """ plan["status"] = "completed" plan["updated_at"] = PlanningStateHandler._now() return plan ================================================ FILE: core/agents/react_based_supervisor/simple_planning_tool.py ================================================ import json from typing import Dict, List, Optional from langchain_core.tools import BaseTool from core.agents.supervisor.planning_handler import PlanningStateHandler class SimplePlanningTool(BaseTool): """ A tool that manages a single project plan in memory. It supports creating, viewing, adding tasks, updating tasks, setting the current task, and finishing the plan. All operations return a JSON string. """ name: str = "planning" description: str = ("Manage a project plan with actions to create, view, add tasks, update tasks, " "set current task, and finish the plan. 
All data is stored in JSON.") def __init__(self): super().__init__() self._plan: Optional[Dict] = None def _run(self, action: str, **kwargs) -> str: try: if action == "create_plan": return self._handle_create_plan(**kwargs) elif action == "view_plan": return self._handle_view_plan() elif action == "add_tasks": return self._handle_add_tasks(**kwargs) elif action == "update_task": return self._handle_update_task(**kwargs) elif action == "set_current_task": return self._handle_set_current_task(**kwargs) elif action == "finish_plan": return self._handle_finish_plan() else: return self._json_error(f"Unknown action: {action}") except Exception as e: return self._json_error(str(e)) async def _arun(self, action: str, **kwargs) -> str: return self._run(action, **kwargs) def _handle_create_plan(self, **kwargs) -> str: title = kwargs.get("title", "Untitled Plan") description = kwargs.get("description", "") tasks_data = kwargs.get("tasks", []) new_plan = PlanningStateHandler.create_plan(title, description) PlanningStateHandler.add_tasks(new_plan, tasks_data) self._plan = new_plan return self._json_ok(self._plan) def _handle_view_plan(self) -> str: if not self._plan: self._plan = PlanningStateHandler.create_plan("Untitled", "") return self._json_ok(self._plan) def _handle_add_tasks(self, **kwargs) -> str: if not self._plan: self._plan = PlanningStateHandler.create_plan("Untitled", "") tasks_data = kwargs.get("tasks", []) PlanningStateHandler.add_tasks(self._plan, tasks_data) return self._json_ok(self._plan) def _handle_update_task(self, **kwargs) -> str: if not self._plan: raise ValueError("No plan exists. Please create a plan first.") # Use 'by_id' instead of 'task_id' by_id = kwargs.get("by_id") new_desc = kwargs.get("description") new_status = kwargs.get("status") new_agent = kwargs.get("agent") new_notes = kwargs.get("notes") new_evaluation = kwargs.get("evaluation") updated = PlanningStateHandler.update_task( self._plan, by_id=by_id, new_desc=new_desc, new_status=new_status, new_agent=new_agent, new_notes=new_notes, new_evaluation=new_evaluation ) self._plan = updated return self._json_ok(self._plan) def _handle_set_current_task(self, **kwargs) -> str: if not self._plan: raise ValueError("No plan available to set current task.") tid = kwargs.get("task_id") if not tid: raise ValueError("Must provide 'task_id' for set_current_task.") PlanningStateHandler.set_current_task(self._plan, tid) return self._json_ok(self._plan) def _handle_finish_plan(self) -> str: if not self._plan: raise ValueError("No plan exists to finish.") finished_plan = PlanningStateHandler.finish_plan(self._plan) self._plan = finished_plan return self._json_ok(finished_plan) def _json_ok(self, plan_data: Dict) -> str: return json.dumps({"ok": True, "plan": plan_data}, ensure_ascii=False, indent=2) def _json_error(self, message: str) -> str: return json.dumps({"ok": False, "error": message}, ensure_ascii=False, indent=2) ================================================ FILE: core/agents/react_based_supervisor/state_schema.py ================================================ from typing import Dict, List, Optional, Any, Literal, TypedDict, Union from langchain_core.messages import BaseMessage from langgraph.prebuilt.chat_agent_executor import AgentState # 定义计划状态类型 PlanningStatus = Literal["not_started", "planning", "executing", "completed", "failed"] # 定义任务状态类型 TaskStatus = Literal["pending", "in_progress", "completed", "failed"] # 定义任务项 class Task(TypedDict, total=False): """任务项定义 表示计划中的一个任务项,包含任务描述、状态、分配的代理等信息 """ id: str # 任务唯一标识符 
description: str # 任务描述 status: TaskStatus # 任务状态 agent: Optional[str] # 分配的代理名称 created_at: str # 创建时间 updated_at: str # 更新时间 completed_at: Optional[str] # 完成时间 dependencies: Optional[List[str]] # 依赖的任务ID列表 notes: Optional[str] # 任务备注 # 定义计划 class Plan(TypedDict, total=False): """计划定义 表示一个完整的计划,包含计划状态、任务列表等信息 """ status: PlanningStatus # 计划状态 tasks: List[Task] # 任务列表 current_task_id: Optional[str] # 当前执行的任务ID created_at: str # 创建时间 updated_at: str # 更新时间 completed_at: Optional[str] # 完成时间 title: Optional[str] # 计划标题 description: Optional[str] # 计划描述 # 扩展AgentState以支持计划功能 class PlanningAgentState(AgentState): """支持计划功能的代理状态 扩展了AgentState,添加了plan字段用于存储计划信息 """ plan: Optional[Plan] = None ================================================ FILE: core/agents/react_based_supervisor/supervisor.py ================================================ import inspect from typing import Any, Callable, Literal, Optional, Type, Union, Dict, Optional from langchain_core.language_models import BaseChatModel, LanguageModelLike from langchain_core.tools import BaseTool from langgraph.graph import END, START, StateGraph from langgraph.prebuilt.chat_agent_executor import ( create_react_agent, AgentState, Prompt, StateSchemaType, StructuredResponseSchema, ) from langgraph.pregel import Pregel from langgraph.utils.runnable import RunnableCallable from core.agents.base.react_agent import ReactAgent from core.agents.supervisor.agent_name import AgentNameMode, with_agent_name from core.agents.supervisor.handoff import ( create_handoff_back_messages, create_handoff_tool, ) OutputMode = Literal["full_history", "last_message"] """Mode for adding agent outputs to the message history in the multi-agent workflow - `full_history`: add the entire agent message history - `last_message`: add only the last message """ MODELS_NO_PARALLEL_TOOL_CALLS = {"o3-mini"} def _supports_disable_parallel_tool_calls(model: LanguageModelLike) -> bool: if not isinstance(model, BaseChatModel): return False if hasattr(model, "model_name") and model.model_name in MODELS_NO_PARALLEL_TOOL_CALLS: return False if not hasattr(model, "bind_tools"): return False if "parallel_tool_calls" not in inspect.signature(model.bind_tools).parameters: return False return True def _make_call_agent( agent: Pregel, output_mode: OutputMode, add_handoff_back_messages: bool, supervisor_name: str, ) -> Callable[[dict], dict] | RunnableCallable: if output_mode not in OutputMode.__args__: raise ValueError( f"Invalid agent output mode: {output_mode}. Needs to be one of {OutputMode.__args__}" ) def _process_output(output: dict) -> dict: messages = output["messages"] if output_mode == "full_history": pass elif output_mode == "last_message": messages = messages[-1:] else: raise ValueError( f"Invalid agent output mode: {output_mode}. 
" f"Needs to be one of {OutputMode.__args__}" ) if add_handoff_back_messages: messages.extend(create_handoff_back_messages(agent.name, supervisor_name)) return { **output, "messages": messages, } def call_agent(state: dict) -> dict: #print(f"🟡 [Sync invoke] Handoff to agent '{agent.name}' with state keys: {list(state.keys())}") output = agent.invoke(state) #print(f"✅ [Sync invoke] Agent '{agent.name}' completed.") return _process_output(output) async def acall_agent(state: dict) -> dict: #print(f"🟡 [Async invoke] Handoff to agent '{agent.name}' with state keys: {list(state.keys())}") output = await agent.ainvoke(state) #print(f"✅ [Async invoke] Agent '{agent.name}' completed.") return _process_output(output) return RunnableCallable(call_agent, acall_agent) def create_supervisor( agents: list[Pregel], *, model: LanguageModelLike, tools: list[BaseTool | Callable] | None = None, prompt: Prompt | None = None, response_format: Optional[ Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] ] = None, state_schema: StateSchemaType = AgentState, config_schema: Type[Any] | None = None, output_mode: OutputMode = "last_message", add_handoff_back_messages: bool = True, supervisor_name: str = "supervisor", include_agent_name: AgentNameMode | None = None, ) -> StateGraph: """Create a multi-agent supervisor. Args: agents: List of agents to manage model: Language model to use for the supervisor tools: Tools to use for the supervisor prompt: Optional prompt to use for the supervisor. Can be one of: - str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"]. - SystemMessage: this is added to the beginning of the list of messages in state["messages"]. - Callable: This function should take in full graph state and the output is then passed to the language model. - Runnable: This runnable should take in full graph state and the output is then passed to the language model. response_format: An optional schema for the final supervisor output. If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key. If not provided, `structured_response` will not be present in the output state. Can be passed in as: - an OpenAI function/tool schema, - a JSON Schema, - a TypedDict class, - or a Pydantic class. - a tuple (prompt, schema), where schema is one of the above. The prompt will be used together with the model that is being used to generate the structured response. !!! Important `response_format` requires the model to support `.with_structured_output` !!! Note `response_format` requires `structured_response` key in your state schema. You can use the prebuilt `langgraph.prebuilt.chat_agent_executor.AgentStateWithStructuredResponse`. state_schema: State schema to use for the supervisor graph. config_schema: An optional schema for configuration. Use this to expose configurable parameters via supervisor.config_specs. output_mode: Mode for adding managed agents' outputs to the message history in the multi-agent workflow. Can be one of: - `full_history`: add the entire agent message history - `last_message`: add only the last message (default) add_handoff_back_messages: Whether to add a pair of (AIMessage, ToolMessage) to the message history when returning control to the supervisor to indicate that a handoff has occurred. supervisor_name: Name of the supervisor node. include_agent_name: Use to specify how to expose the agent name to the underlying supervisor LLM. 
            - None: Relies on the LLM provider using the name attribute on the AI message.
              Currently, only OpenAI supports this.
            - "inline": Add the agent name directly into the content field of the AI message
              using XML-style tags.
              Example: "How can I help you" -> "<name>agent_name</name><content>How can I help you?</content>"
    """
    agent_names = set()
    for agent in agents:
        if agent.name is None or agent.name == "LangGraph":
            raise ValueError(
                "Please specify a name when you create your agent, either via `create_react_agent(..., name=agent_name)` "
                "or via `graph.compile(name=name)`."
            )
        if agent.name in agent_names:
            raise ValueError(
                f"Agent with name '{agent.name}' already exists. Agent names must be unique."
            )
        agent_names.add(agent.name)

    handoff_tools = [create_handoff_tool(agent_name=agent.name) for agent in agents]
    all_tools = (tools or []) + handoff_tools
    if _supports_disable_parallel_tool_calls(model):
        model = model.bind_tools(all_tools, parallel_tool_calls=False)
    else:
        model = model.bind_tools(all_tools)
    if include_agent_name:
        model = with_agent_name(model, include_agent_name)

    supervisor = create_react_agent(
        name=supervisor_name,
        model=model,
        tools=all_tools,
        prompt=prompt,
        state_schema=state_schema,
        response_format=response_format,
        debug=False,
    )

    # Build the multi-agent supervisor graph using the langgraph StateGraph setup
    builder = StateGraph(state_schema, config_schema=config_schema)
    builder.add_node(supervisor, destinations=tuple(agent_names) + (END,))
    builder.add_edge(START, supervisor.name)
    for agent in agents:
        # If agent is a "ReactAgent" or similar
        if hasattr(agent, "get_agent") and callable(agent.get_agent):
            agent = agent.get_agent()  # retrieve the compiled subgraph
        builder.add_node(
            agent.name,
            _make_call_agent(
                agent,
                output_mode,
                add_handoff_back_messages,
                supervisor_name,
            ),
        )
        builder.add_edge(agent.name, supervisor.name)
    return builder

================================================
FILE: core/agents/react_supervisor_agent.py
================================================
from typing import Any, Callable, Dict, List, Optional, Union
import re
from langchain_core.language_models import LanguageModelLike
from langchain_core.tools import BaseTool
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.types import Checkpointer
from langgraph.prebuilt.chat_agent_executor import (
    AgentState,
    StateSchemaType,
)
from langgraph.utils.runnable import RunnableCallable
from core.agents.react_based_supervisor import create_supervisor
from core.agents.react_based_supervisor.simple_planning_tool import SimplePlanningTool
from core.agents.base.base_agent import BaseAgent
from core.agents.react_based_supervisor.state_schema import PlanningAgentState
import logging

logger = logging.getLogger(__name__)


class SupervisorAgent(BaseAgent):
    """Supervisor class for managing multiple agents with planning capabilities.

    This class provides a high-level interface for creating a supervisor workflow
    that can manage and coordinate multiple agents. It also includes planning
    capabilities to create and manage a plan for complex tasks using a
    state-driven approach. The planning functionality is implemented using
    PlanningStateHandler and PlanningTool, which provide a more structured and
    flexible way to manage tasks compared to the previous TodolistTool approach.
    """

    _PROMPT_TEMPLATE = """You are a Supervisor Agent. Your job is to analyze user requests and coordinate multiple agents to complete tasks.
## Task Approach Methodology ### Understanding Requirements - Analyzing user requests to identify core needs - Asking clarifying questions when requirements are ambiguous - Breaking down complex requests into manageable components - Identifying potential challenges before beginning work ### Coordination - Identifying appropriate agents for each task - Delegating tasks to specialized agents - Tracking progress and ensuring task completion - Synthesizing information from multiple agents Remember: Effective coordination is essential for successful task completion. Take time to understand the request and delegate appropriately. {tools} """ _PLANNING_PROMPT_TEMPLATE = """You are a Supervisor agent. Your role is to analyze user requests, break them down into actionable tasks, and coordinate specialized agents (e.g., research_expert, coder_expert, reporter_expert) to complete them. # Working with Complex Requests 1. FIRST, carefully analyze the user's request and break it down into clear, actionable tasks 2. Identify which agent is best suited for each part of the task 3. Use the handoff tools to delegate tasks to appropriate agents ONE AT A TIME 4. WAIT for each agent to COMPLETELY FINISH their assigned task before proceeding 5. Review the output from each agent before delegating the next task 6. Maintain a sequential workflow - never delegate multiple tasks simultaneously 7. Synthesize the results and provide a coherent response to the user 8. Provide a final summary when all tasks are done """ _PLANNING_TOOL_TEMPLATE = """ # Planning Tool Instructions You have access to a "planning" tool that uses JSON for all operations. Do NOT include any "state" field in your calls. Use the following actions exactly as defined: 1. "create_plan": Create a new plan. - Required fields: - title (string) - description (string) - tasks (list of task objects). Each task object must include: "description": string, "status": "pending" (all tasks must have "status": "pending" initially), "agent": string (empty if not assigned), "notes": string (empty if none), "evaluation": string (empty if none) - Example: { "action": "create_plan", "title": "Python Scraper for Tech News", "description": "Build a Python scraper to fetch the latest tech news and save it as CSV", "tasks": [ {"description": "Research Python scraping libraries", "status": "pending", "agent": "", "notes": "", "evaluation": ""}, {"description": "Implement the scraper", "status": "pending", "agent": "", "notes": "", "evaluation": ""}, {"description": "Test the code", "status": "pending", "agent": "", "notes": "", "evaluation": ""} ] } 2. "view_plan": Retrieve the current plan. - Example: { "action": "view_plan" } 3. "add_tasks": Add additional tasks to the plan. - Required: - tasks: list of task objects (same format as above) - Example: { "action": "add_tasks", "tasks": [ {"description": "Write documentation", "status": "pending", "agent": "", "notes": "", "evaluation": ""} ] } 4. "update_task": Update an existing task. - Identify the task by "by_id" (the task's unique ID from the plan). - You may update any of: "description", "status", "agent", "notes", "evaluation". - Example: { "action": "update_task", "by_id": "TASK-UUID", "status": "completed", "evaluation": "The scraper works perfectly." } 5. "set_current_task": Set the current task by its ID. - Example: { "action": "set_current_task", "task_id": "TASK-UUID" } 6. "finish_plan": Mark the entire plan as completed. 
- Example: { "action": "finish_plan" } Important: - Always produce valid JSON for your tool calls. - Continuously update and monitor the plan until every task's status is "completed" before delivering your final answer. - If the plan is not fully completed, do not stop; keep updating the plan with appropriate calls. """ def __init__( self, agents: List[BaseAgent], model: LanguageModelLike, tools: Optional[List[Union[BaseTool, Callable]]] = None, prompt: Optional[str] = None, state_schema: StateSchemaType = AgentState, supervisor_name: str = "supervisor", checkpointer: Optional[Checkpointer] = None, output_mode: str = "last_message", # * full_history or last_message * enable_planning: bool = True, # * True or False * ): """Initialize a supervisor. Args: agents: List of agents to manage model: Language model to use for the supervisor tools: Optional list of tools available to the supervisor prompt: Optional prompt to use for the supervisor state_schema: State schema to use for the supervisor graph supervisor_name: Name of the supervisor node checkpointer: Optional checkpointer to use for the supervisor output_mode: Mode for adding agent outputs to the message history ("full_history" or "last_message") enable_planning: Whether to enable planning capabilities auto_planning: Whether to automatically generate plans for complex tasks """ # 设置规划相关属性 self._enable_planning = enable_planning # 如果启用规划功能,设置状态模式为PlanningAgentState if self._enable_planning and state_schema == AgentState: state_schema = PlanningAgentState # Store agent-specific attributes before super().__init__ self.agents = agents self.output_mode = output_mode self.supervisor_name = supervisor_name self.state_schema = state_schema self.checkpointer = checkpointer self.tools = tools or [] self._workflow = None # 生成基础提示词 # _agents_prompt = self._generate_agents_prompt() _final_prompt = self._PLANNING_PROMPT_TEMPLATE + "/n/n" + self._PLANNING_TOOL_TEMPLATE if self._enable_planning else self._PROMPT_TEMPLATE if tools is None: tools = [] # 如果启用规划功能,添加规划提示模板并添加规划工具 if self._enable_planning: tools.append(SimplePlanningTool()) # 初始化BaseAgent父类 super().__init__( name=supervisor_name, model=model, tools=tools, checkpointer=checkpointer, prompt=_final_prompt, ) def build(self) -> StateGraph: """Build the supervisor workflow. 
Returns: The built StateGraph """ if self._workflow is not None: return self._workflow self._workflow = create_supervisor( agents=self.agents, model=self.model, tools=self.tools, prompt=self.base_prompt, state_schema=self.state_schema, supervisor_name=self.supervisor_name, output_mode=self.output_mode, ) return self._workflow ================================================ FILE: core/agents/sb_supervisor_agent.py ================================================ # reason_graph/supervisor_agent.py from typing import Callable, List, Optional, Union, cast, Literal from langchain_core.language_models import LanguageModelLike from langchain_core.tools import BaseTool from langgraph.graph import StateGraph from langgraph.types import Checkpointer # 内部导入 from core.agents.base.base_agent import BaseAgent from core.agents.state_based_supervisor.state_schema import PlanningAgentState, StateSchemaType # 导入 PlanningAgentState # 导入重构后的 create_supervisor 函数 from core.agents.state_based_supervisor.supervisor_graph import create_supervisor from core.agents.state_based_supervisor.agent_name import AgentNameMode import logging logger = logging.getLogger(__name__) class SupervisorAgent(BaseAgent): """ Supervisor Agent 类 (最终版) 负责协调子 Agent 并管理规划 (使用状态驱动方法)。 invoke/ainvoke 继承自 BaseAgent,负责完整流程。 """ def __init__( self, agents: List[BaseAgent], # 子 Agent 实例列表 model: LanguageModelLike, # Supervisor 使用的 LLM tools: Optional[List[Union[BaseTool, Callable]]] = None, # Supervisor 特有工具 state_schema: StateSchemaType = PlanningAgentState, supervisor_name: str = "supervisor", checkpointer: Optional[Checkpointer] = None, output_mode: str = "last_message", # enable_planning: bool = True, # 不再需要,强制使用 Planning include_agent_name: Optional[str] = "inline", # BaseAgent 参数 max_context_messages: Optional[int] = None, max_context_tokens: Optional[int] = None, model_name: Optional[str] = None, ): """初始化 Supervisor Agent""" if state_schema != PlanningAgentState: print("Warning: SupervisorAgent forces state_schema to PlanningAgentState.") state_schema = PlanningAgentState self.sub_agents = agents self.output_mode = output_mode self.include_agent_name = cast(Optional[AgentNameMode], include_agent_name) # 初始化 BaseAgent 父类 super().__init__( name=supervisor_name, model=model, tools=tools or [], checkpointer=checkpointer, prompt=None, # 核心 Prompt 在 supervisor_node_logic 中处理 max_context_messages=max_context_messages, max_context_tokens=max_context_tokens, model_name=model_name, ) # _workflow_definition 和 _executable_agent 由 BaseAgent 管理 def build(self) -> Optional[StateGraph]: """构建 Supervisor 的 LangGraph 工作流图定义。""" # 调用重构后的 create_supervisor 函数来获取 StateGraph 定义 # 这个 StateGraph 包含了手写的 supervisor_node_logic if self._workflow: return self._workflow print(f"Building supervisor graph definition for '{self.name}'...") try: graph_definition = create_supervisor( model=self.model, sub_agents=self.sub_agents, state_schema=PlanningAgentState, # 强制使用 tools=self.tools, output_mode=cast(Literal["full_history", "last_message"], self.output_mode), supervisor_name=self.name, include_agent_name=self.include_agent_name, ) self._workflow = graph_definition # 存储图定义 print(f"Supervisor graph definition built for '{self.name}'.") return self._workflow except Exception as e: print(f"!!! 
Error building supervisor graph definition '{self.name}': {e}")
            import traceback
            traceback.print_exc()
            self._workflow = None
            raise e

    # compile() is inherited from BaseAgent: it calls build() above to get the
    # StateGraph definition, compiles it, and creates the final
    # _executable_agent, including the preprocessing steps.
    # invoke, ainvoke, get_agent, and reset are inherited from BaseAgent.

================================================
FILE: core/agents/state_based_supervisor/__init__.py
================================================

================================================
FILE: core/agents/state_based_supervisor/agent_name.py
================================================
import re
from typing import Literal
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.runnables import RunnableLambda

NAME_PATTERN = re.compile(r"<name>(.*?)</name>", re.DOTALL)
CONTENT_PATTERN = re.compile(r"<content>(.*?)</content>", re.DOTALL)

AgentNameMode = Literal["inline"]


def _is_content_blocks_content(content: list[dict] | str) -> bool:
    return (
        isinstance(content, list)
        and len(content) > 0
        and isinstance(content[0], dict)
        and "type" in content[0]
    )


def add_inline_agent_name(message: BaseMessage) -> BaseMessage:
    """Add name and content XML tags to the message content.

    Examples:
        >>> add_inline_agent_name(AIMessage(content="Hello", name="assistant"))
        AIMessage(content="<name>assistant</name><content>Hello</content>", name="assistant")

        >>> add_inline_agent_name(AIMessage(content=[{"type": "text", "text": "Hello"}], name="assistant"))
        AIMessage(content=[{"type": "text", "text": "<name>assistant</name><content>Hello</content>"}], name="assistant")
    """
    if not isinstance(message, AIMessage) or not message.name:
        return message

    formatted_message = message.model_copy()
    if _is_content_blocks_content(formatted_message.content):
        text_blocks = [block for block in message.content if block["type"] == "text"]
        non_text_blocks = [block for block in message.content if block["type"] != "text"]
        content = text_blocks[0]["text"] if text_blocks else ""
        formatted_content = f"<name>{message.name}</name><content>{content}</content>"
        formatted_message.content = non_text_blocks + [{"type": "text", "text": formatted_content}]
    else:
        formatted_message.content = (
            f"<name>{message.name}</name><content>{formatted_message.content}</content>"
        )
    return formatted_message


def remove_inline_agent_name(message: BaseMessage) -> BaseMessage:
    """Remove explicit name and content XML tags from the AI message content.

    Examples:
        >>> remove_inline_agent_name(AIMessage(content="<name>assistant</name><content>Hello</content>", name="assistant"))
        AIMessage(content="Hello", name="assistant")

        >>> remove_inline_agent_name(AIMessage(content=[{"type": "text", "text": "<name>assistant</name><content>Hello</content>"}], name="assistant"))
        AIMessage(content=[{"type": "text", "text": "Hello"}], name="assistant")
    """
    if not isinstance(message, AIMessage) or not message.name:
        return message

    is_content_blocks_content = _is_content_blocks_content(message.content)
    if is_content_blocks_content:
        text_blocks = [block for block in message.content if block["type"] == "text"]
        if not text_blocks:
            return message
        non_text_blocks = [block for block in message.content if block["type"] != "text"]
        content = text_blocks[0]["text"]
    else:
        content = message.content

    name_match: re.Match | None = NAME_PATTERN.search(content)
    content_match: re.Match | None = CONTENT_PATTERN.search(content)
    if not name_match or not content_match:
        return message
    if name_match.group(1) != message.name:
        return message

    parsed_content = content_match.group(1)
    parsed_message = message.model_copy()
    if is_content_blocks_content:
        content_blocks = non_text_blocks
        if parsed_content:
            content_blocks.append({"type": "text", "text": parsed_content})
        parsed_message.content = content_blocks
    else:
        parsed_message.content = parsed_content
    return parsed_message


def with_agent_name(
    model: LanguageModelLike,
    agent_name_mode: AgentNameMode,
) -> LanguageModelLike:
    """Attach formatted agent names to the messages passed to and from a language model.

    This is useful for making a message history with multiple agents more coherent.

    NOTE: agent name is consumed from the message.name field.
        If you're using an agent built with create_react_agent, name is automatically set.
        If you're building a custom agent, make sure to set the name on the AI message
        returned by the LLM.

    Args:
        model: Language model to add agent name formatting to.
        agent_name_mode: Use to specify how to expose the agent name to the LLM.
            - "inline": Add the agent name directly into the content field of the AI message
              using XML-style tags.
              Example: "How can I help you" -> "<name>agent_name</name><content>How can I help you?</content>".
    """
    if agent_name_mode == "inline":
        process_input_message = add_inline_agent_name
        process_output_message = remove_inline_agent_name
    else:
        raise ValueError(
            f"Invalid agent name mode: {agent_name_mode}. 
Needs to be one of: {AgentNameMode.__args__}" ) def process_input_messages(messages: list[BaseMessage]) -> list[BaseMessage]: return [process_input_message(message) for message in messages] model = ( process_input_messages | model | RunnableLambda(process_output_message, name="process_output_message") ) return model ================================================ FILE: core/agents/state_based_supervisor/evaluate_result_node.py ================================================ # reason_graph/evaluate_result_node.py import json import time import copy import traceback import anyio from typing import Dict, Any, List, Optional, Union from langchain_core.messages import BaseMessage, AIMessage, ToolMessage from langchain_core.runnables import RunnableConfig # 内部导入 (确保路径正确) try: from .state_schema import PlanningAgentState, TaskStatus, Plan, Task from .planning_handler import PlanningStateHandler except ImportError as e: print(f"Error importing modules in evaluate_result_node.py: {e}") # Fallbacks class PlanningAgentState(Dict): pass; class Plan(Dict): pass; class Task(Dict): pass TaskStatus = str class PlanningStateHandler: # Dummy @staticmethod def update_task(plan, by_id, **kwargs): return plan @staticmethod def set_current_task(plan, task_id): return plan @staticmethod def get_task(plan, task_id): return None @staticmethod def update_plan_status(plan): return plan async def evaluate_result_node_logic(state: PlanningAgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: """ 评估子 Agent 返回结果并更新计划状态的节点逻辑 (异步, 优化评估逻辑)。 """ print(f"--- Entering Evaluate Result Node ---") messages: List[BaseMessage] = state.get('messages', []) plan: Optional[Plan] = state.get('plan') last_message = messages[-1] if messages else None error_message: Optional[str] = None plan_updated: bool = False updated_plan: Optional[Plan] = copy.deepcopy(plan) if plan else None if not updated_plan: print("Evaluate Result Node: No plan found in state. Skipping.") return {} current_task_id = updated_plan.get("current_task_id") if not current_task_id: # Fallback logic for finding current task (不变) print("Warning: Evaluate Result Node - No current_task_id found in plan...") in_progress_tasks = [t for t in updated_plan.get('tasks', []) if t.get('status') == 'in_progress'] if len(in_progress_tasks) == 1: current_task_id = in_progress_tasks[0].get('id'); print(f" Fallback: Found task {current_task_id}") else: error_message = "Evaluation failed: Cannot determine finished task."; print(f"ERROR: {error_message}"); return {"plan": updated_plan, "error": error_message, "messages": []} agent_result_content: Optional[str] = None agent_name: Optional[str] = None if isinstance(last_message, AIMessage): agent_result_content = str(last_message.content) if last_message.content is not None else "" # Ensure string agent_name = last_message.name or "SubAgent" print(f" Evaluating result from: {agent_name} for task ID: {current_task_id}") else: agent_result_content = f"Error: Expected AIMessage result, got {type(last_message).__name__}." agent_name = "System/Error" print(f"Warning: Last message not AIMessage. Assuming task failed for {current_task_id}.") # --- 优化的评估逻辑 --- new_status: TaskStatus = "completed" # 默认成功 evaluation_notes = f"Result received from {agent_name}." # 1. 检查是否为空内容 (或只有空白符) if agent_result_content is None or not agent_result_content.strip(): new_status = "failed" evaluation_notes = f"Task failed: Agent {agent_name} returned empty content." print(f" Task {current_task_id} evaluated as FAILED (Empty Result).") # 2. 
检查是否以明确的错误标识开头 (需要工具配合) # 假设工具出错时会在返回字符串前加上 "Error: " 或 "Execution Failed: " elif agent_result_content.strip().startswith(("Error:", "Execution Failed:", "Tool Error:")): new_status = "failed" evaluation_notes = f"Task failed: Agent {agent_name} reported an error: {agent_result_content[:150]}..." print(f" Task {current_task_id} evaluated as FAILED (Explicit Error Signal).") # 3. (可选) 添加其他特定检查,例如检查是否只是"我不明白"之类的回复 elif len(agent_result_content) < 50 and any(kw in agent_result_content.lower() for kw in ["don't know", "cannot fulfill", "无法回答", "不明白"]): new_status = "failed" # 或 "pending_review" ? 暂时设为 failed evaluation_notes = f"Task likely failed: Agent {agent_name} indicated inability to fulfill request." print(f" Task {current_task_id} evaluated as FAILED (Agent Indicated Inability).") else: # 如果以上都不是,则认为是成功 new_status = "completed" print(f" Task {current_task_id} evaluated as COMPLETED.") # --- 评估逻辑结束 --- # --- 更新 Plan 状态 (逻辑不变) --- try: update_kwargs = { "new_status": new_status, "new_evaluation": evaluation_notes, "new_notes": agent_result_content[:1000] + "..." if agent_result_content and len(agent_result_content) > 1000 else agent_result_content } print(f" Updating task {current_task_id} with: {{'status': '{new_status}', ...}}") if updated_plan and PlanningStateHandler.get_task(updated_plan, current_task_id): updated_plan = PlanningStateHandler.update_task(updated_plan, by_id=current_task_id, **update_kwargs) updated_plan = PlanningStateHandler.set_current_task(updated_plan, None) updated_plan = PlanningStateHandler.update_plan_status(updated_plan) print(f" Plan status after evaluation update: {updated_plan.get('status')}") plan_updated = True else: raise ValueError(f"Task ID '{current_task_id}' not found or plan invalid before update.") except ValueError as ve: error_message = f"Error updating plan: {ve}"; print(f"ERROR: {error_message}"); traceback.print_exc() except Exception as e: error_message = f"Unexpected error updating plan: {e}"; print(f"ERROR: {error_message}"); traceback.print_exc() # --- 准备返回字典 (逻辑不变) --- updates: Dict[str, Any] = {} if updated_plan is not None: updates["plan"] = updated_plan elif plan is not None: updates["plan"] = plan # 记录本节点错误,或清除旧错误 current_state_error = state.get("error") if error_message: updates["error"] = error_message elif current_state_error: updates["error"] = None updates["messages"] = [] # Evaluator 不添加消息 print(f"--- Exiting Evaluate Result Node. 
Plan updated: {plan_updated} ---") return updates # --- 同步包装器 (保持不变) --- def evaluate_result_node_logic_sync(state: PlanningAgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: """evaluate_result_node_logic 的同步包装器""" print(f"--- Entering Evaluate Result Node (Sync Wrapper) ---") try: import anyio return anyio.run(evaluate_result_node_logic, state, config) # type: ignore except Exception as e: print(f"Error running evaluate_result_node_logic synchronously: {e}") traceback.print_exc() return {"error": f"Evaluate Result sync execution failed: {e}", "plan": state.get("plan"), "messages": []} ================================================ FILE: core/agents/state_based_supervisor/handoff.py ================================================ # reason_graph/handoff.py # (Paste the code user provided for handoff.py here) import re import uuid from typing import List, Tuple # Import Tuple from langchain_core.messages import AIMessage, ToolCall, ToolMessage, BaseMessage # Import BaseMessage from langchain_core.tools import BaseTool, InjectedToolCallId, tool from langgraph.prebuilt import InjectedState from langgraph.types import Command from typing_extensions import Annotated WHITESPACE_RE = re.compile(r"\s+") def _normalize_agent_name(agent_name: str) -> str: """Normalize an agent name to be used inside the tool name.""" if not agent_name: return "unknown_agent" return WHITESPACE_RE.sub("_", agent_name.strip()).lower() # Note: The original code uses @tool decorator which requires function arguments. # To inject state, the decorated function needs the Annotated state argument. # Let's define the function first and then apply the decorator, or use functools.partial. # Using the function approach first for clarity. def _handoff_to_agent_implementation( state: Annotated[dict, InjectedState], # Inject state here tool_call_id: Annotated[str, InjectedToolCallId], # Inject tool_call_id target_agent_name: str, # Pass the target agent name tool_name: str # Pass the specific tool name for the ToolMessage ) -> Command: """Ask another agent for help. This is the core logic.""" # Create the ToolMessage confirming the handoff BEFORE generating the Command """Handoff 核心逻辑,添加日志""" print(f"\n--- DEBUG: Entering _handoff_to_agent_implementation ---") print(f" - Target Agent: {target_agent_name}") print(f" - Tool Name: {tool_name}") print(f" - Tool Call ID: {tool_call_id}") # print(f" - Current State Keys: {list(state.keys())}") # 可选:打印状态键 tool_message = ToolMessage( content=f"Okay, handing off to {target_agent_name}. 
The current state and task context have been passed.", name=tool_name, tool_call_id=tool_call_id, ) print(f" - Created ToolMessage: ID={tool_message.tool_call_id}, Name={tool_message.name}") # The Command tells LangGraph to route to the target agent node # It also includes the ToolMessage in the state update for the next step command_obj = Command( goto=target_agent_name, # graph=Command.PARENT, # PARENT is default, usually not needed unless nested graphs update={"messages": [tool_message]}, # Return only the NEW message to be added ) print(f" - Created Command: goto='{command_obj.goto}', update contains {len(command_obj.update.get('messages',[]))} message(s)") print(f"--- DEBUG: Exiting _handoff_to_agent_implementation ---") return command_obj def create_handoff_tool(*, agent_name: str) -> BaseTool: """Create a tool that can handoff control to the requested agent.""" if not agent_name: raise ValueError("agent_name cannot be empty for create_handoff_tool") normalized_name = _normalize_agent_name(agent_name) tool_name = f"transfer_to_{normalized_name}" # Use functools.partial to fix the target_agent_name and tool_name arguments import functools specific_handoff_logic = functools.partial( _handoff_to_agent_implementation, target_agent_name=agent_name, tool_name=tool_name ) # Decorate the partial function # The arguments 'state' and 'tool_call_id' will be automatically injected by LangGraph # when the tool is called due to the Annotations used in _handoff_to_agent_implementation @tool(tool_name) def handoff_tool_wrapper( state: Annotated[dict, InjectedState], tool_call_id: Annotated[str, InjectedToolCallId] ) -> Command: """Dynamically generated tool description: Ask the '{agent_name}' agent for help with the current task or question.""" # --- 添加 Debug 日志 --- print(f"\n--- DEBUG: Handoff Tool '{tool_name}' (wrapper) CALLED ---") # --- return specific_handoff_logic(state=state, tool_call_id=tool_call_id) # type: ignore # Set a more descriptive description handoff_tool_wrapper.description = f"Use this tool to delegate the current task or ask a question to the '{agent_name}' agent. Pass the necessary context or instructions in your reasoning before calling this tool." return handoff_tool_wrapper def create_handoff_back_messages( agent_name: str, supervisor_name: str ) -> Tuple[AIMessage, ToolMessage]: """Create a pair of (AIMessage, ToolMessage) to add to the message history when returning control to the supervisor.""" tool_call_id = str(uuid.uuid4()) # Although no tool exists for transferring back, we simulate the pattern # The AIMessage signals intent, the ToolMessage confirms the transition occurred in the graph logic simulated_tool_name = f"transfer_back_to_{_normalize_agent_name(supervisor_name)}" # The AIMessage contains the *final output* of the sub-agent in its content field # It should also indicate the intent to hand back, though the graph logic forces this anyway. # The content here is just a placeholder - the actual content comes from the agent's final response. ai_message_content = f"Task completed. Transferring back to {supervisor_name}." # We still generate a ToolCall structure for consistency in the AIMessage, even if no real tool is called on supervisor side for hand-back. 
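    # For illustration, the pair assembled below looks roughly like this
    # (agent names and IDs hypothetical):
    #   AIMessage(name="research_expert",
    #             content="Task completed. Transferring back to supervisor.",
    #             tool_calls=[ToolCall(name="transfer_back_to_supervisor", args={}, id="<uuid>")])
    #   ToolMessage(name="transfer_back_to_supervisor", tool_call_id="<uuid>",
    #               content="Successfully transferred back to supervisor from research_expert.")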
tool_calls = [ToolCall(name=simulated_tool_name, args={}, id=tool_call_id)] # Create the AIMessage - crucial to include the sub-agent's name ai_message = AIMessage( content=ai_message_content, # Placeholder - see note below tool_calls=tool_calls, name=agent_name, # Identify which agent is responding ) # The ToolMessage confirms the transition happened from the graph's perspective tool_message = ToolMessage( content=f"Successfully transferred back to {supervisor_name} from {agent_name}.", name=simulated_tool_name, tool_call_id=tool_call_id, ) # IMPORTANT NOTE: The `_make_call_agent` helper function should populate the # `ai_message.content` with the *actual* final response message(s) from the sub-agent, # replacing the placeholder content above. It keeps the tool_calls structure. # The code provided for `_make_call_agent` seems to handle extracting `output['messages']`. # We need to ensure it correctly structures the AIMessage part of the tuple returned here. # Let's refine create_handoff_back_messages to just create the ToolMessage, # as the AIMessage content comes from the sub-agent's actual final output. # Refined approach: _make_call_agent gets the final AI response, we only need the ToolMessage here? # No, the pattern expects both. Let's assume _make_call_agent takes the *last* message from the # sub-agent's output and packages it into this AIMessage structure. return ai_message, tool_message # Return both for the standard pattern ================================================ FILE: core/agents/state_based_supervisor/planner_node.py ================================================ import re import json import time import copy import ast import traceback import anyio # <--- 导入 anyio from typing import Dict, Any, List, Optional, Union from datetime import datetime from langchain_core.messages import BaseMessage, AIMessage, SystemMessage, HumanMessage from langchain_core.runnables import RunnableConfig # 内部导入 try: from .state_schema import PlanningAgentState, Plan from .planning_handler import PlanningStateHandler from .prompt import PLANNER_SYSTEM_PROMPT_TEMPLATE except ImportError as e: print(f"Error importing modules in planner_node.py: {e}") class PlanningAgentState(Dict): pass; class Plan(Dict): pass; class PlanningStateHandler: pass PLANNER_SYSTEM_PROMPT_TEMPLATE = "Fallback Planner Prompt: Error loading template. Args: {agent_descriptions}" # --- Planner 节点核心逻辑 (异步) --- async def planner_node_logic( state: PlanningAgentState, config: Optional[RunnableConfig], model: Any, # Planner 使用的 LLM agent_description_map: Dict[str, str] # 需要 Agent 描述来分配任务 ) -> Dict[str, Any]: """Planner 节点逻辑:分析请求,生成初始计划""" print(f"--- Entering Planner Node ---") messages: List[BaseMessage] = state.get('messages', []) # Planner 通常在 plan 为空时运行 plan: Optional[Plan] = state.get('plan') if plan: print("Planner Node: Plan already exists. Skipping plan creation.") # 如果计划已存在,Planner 不应再执行,直接返回当前状态? # 或者返回一个空更新,让图流向 Supervisor? # 返回空更新更安全,让 Supervisor 继续 return {} # 返回空字典,状态不变 if not messages: print("Planner Node: No messages found to create a plan from.") return {"error": "Planner received no messages."} # --- 1. 准备 Planner Prompt --- # Planner 只需要 Agent 描述,不需要 plan_json 或 current_date? # 可以让它知道日期 desc_list = [f"- {name}: {desc}" for name, desc in agent_description_map.items()] agent_descriptions_str = "\n".join(desc_list) current_date_str = datetime.now().strftime("%a, %b %d, %Y") # Planner 也可能需要日期 system_prompt_text = "Error: Planner prompt template could not be loaded/formatted." 
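    # Format the planner template (loaded in the try block below) with the agent
    # descriptions and date assembled above; the template is expected to expose
    # {agent_descriptions} and {current_date} placeholders. On any failure,
    # system_prompt_text keeps the fallback error string.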
try: # 加载 Planner 的模板 from .prompt import PLANNER_SYSTEM_PROMPT_TEMPLATE system_prompt_text = PLANNER_SYSTEM_PROMPT_TEMPLATE.format( agent_descriptions=agent_descriptions_str, # 如果 Planner Prompt 需要日期: current_date=current_date_str ) except ImportError: print("ERROR: Could not import PLANNER_SYSTEM_PROMPT_TEMPLATE") except KeyError as e: print(f"ERROR: Missing key in planner prompt formatting: {e}") except Exception as e: print(f"ERROR: Unexpected error loading/formatting planner prompt: {e}") # Planner 的输入只需要 System Prompt 和用户的初始请求(通常是第一条) # 或者传递最后几条消息?为了简单,先只用第一条 HumanMessage initial_user_request = next((m for m in messages if isinstance(m, HumanMessage)), None) if not initial_user_request: print("Planner Node: No HumanMessage found in initial state.") return {"error": "Planner did not find initial user request."} llm_input_messages = [SystemMessage(content=system_prompt_text), initial_user_request] # --- 2. 调用 Planner LLM --- print("--- Calling Planner LLM ---") response: Optional[AIMessage] = None llm_error_msg: Optional[str] = None try: response = await model.ainvoke(llm_input_messages, config=config) if not isinstance(response, AIMessage): raise TypeError("Planner LLM returned non-AIMessage.") # Planner 的回复主要是指令,可以不设置 name print(f"Planner LLM Raw Response Content: {response.content[:300]}...") # Planner 不应该调用工具 if response.tool_calls: print("Warning: Planner LLM unexpectedly generated tool calls!") messages_to_add: List[BaseMessage] = [response] # 可以选择是否将 Planner 的思考过程加入 history except Exception as e: print(f"!!! Error invoking Planner LLM: {e}"); traceback.print_exc() llm_error_msg = f"Planner LLM invocation failed: {e}" messages_to_add = [] response = None # --- 3. 处理 Planner LLM 回复 (解析 CREATE_PLAN) --- new_plan: Optional[Plan] = None plan_updated: bool = False # 标记计划是否在本节点成功创建 directive_error_msg: Optional[str] = None if response and isinstance(response.content, str): try: plan_match = re.search(r"PLAN_UPDATE:\s*CREATE_PLAN\s*(\{.*?\})\s*$", response.content, re.IGNORECASE | re.DOTALL | re.MULTILINE) if plan_match: args_json_str = plan_match.group(1) print(f"Planner directive found: CREATE_PLAN with args: {args_json_str[:100]}...") try: args = json.loads(args_json_str) if not isinstance(args, dict): raise ValueError("Args JSON not a dict.") title=args.get("title", "Plan"); desc=args.get("description",""); tasks=args.get("tasks",[]) if isinstance(tasks, list) and all(isinstance(t, dict) and 'description' in t for t in tasks): for task_data in tasks: task_data['status'] = 'pending' # 强制状态 new_plan = PlanningStateHandler.create_plan(title, desc) new_plan = PlanningStateHandler.add_tasks(new_plan, tasks); plan_updated = True print("DEBUG: Plan successfully created by Planner node.") else: raise ValueError("Invalid 'tasks' format (must be list of dicts with 'description').") except (json.JSONDecodeError, ValueError, KeyError, TypeError) as e: err_msg = f"Error processing CREATE_PLAN directive: {type(e).__name__} - {e}" print(err_msg); traceback.print_exc(); directive_error_msg = err_msg except Exception as e: err_msg = f"Unexpected error processing CREATE_PLAN: {type(e).__name__} - {e}" print(err_msg); traceback.print_exc(); directive_error_msg = err_msg else: directive_error_msg = "Planner LLM did not output a valid PLAN_UPDATE: CREATE_PLAN directive." 
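            # For reference, a well-formed planner reply is expected to end with a
            # directive shaped like this (values hypothetical):
            #   PLAN_UPDATE: CREATE_PLAN {"title": "...", "description": "...",
            #     "tasks": [{"description": "...", "agent": "research_expert", "status": "pending"}]}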
print(f"Warning: {directive_error_msg}") # 即使没有指令,也可能需要返回 Planner 的回复消息 # 但如果没有 plan,流程可能无法继续,所以记录错误 except Exception as outer_e: directive_error_msg = f"Error searching for PLAN_UPDATE directive: {outer_e}" print(f"ERROR: {directive_error_msg}"); traceback.print_exc() # --- 4. 准备返回的状态更新 --- updates: Dict[str, Any] = {"messages": messages_to_add} # 添加 Planner 的回复消息 if plan_updated and new_plan: updates["plan"] = new_plan # 返回新创建的 Plan final_error = llm_error_msg or directive_error_msg if final_error: # 记录 Planner 步骤中遇到的第一个错误 updates["error"] = final_error print(f"--- Exiting Planner Node. Plan created: {plan_updated} ---") return updates # --- Planner 节点的同步包装器 (使用 anyio) --- def planner_node_logic_sync( state: PlanningAgentState, config: Optional[RunnableConfig], model: Any, agent_description_map: Dict[str, str] ) -> Dict[str, Any]: """planner_node_logic 的同步包装器""" print(f"--- Entering Planner Node (Sync Wrapper) ---") try: # 使用 anyio 在同步函数中运行异步函数 return anyio.run( # type: ignore planner_node_logic, state, config, model, agent_description_map ) except Exception as e: print(f"Error running planner_node_logic synchronously: {e}") traceback.print_exc() return {"error": f"Planner sync execution failed: {e}", "messages": state.get("messages",[])} ================================================ FILE: core/agents/state_based_supervisor/planning_handler.py ================================================ # reason_graph/planning_handler.py import uuid import datetime from typing import List, Dict, Optional, Any from .state_schema import TaskStatus, PlanningStatus, Task, Plan # 从 state_schema 导入类型 class PlanningStateHandler: """ 使用静态方法管理一个表示项目计划的字典。 计划现在存储在 LangGraph 的状态中,此类提供操作该字典的函数。 """ @staticmethod def _now() -> str: return datetime.datetime.now(datetime.timezone.utc).isoformat() @staticmethod def _gen_id() -> str: # 生成更易读的任务 ID (可选) # return f"task_{str(uuid.uuid4())[:8]}" return str(uuid.uuid4()) @staticmethod def create_plan(title: str, description: str) -> Plan: """创建一个新的 Plan 字典""" now = PlanningStateHandler._now() return Plan( title=title, description=description, status="planning", # 初始状态为规划中 tasks=[], current_task_id=None, created_at=now, updated_at=now, completed_at=None, ) @staticmethod def create_task(description: str, agent: Optional[str] = None, dependencies: Optional[List[str]] = None) -> Task: """创建一个新的 Task 字典""" now = PlanningStateHandler._now() return Task( id=PlanningStateHandler._gen_id(), description=description.strip(), status="pending", # 初始状态为待处理 agent=agent.strip() if agent else None, created_at=now, updated_at=now, completed_at=None, dependencies=dependencies or [], notes=None, evaluation=None, result=None, ) @staticmethod def add_tasks(plan: Plan, tasks_data: List[Dict[str, Any]]) -> Plan: """向 Plan 字典中添加任务""" if not isinstance(plan, dict) or "tasks" not in plan: raise ValueError("Invalid plan structure provided.") if not isinstance(tasks_data, list): raise ValueError("tasks_data must be a list of task dictionaries.") for tinfo in tasks_data: desc = tinfo.get("description") if not desc: continue # 跳过没有描述的任务 agent = tinfo.get("agent") deps = tinfo.get("dependencies") task = PlanningStateHandler.create_task(desc, agent, deps) plan["tasks"].append(task) # 如果添加任务时计划仍在 planning 阶段,可以转为 ready if plan.get("status") == "planning": plan["status"] = "ready" plan["updated_at"] = PlanningStateHandler._now() return plan @staticmethod def update_task(plan: Plan, by_id: Optional[str] = None, new_desc: Optional[str] = None, new_status: Optional[TaskStatus] = None, new_agent: 
Optional[str] = None, new_notes: Optional[str] = None, new_evaluation: Optional[str] = None, new_result: Optional[Any] = None) -> Plan: """更新 Plan 字典中指定 ID 的任务""" if not isinstance(plan, dict) or "tasks" not in plan: raise ValueError("Invalid plan structure provided.") if not by_id: raise ValueError("Must provide 'by_id' to update a task.") task = next((t for t in plan["tasks"] if t.get("id") == by_id), None) if not task: raise ValueError(f"No matching task found with ID: {by_id}") updated = False if new_desc is not None and task.get("description") != new_desc.strip(): task["description"] = new_desc.strip() updated = True if new_status is not None and task.get("status") != new_status.strip(): task["status"] = new_status.strip() if new_status.strip() == "completed": task["completed_at"] = PlanningStateHandler._now() updated = True if new_agent is not None and task.get("agent") != new_agent.strip(): task["agent"] = new_agent.strip() updated = True if new_notes is not None and task.get("notes") != new_notes.strip(): task["notes"] = new_notes.strip() updated = True if new_evaluation is not None and task.get("evaluation") != new_evaluation.strip(): task["evaluation"] = new_evaluation.strip() updated = True if new_result is not None: # 直接更新结果(谨慎使用,可能很大) task["result"] = new_result updated = True if updated: task["updated_at"] = PlanningStateHandler._now() plan["updated_at"] = PlanningStateHandler._now() # 更新整个计划的更新时间 # 检查并更新整个计划的状态 plan = PlanningStateHandler.update_plan_status(plan) return plan @staticmethod def update_plan_status(plan: Plan) -> Plan: """根据任务状态自动更新计划状态""" if not isinstance(plan, dict) or "tasks" not in plan: return plan # Return as is if invalid tasks = plan["tasks"] if not tasks: # 没有任务 if plan.get("status") not in ["completed", "failed", "error"]: plan["status"] = "ready" # 或 "completed" 如果没有任务就算完成? 
设为 ready 似乎更合理
            return plan

        all_completed = all(t.get("status") == "completed" for t in tasks)
        any_failed = any(t.get("status") == "failed" for t in tasks)
        any_in_progress = any(t.get("status") in ["in_progress", "pending_review"] for t in tasks)
        any_pending = any(t.get("status") == "pending" for t in tasks)

        current_status = plan.get("status")
        new_status = current_status
        if any_failed:
            new_status = "failed"  # or "error"
        elif all_completed:
            new_status = "completed"
            plan["completed_at"] = PlanningStateHandler._now()
        elif any_in_progress:
            new_status = "executing"
        elif any_pending or not any_in_progress:
            # still pending tasks, or everything finished but not completed/failed
            if current_status not in ["completed", "failed", "error"]:  # don't overwrite a terminal status
                new_status = "ready"  # ready to execute or waiting for new tasks

        if new_status != current_status:
            plan["status"] = new_status
            plan["updated_at"] = PlanningStateHandler._now()
        return plan

    @staticmethod
    def set_current_task(plan: Plan, task_id: Optional[str]) -> Plan:
        """Set the current task ID in the Plan dict."""
        if not isinstance(plan, dict):
            raise ValueError("Invalid plan structure provided.")
        if task_id is None:
            plan["current_task_id"] = None
            plan["updated_at"] = PlanningStateHandler._now()
            return plan
        found = any(t.get("id") == task_id for t in plan.get("tasks", []))
        if not found:
            raise ValueError(f"Task ID '{task_id}' not found in plan.")
        if plan.get("current_task_id") != task_id:
            plan["current_task_id"] = task_id
            plan["updated_at"] = PlanningStateHandler._now()
        return plan

    @staticmethod
    def get_task(plan: Plan, task_id: str) -> Optional[Task]:
        """Get a task dict by its ID."""
        if not isinstance(plan, dict) or "tasks" not in plan:
            return None
        return next((t for t in plan["tasks"] if t.get("id") == task_id), None)

    @staticmethod
    def get_next_pending_task(plan: Plan) -> Optional[Task]:
        """Get the next pending task whose dependencies are all completed."""
        if not isinstance(plan, dict) or "tasks" not in plan:
            return None
        completed_task_ids = {t["id"] for t in plan["tasks"] if t.get("status") == "completed"}
        for task in plan["tasks"]:
            if task.get("status") == "pending":
                dependencies = task.get("dependencies", [])
                if not dependencies or all(dep_id in completed_task_ids for dep_id in dependencies):
                    return task
        return None  # no suitable next task found

    @staticmethod
    def finish_plan(plan: Plan) -> Plan:
        """Force-mark the plan as completed."""
        if not isinstance(plan, dict):
            raise ValueError("Invalid plan structure provided.")
        if plan.get("status") != "completed":
            plan["status"] = "completed"
            plan["completed_at"] = PlanningStateHandler._now()
        plan["updated_at"] = PlanningStateHandler._now()
        return plan

================================================
FILE: core/agents/state_based_supervisor/prompt.py
================================================
# # --- Planner Agent System Prompt (earlier draft) ---
# PLANNER_SYSTEM_PROMPT_TEMPLATE = """You are an expert planning agent. Your sole responsibility is to analyze a user request and create a detailed, step-by-step plan to fulfill it by coordinating specialized agents.
# The current date is {current_date}.
# ## Agent Descriptions:
# {agent_descriptions}
# *(This list includes the capabilities of available specialist agents.)*
# ## Task:
# Analyze the user request provided in the message history. Break it down into a sequence of logical tasks. For each task, determine the most suitable agent from the descriptions provided.
# ## Output Format:
# You MUST output **ONLY** a single `PLAN_UPDATE: CREATE_PLAN <json_args>` directive in your response content. The JSON arguments MUST be valid and contain:
# - "title": A concise title for the overall plan.
# - "description": A brief description summarizing the user's goal.
# - "tasks": A list of task objects. Each task object MUST contain:
#   - "description": A clear and actionable description of the specific sub-task.
#   - "agent": The name of the MOST SUITABLE agent from the Agent Descriptions to perform this task. Leave empty ("") if unsure or if it's a general task.
#   - "status": Set **all** initial tasks to **"pending"**.
#   - (Optional) "dependencies": A list of task IDs (UUIDs that will be generated later) this task depends on, if any (usually empty for initial plan).
# **Example JSON Args:**
# `{{"title": "Research and Report on AI Ethics", "description": "User wants a report on AI ethics, including research and writing.", "tasks": [{{"description": "Research current trends in AI ethics using web search", "agent": "research_expert", "status": "pending"}}, {{"description": "Write a structured report summarizing the findings", "agent": "reporter_expert", "status": "pending", "dependencies": [""]}}]}}`
# *(Note: Actual IDs are UUIDs generated later, dependencies often added via UPDATE_TASK)*
# **CRITICAL**: Output **ONLY** the `PLAN_UPDATE: CREATE_PLAN <json_args>` directive and nothing else. Do not add conversational text. Make sure the JSON is valid.
# """

# SUPERVISOR_PLANNING_PROMPT_TEMPLATE = """You are a meticulous top-level Supervisor agent responsible for executing an existing plan, coordinating specialist agents, and managing task execution based on the provided state. You rely on an external evaluator node to assess task completion after agents run.
# The current date is {current_date}.
# ## Current Plan State:
# ```json
# {plan_json}
# ```
# *(Review plan status and individual task statuses and IDs (UUIDs). Your main goal is to drive the plan status to 'completed'.)*
# ## Agent Descriptions:
# {agent_descriptions}
# ## Your Goal:
# Execute the **existing plan** strictly step-by-step towards 'completed' status. Make **exactly one** logical primary decision per turn. **Do NOT evaluate agent results or mark tasks 'completed'/'failed' yourself.**
# ## Workflow & Decision Process (Strict Sequence):
# 1. **Analyze State**: Review the latest messages and the 'Current Plan State'. (Note: If the last message is from a sub-agent, an evaluator node has already processed it and updated the plan state before your turn).
# 2. **Determine ONE Next Action**: Execute the FIRST matching condition below and **IMMEDIATELY END YOUR TURN**:
#     * **A. Initiate Next Task**: If the plan is 'ready' or 'executing', AND no task is currently 'in_progress', AND a 'pending' task is ready (dependencies met):
#         * **Action**: Find the FIRST such task. Output **ONLY** `PLAN_UPDATE: UPDATE_TASK <json_args>`. **CRITICAL: Use the exact UUID for `by_id`!** JSON Args should be `{{"by_id": "<task_uuid>", "status": "in_progress"}}`.
#     * **B. Delegate In-Progress Task**: If a task **currently has status 'in_progress'** (check plan state):
#         * **Action**: Identify the best agent. Output **ONLY** the `transfer_to_<agent_name>` tool call. **CRITICAL**: Tool call args **MUST** include `"task_id": "<task_uuid>"` and clear `"instructions"`.
#     * **C. Finish Plan**: If **ALL** tasks in the plan now have status 'completed' AND the plan status is NOT 'completed' yet (check plan state provided):
#         * **Action**: Output **ONLY** `PLAN_UPDATE: FINISH_PLAN {{}}`.
#     * **D. Generate Final Output**: If the **Plan Status IS 'completed'** (check plan state provided):
#         * **Action**: Decide final output format based on original request. EITHER call `transfer_to_reporter_expert` (passing context in args, like relevant task IDs) OR generate the final `AIMessage` content yourself summarizing the overall result.
#     * **E. Waiting/Blocked/Failed**: If no other action is appropriate (e.g., plan status 'failed', or waiting for dependencies):
#         * **Action**: Output a brief waiting or status message explaining the situation.
# ## Output Constraints:
# - Your response MUST contain exactly ONE primary action (ONE PLAN_UPDATE directive OR ONE transfer_to tool call OR the final answer OR a status message).
# - `PLAN_UPDATE:` directives MUST be in the text content with **valid JSON arguments**.
# - **CRITICAL**: `UPDATE_TASK` **MUST** use the correct Task UUID string for `"by_id"`.
# ## Planning Directives Format (Mandatory - JSON Args in text):
# - `PLAN_UPDATE: ADD_TASKS {{"tasks": [...]}}` # You can still add tasks if needed mid-plan
# - `PLAN_UPDATE: UPDATE_TASK {{"by_id": "<task_uuid>", "status": "in_progress", "notes": ""}}` (**UUID!** Only use non-terminal statuses).
# - `PLAN_UPDATE: FINISH_PLAN {{}}`
# ## Tool Usage:
# - Only `transfer_to_<agent_name>` tools. Args **MUST** include `"task_id"` and `"instructions"`.
# Now, analyze the current state (which reflects any recent evaluations) and the LAST message, and determine the single next action based strictly on the workflow for **executing the existing plan**. Remember, you do **not** evaluate results or mark tasks complete/failed.
# """

# --- Planner Agent System Prompt ---
PLANNER_SYSTEM_PROMPT_TEMPLATE = """You are an expert planning agent. Your sole responsibility is to analyze a user request and create a detailed, step-by-step plan to fulfill it by coordinating specialized agents.
The current date is {current_date}.

## Agent Descriptions:
{agent_descriptions}
*(This list includes the capabilities of available specialist agents.)*

## Task:
Analyze the user request provided in the message history. Break it down into a sequence of logical tasks. For each task, determine the most suitable agent from the descriptions provided.

## Task Granularity Guidelines:
- **IMPORTANT**: Maintain appropriate task granularity based on complexity:
  - For simple requests, create just 1-2 tasks that can be completed by a single agent
  - For complex requests, break down into 3-5 logical steps
  - Avoid excessive fragmentation of simple tasks
  - Each task should represent a meaningful unit of work

## Output Format:
You MUST output **ONLY** a single `PLAN_UPDATE: CREATE_PLAN <json_args>` directive in your response content. The JSON arguments MUST be valid and contain:
- "title": A concise title for the overall plan.
- "description": A brief description summarizing the user's goal.
- "tasks": A list of task objects. Each task object MUST contain:
  - "description": A clear and actionable description of the specific sub-task.
  - "agent": The name of the MOST SUITABLE agent from the Agent Descriptions to perform this task. Leave empty ("") if unsure or if it's a general task.
  - "status": Set **all** initial tasks to **"pending"**.
  - (Optional) "dependencies": Usually empty for initial plan.

**Example JSON Args for SIMPLE request:**
`{{"title": "Answer Question About Python", "description": "User wants to know how to use list comprehensions in Python", "tasks": [{{"description": "Provide a comprehensive explanation of Python list comprehensions with examples", "agent": "coder_expert", "status": "pending"}}]}}`

**Example JSON Args for COMPLEX request:**
`{{"title": "Research and Report on AI Ethics", "description": "User wants a detailed report on AI ethics", "tasks": [{{"description": "Research current trends in AI ethics using web search", "agent": "research_expert", "status": "pending"}}, {{"description": "Write a structured report summarizing the findings", "agent": "reporter_expert", "status": "pending"}}]}}`

**CRITICAL**: Output **ONLY** the `PLAN_UPDATE: CREATE_PLAN <json_args>` directive and nothing else. Do not add conversational text. Make sure the JSON is valid.
"""

# --- Supervisor Planning Prompt (allows combined actions + enforces UUID/JSON) ---
SUPERVISOR_PLANNING_PROMPT_TEMPLATE = """You are a meticulous top-level Supervisor agent responsible for executing an existing plan, coordinating specialist agents, and managing task execution based on the provided state.
The current date is {current_date}.

## Current Plan State:
```json
{plan_json}
```
*(Review plan status and individual task statuses and IDs (UUIDs). Your main goal is to drive the plan status to 'completed'.)*

## Agent Descriptions:
{agent_descriptions}
*(This list includes specialist agents and yourself.)*

## Your Goal:
Execute the **existing plan** step-by-step towards 'completed' status by making logical decisions and issuing appropriate directives and tool calls.

## Workflow & Decision Guidelines:
1. **Analyze State**: Review the latest messages (especially agent results) and the 'Current Plan State'.
2. **Determine Next Action(s)**: Based on the analysis, decide the next logical step(s).
    * **If a sub-agent just returned results**:
        a. Evaluate the result against the task.
        b. Issue the `PLAN_UPDATE: UPDATE_TASK <json_args>`. **CRITICAL: Use the exact Task UUID for `by_id`!** Include `evaluation` and `notes`.
        c. **After** the update directive, **if** more tasks are pending and ready, you **CAN** identify the next task, issue `PLAN_UPDATE: UPDATE_TASK <json_args>` (using its UUID), **AND** issue the corresponding `transfer_to_<agent_name>` tool call **in the same response**.
    * **If no agent just returned, AND a 'pending' task is ready**:
        a. Identify the *next* suitable 'pending' task.
        b. Issue `PLAN_UPDATE: UPDATE_TASK <json_args>` (using its UUID).
        c. **Immediately following** the directive in the same response, issue the corresponding `transfer_to_<agent_name>` tool call with instructions (including Task UUID).
    * **If ALL tasks are 'completed' AND plan status is NOT 'completed' yet**:
        a. Issue `PLAN_UPDATE: FINISH_PLAN {{}}`.
        b. **In the same response**, decide the final output: EITHER call `transfer_to_reporter_expert` OR generate the final `AIMessage` content yourself.
    * **If Plan Status IS 'completed'**:
        a. Your job is done. Generate the final `AIMessage` content if you didn't call the reporter in the previous step.
    * **If Waiting/Blocked/Failed**: Output a status message explaining the situation.

## Output Constraints:
- Your response **CAN** contain **both** a `PLAN_UPDATE:` directive (in content) and a `transfer_to_<agent_name>` tool call if logically appropriate (e.g., completing one task and starting the next).
- Your response **CAN** contain **both** `PLAN_UPDATE: FINISH_PLAN` and the final action (call reporter or final answer).
- **NEVER** delegate to more than one agent simultaneously (only one `transfer_to_` tool call per response).
- `PLAN_UPDATE:` directives MUST be in the text content with **valid JSON arguments**.
- **CRITICAL**: `UPDATE_TASK` **MUST** use the correct Task UUID string for `"by_id"`.

## Planning Directives Format (Mandatory - JSON Args in text):
Use these exact formats **within your response content**. Arguments **MUST** be a valid JSON string.
- `PLAN_UPDATE: ADD_TASKS {{"tasks": [...]}}`
- `PLAN_UPDATE: UPDATE_TASK {{"by_id": "<task_uuid>", "status": "<new_status>", "evaluation": "<evaluation>", "notes": "<notes>"}}` (**UUID!**)
- `PLAN_UPDATE: FINISH_PLAN {{}}`
*(Note: CREATE_PLAN is handled by the Planner Agent)*

## Tool Usage:
- Only `transfer_to_<agent_name>` tools are callable by you. Args **MUST** include `"task_id"` and `"instructions"`.

Now, analyze the current state and messages, and determine the necessary action(s) for this turn.
"""

# **Key adjustments:**
# 1. **Combined actions allowed**: The Workflow and Output Constraints now explicitly allow the Supervisor to both
#    update the plan state (via a `PLAN_UPDATE:` directive) and delegate a task (via a `transfer_to_` tool call)
#    within one turn, or to finish the plan and produce the final output in the same response. This gives the LLM
#    more flexibility and may better match its "thinking habits".
# 2. **Core requirements kept**: `PLAN_UPDATE` arguments must still be valid JSON, and `UPDATE_TASK` **must** use
#    the correct Task UUID. Delegating to more than one agent at a time is still forbidden.
# 3. **Strict `STOP` instruction removed**: The LLM is no longer forced to end its turn immediately after emitting
#    a `PLAN_UPDATE` directive.
# **Expected effects:**
# * After processing a sub-agent's result and updating the task status, the Supervisor LLM can issue the
#   `transfer_to_` call for the next ready task in the same reply, saving one interaction round.
# * Once all tasks are done, it can emit `FINISH_PLAN` and decide the final output (call the Reporter or
#   summarize itself) in a single step.
# * **Potential risk**: this flexibility may make the LLM more error-prone in complex situations (e.g., delegating
#   without updating the status first, or combining actions incorrectly). Given that the strict step-by-step
#   version also ran into problems, this approach is worth trying.

================================================
FILE: core/agents/state_based_supervisor/state_schema.py
================================================
# reason_graph/state_schema.py
import operator
from typing import Dict, List, Optional, Any, Literal, TypedDict, Sequence, Annotated, Union

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep, RemainingSteps

# Plan-level status values
PlanningStatus = Literal["not_started", "planning", "ready", "executing", "completed", "failed", "error"]

# Task-level status values
TaskStatus = Literal["pending", "ready", "in_progress", "completed", "failed", "skipped", "pending_review", "revision_needed"]


class Task(TypedDict, total=False):
    """A single task item in a plan: description, status, assigned agent, and bookkeeping fields."""
    id: str                            # unique task identifier
    description: str                   # task description
    status: TaskStatus                 # task status
    agent: Optional[str]               # name of the (suggested) executing agent
    created_at: str                    # creation time (ISO format)
    updated_at: str                    # last update time (ISO format)
    completed_at: Optional[str]        # completion time (ISO format)
    dependencies: Optional[List[str]]  # IDs of tasks this task depends on
    notes: Optional[str]               # execution notes (updated by the agent or the Supervisor)
    evaluation: Optional[str]          # result assessment (by the Supervisor LLM or the Evaluator agent)
    result: Optional[Any]              # (optional) summary of the task's direct output


class Plan(TypedDict, total=False):
    """A complete plan: overall status, task list, and bookkeeping timestamps."""
    status: PlanningStatus             # plan status
    tasks: List[Task]                  # task list
    current_task_id: Optional[str]     # task the Supervisor is currently focused on / processing
    created_at: str                    # creation time (ISO format)
    updated_at: str                    # last update time (ISO format)
    completed_at: Optional[str]        # completion time (ISO format)
    title: Optional[str]               # plan title
    description: Optional[str]         # plan description (usually the original user request)


# Extended base AgentState with planning support for the Supervisor graph.
# Note: TypedDict fields cannot carry runtime defaults, so optionality is
# expressed via Optional[...] rather than `= None` assignments.
class PlanningAgentState(TypedDict):
    """State schema for the Supervisor graph, with planning support."""
    messages: Annotated[Sequence[BaseMessage], add_messages]  # message history
    plan: Optional[Plan]               # the plan object
    # last_agent_result: Optional[Dict[str, Any]]  # {name: ..., content: ...} of the sub-agent that just finished
    is_last_step: IsLastStep           # LangGraph internal flag
    remaining_steps: RemainingSteps    # LangGraph internal counter, guards against infinite loops
    error: Optional[str]               # records errors raised during execution
    # Additional globally shared state fields can be added here as needed,
    # e.g.: shared_context: Optional[Dict]


# A slightly slimmer state for sub-agents that do not need the plan.
class BasicAgentState(TypedDict):
    """Basic agent state: message history only."""
    messages: Annotated[Sequence[BaseMessage], add_messages]
    is_last_step: IsLastStep
    remaining_steps: RemainingSteps
    error: Optional[str]


# Convenience alias for type hints
StateSchemaType = Union[Dict[str, Any], PlanningAgentState, BasicAgentState]

================================================
FILE: core/agents/state_based_supervisor/supervisor_graph.py
================================================
# reason_graph/supervisor_graph.py
import inspect
import re
import functools
import uuid
import asyncio
import anyio
import traceback
from typing import Any, Callable, List, Optional, Type, Union, Dict, Literal, Sequence, cast  # cast used below

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.tools import BaseTool
from langchain_core.messages import AIMessage, ToolMessage, BaseMessage, ToolCall, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.utils.runnable import RunnableCallable
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.pregel import Pregel

# Internal imports
try:
    from core.agents.base.base_agent import BaseAgent
    from .handoff import create_handoff_tool, _normalize_agent_name
    from .state_schema import PlanningAgentState, Plan
    from .supervisor_node import supervisor_node_logic                  # async node logic
    from .planner_node import planner_node_logic, planner_node_logic_sync
    from .evaluate_result_node import evaluate_result_node_logic, evaluate_result_node_logic_sync
    from .agent_name import AgentNameMode, with_agent_name
except ImportError as e:
    print(f"Error importing modules in supervisor_graph.py: {e}")
    # Dummy fallbacks so type hints still resolve
    class BaseAgent: pass
    class PlanningAgentState(Dict): pass
    class Plan(Dict): pass
    class Pregel: pass
    AgentNameMode = Literal["inline"]
    def create_handoff_tool(*args, **kwargs): return None  # type: ignore
    def _normalize_agent_name(s: str) -> str: return s
    async def supervisor_node_logic(*args, **kwargs): return {}
    async def planner_node_logic(*args, **kwargs): return {}
    def planner_node_logic_sync(*args, **kwargs): return {}
    async def evaluate_result_node_logic(*args, **kwargs): return {}
    def evaluate_result_node_logic_sync(*args, **kwargs): return {}
    def with_agent_name(model, mode): return model

# OutputMode, MODELS_NO_PARALLEL_TOOL_CALLS, _supports_disable_parallel_tool_calls (unchanged)
OutputMode = Literal["full_history", "last_message"]
MODELS_NO_PARALLEL_TOOL_CALLS = {"o3-mini"}

def _supports_disable_parallel_tool_calls(model: LanguageModelLike) -> bool:
    if not isinstance(model, BaseChatModel):
        return False
    if hasattr(model, "model_name") and model.model_name in MODELS_NO_PARALLEL_TOOL_CALLS:
        return False
    if not hasattr(model, "bind_tools"):
        return False
    if "parallel_tool_calls" not in inspect.signature(model.bind_tools).parameters:
        return False
    return True

# _make_call_agent (unchanged -- already supports sync and async invocation)
def _make_call_agent(
    agent_graph: Pregel,
    output_mode: OutputMode,
    add_handoff_back_messages: bool,
supervisor_name: str, ) -> RunnableCallable: if output_mode not in ["full_history", "last_message"]: raise ValueError(...) async def acall_agent(state: Dict, config: Optional[RunnableConfig] = None) -> Dict: agent_name = getattr(agent_graph, 'name', 'sub_agent') print(f"🟡 [Async invoke] Handoff to agent '{agent_name}'") sub_agent_input = {"messages": state.get("messages", [])} output: Dict[str, Any] = {} agent_error: Optional[str] = None try: output = await agent_graph.ainvoke(sub_agent_input, config=config) print(f"✅ [Async invoke] Agent '{agent_name}' completed.") except Exception as e: print(f"!!! Error during sub-agent {agent_name} ainvoke: {e}"); traceback.print_exc() agent_error = f"Error executing agent '{agent_name}': {type(e).__name__}" sub_agent_messages: List[BaseMessage] = output.get("messages", []) returned_messages: List[BaseMessage] = [] if not sub_agent_messages and not agent_error: returned_messages = [AIMessage(content="(No output received from agent)", name=agent_name)] elif output_mode == "last_message": last_ai_message = next((m for m in reversed(sub_agent_messages) if isinstance(m, AIMessage)), None) returned_messages = [last_ai_message] if last_ai_message else sub_agent_messages[-1:] else: returned_messages = sub_agent_messages last_content = agent_error if not last_content and returned_messages: last_content = str(returned_messages[-1].content) if hasattr(returned_messages[-1], 'content') else "(No textual content)" return { "messages": returned_messages, "last_agent_result": { "agent_name": agent_name, "content": last_content or "(Agent execution finished without specific output or error)" } } def call_agent(state: Dict, config: Optional[RunnableConfig] = None) -> Dict: agent_name = getattr(agent_graph, 'name', 'sub_agent') print(f"🟡 [Sync invoke] Handoff to agent '{agent_name}'") sub_agent_input = {"messages": state.get("messages", [])} output: Dict[str, Any] = {} agent_error: Optional[str] = None try: output = agent_graph.invoke(sub_agent_input, config=config); print(f"✅ [Sync invoke] Agent '{agent_name}' completed.") except NotImplementedError: agent_error = f"Error: Sync invoke not supported by agent '{agent_name}'."; print(agent_error) except Exception as e: agent_error = f"Error during sub-agent {agent_name} invoke: {e}"; print(f"!!! 
{agent_error}") sub_agent_messages: List[BaseMessage] = output.get("messages", []) returned_messages: List[BaseMessage] = [] if not sub_agent_messages and not agent_error: returned_messages = [AIMessage(content="(No output received)", name=agent_name)] elif output_mode == "last_message": last_ai_message = next((m for m in reversed(sub_agent_messages) if isinstance(m, AIMessage)), None) returned_messages = [last_ai_message] if last_ai_message else sub_agent_messages[-1:] else: returned_messages = sub_agent_messages last_content = agent_error if not last_content and returned_messages: last_content = str(returned_messages[-1].content) if hasattr(returned_messages[-1], 'content') else "(No content)" return { "messages": returned_messages, "last_agent_result": { "agent_name": agent_name, "content": last_content or "(Agent sync execution finished)" } } return RunnableCallable(func=call_agent, afunc=acall_agent, name=f"Call_{getattr(agent_graph, 'name', 'sub_agent')}") def supervisor_node_logic_sync( state: PlanningAgentState, config: Optional[RunnableConfig], model: Any, supervisor_name: str, agent_description_map: Dict[str, str] ) -> Dict[str, Any]: print(f"--- Entering Supervisor Node (Sync Wrapper) ---") try: return anyio.run( supervisor_node_logic, state, config, model, supervisor_name, agent_description_map ) except Exception as e: print(f"Error running supervisor_node_logic synchronously using anyio: {e}") import traceback traceback.print_exc() return {"error": f"Sync execution wrapper failed: {e}", "messages": state.get("messages",[])} def create_supervisor( model: LanguageModelLike, sub_agents: List[BaseAgent], state_schema: Type[PlanningAgentState] = PlanningAgentState, config_schema: Type[Any] | None = None, tools: list[BaseTool | Callable] | None = None, output_mode: OutputMode = "last_message", add_handoff_back_messages: bool = False, supervisor_name: str = "supervisor", planner_node_name: str = "planner", evaluator_node_name: str = "evaluate_result", handoff_executor_name: str = "handoff_executor", include_agent_name: AgentNameMode | None = "inline", ) -> StateGraph: agent_graphs: Dict[str, Pregel] = {} agent_names: List[str] = [] agent_description_map: Dict[str, str] = {} # --- 1. 提取 Agent 信息 --- for agent in sub_agents: if not isinstance(agent, BaseAgent): raise TypeError(...) if not agent.name or agent.name == "LangGraph": raise ValueError(...) if agent.name in agent_graphs: raise ValueError(...) agent_names.append(agent.name) agent_description_map[agent.name] = getattr(agent, 'description', '...') try: compiled_graph = agent.get_agent() if not isinstance(compiled_graph, Pregel): core_graph = getattr(compiled_graph, 'last', None) if isinstance(core_graph, Pregel): compiled_graph = core_graph else: raise TypeError(f"Could not retrieve Pregel instance from agent '{agent.name}'.get_agent()") agent_graphs[agent.name] = compiled_graph except Exception as e: raise e # --- 2. 创建 Handoff 工具 --- handoff_tools = [create_handoff_tool(agent_name=name) for name in agent_names] supervisor_callable_tools = (tools or []) + handoff_tools print(f"Supervisor '{supervisor_name}' bound with tools: {[t.name for t in supervisor_callable_tools]}") # --- 3. 
    # Bind tools to the supervisor model ---
    bound_supervisor_model: LanguageModelLike
    if not supervisor_callable_tools:
        print(f"Warning: Supervisor '{supervisor_name}' has no tools bound.")
        bound_supervisor_model = model
    elif _supports_disable_parallel_tool_calls(model):
        bound_supervisor_model = model.bind_tools(supervisor_callable_tools, parallel_tool_calls=False)
    else:
        bound_supervisor_model = model.bind_tools(supervisor_callable_tools)
    if include_agent_name:
        bound_supervisor_model = with_agent_name(bound_supervisor_model, include_agent_name)

    # --- 4. Build the StateGraph ---
    builder = StateGraph(state_schema, config_schema=config_schema)

    # --- 5. Add the Planner node (sync/async wrappers) ---
    planner_logic_partial_async = functools.partial(
        planner_node_logic,
        model=model,
        agent_description_map=agent_description_map,
    )
    planner_logic_partial_sync = functools.partial(
        planner_node_logic_sync,
        model=model,
        agent_description_map=agent_description_map,
    )
    planner_runnable = RunnableCallable(
        func=planner_logic_partial_sync,
        afunc=planner_logic_partial_async,
        name=planner_node_name
    )
    builder.add_node(planner_node_name, planner_runnable)

    # --- 6. Add the Supervisor node (sync/async wrappers) ---
    supervisor_logic_partial_async = functools.partial(
        supervisor_node_logic,
        model=bound_supervisor_model,
        supervisor_name=supervisor_name,
        agent_description_map=agent_description_map,
    )
    supervisor_logic_partial_sync = functools.partial(
        supervisor_node_logic_sync,
        model=bound_supervisor_model,
        supervisor_name=supervisor_name,
        agent_description_map=agent_description_map,
    )
    supervisor_runnable = RunnableCallable(
        func=supervisor_logic_partial_sync,
        afunc=supervisor_logic_partial_async,
        name=supervisor_name
    )
    builder.add_node(supervisor_name, supervisor_runnable)

    # --- 7. Add the sub-agent nodes; each one routes to the Evaluator when finished ---
    for name, compiled_graph in agent_graphs.items():
        builder.add_node(name, _make_call_agent(compiled_graph, output_mode, add_handoff_back_messages, supervisor_name))
        builder.add_edge(name, evaluator_node_name)

    # --- 8. Add the handoff tool execution node ---
    handoff_executor_node = ToolNode(handoff_tools, name=handoff_executor_name)
    builder.add_node(handoff_executor_name, handoff_executor_node)

    # --- 9. Add the evaluate-result node ---
    # The Evaluator does not need the model or the agent descriptions as direct parameters.
    evaluator_runnable = RunnableCallable(func=evaluate_result_node_logic_sync, afunc=evaluate_result_node_logic, name=evaluator_node_name)
    builder.add_node(evaluator_node_name, evaluator_runnable)  # type: ignore

    # --- 10. Wire up the graph's entry point and edges ---
    builder.set_entry_point(planner_node_name)
    builder.add_edge(planner_node_name, supervisor_name)

    def route_from_supervisor(state: PlanningAgentState) -> str:
        messages = state.get('messages', [])
        plan = state.get('plan')
        last_message = messages[-1] if messages else None
        if not isinstance(last_message, AIMessage):
            print("Routing: Last message not AIMessage, looping supervisor.")
            return supervisor_name
        if last_message.tool_calls:
            tool_call = last_message.tool_calls[0]
            agent_name_match = re.match(r"transfer_to_(\w+)", tool_call["name"])
            if agent_name_match and agent_name_match.group(1) in agent_names:
                extracted_name = agent_name_match.group(1)
                print(f"DEBUG route_from_supervisor: Tool Call Name = {repr(tool_call['name'])}")
                print(f"DEBUG route_from_supervisor: Extracted Target Name = {repr(extracted_name)}")
                print(f"DEBUG route_from_supervisor: Available Agent Names = {repr(agent_names)}")
                print(f"Routing: Supervisor -> HandoffExecutor (for {extracted_name})")
                return handoff_executor_name
            else:
                # `agent_name_match` may be None here, so avoid referencing its groups.
                print(f"DEBUG route_from_supervisor: Membership check failed for tool {repr(tool_call['name'])} against {repr(agent_names)}.")
                print(f"Warning: Supervisor called unknown/invalid tool: {tool_call['name']}. Looping supervisor.")
                return supervisor_name
        if plan and plan.get("status") == "completed":
            print("Routing: Plan completed -> END")
            return END
        print(f"Routing: No tool call and plan not completed (status: {plan.get('status') if plan else 'None'}). Looping supervisor.")
        return supervisor_name

    builder.add_conditional_edges(
        supervisor_name,
        route_from_supervisor,
        {
            handoff_executor_name: handoff_executor_name,
            supervisor_name: supervisor_name,
            END: END,
        }
    )

    # After the Handoff Executor runs, LangGraph follows the Command(goto=...) emitted by
    # the handoff tool straight to the target sub-agent, so no explicit edge is needed here.
    # Sub-agent -> Evaluator edges were added in step 7; the Evaluator hands control back
    # to the Supervisor:
    builder.add_edge(evaluator_node_name, supervisor_name)

    print("Supervisor graph definition created with Planner and Evaluator nodes.")
    return builder  # return the StateGraph definition

================================================
FILE: core/agents/state_based_supervisor/supervisor_node.py
================================================
# reason_graph/supervisor_node.py
import re
import json
import time
import copy
import ast
import traceback
from typing import Dict, Any, List, Optional, Union, cast
from datetime import datetime

from langchain_core.messages import BaseMessage, AIMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_core.messages import ToolCall
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END

# Internal imports (make sure the paths are correct)
try:
    from .state_schema import PlanningAgentState, TaskStatus, Plan
    from .planning_handler import PlanningStateHandler
    from .prompt import SUPERVISOR_PLANNING_PROMPT_TEMPLATE
except ImportError as e:
    print(f"Error importing modules in supervisor_node.py: {e}")
    # Fallbacks
    class PlanningAgentState(Dict): pass
    class Plan(Dict): pass
    class PlanningStateHandler:
        @staticmethod
        def update_task(*args, **kwargs): return kwargs.get('plan')
        @staticmethod
        def create_plan(*args, **kwargs): return {}
        @staticmethod
        def add_tasks(*args, **kwargs): return kwargs.get('plan')
        @staticmethod
        def finish_plan(*args, **kwargs): return kwargs.get('plan')
        @staticmethod
        def get_task(*args, **kwargs): return None
        @staticmethod
        def update_plan_status(*args, **kwargs): return kwargs.get('plan')
        @staticmethod
        def set_current_task(*args, **kwargs): return kwargs.get('plan')
    SUPERVISOR_PLANNING_PROMPT_TEMPLATE = "Fallback Prompt: Error loading template."
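# Illustrative note (a sketch, not part of the original module): the helper below,
# parse_directive_args, expects a directive string whose first token is the command
# word and whose tail is a JSON object, matching the formats mandated by
# SUPERVISOR_PLANNING_PROMPT_TEMPLATE. Hypothetical inputs and outputs (the UUID
# fragment is made up):
#
#   parse_directive_args('UPDATE_TASK {"by_id": "9b2e7c1a-...", "status": "in_progress"}')
#     -> {"by_id": "9b2e7c1a-...", "status": "in_progress"}
#   parse_directive_args("ADD_TASKS {'tasks': [{'description': 'x'}]}")
#     -> {'tasks': [{'description': 'x'}]}   # single quotes: ast.literal_eval fallback
#   parse_directive_args('FINISH_PLAN {}')
#     -> {}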
# --- Directive argument parsing (JSON, with ast.literal_eval fallback) ---
def parse_directive_args(directive_str: str) -> Dict[str, Any]:
    """Parse the JSON arguments out of a directive string."""
    args = {}
    # Treat everything from the first '{' to the trailing '}' as the JSON payload.
    json_match = re.search(
        r"(\{.*?\})\s*$",
        directive_str.split(maxsplit=1)[1] if len(directive_str.split(maxsplit=1)) > 1 else "",
        re.DOTALL,
    )
    if json_match:
        args_json_str = json_match.group(1)
        try:
            args = json.loads(args_json_str)
            if not isinstance(args, dict):
                raise ValueError("Args JSON not a dict.")
            print(f"DEBUG: Parsed args via JSON: {args}")
            return args
        except json.JSONDecodeError as json_err:
            print(f"Warning: JSON parsing failed ({json_err}), trying ast.literal_eval...")
            try:
                args = ast.literal_eval(args_json_str)
                if not isinstance(args, dict):
                    raise ValueError("ast.literal_eval didn't return dict.")
                print(f"DEBUG: Parsed args via ast.literal_eval: {args}")
                return args
            except Exception as ast_err:
                raise ValueError(f"Failed to parse args: {ast_err}. Raw: '{args_json_str}'") from ast_err
    elif directive_str.strip().upper().endswith("{}"):
        # Handles the FINISH_PLAN {} case
        return {}
    else:
        # No valid JSON arguments found; warn and return empty args.
        print(f"Warning: Could not find valid JSON arguments in directive: '{directive_str}'. Returning empty args.")
        return {}


# --- Supervisor node core logic (result handling removed; sets current_task_id) ---
async def supervisor_node_logic(
    state: PlanningAgentState,
    config: Optional[RunnableConfig],
    model: Any,
    supervisor_name: str,
    agent_description_map: Dict[str, str]
) -> Dict[str, Any]:
    """Supervisor node core logic (no longer updates task status from agent results)."""
    print(f"--- Entering Supervisor Node ({supervisor_name}) ---")
    messages: List[BaseMessage] = state.get('messages', [])
    plan: Optional[Plan] = state.get('plan')
    current_error = state.get('error'); state['error'] = None
    if current_error:
        print(f" Supervisor saw previous error: {current_error}")

    # --- 0. Check that a plan exists (unchanged) ---
    if not plan:
        print("ERROR: Supervisor node requires a plan, but none found in state.")
        return {"error": "Plan is missing.", "messages": []}

    # --- 1. Prepare the prompt (unchanged) ---
    plan_json_str = json.dumps(plan, indent=2, ensure_ascii=False)
    desc_list = [f"- {name}: {desc}" for name, desc in agent_description_map.items()]
    desc_list.append(f"- {supervisor_name}: Coordinates tasks...")
    agent_descriptions_str = "\n".join(desc_list)
    system_prompt_text = "Error loading/formatting prompt"
    try:
        current_date_str = datetime.now().strftime("%a, %b %d, %Y")
        system_prompt_text = SUPERVISOR_PLANNING_PROMPT_TEMPLATE.format(
            plan_json=plan_json_str,
            agent_descriptions=agent_descriptions_str,
            current_date=current_date_str
        )
    except Exception as e:
        print(f"ERROR loading/formatting prompt: {e}")
    llm_input_messages = [SystemMessage(content=system_prompt_text)] + messages

    # --- 2. Call the Supervisor LLM (unchanged) ---
    print("--- Calling Supervisor LLM ---"); response = None; llm_error_msg = None
    try:
        response = await model.ainvoke(llm_input_messages, config=config)
        if not isinstance(response, AIMessage):
            raise TypeError(f"LLM returned non-AIMessage: {type(response)}")
        if not response.name:
            response.name = supervisor_name
        print(f"Supervisor LLM Raw Response Content: {response.content[:300]}...")
        if response.tool_calls:
            print(f"Supervisor LLM Tool Calls: {response.tool_calls}")
        messages_to_add = [response]
    except Exception as e:
        print(f"!!! Error invoking Supervisor LLM: {e}"); traceback.print_exc()
        llm_error_msg = f"LLM failed: {e}"; messages_to_add = []; response = None

    # --- 3. Process the LLM reply ---
    plan_updated: bool = False
    updated_plan: Optional[Plan] = copy.deepcopy(plan)  # start from the current plan
    directive_error_msg: Optional[str] = None
    task_id_to_delegate: Optional[str] = None  # task ID to delegate this turn

    if response and isinstance(response.content, str):
        # --- A. First parse and apply all PLAN_UPDATE directives
        #        (handling of status='completed'/'failed' has been removed) ---
        try:
            plan_directives = re.findall(r"PLAN_UPDATE:\s*(\w+)\s*(\{.*?\})\s*$", response.content, re.IGNORECASE | re.DOTALL | re.MULTILINE)
            plan_directives.extend(re.findall(r"PLAN_UPDATE:\s*(FINISH_PLAN)\s*(\{\})\s*$", response.content, re.IGNORECASE | re.DOTALL | re.MULTILINE))
            if plan_directives:
                print(f"Found {len(plan_directives)} PLAN_UPDATE directive(s).")
            for command, args_json_str in plan_directives:
                command = command.upper(); args_json_str = args_json_str if args_json_str else "{}"
                print(f"Processing directive: {command} with args JSON: {args_json_str[:100]}...")
                try:
                    args = json.loads(args_json_str)  # parse as JSON
                    if not isinstance(args, dict):
                        raise ValueError("Args not dict.")
                    # --- Apply the planning directive ---
                    if command == "ADD_TASKS":
                        if not updated_plan:
                            raise ValueError("No plan.")
                        tasks = args.get("tasks", [])
                        if isinstance(tasks, list):
                            # Force newly added tasks to start as 'pending'
                            for task_data in tasks:
                                task_data['status'] = 'pending'
                            updated_plan = PlanningStateHandler.add_tasks(updated_plan, tasks); plan_updated = True
                        else:
                            raise ValueError("Invalid 'tasks'.")
                    elif command == "UPDATE_TASK":
                        if not updated_plan:
                            raise ValueError("No plan.")
                        by_id = args.get("by_id")
                        if not by_id or not isinstance(by_id, str):
                            raise ValueError("Requires string 'by_id'.")
                        by_id = by_id.strip()
                        task_exists = PlanningStateHandler.get_task(updated_plan, by_id)
                        if not task_exists:
                            raise ValueError(f"Task ID '{by_id}' not found!")
                        # Only apply 'in_progress' (non-terminal) status changes here, plus
                        # notes/evaluation. 'evaluation' records the LLM's own assessment.
                        new_status = args.get("status"); notes_text = args.get("notes"); eval_text = args.get("evaluation")
                        update_kwargs = {}
                        # Deliberately NOT setting "completed", "failed", or "pending_review" here.
                        if new_status and new_status == "in_progress":
                            update_kwargs['new_status'] = "in_progress"
                            task_id_to_delegate = by_id  # remembered; applied just before handoff
                        # notes and evaluation can always be updated when the LLM provides them
                        if notes_text is not None: update_kwargs['new_notes'] = notes_text
                        if eval_text is not None: update_kwargs['new_evaluation'] = eval_text
                        if update_kwargs:  # only call when there is actually something to update
                            print(f"Updating task {by_id} with: {update_kwargs}")
                            updated_plan = PlanningStateHandler.update_task(updated_plan, by_id=by_id, **update_kwargs); plan_updated = True
                    elif command == "FINISH_PLAN":
                        if not updated_plan:
                            raise ValueError("No plan.")
                        updated_plan = PlanningStateHandler.finish_plan(updated_plan); plan_updated = True
                    else:
                        print(f"Warning: Unknown PLAN_UPDATE command '{command}' ignored by Supervisor.")
                except (json.JSONDecodeError, ValueError, KeyError, TypeError) as e:
                    err_msg = f"Error processing plan directive '{command} {args_json_str}': {type(e).__name__} - {e}"
                    print(err_msg); traceback.print_exc()
                    if not directive_error_msg: directive_error_msg = err_msg  # keep only the first error
                except Exception as e:
                    err_msg = f"Unexpected error processing directive '{command} {args_json_str}': {type(e).__name__} - {e}"
                    print(err_msg); traceback.print_exc()
                    if not directive_error_msg: directive_error_msg = err_msg
            # --- Recompute the plan status ---
            if plan_updated and updated_plan:
                updated_plan = PlanningStateHandler.update_plan_status(updated_plan)
                print(f"Plan status after updates by Supervisor: {updated_plan.get('status')}")
        except Exception as outer_e:
            err_msg = f"Error occurred while searching for PLAN_UPDATE
directives: {outer_e}" print(err_msg); traceback.print_exc() if not directive_error_msg: directive_error_msg = err_msg # --- B. 检查 Tool Calls 并设置 Current Task ID --- handoff_tool_call: Optional[Dict] = None # 显式初始化 if response and response.tool_calls: for tool_call in response.tool_calls: agent_name_match = re.match(r"transfer_to_(\w+)", tool_call["name"]) # **使用 agent_description_map.keys() 来检查** if agent_name_match and agent_name_match.group(1) in agent_description_map.keys(): handoff_tool_call = cast(Dict, tool_call) # 找到第一个有效的就用它 break # 如果决定 Handoff,尝试设置 plan 中的 current_task_id if handoff_tool_call and updated_plan: # **关键**: 尝试从 Tool Call 的 args 中获取 task_id (Prompt 要求 LLM 必须提供) tool_args = handoff_tool_call.get("args", {}) task_id_from_tool = tool_args.get("task_id") if isinstance(tool_args, dict) else None # 如果 Tool args 中没有,再使用之前记录的 task_id_to_delegate (标记为 in_progress 的) effective_task_id = task_id_from_tool or task_id_to_delegate if effective_task_id: print(f"Setting current_task_id in plan to: {effective_task_id}") try: # 验证 ID 存在 if PlanningStateHandler.get_task(updated_plan, effective_task_id): updated_plan = PlanningStateHandler.set_current_task(updated_plan, effective_task_id) # plan_updated 标志可能已经被 Plan Directive 设置,这里不需要重复设置 else: print(f"Warning: Task ID '{effective_task_id}' provided for delegation not found. Cannot set current_task_id.") # 记录错误,阻止 Handoff? 或者让路由回到 Supervisor? directive_error_msg = directive_error_msg or f"Invalid Task ID '{effective_task_id}' for delegation." except Exception as e: err_msg = f"Error setting current_task_id to '{effective_task_id}': {e}" print(f"ERROR: {err_msg}") if not directive_error_msg: directive_error_msg = err_msg # --- 4. 准备最终返回的状态更新字典 --- updates: Dict[str, Any] = {"messages": messages_to_add} if updated_plan is not None: updates["plan"] = updated_plan elif plan is not None: updates["plan"] = plan final_error = llm_error_msg or directive_error_msg if final_error: updates["error"] = final_error elif state.get("error"): updates["error"] = None # 清除旧错误 print(f"--- Exiting Supervisor Node. Plan updated this step: {plan_updated} ---") return updates ================================================ FILE: core/agents/sub_agents/__init__.py ================================================ ================================================ FILE: core/agents/sub_agents/coder_agent.py ================================================ # Refactored coder_agent.py from typing import Any, List, Optional, Union, Callable, Type from langchain_core.language_models import LanguageModelLike from langchain_core.tools import BaseTool from langchain_core.messages import SystemMessage from langgraph.types import Checkpointer from core.agents.base.react_agent import ReactAgent from core.tools.registry import get_tools_by_category, ToolCategory, get_tool_instance # Import get_tool_instance import logging logger = logging.getLogger(__name__) class CoderAgent(ReactAgent): """ Coder Agent (Refactored) - Interacts with a sandboxed Linux environment via code execution tools. """ def __init__( self, name: str = "coder_expert", model: LanguageModelLike = None, tools: Optional[List[Union[BaseTool, Callable]]] = None, checkpointer: Optional[Checkpointer] = None, max_context_messages: Optional[int] = None, max_context_tokens: Optional[int] = 100000, # Coding might need more context **kwargs ): # 1. Define Description description = "Writes, executes, tests, and debugs Python code and Linux shell commands within a secure sandboxed environment. 
Can install packages, manage files, and interact with the network." # 2. Get Tools from Registry agent_tools = [] default_tool_name = "e2b_code_interpreter" # Expected tool name try: code_tools = get_tools_by_category(ToolCategory.CODE_INTERPRETER) + get_tools_by_category(ToolCategory.FILE_SYSTEM) agent_tools.extend(code_tools) # Optionally add file system tools if not included in interpreter tool # fs_tools = get_tools_by_category(ToolCategory.FILE_SYSTEM) # agent_tools.extend(fs_tools) print(f"[{name}] Loaded tools from registry: {[t.name for t in agent_tools if hasattr(t,'name')]}") # Verify the main execution tool is present if not any(getattr(t,'name', None) == default_tool_name for t in agent_tools): print(f"CRITICAL Warning: CoderAgent '{name}' is missing the primary '{default_tool_name}' tool!") # Attempt to get it specifically if missing? specific_tool = get_tool_instance(default_tool_name) if specific_tool: agent_tools.append(specific_tool) except Exception as e: print(f"Warning: Failed to get tools from registry for {name}: {e}") if tools: # Merge extra tools existing_names = {t.name for t in agent_tools if hasattr(t,'name')} agent_tools.extend([t for t in tools if getattr(t, 'name', None) not in existing_names]) if not agent_tools: print(f"CRITICAL Warning: CoderAgent '{name}' initialized with NO tools!") # 3. Define System Prompt (using the capabilities) tool_name_for_prompt = next((t.name for t in agent_tools if hasattr(t, 'name') and 'code' in t.name.lower()), default_tool_name) # Try to get actual tool name base_prompt = f"""You are an expert Coder Agent interacting with a secure, sandboxed Linux environment provided by the '{tool_name_for_prompt}' tool. Your goal is to fulfill coding, file manipulation, or shell command requests by generating and executing appropriate code or commands within this sandbox. Available Tools: {self._format_tools_for_prompt(agent_tools)} - **{tool_name_for_prompt}**: Executes Python code or shell commands within the sandboxed Linux environment. Returns stdout, stderr, execution errors, and potentially file outputs or structured results (like image data). To run shell commands, generate Python code that uses the 'subprocess' module OR if the tool directly supports it, prefix the command with '!'. Always prefer generating Python code for complex shell operations or when needing output capture. Key Capabilities of the Sandbox Environment (via the tool): - Execute Python 3 code. - Install Python packages using pip (generate code like `import subprocess; subprocess.run(['pip', 'install', 'requests'], check=True)`). - Run standard Linux shell commands (e.g., `ls`, `pwd`, `mkdir`, `curl`, `git`, etc. using Python's subprocess). - Access and manipulate a persistent filesystem within the sandbox (typically starting in `/home/user/` or `/`). Create, read, write, delete files and directories. - Access the internet from within the sandbox for tasks like cloning repos or fetching data. Workflow & Instructions: 1. **Analyze Request**: Understand the goal, constraints, and required inputs/outputs. 2. **Plan Steps**: Outline the necessary code or commands. Consider file paths, dependencies, and error handling. 3. **Generate Code/Command**: Write the Python code or shell command sequence needed. For non-trivial Python, include comments. 4. **Execute using Tool**: Prepare the arguments for the '{tool_name_for_prompt}' tool (usually the code string or command string) and invoke the tool. 5. 
**Analyze Output**: Carefully review the stdout, stderr, errors, and any results returned by the tool. 6. **Debug/Iterate**: If errors occurred or the output is not as expected, analyze the error, revise the code/command, and execute again using the tool. 7. **Final Output**: Once the task is successfully completed, provide the final working code (if relevant), a summary of the execution results (stdout/stderr highlights), confirmation of file operations, and any requested explanation. If the task cannot be completed, explain why. 8. **File Handling**: If generating files (code, data, images), clearly state the full path within the sandbox where the file was saved (e.g., `/home/user/my_script.py`, `/home/user/output.csv`). Do not attempt to display images directly in your response. Focus strictly on tasks achievable within the sandboxed environment using the provided tool. Be precise and careful with file paths and commands. """ # 4. Call super().__init__ super().__init__( name=name, model=model, tools=agent_tools, prompt=base_prompt, description=description, checkpointer=checkpointer, max_context_messages=max_context_messages, max_context_tokens=max_context_tokens, **kwargs ) print(f"CoderAgent '{self.name}' initialized with tools: {[t.name for t in self.tools if hasattr(t,'name')]}") # Inherits _format_tools_for_prompt and other methods from BaseAgent/ReactAgent ================================================ FILE: core/agents/sub_agents/data_analyst_agent.py ================================================ # data_analyst_agent.py (or in main.py) from typing import Any, List, Optional, Union, Callable, Type from langchain_core.language_models import LanguageModelLike from langchain_core.tools import BaseTool from langchain_core.messages import SystemMessage from langgraph.types import Checkpointer # Internal imports - ensure paths are correct from core.agents.base.react_agent import ReactAgent from core.tools.registry import get_tools_by_category, ToolCategory, get_tool_instance # Import necessary functions import logging logger = logging.getLogger(__name__) # Assume ToolCategory.CODE_INTERPRETER exists # Assume ToolCategory.FILE_SYSTEM exists if needed class DataAnalystAgent(ReactAgent): """ Data Analyst Agent (Refactored) - Focuses on analyzing structured data using code execution sandbox. - Generates insights and saves visualizations to files. """ def __init__( self, name: str = "data_analyst_expert", model: LanguageModelLike = None, tools: Optional[List[Union[BaseTool, Callable]]] = None, checkpointer: Optional[Checkpointer] = None, max_context_messages: Optional[int] = None, max_context_tokens: Optional[int] = 120000, # Analysis might need decent context debug: bool = False, **kwargs ): # 1. Define Description for Supervisor description = "Analyzes structured data (provided in context or potentially read from sandbox files) using Python (Pandas, NumPy, Matplotlib, Seaborn) within a secure code execution environment. Performs statistical analysis, identifies trends, generates insights, and creates data visualizations (saved as files in the sandbox)." # 2. 
Get Tools from Registry agent_tools = [] default_tool_name = "e2b_code_interpreter" # Tool needed for execution try: # Primarily needs Code Interpreter code_tools = get_tools_by_category(ToolCategory.CODE_INTERPRETER) + get_tools_by_category(ToolCategory.FILE_SYSTEM) # 需要代码和文件工具 agent_tools.extend(code_tools) # Optionally, add File System tools if needed to read data files # fs_tools = get_tools_by_category(ToolCategory.FILE_SYSTEM) # agent_tools.extend(fs_tools) print(f"[{name}] Loaded tools from registry: {[t.name for t in agent_tools if hasattr(t,'name')]}") # Verify the main execution tool is present if not any(getattr(t,'name', None) == default_tool_name for t in agent_tools): print(f"CRITICAL Warning: DataAnalystAgent '{name}' is missing the primary '{default_tool_name}' tool!") specific_tool = get_tool_instance(default_tool_name) if specific_tool: agent_tools.append(specific_tool) except Exception as e: print(f"Warning: Failed to get tools from registry for {name}: {e}") if tools: # Merge extra tools existing_names = {t.name for t in agent_tools if hasattr(t,'name')} agent_tools.extend([t for t in tools if getattr(t, 'name', None) not in existing_names]) if not agent_tools: print(f"CRITICAL Warning: DataAnalystAgent '{name}' initialized with NO execution tools!") # 3. Define System Prompt tool_name_for_prompt = next((t.name for t in agent_tools if hasattr(t, 'name') and 'code' in t.name.lower()), default_tool_name) base_prompt = f"""You are an expert Data Analyst. Your task is to analyze data using Python code within a secure sandbox environment accessed via the '{tool_name_for_prompt}' tool. Libraries like Pandas, NumPy, Matplotlib, and Seaborn are available (install if needed using pip in your code). Available Tools: {self._format_tools_for_prompt(agent_tools)} - **{tool_name_for_prompt}**: Executes Python code in the sandbox. Returns stdout, stderr, errors, and potentially structured results. Key Instructions: 1. **Understand Data & Goal**: Identify the data source (likely provided in previous messages or mentioned as a sandbox file path like '/home/user/data.csv') and the specific analysis question or goal. 2. **Plan Analysis**: Briefly outline the Python code steps (e.g., load data into Pandas DataFrame, clean/transform data, perform calculations, generate plot). 3. **Write Python Code**: Generate the necessary Python code. Use libraries effectively. Import necessary libraries (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`). 4. **Handle Files (If Needed)**: If reading/writing files within the sandbox, use standard Python file I/O within your code (e.g., `pd.read_csv('/home/user/data.csv')`, `df.to_csv('/home/user/output.csv')`). 5. **Handle Visualizations**: If asked to create plots: * Generate the plot using Matplotlib/Seaborn. * **MUST save the plot to a file** inside the sandbox (e.g., `/home/user/plots/my_plot.png`). Use `plt.savefig('/home/user/plots/my_plot.png')`. Create directories if necessary (`os.makedirs('/home/user/plots', exist_ok=True)`). * Use `plt.show()` or `plt.close()` after saving to clear the plot buffer. * **DO NOT attempt to return image data directly.** Images cannot be displayed in the response. * In your response, **state that the plot was generated and provide the full path** where it was saved in the sandbox (e.g., "I have generated a scatter plot and saved it to /home/user/plots/scatter_plot.png"). 6. **Execute Code**: Use the '{tool_name_for_prompt}' tool to run your complete Python script. 7. 
**Analyze Results**: Interpret the output (stdout, numerical results, errors) from the tool execution. 8. **Present Findings**: Summarize your analysis and findings clearly. Use Markdown tables for structured data if helpful. Mention any plots saved and their paths. If errors occurred, explain them. 9. **Focus**: Concentrate on data analysis using code execution. Do not perform web searches unless specifically instructed and given tools for it. """ # 4. Call super().__init__ super().__init__( name=name, model=model, tools=agent_tools, prompt=base_prompt, description=description, checkpointer=checkpointer, max_context_messages=max_context_messages, max_context_tokens=max_context_tokens, debug=debug, **kwargs ) print(f"DataAnalystAgent '{self.name}' initialized.") # Inherits _format_tools_for_prompt and other methods ================================================ FILE: core/agents/sub_agents/designer_agent.py ================================================ # 文件路径示例: reason_graph/designer_agent.py from typing import Any, List, Optional, Union, Callable, Type from langchain_core.language_models import LanguageModelLike # 确保导入正确类型 from langchain_core.tools import BaseTool from langchain_core.messages import SystemMessage from langgraph.types import Checkpointer # 内部导入 from core.agents.base.react_agent import ReactAgent from core.tools.registry import get_tools_by_category, ToolCategory # 导入 Registry # 假设您的 Flux 工具已注册或在此导入 # from core.tools.flux_image_tool import FluxImageGeneratorTool import logging logger = logging.getLogger(__name__) # 假设的 ToolCategory.IMAGE_GENERATION if not hasattr(ToolCategory, 'IMAGE_GENERATION'): ToolCategory.IMAGE_GENERATION = ToolCategory.OTHER class DesignerAgent(ReactAgent): """ 设计 Agent (重构版) - 能够理解图像上下文,并使用工具生成新的视觉内容。 - 应用设计原则来完成海报、网页等设计任务。 """ def __init__( self, name: str = "designer_expert", model: LanguageModelLike = None, # <--- 必须传入多模态模型 (e.g., gpt-4o) tools: Optional[List[Union[BaseTool, Callable]]] = None, checkpointer: Optional[Checkpointer] = None, max_context_messages: Optional[int] = None, max_context_tokens: Optional[int] = 8000, # 调整上下文需求 debug: bool = False, **kwargs ): # 1. 定义 Agent 描述 description = "Understands images provided in context and generates new visual content (images, mockups, diagrams) using specialized image generation tools (like Flux). Can apply design thinking for tasks like poster or web page layout design." # 2. 获取工具 (主要是图像生成工具) agent_tools = [] try: # 从 Registry 获取图像生成工具 img_tools = get_tools_by_category(ToolCategory.IMAGE_GENERATION) agent_tools.extend(img_tools) # 也可以直接实例化 # agent_tools.append(FluxImageGeneratorTool()) # 如果不使用 Registry print(f"[{name}] Loaded tools: {[t.name for t in agent_tools if hasattr(t,'name')]}") except Exception as e: print(f"Warning: Failed to get IMAGE_GENERATION tools for {name}: {e}") if tools: # 合并额外工具 existing_names = {t.name for t in agent_tools if hasattr(t,'name')} agent_tools.extend([t for t in tools if getattr(t, 'name', None) not in existing_names]) if not agent_tools: print(f"CRITICAL Warning: DesignerAgent '{name}' initialized with NO generation tools!") # 3. 定义 System Prompt tool_name_for_prompt = next((t.name for t in agent_tools if hasattr(t, 'name') and 'generat' in t.name.lower()), "image_generator_tool") # 获取工具名 base_prompt = f"""You are an expert Visual Designer and Creative Assistant. Your capabilities include understanding images provided in the conversation history and generating new images using available tools based on detailed text prompts. 
Available Tools: {self._format_tools_for_prompt(agent_tools)} - **{tool_name_for_prompt}**: Use this tool to generate images. Input requires a detailed 'prompt'. Key Instructions & Workflow: 1. **Understand Request**: Analyze the user request, paying attention to both text and any images provided in the message history. Identify the core visual goal (e.g., analyze image, generate image, design layout). 2. **Image Understanding (If Applicable)**: If the request involves analyzing or describing an existing image from the history, provide your analysis directly based on your multimodal understanding. 3. **Design Thinking (For Generation/Design Tasks)**: * **Clarify**: If the request is vague (e.g., "design a logo"), think about necessary elements: target audience, brand feeling, key symbols, color preferences, desired style (minimalist, vintage, futuristic, etc.). You might need to state assumptions if details are missing. * **Conceptualize**: Describe the visual elements, layout, color palette, and overall composition you plan to generate. * **Formulate Prompt for Tool**: Translate your design concept into a **highly detailed and descriptive text prompt** suitable for the `{tool_name_for_prompt}`. Include style, mood, composition, colors, and specific objects. 4. **Use Generation Tool**: Call the `{tool_name_for_prompt}` with the detailed prompt you formulated. 5. **Present Result**: * State that you have generated the image. * Provide the result from the tool (e.g., the image URL or identifier). * Briefly describe the generated image and how it matches the design concept or request. * **Important**: Do NOT attempt to display the image directly in your text response. Only provide the URL or description. 6. **Handle Errors**: If the tool fails, report the error clearly. Focus on visual design and generation tasks. Use your understanding of design principles when conceptualizing visuals for requests like posters or web mockups. """ # 4. 调用父类 __init__ super().__init__( name=name, model=model, # 必须是多模态模型 tools=agent_tools, prompt=base_prompt, description=description, checkpointer=checkpointer, max_context_messages=max_context_messages, max_context_tokens=max_context_tokens, debug=debug, **kwargs ) print(f"DesignerAgent '{self.name}' initialized.") # 继承 _format_tools_for_prompt 和其他 BaseAgent/ReactAgent 方法 ================================================ FILE: core/agents/sub_agents/reporter_agent.py ================================================ # 文件路径: reason_graph/reporter_agent.py import json import time from datetime import datetime from typing import Dict, Any, List, Optional, Union, Type, cast, Sequence # --- LangChain / LangGraph --- from langchain_core.language_models import LanguageModelLike from langchain_core.tools import BaseTool from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage, AIMessage from langchain_core.runnables import RunnableConfig, Runnable from langgraph.graph import StateGraph, END, START # 导入 StateGraph, END, START from langgraph.graph.graph import CompiledGraph from langgraph.types import Checkpointer # --- 内部导入 --- from core.agents.base.base_agent import BaseAgent # 导入最终版 BaseAgent # 导入最终报告的 Prompt 模板 import logging logger = logging.getLogger(__name__) class ReporterAgent(BaseAgent): """ 报告 Agent (最终版) - 继承自 BaseAgent。 - 负责基于完整的消息历史和明确指令生成最终 Markdown 报告。 - 内部包含一个简单的图用于执行报告生成任务。 """ FINAL_REPORT_SYSTEM_PROMPT_TEMPLATE = """You are a professional writer and editor AI assistant. 
Your primary goal is to generate high-quality, well-structured text content based on the specific instructions provided in the latest message and the relevant information available in the preceding conversation history. The current date is {current_date}. **Your Task Execution Workflow:** 1. **Identify Instructions:** Carefully read the **last message** you received, which contains the specific writing task assigned to you by the supervisor. Understand the desired output (e.g., summary, report section, full report), format, tone, and any other requirements. 2. **Gather Context:** Review the preceding messages in the conversation history to find the necessary information, data points, findings, or creative elements needed to complete the assigned task. 3. **Compose Output:** Write the text according to the instructions. * If asked for creative content (like a poem), focus on fulfilling the creative request. * If asked for a summary or section, synthesize the relevant information concisely and accurately. * If asked to compile a **full report**, structure it logically (e.g., Introduction, Body, Conclusion), use Markdown formatting effectively, and incorporate information/citations from the history as instructed. Adhere to any specified length or style guidelines. 4. **Final Response:** Your output should be **only** the requested written text. Do not add extra conversational phrases unless necessary for context. Do not include planning directives or attempt to call tools (unless a specific writing/editing tool was provided and instructed for use). If you cannot fulfill the request due to missing information in the history, state that clearly. """ def __init__( self, name: str = "reporter_expert", model: LanguageModelLike = None, # 应传入适合长文本生成的模型 checkpointer: Optional[Checkpointer] = None, max_context_messages: Optional[int] = None, max_context_tokens: Optional[int] = 16000, # 报告生成可能需要处理长上下文 debug: bool = False, prompt_template: str = FINAL_REPORT_SYSTEM_PROMPT_TEMPLATE, # 使用最终报告模板 **kwargs # 接收其他 BaseAgent 参数 ): # 1. 定义 Agent 描述 (给 Supervisor 看) description = "Synthesizes information from the complete conversation history and task results into a final, comprehensive, well-structured, and potentially cited Markdown research report, following specific instructions." # 2. 定义工具列表 (Reporter 通常不需要工具) agent_tools = [] # 3. 存储基础 Prompt 模板 (将在节点逻辑中使用) # 注意:我们将模板本身(或其引用)存储起来,而不是格式化后的 prompt self.report_prompt_template = prompt_template # 4. 调用父类 __init__ super().__init__( name=name, model=model, # 传入用于报告生成的 LLM tools=agent_tools, prompt=None, # BaseAgent 的 prompt 字段不直接用于此 Agent 的核心逻辑 description=description, checkpointer=checkpointer, max_context_messages=max_context_messages, max_context_tokens=max_context_tokens, # **kwargs 传递 debug 等 **kwargs ) print(f"ReporterAgent '{self.name}' initialized.") async def _generate_report_node_logic(self, state: Dict[str, Any], config: RunnableConfig) -> Dict[str, Any]: """报告生成节点的核心逻辑""" # 注意:这里的 state 已经是经过 BaseAgent._preprocess_state 处理后的状态 print(f"--- Entering Node: {self.name}._generate_report_node_logic ---") messages: List[BaseMessage] = state.get("messages", []) # 理论上,所有需要的信息都应该在 messages 历史中, # 特别是 Supervisor 委派时的最后一条指令消息。 if not messages: error_msg = "Error: No messages found in state for report generation." 
            print(error_msg)
            return {"messages": [AIMessage(content=f"# Report Generation Failed\n\n{error_msg}", name=self.name)]}

        # --- Format the system prompt (includes the current date) ---
        try:
            current_date_str = datetime.now().strftime("%a, %b %d, %Y")
            system_prompt = self.report_prompt_template.format(current_date=current_date_str)
        except Exception as e:
            print(f"Error formatting report system prompt: {e}")
            system_prompt = "You are a report writing assistant. Synthesize the provided messages into a final report."  # fallback

        # --- Prepare the LLM input ---
        # The input is the system prompt plus the full, preprocessed (truncated) message
        # history; BaseAgent._preprocess_state has already handled truncation.
        llm_input_messages = [SystemMessage(content=system_prompt)] + messages

        # --- Call the LLM to generate the report ---
        final_report_markdown = ""
        llm_error = None
        try:
            print(f"--- Calling LLM for Final Report Generation ({self.name}) ---")
            # Uses self.model (the LLM instance supplied at init time)
            response = await self.model.ainvoke(llm_input_messages, config=config)
            final_report_markdown = response.content
            print(f"--- Report Generation LLM Call Successful ({self.name}). Length: {len(final_report_markdown)} chars ---")
        except Exception as e:
            print(f"!!! Error during Report Generation LLM call ({self.name}): {e}")
            llm_error = f"Report generation failed due to LLM error: {e}"
            final_report_markdown = f"# Report Generation Failed\n\nError: {str(e)}"
            # A fuller traceback could be printed here:
            # import traceback
            # traceback.print_exc()

        # --- Return the state update containing the report (or the error) ---
        # The reporter's final output is the report itself. It is appended to the message
        # history rather than replacing it, so the caller (Supervisor or main) can see
        # both the full history and the final report.
        return {
            "messages": [AIMessage(content=final_report_markdown, name=self.name)],
            "error": state.get("error") or llm_error  # keep or record the error
        }

    def build(self) -> Optional[StateGraph]:
        """Build the Reporter Agent's simple workflow: Start -> GenerateReport -> End."""
        if self._workflow:
            return self._workflow
        print(f"Building internal graph for ReporterAgent '{self.name}'")
        # The reporter would normally use BasicAgentState, since it never touches the Plan.
        # To stay compatible with a Supervisor that may pass PlanningAgentState, a generic
        # dict state is used here (only `messages` matters to this agent); alternatively a
        # dedicated ReporterState could be defined.
        workflow = StateGraph(Dict[str, Any])
        # Add the report-generation node, making sure it can access self.model.
        # functools.partial does not work directly on async instance methods, so wrap it:
        async def node_wrapper(state, config):
            return await self._generate_report_node_logic(state, config)
        workflow.add_node("generate_report", node_wrapper)  # type: ignore
        workflow.add_edge(START, "generate_report")
        workflow.add_edge("generate_report", END)
        self._workflow = workflow
        return workflow

    # The compile method is inherited from BaseAgent: it calls build() above to get the
    # StateGraph definition, compiles it, and creates the final _executable_agent with the
    # preprocessing step (_preprocess_state) attached.
    # invoke, ainvoke, get_agent (get_executable_agent), and reset are inherited from BaseAgent.

================================================
FILE: core/agents/sub_agents/research_agent.py
================================================
# Example path: reason_graph/research_agent.py
from typing import Any, List, Optional, Union, Callable, Type, cast

from langchain_core.language_models import LanguageModelLike
from langchain_core.tools import BaseTool
from langchain_core.messages import SystemMessage
from langgraph.types import Checkpointer

# Internal imports -- make sure the paths are correct
from core.agents.base.react_agent import ReactAgent
# Tool-registry imports: only get_tools_by_category and ToolCategory are needed
# (importing get_tool or get_registered_tools is no longer necessary)
from core.tools.registry import get_tools_by_category, ToolCategory

import logging
logger = logging.getLogger(__name__)

# Assume ToolCategory has SEARCH and WEB_Browse members; fall back to OTHER if not
if not hasattr(ToolCategory, 'SEARCH'): ToolCategory.SEARCH = ToolCategory.OTHER
if not hasattr(ToolCategory, 'WEB_Browse'): ToolCategory.WEB_Browse = ToolCategory.OTHER
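# Usage sketch (illustrative, not part of the original module) for the ResearchAgent
# defined below. The model choice and the extra Tavily tool are assumptions for this
# example; in this repo the search tools are normally pre-registered in
# core.tools.registry and picked up via get_tools_by_category(ToolCategory.SEARCH).
#
#   from langchain_openai import ChatOpenAI
#   from langchain_community.tools.tavily_search import TavilySearchResults
#
#   agent = ResearchAgent(
#       model=ChatOpenAI(model="gpt-4o"),
#       tools=[TavilySearchResults(max_results=5)],  # merged with registry tools by __init__
#   )
#   graph = agent.get_agent()  # compiled LangGraph, as consumed by create_supervisor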
================================================
FILE: core/agents/sub_agents/research_agent.py
================================================
# Example file path: reason_graph/research_agent.py
from typing import Any, List, Optional, Union, Callable, Type, cast

from langchain_core.language_models import LanguageModelLike
from langchain_core.tools import BaseTool
from langchain_core.messages import SystemMessage
from langgraph.types import Checkpointer

# Internal imports -- make sure the paths are correct
from core.agents.base.react_agent import ReactAgent
# Registry imports -- only get_tools_by_category and ToolCategory are needed
from core.tools.registry import get_tools_by_category, ToolCategory
# *** get_tool / get_registered_tools imports are no longer needed ***

import logging
logger = logging.getLogger(__name__)

# Assume ToolCategory provides SEARCH and WEB_Browse; fall back to OTHER if missing
if not hasattr(ToolCategory, 'SEARCH'):
    ToolCategory.SEARCH = ToolCategory.OTHER
if not hasattr(ToolCategory, 'WEB_Browse'):
    ToolCategory.WEB_Browse = ToolCategory.OTHER


class ResearchAgent(ReactAgent):
    """
    Research agent (refactored):
    - inherits from the new ReactAgent
    - focuses on defining its own tools and prompt
    - custom state management and helper methods removed
    """
    def __init__(
        self,
        name: str = "research_expert",
        model: LanguageModelLike = None,
        tools: Optional[List[Union[BaseTool, Callable]]] = None,
        checkpointer: Optional[Checkpointer] = None,
        max_context_messages: Optional[int] = None,
        max_context_tokens: Optional[int] = 8000,
        debug: bool = False,
        **kwargs
    ):
        # 1. Agent description (unchanged)
        description = "Expert at finding, extracting, and synthesizing the latest information, data, and background knowledge on specific topics using search engines (like Tavily, Google Search) and web Browse tools (like Firecrawl, Arxiv). Capable of providing source links and content summaries."

        # 2. --- Fetch and merge tools from the registry ---
        agent_tools: List[Union[BaseTool, Callable]] = []
        search_tools_loaded: List[Union[BaseTool, Callable]] = []  # kept for the check below
        browse_tools_loaded: List[Union[BaseTool, Callable]] = []
        try:
            search_tools_loaded = get_tools_by_category(ToolCategory.SEARCH)
            agent_tools.extend(search_tools_loaded)
            try:
                browse_tools_loaded = get_tools_by_category(ToolCategory.WEB_Browse)
                agent_tools.extend(browse_tools_loaded)
            except Exception as e:
                if debug:
                    print(f"[{name}] Info: Failed to get WEB_Browse tools: {e}")
            print(f"[{name}] Loaded tools from registry: {[t.name for t in agent_tools if hasattr(t, 'name')]}")
            # --- Simplified core-tool check: just verify the SEARCH list is non-empty ---
            if not search_tools_loaded:
                print(f"CRITICAL Warning: ResearchAgent '{name}' initialized without any SEARCH tools from registry!")
        except Exception as e:
            print(f"Warning: Failed to get tools from registry for {name}: {e}")

        # Merge externally supplied `tools` (logic unchanged)
        if tools:
            existing_tool_names = {t.name for t in agent_tools if hasattr(t, 'name')}
            added_external_count = 0
            for tool in tools:
                tool_name = getattr(tool, 'name', None)
                if tool_name and tool_name not in existing_tool_names:
                    agent_tools.append(tool)
                    existing_tool_names.add(tool_name)
                    added_external_count += 1
                elif not tool_name:
                    agent_tools.append(tool)
                    added_external_count += 1
            if added_external_count > 0:
                print(f"[{name}] Merged {added_external_count} external tool(s).")

        # --- Simplified final tool check ---
        if not agent_tools:
            print(f"CRITICAL Warning: ResearchAgent '{name}' initialized with NO tools configured!")
        # (The previous complex any(...) check is no longer needed.)

        # 3. Agent system prompt (logic unchanged)
        base_prompt = f"""You are a professional Research Analyst expert...

Available Tools:
{self._format_tools_for_prompt(agent_tools)}

Instructions:
- Analyze the request in the message history.
- If the request requires searching for current information, facts, data, or background knowledge, you MUST use one of your search tools (like 'tavily_search_results').
- When using tools, formulate concise and effective search queries based on the request.
- Synthesize the information found from the tools into a clear and informative answer.
- If you use information from a tool, cite the source implicitly in your response (e.g., "According to [Source Title], ...").
- If the initial search is insufficient, analyze the results and decide if further searches with refined queries or different tools are needed.
- If you cannot find the information after thorough searching, or if the tools return errors, clearly state the limitations encountered. Do not invent information.
"""

        # 4. Call the parent __init__ (logic unchanged)
        super().__init__(
            name=name,
            model=model,
            tools=agent_tools,
            prompt=base_prompt,
            description=description,
            checkpointer=checkpointer,
            max_context_messages=max_context_messages,
            max_context_tokens=max_context_tokens,
            debug=debug,
            **kwargs
        )
        print(f"ResearchAgent '{self.name}' initialized with final tools: {[t.name for t in self.tools if hasattr(t, 'name')]}")
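# --- Hypothetical usage sketch (not part of the original file). Assumes the tool
# registry has been populated (importing core.tools triggers preregister_core_tools)
# and that `model` is any LangChain chat model; argument names are assumptions. ---
async def _example_research_agent_usage(model) -> None:
    """Sketch: registry tools are merged automatically; extra tools may be passed via `tools=`."""
    agent = ResearchAgent(model=model, debug=True)
    agent.compile()  # inherited from BaseAgent
    result = await agent.ainvoke(
        {"messages": [("user", "Find recent news about LangGraph releases.")]}
    )
    print(result["messages"][-1].content)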
================================================
FILE: core/llm/llm_manager.py
================================================
# reason_graph/llm_manager.py
import os
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_openai import ChatOpenAI
# (ChatGroq import removed)
from dotenv import load_dotenv

# Load environment variables
load_dotenv()


class ModelType(Enum):
    """Model provider types."""
    OPENAI = auto()
    XAI = auto()
    DEEPSEEK = auto()
    CUSTOM = auto()  # kept for other OpenAI-compatible APIs


class ModelCapability(Enum):
    """Model capabilities."""
    GENERAL = auto()
    PLANNING = auto()
    REASONING = auto()
    CREATIVE = auto()
    RESEARCH = auto()
    CODE = auto()
    LONG_CONTEXT = auto()


class LLMManager:
    """
    Model manager (merged version V2):
    - automatically registers models from the config at initialization,
    - supports capability-based model lookup,
    - instantiates models lazily,
    - loads API keys / base URLs from environment variables.
    """
    def __init__(self):
        """Initialize the manager, load configuration, and auto-register models."""
        self._models_config: Dict[str, Dict[str, Any]] = {}
        self._models_instance: Dict[str, BaseChatModel] = {}
        self._default_model_id: Optional[str] = None
        self._capability_models: Dict[ModelCapability, str] = {}

        # Load API keys and base URLs (unchanged)
        self._loaded_api_keys = {
            ModelType.OPENAI: os.getenv("OPENAI_API_KEY"),
            ModelType.XAI: os.getenv("XAI_API_KEY"),
            ModelType.DEEPSEEK: os.getenv("DEEPSEEK_API_KEY"),
            ModelType.CUSTOM: os.getenv("LLM_API_KEY"),
        }
        self._loaded_base_urls = {
            ModelType.OPENAI: os.getenv("OPENAI_BASE_URL"),
            ModelType.XAI: os.getenv("XAI_BASE_URL"),
            ModelType.DEEPSEEK: os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1"),
            ModelType.CUSTOM: os.getenv("LLM_BASE_URL"),
        }
        print("LLMManager initialized.")
        print("Loaded API Keys for:", [k.name for k, v in self._loaded_api_keys.items() if v])
        print("Loaded Base URLs for:", {k.name: v for k, v in self._loaded_base_urls.items() if v})

        # --- Auto-register models ---
        try:
            from .model_config import SUPPORTED_MODELS_CONFIG  # import from the config file
            print("Registering models from config...")
            for model_id, config in SUPPORTED_MODELS_CONFIG.items():
                # Check that the required key/URL exists; otherwise skip registration with a warning
                model_type = config.get("model_type")
                api_key = config.get("config_override", {}).get("api_key") or self._loaded_api_keys.get(model_type)
                base_url = config.get("config_override", {}).get("base_url") or self._loaded_base_urls.get(model_type)
                # OpenAI may rely solely on the OPENAI_API_KEY environment variable
                if model_type == ModelType.OPENAI and not api_key:
                    api_key = os.getenv("OPENAI_API_KEY")  # re-check the OpenAI-specific key
                # Enforce key/URL requirements per type
                key_required = model_type not in [ModelType.CUSTOM]  # CUSTOM may be anonymous
                url_required = model_type in [ModelType.XAI, ModelType.CUSTOM]  # DeepSeek has a default URL
                if key_required and not api_key:
                    print(f" Skipping registration for '{model_id}': Required API key for type '{model_type.name}' not found.")
                    continue
                if url_required and not base_url:
                    print(f" Skipping registration for '{model_id}': Required Base URL for type '{model_type.name}' not found.")
                    continue
                # Call the internal registration method
                self._register_model(
                    model_id=model_id,
                    model_type=config["model_type"],
                    model_name=config["model_name"],
                    model_class=config.get("model_class"),  # may be None
                    capabilities=config.get("capabilities", [ModelCapability.GENERAL]),
                    set_as_default=config.get("is_default", False),
                    config_override=config.get("config_override"),
                    **config.get("kwargs", {})
                )
            print("Model registration complete.")
            # If no config entry was marked is_default=True, fall back to the first registered model
            if not self._default_model_id and self._models_config:
                fallback_default = list(self._models_config.keys())[0]
                print(f"Warning: No default model marked in config. Falling back to first registered: '{fallback_default}'")
                self._default_model_id = fallback_default
        except ImportError:
            print("Warning: Could not import model_config.py. No models registered automatically.")
        except Exception as e:
            print(f"Error during automatic model registration: {e}")

        print(f"Default model set to: {self._default_model_id}")
        print(f"Capability mapping: {self.list_capabilities()}")
        print("-" * 20)

    # register_model is now an internal method
    def _register_model(
        self,
        model_id: str,
        model_type: ModelType,
        model_name: str,
        model_class: Optional[Type[BaseChatModel]] = None,
        capabilities: List[ModelCapability] = [ModelCapability.GENERAL],
        set_as_default: bool = False,
        config_override: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> None:
        """(Internal) Registers a model configuration."""
        if model_id in self._models_config:
            # Decide on behavior: overwrite or ignore? Let's overwrite with warning.
            print(f" Overwriting registration for existing model_id: '{model_id}'")
            # pass  # if ignore is preferred
        if model_class is None:
            model_class = ChatOpenAI
        self._models_config[model_id] = {
            "type": model_type,
            "name": model_name,
            "class": model_class,
            "capabilities": list(set(capabilities)),
            "config_override": config_override or {},
            "kwargs": kwargs,
        }
        print(f" Registered model config: '{model_id}' (Type: {model_type.name}, Class: {model_class.__name__})")
        if set_as_default:
            self._default_model_id = model_id
            print(f" Set '{model_id}' as default.")
        for capability in capabilities:
            if capability not in self._capability_models:
                self._capability_models[capability] = model_id
                print(f" Mapped capability '{capability.name}' to '{model_id}'.")

    def set_default_model(self, model_id: str) -> None:
        """Set the default model."""
        if model_id not in self._models_config:
            raise ValueError(f"Model ID '{model_id}' not registered.")
        self._default_model_id = model_id

    def set_capability_model(self, capability: ModelCapability, model_id: str) -> None:
        """Set the preferred model for a specific capability."""
        if model_id not in self._models_config:
            raise ValueError(f"Model ID '{model_id}' not registered.")
        model_info = self._models_config[model_id]
        if capability not in model_info.get("capabilities", []):
            print(f"Warning: Model '{model_id}' not registered with capability '{capability.name}'.")
        self._capability_models[capability] = model_id
    # _get_instance (core instantiation logic)
    def _get_instance(self, model_id: str) -> BaseChatModel:
        """(Internal) Gets or creates a model instance."""
        if model_id in self._models_instance:
            return self._models_instance[model_id]
        if model_id not in self._models_config:
            raise ValueError(f"Model ID '{model_id}' not registered or registration skipped due to missing config.")

        config = self._models_config[model_id]
        model_type = config["type"]
        model_name = config["name"]
        model_class = config["class"]
        config_override = config["config_override"]
        kwargs = config["kwargs"]

        # Determine key/URL (config_override first, then env)
        api_key = config_override.get("api_key", self._loaded_api_keys.get(model_type))
        base_url = config_override.get("base_url", self._loaded_base_urls.get(model_type))
        # OpenAI-specific key handling
        if model_type == ModelType.OPENAI and not api_key:
            api_key = os.getenv("OPENAI_API_KEY")

        # Check required configuration
        key_required = model_type not in [ModelType.CUSTOM]
        url_required = model_type in [ModelType.XAI, ModelType.DEEPSEEK, ModelType.CUSTOM]
        if key_required and not api_key:
            raise ValueError(f"API key required but not found for '{model_id}' (Type: {model_type.name}). Set in .env or config_override.")
        if url_required and not base_url:
            raise ValueError(f"Base URL required but not found for '{model_id}' (Type: {model_type.name}). Set in .env or config_override.")

        print(f"Instantiating model: ID='{model_id}', Type='{model_type.name}', Name='{model_name}'")

        # Prepare constructor arguments
        init_kwargs = kwargs.copy()
        if model_class == ChatOpenAI:
            init_kwargs['model'] = model_name
            if api_key:
                init_kwargs['openai_api_key'] = api_key
            if base_url:
                init_kwargs['openai_api_base'] = base_url
        # elif model_class == ChatGroq: ...  # removed
        else:
            # Try generic parameters; many compatible classes accept `model` too
            init_kwargs['model'] = model_name
            init_kwargs['model_name'] = model_name
            if api_key:
                init_kwargs['api_key'] = api_key
            if base_url:
                init_kwargs['base_url'] = base_url

        # Strip internal configuration keys
        for k in ["config_override", "capabilities", "type", "class", "name", "instance"]:
            init_kwargs.pop(k, None)

        # Instantiate
        try:
            instance = model_class(**init_kwargs)
            self._models_instance[model_id] = instance
            return instance
        except Exception as e:
            print(f"!!! Failed to instantiate model '{model_id}'")
            raise e

    # get_model and get_model_for_capability (unchanged; both delegate to _get_instance)
    def get_model(self, model_id: Optional[str] = None) -> BaseChatModel:
        """Get a model instance (by ID, or the default)."""
        target_id = model_id
        if target_id is None:
            if self._default_model_id is None:
                raise ValueError("No default model set.")
            target_id = self._default_model_id
        if target_id not in self._models_config:
            raise ValueError(f"Model ID '{target_id}' not registered.")
        return self._get_instance(target_id)

    def get_model_for_capability(self, capability: ModelCapability) -> BaseChatModel:
        """Get a model instance that provides a specific capability."""
        if capability not in self._capability_models:
            print(f"No preferred model for '{capability.name}'. Falling back to default.")
            if self._default_model_id is None:
                raise ValueError(f"No model for '{capability.name}' and no default set.")
            model_id = self._default_model_id
        else:
            model_id = self._capability_models[capability]
            print(f"Using model '{model_id}' for capability '{capability.name}'.")
        return self.get_model(model_id)

    # list_models and list_capabilities (unchanged)
    def list_models(self) -> Dict[str, Dict[str, Any]]:
        """List all registered models and their configurations."""
        result = {}
        for model_id, model_info in self._models_config.items():
            result[model_id] = {
                "type": model_info["type"].name,
                "name": model_info["name"],
                "class": model_info["class"].__name__,
                "capabilities": [c.name for c in model_info.get("capabilities", [])],
                "is_default": model_id == self._default_model_id,
                "kwargs": model_info.get("kwargs"),
                "config_override": model_info.get("config_override"),
            }
        return result

    def list_capabilities(self) -> Dict[str, str]:
        return {capability.name: model_id for capability, model_id in self._capability_models.items()}
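# --- Hypothetical usage sketch (not part of the original file). Assumes the entries
# in model_config.py registered successfully, i.e. the relevant API keys are set. ---
if __name__ == "__main__":
    manager = LLMManager()
    default_model = manager.get_model()  # resolves to the default model id
    coder = manager.get_model_for_capability(ModelCapability.CODE)
    print("Registered:", list(manager.list_models().keys()))
    print("Capability map:", manager.list_capabilities())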
================================================
FILE: core/llm/model_config.py
================================================
# reason_graph/model_config.py
from langchain_openai import ChatOpenAI
# from langchain_groq import ChatGroq  # no longer needed
# (If non-OpenAI-compatible providers are supported in the future, import them here.)
from .llm_manager import ModelType, ModelCapability  # enums from the sibling llm_manager

# Supported models and their configurations.
# Each key is the internal model_id.
SUPPORTED_MODELS_CONFIG = {
    "openai_gpt4o": {
        "model_type": ModelType.OPENAI,
        "model_name": "gpt-4o",  # name used for the API call
        "model_class": ChatOpenAI,
        "capabilities": [
            ModelCapability.GENERAL, ModelCapability.PLANNING, ModelCapability.REASONING,
            ModelCapability.CREATIVE, ModelCapability.LONG_CONTEXT, ModelCapability.CODE,
            ModelCapability.RESEARCH  # GPT-4o can handle some research too
        ],
        "is_default": False,  # not the default
        "config_override": {},  # may override env vars, e.g. {'api_key': '...'}
        "kwargs": {"temperature": 0.1}  # extra constructor arguments
    },
    "openai_gpt4o_mini": {
        "model_type": ModelType.OPENAI,
        "model_name": "gpt-4o-mini",
        "model_class": ChatOpenAI,
        "capabilities": [ModelCapability.GENERAL, ModelCapability.REASONING, ModelCapability.CREATIVE],
        "is_default": True,  # <--- the default model
        "config_override": {},
        "kwargs": {"temperature": 0.0}
    },
    "xai_grok": {  # id chosen as xai_grok
        "model_type": ModelType.XAI,
        "model_name": "grok-2-latest",  # or whatever model name the xAI API actually accepts
        "model_class": ChatOpenAI,  # assumes an OpenAI-compatible connection
        "capabilities": [ModelCapability.GENERAL, ModelCapability.REASONING, ModelCapability.LONG_CONTEXT, ModelCapability.CREATIVE],
        "is_default": False,
        "config_override": {},  # key/URL loaded from env (XAI_API_KEY, XAI_BASE_URL)
        "kwargs": {"temperature": 0.2}
    },
    "deepseek_v3": {  # id chosen as deepseek_v3
        "model_type": ModelType.DEEPSEEK,
        "model_name": "deepseek/deepseek-v3-0324",  # DeepSeek chat model API name
        "model_class": ChatOpenAI,  # OpenAI-compatible connection
        "capabilities": [ModelCapability.GENERAL, ModelCapability.REASONING, ModelCapability.CODE, ModelCapability.LONG_CONTEXT],
        "is_default": False,
        "config_override": {},  # key/URL loaded from env (DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL)
        "kwargs": {"temperature": 0.0}
    },
    # --- Additional model configurations can be added here ---
    # "groq_llama3_70b": {
    #     "model_type": ModelType.GROQ,
    #     "model_name": "llama3-70b-8192",
    #     "model_class": ChatGroq,  # requires importing ChatGroq
    #     "capabilities": [...],
    #     "is_default": False,
    #     "config_override": {},
    #     "kwargs": {"temperature": 0.1}
    # },
}
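# --- Hypothetical example entry (not part of the original config) that could be added
# inside SUPPORTED_MODELS_CONFIG: overriding the env-based credentials for one model
# via `config_override`. All names and URLs here are illustrative. ---
# "my_private_gateway": {
#     "model_type": ModelType.CUSTOM,
#     "model_name": "gpt-4o-mini",
#     "model_class": ChatOpenAI,
#     "capabilities": [ModelCapability.GENERAL],
#     "is_default": False,
#     "config_override": {"api_key": "sk-...", "base_url": "https://my-gateway.example.com/v1"},
#     "kwargs": {"temperature": 0.0},
# },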
================================================
FILE: core/mcp/README.md
================================================
# Mentis MCP Client & Configuration Guide

This directory (`core/mcp/`) contains a Python client implementation for interacting with MCP (Model Context Protocol) servers.

## Background

MCP aims to give AI models (such as LLM agents) a standard protocol for interacting with external tools and services. The goal of this client is to provide a flexible, configurable way to connect to MCP servers and integrate the tools they expose into LangChain agents.

## The Client (`MCPClient`)

The core implementation is the `MCPClient` class (in `client.py`), with the following characteristics:

* **Configuration-driven:** Connection/launch information for one or more servers is managed through a JSON file at `core/mcp/config.json`, compatible with the "Cursor-style" configuration format.
* **Flexible connections:**
    * **Launching local servers (stdio):** If the configuration provides `command` and `args`, the client executes that command to start the server process and communicates over **STDIO**. This is useful for standard MCP servers launched via `uvx` or `python -m`.
    * **Connecting to remote servers (sse):** If the configuration provides a `url`, the client connects over **SSE** to an already-running MCP server at that URL.
* **Async architecture:** Built on `asyncio`, suitable for asynchronous applications.
* **Robust resource management:** Connections and sessions are managed with `contextlib.AsyncExitStack` for more reliable shutdown.
* **LangChain integration support:** Provides the groundwork for loading MCP tools as LangChain `BaseTool` objects (with known adapter issues, see below).
## How to Use

### 1. Configure servers (`core/mcp/config.json`)

Create a `config.json` file in this directory defining the MCP servers you want to connect to. The file is a JSON object whose keys are logical server names and whose values are the configuration details for each server.

**Example `config.json` (external standard servers only):**

```json
{
  "fetch_via_uvx": {
    "id": "fetch-uvx-stdio",
    "type": "mcp-server",
    "description": "Fetch Server launched by uvx via stdio",
    "connection": {
      "transport": "stdio",
      "command": "uvx",
      "args": ["mcp-server-fetch"],
      "timeout": 45
    }
  },
  "everything": {
    "id": "everything-stdio",
    "type": "mcp-server",
    "description": "Everything Server launched by npx via stdio",
    "connection": {
      "transport": "stdio",
      "command": "npx",
      "args": ["-y", "@modelcontextprotocol/server-everything"],
      "env": {
        // If the Everything Server needs API keys, add them here,
        // or make sure the environment variables of the client script are inherited.
        // "OPENAI_API_KEY": "YOUR_KEY",
        // "TAVILY_API_KEY": "YOUR_KEY"
      },
      "timeout": 60
    }
  },
  "external_sse_example": {
    "id": "external-sse",
    "type": "mcp-server",
    "description": "Connect to a pre-running SSE server (Example)",
    "connection": {
      "transport": "sse",
      "url": "http://localhost:9001/sse" // assumes a server is running here
    }
  }
}
```

**Important:**

* When launching a server via `command`, make sure the command (`uvx`, `npx`, `python`, ...) is available in your environment.
* If a server needs API keys, provide them via the `env` field or system environment variables.
* `transport: "stdio"` tells the client to connect over stdio; `transport: "sse"` tells it to connect over SSE.

### 2. Client code example

Load the configuration with `config_loader.py` and use `MCPClient` in an `async with` block.

```python
import asyncio
import os
from core.mcp.client import MCPClient
from core.mcp.config_loader import load_config

# LangChain imports (if an agent is needed)
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from langchain_core.tools import BaseTool, Tool

# Pydantic schema for the tool (used to create the Tool manually)
from pydantic.v1 import BaseModel, Field  # or v2

# --- Fetch schema example ---
class FetchInputSchema(BaseModel):
    url: str = Field(..., description="URL to fetch")
    # ... other fields ...

async def main():
    # --- Load the configuration ---
    config_path = os.path.join(os.path.dirname(__file__), "config.json")  # assumes config.json sits alongside
    try:
        all_configs = load_config(config_path)
        # Pick the configuration to use
        server_key = "fetch_via_uvx"  # or "everything", "e2b_stdio", etc.
        mcp_config = all_configs.get(server_key)
        if not mcp_config:
            print(f"Config '{server_key}' not found.")
            return
    except Exception as e:
        print(f"Failed to load config: {e}")
        return

    # --- Use MCPClient ---
    async with MCPClient(mcp_config) as client:
        print(f"Connected to MCP Server '{server_key}'. Session active: {client.session is not None}")
        if not client.session:
            return

        # --- Get and use tools ---
        # Option 1: the standard way (with known issues)
        # print("\nAttempting standard tool loading via load_mcp_tools...")
        # loaded_tools = client.get_tools()  # internally calls load_mcp_tools
        # print(f"load_mcp_tools returned {len(loaded_tools)} tools.")
        # # !! Note: for some server implementations (like the earlier MentisMCPServer),
        # # !! the args_schema of tools returned by load_mcp_tools may be wrong, which
        # # !! makes agent calls fail. For standard servers like the Fetch Server the
        # # !! loaded schema may be correct -- judge from the printed schema.

        # Option 2: [currently recommended] create the Tool object manually
        # (bypasses the load_mcp_tools issue)
        print("\nManually creating Tool object with correct schema...")
        tool_name = "fetch"  # assuming we test the Fetch Server
        tool_description = "Fetches URL content."  # can be fetched from the server or written by hand
        correct_schema = FetchInputSchema  # use the correct Pydantic model

        # Define the invocation logic
        async def call_mcp_tool_wrapper(**kwargs) -> str:
            # (internally uses client.session.call_tool to send a correct request)
            # see examples/14_mcp_fetch_test.py for the full implementation
            if not client or not client.session:
                return "ERROR: Session lost."
            try:
                req_params = {"name": tool_name, "arguments": kwargs}
                from mcp.types import CallToolRequest  # import needed
                request = CallToolRequest(method='tools/call', params=req_params)
                result = await client.session.call_tool(request)
                if hasattr(result, 'result'):
                    return str(result.result)
                elif hasattr(result, 'error'):
                    return f"Tool Error: {result.error.message}"
                else:
                    return "Unknown response"
            except Exception as e:
                return f"Error: {e}"

        # Create the LangChain Tool
        manual_tool = Tool.from_function(
            name=tool_name,
            description=tool_description,
            args_schema=correct_schema,
            coroutine=call_mcp_tool_wrapper
        )
        tools_for_agent = [manual_tool]
        print(f"Manual tool '{manual_tool.name}' created.")

        # --- Use an agent ---
        try:
            # model = llm_manager.get_model("openai_gpt4o_mini")  # get an LLM
            # agent = create_react_agent(model, tools_for_agent)
            # response = await agent.ainvoke(...)
            # print("Agent Response:", response)
            print("\nAgent execution part skipped in README example.")
            print("Refer to examples/14_mcp_fetch_test.py for full Agent integration.")
        except Exception as e:
            print(f"Agent execution error: {e}")

# if __name__ == "__main__":
#     asyncio.run(main())
```

## About the Self-Built MCP Server (MentisMCPServer)

Earlier in development we built a `MentisMCPServer` class in `core/mcp/server.py` to dynamically wrap the LangChain `BaseTool`s from our internal tool registry (`core/tools/registry.py`) as MCP tools.

**Main challenge encountered:**

When using the `FastMCP` library's `@mcp.tool` decorator to register these wrappers dynamically, the server failed to broadcast the tools' **input schemas** correctly. The client's `load_mcp_tools` therefore received wrong schema information, and LangChain agents failed with argument errors when calling the tools.

Refactoring the server's registration logic (registering top-level wrapper functions directly on the `FastMCP` instance in `run_server.py`) **did resolve** the schema-broadcasting problem, so `load_mcp_tools` obtained correct schemas; however, later tests showed the agent (`create_react_agent`) could still hit internal errors (`TypeError`) when calling these tools.

**Conclusion and recommendations:**

Because of the deep library-interaction and debugging obstacles encountered when combining LangChain tools, dynamic wrapping, `FastMCP`, and LangChain agents, we **currently do not recommend** exposing `MentisMCPServer` as a stable, reliable service.

**Recommended ways to provide or consume an MCP server:**

1. **Use community standard servers:** Use prebuilt, community- or officially-provided MCP servers such as `mcp-server-fetch` or `@modelcontextprotocol/server-everything`, configured via `command` (e.g. `uvx`, `npx`, `python -m`) or `url` in `config.json`.
2. **Adopt the simple server pattern:** If you need to implement your own MCP server to expose specific functionality, follow the simple examples in the `modelcontextprotocol/servers` repository (such as `math_server` or `time_server`) and **register tool functions directly** (decorate top-level `async def` functions with `@mcp_instance.tool`), avoiding complex dynamic wrapping layers; see the sketch below.
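A minimal sketch of that direct-registration pattern, mirroring `core/mcp/test/minimal_fastmcp_test.py` in this repository (the `add` tool and server name here are illustrative):

```python
from mcp.server.fastmcp import FastMCP

mcp_instance = FastMCP(name="SimpleMathServer")

@mcp_instance.tool(name="add", description="Add two integers and return the sum.")
async def add(a: int, b: int) -> int:
    """Top-level async tool function registered directly -- no dynamic wrapper layer."""
    return a + b

if __name__ == "__main__":
    # stdio works well for servers launched as a subprocess (see config.json above)
    mcp_instance.run(transport="stdio")
```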
================================================
FILE: core/mcp/__init__.py
================================================
# core/mcp/__init__.py
"""
MCP (Model Context Protocol) functionality module.
"""


================================================
FILE: core/mcp/client.py
================================================
import os
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, Type, Literal, TypedDict, cast
from types import TracebackType
import re
import sys
import json
import traceback
from contextlib import asynccontextmanager, AsyncExitStack

# --- MCP Imports ---
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client

# --- Adapter Import ---
try:
    from langchain_mcp_adapters.tools import load_mcp_tools
    LOAD_MCP_TOOLS_AVAILABLE = True
except ImportError:
    print("Warning: langchain-mcp-adapters not found. load_mcp_tools will be unavailable.")
    async def load_mcp_tools(session: ClientSession) -> list:
        return []
    LOAD_MCP_TOOLS_AVAILABLE = False

# --- LangChain / Pydantic Imports ---
from langchain_core.tools import BaseTool
try:
    from pydantic.v1 import BaseModel as BaseModelV1
except ImportError:
    from pydantic import BaseModel as BaseModelV1  # fallback

# --- Config Loader Import ---
try:
    from .config_loader import MCPConfig, StdioConfig, SSEConfig
except ImportError:
    print("WARNING: Could not import config models from .config_loader.")
    MCPConfig = Any; StdioConfig = Any; SSEConfig = Any  # placeholders

print("--- DEBUG: Loading FINAL client.py (Config-Driven + AsyncExitStack) ---")


class MCPClient:
    """Config-driven MCP client using AsyncExitStack."""

    def __init__(self, config: MCPConfig):
        self.config = config
        self.session: Optional[ClientSession] = None
        self.tools: List[BaseTool] = []
        self._stack: AsyncExitStack = AsyncExitStack()
        self._server_process: Optional[asyncio.subprocess.Process] = None

    async def __aenter__(self) -> "MCPClient":
        print(f"DEBUG: MCPClient entering context for config ID: {getattr(self.config, 'id', 'N/A')}")
        try:
            connection_config = self.config.connection
            transport_ctx = None
            reader = None
            writer = None

            if isinstance(connection_config, SSEConfig) and connection_config.url:
                # --- Direct SSE ---
                print(f"DEBUG: Connecting via SSE to {connection_config.url}")
                transport_ctx = sse_client(
                    connection_config.url,
                    getattr(connection_config, 'headers', None),
                    getattr(connection_config, 'timeout', 5.0),
                    getattr(connection_config, 'sse_read_timeout', 300.0)
                )
                reader, writer = await self._stack.enter_async_context(transport_ctx)
                print("DEBUG: SSE transport context entered.")
            elif isinstance(connection_config, StdioConfig) and connection_config.command:
                # --- Launch via Command + STDIO ---
                print(f"DEBUG: Launching command via STDIO: {connection_config.command} {' '.join(connection_config.args)}")
                merged_env = os.environ.copy()
                if connection_config.env:
                    merged_env.update(connection_config.env)
                server_params = StdioServerParameters(
                    command=connection_config.command,
                    args=connection_config.args,
                    env=merged_env,
                    cwd=connection_config.cwd,
                    encoding=connection_config.encoding,
                    encoding_error_handler=connection_config.encoding_error_handler,
                    startup_timeout=connection_config.timeout
                )
                transport_ctx = stdio_client(server_params)
                reader, writer = await self._stack.enter_async_context(transport_ctx)
                print("DEBUG: STDIO transport context entered.")
            else:
                # Fallback/error: the config may be wrong or the transport missing.
                # Check for a command before assuming an SSE launch.
                if hasattr(connection_config, 'command') and connection_config.command:
                    # This is the complex "launch then connect SSE" case from the guide.
                    # Keeping it simple for now: if transport isn't 'stdio', it must be 'sse' with a URL.
                    raise NotImplementedError("Launching command for SSE connection (URL capture) not implemented in this client version. Use direct SSE URL or STDIO command.")
                else:
                    raise ValueError("Invalid configuration: must have 'url' for SSE or 'command' for STDIO.")

            # --- Establish ClientSession ---
            session_kwargs = getattr(connection_config, 'session_kwargs', None) or {}
            session_ctx = ClientSession(reader, writer, **session_kwargs)
            self.session = await self._stack.enter_async_context(session_ctx)
            print("DEBUG: ClientSession context entered.")

            # --- Initialize and Load Tools (with schema check) ---
            print("Initializing MCP session...")
            await asyncio.wait_for(self.session.initialize(), timeout=30.0)
            print("MCP session initialized.")

            if LOAD_MCP_TOOLS_AVAILABLE:
                print("Loading MCP tools (via langchain-mcp-adapters)...")
                loaded_tools_from_mcp = await load_mcp_tools(self.session)
                print(f"Successfully loaded {len(loaded_tools_from_mcp)} tool descriptions.")
                print("--- Loaded Tools & Args Schema (Diagnostic) ---")
                self.tools = []
                all_schemas_ok = True  # aggregate result of the per-tool checks below
                for i, tool in enumerate(loaded_tools_from_mcp):
                    schema = getattr(tool, 'args_schema', 'N/A')
                    tool_name = getattr(tool, 'name', f'Tool_{i+1}')
                    print(f"{i+1}. Tool Name: {tool_name}")
                    schema_detail = "N/A"
                    is_correct = None  # undetermined
                    if schema != 'N/A':
                        # Schema printing and basic check
                        schema_dict = None
                        if isinstance(schema, type) and issubclass(schema, BaseModelV1):
                            try:
                                schema_dict = schema.schema()
                                schema_detail = f"(PydanticV1): {json.dumps(schema_dict, indent=2)}"
                            except Exception as e_schema:
                                schema_detail = f"(PydanticV1): Error - {e_schema}"
                        elif hasattr(schema, 'model_json_schema'):
                            try:
                                schema_dict = schema.model_json_schema()
                                schema_detail = f"(PydanticV2): {json.dumps(schema_dict, indent=2)}"
                            except Exception as e_schema:
                                schema_detail = f"(PydanticV2): Error - {e_schema}"
                        else:
                            schema_detail = f"(Unknown Type): {schema}"
                        # Basic check: does it look like the faulty kwargs-only schema?
                        if isinstance(schema_dict, dict):
                            props = schema_dict.get('properties', {})
                            if list(props.keys()) == ['kwargs'] and props['kwargs'].get('type') == 'string':
                                is_correct = False
                                schema_detail += " <-- LOOKS WRONG (kwargs only!)"
                            elif props:
                                is_correct = True  # has properties other than just kwargs
                                schema_detail += " <-- Looks structured correctly"
                            else:
                                is_correct = True  # no properties; might be a simple input
                                schema_detail += " <-- No properties defined"
                    else:
                        is_correct = False  # no schema is usually wrong
                    if is_correct is False:
                        all_schemas_ok = False
                    print(f" Args Schema: {schema_detail}")
                    print("-" * 15)
                    self.tools.append(tool)
                print("Schema Check Result: " + (
                    "All schemas look structured correctly."
                    if all_schemas_ok
                    else "One or more schemas look incorrect (kwargs only or missing)!"
                ))
                print("-------------------------------------------")
            else:
                print("Warning: load_mcp_tools unavailable.")
                self.tools = []

            print(f"MCPClient ready. Loaded {len(self.tools)} tools via adapter.")
            return self
        except Exception as enter_err:
            print(f"ERROR: Failed during MCPClient __aenter__: {type(enter_err).__name__}: {enter_err}")
            await self.close()
            raise
    async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]):
        print("DEBUG: MCPClient exiting context...")
        await self.close()
        print("DEBUG: MCPClient context exited.")

    async def close(self):
        """Closes connections and resets state using AsyncExitStack."""
        print("Closing MCP Client...")
        if hasattr(self, '_stack') and self._stack:
            print(" Closing managed async contexts (via AsyncExitStack)...")
            try:
                await self._stack.aclose()
                print(" AsyncExitStack closed.")
            except Exception as e:
                print(f"WARNING: Error closing AsyncExitStack: {type(e).__name__}: {e}")
            finally:
                self._stack = None
        else:
            print(" No active AsyncExitStack.")
        self.session = None
        self.tools = []
        self._transport_ctx = None
        self._server_process = None
        print("MCP Client state reset.")

    def get_tools(self) -> List[BaseTool]:
        """Returns the list of tools loaded by load_mcp_tools."""
        return self.tools


================================================
FILE: core/mcp/config_loader.py
================================================
# core/mcp/config_loader.py (load_config return type updated)
import json
import os
from pathlib import Path
from typing import Dict, Any, Optional, List, Literal, Union, Type

try:
    from pydantic.v1 import BaseModel, Field, ValidationError, validator
    PYDANTIC_V = 1
except ImportError:
    try:
        from pydantic import BaseModel, Field, ValidationError, validator  # type: ignore
        PYDANTIC_V = 2
    except ImportError:
        raise ImportError("Pydantic (v1 or v2) required.")
from typing_extensions import TypedDict

EncodingErrorHandler = Literal["strict", "ignore", "replace"]


class StdioConfig(BaseModel):
    transport: Literal["stdio"] = "stdio"
    command: str = Field(...)
    args: List[str] = Field(default_factory=list)
    env: Optional[Dict[str, str]] = None
    cwd: Optional[Union[str, Path]] = None
    encoding: str = Field(default="utf-8")
    encoding_error_handler: EncodingErrorHandler = Field(default="strict")
    timeout: int = Field(default=30, gt=0)
    session_kwargs: Optional[Dict[str, Any]] = None

    if PYDANTIC_V == 1:
        class Config:
            extra = 'forbid'
    else:
        model_config = {'extra': 'forbid'}


class SSEConfig(BaseModel):
    transport: Literal["sse"] = "sse"
    url: str = Field(...)
    headers: Optional[Dict[str, Any]] = None
    timeout: float = Field(default=5.0, gt=0)
    sse_read_timeout: float = Field(default=300.0, gt=0)
    session_kwargs: Optional[Dict[str, Any]] = None

    if PYDANTIC_V == 1:
        class Config:
            extra = 'forbid'
    else:
        model_config = {'extra': 'forbid'}


class MCPConfig(BaseModel):
    """Represents the structure for a single server configuration."""
    id: Optional[str] = Field(default=None)
    type: Literal["mcp-server"] = Field(default="mcp-server")
    description: Optional[str] = Field(default=None)
    connection: Union[StdioConfig, SSEConfig] = Field(..., discriminator='transport')

    if PYDANTIC_V == 1:
        class Config:
            extra = 'forbid'
    else:
        model_config = {'extra': 'forbid'}


# --- load_config (updated) ---
def load_config(config_path: Union[str, Path]) -> Dict[str, MCPConfig]:
    """
    Loads the central MCP configuration JSON file and validates each server entry.

    Args:
        config_path: Path to the central config.json file.

    Returns:
        A dictionary where keys are server names and values are validated MCPConfig objects.
    """
    config_p = Path(config_path).resolve()
    if not config_p.is_file():
        raise FileNotFoundError(f"Configuration file not found at: {config_p}")
    print(f"DEBUG: Loading central MCP configuration from: {config_p}")
    validated_configs: Dict[str, MCPConfig] = {}
    try:
        with open(config_p, 'r', encoding='utf-8') as f:
            raw_config_dict = json.load(f)
        if not isinstance(raw_config_dict, dict):
            raise TypeError("Root configuration must be a JSON object (dictionary).")
        # Validate every server configuration in the dictionary
        for server_name, config_data in raw_config_dict.items():
            print(f"DEBUG: Validating config for server: '{server_name}'")
            if not isinstance(config_data, dict):
                print(f"WARNING: Entry for '{server_name}' is not a dictionary. Skipping.")
                continue
            try:
                # Make sure connection and transport are present
                if 'connection' not in config_data:
                    raise ValueError("Missing 'connection'")
                if 'transport' not in config_data.get('connection', {}):
                    raise ValueError("Missing 'transport' in connection")
                if PYDANTIC_V == 2:
                    validated_config = MCPConfig.model_validate(config_data)
                else:  # Pydantic V1
                    validated_config = MCPConfig.parse_obj(config_data)
                validated_configs[server_name] = validated_config
                print(f"DEBUG: Config for '{server_name}' validated successfully.")
            except (ValidationError, ValueError) as e_val:
                print(f"ERROR: Validation failed for server '{server_name}' config:\n{e_val}\nSkipping this server.")
                # We could keep loading the remaining configs, or raise here to fail the whole load.
        if not validated_configs:
            print("WARNING: No valid server configurations were loaded.")
        print(f"DEBUG: Central configuration loaded. Found {len(validated_configs)} valid server configs.")
        return validated_configs
    except json.JSONDecodeError as e:
        print(f"ERROR: Failed to decode JSON from {config_p}: {e}")
        raise
    except Exception as e:
        print(f"ERROR: An unexpected error occurred loading config {config_p}: {e}")
        raise
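# --- Hypothetical usage sketch (not part of the original file); assumes the
# mcp_server_config.json shipped in this directory. ---
if __name__ == "__main__":
    configs = load_config(Path(__file__).parent / "mcp_server_config.json")
    for server_name, cfg in configs.items():
        print(f"{server_name}: transport={cfg.connection.transport}")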
""" config_p = Path(config_path).resolve() if not config_p.is_file(): raise FileNotFoundError(f"Configuration file not found at: {config_p}") print(f"DEBUG: Loading central MCP configuration from: {config_p}") validated_configs: Dict[str, MCPConfig] = {} try: with open(config_p, 'r', encoding='utf-8') as f: raw_config_dict = json.load(f) if not isinstance(raw_config_dict, dict): raise TypeError("Root configuration must be a JSON object (dictionary).") # 遍历字典中的每个服务器配置并验证 for server_name, config_data in raw_config_dict.items(): print(f"DEBUG: Validating config for server: '{server_name}'") if not isinstance(config_data, dict): print(f"WARNING: Entry for '{server_name}' is not a dictionary. Skipping.") continue try: # 确保 connection 和 transport 存在 if 'connection' not in config_data: raise ValueError("Missing 'connection'") if 'transport' not in config_data.get('connection', {}): raise ValueError("Missing 'transport' in connection") if PYDANTIC_V == 2: validated_config = MCPConfig.model_validate(config_data) else: # Pydantic V1 validated_config = MCPConfig.parse_obj(config_data) validated_configs[server_name] = validated_config print(f"DEBUG: Config for '{server_name}' validated successfully.") except (ValidationError, ValueError) as e_val: print(f"ERROR: Validation failed for server '{server_name}' config:\n{e_val}\nSkipping this server.") #可以选择继续加载其他配置,或者在这里 raise 让整个加载失败 if not validated_configs: print("WARNING: No valid server configurations were loaded.") print(f"DEBUG: Central configuration loaded. Found {len(validated_configs)} valid server configs.") return validated_configs except json.JSONDecodeError as e: print(f"ERROR: Failed to decode JSON from {config_p}: {e}"); raise except Exception as e: print(f"ERROR: An unexpected error occurred loading config {config_p}: {e}"); raise ================================================ FILE: core/mcp/mcp_server_config.json ================================================ { "fetch_via_uvx": { "id": "fetch-uvx-stdio", "type": "mcp-server", "description": "Fetch Server launched by uvx via stdio", "connection": { "transport": "stdio", "command": "uvx", "args": [ "mcp-server-fetch" ], "env": null, "cwd": null, "encoding": "utf-8", "encoding_error_handler": "strict", "timeout": 45 } }, "everything": { "id": "everything-stdio", "type": "mcp-server", "description": "Everything Server", "connection": { "transport": "stdio", "command": "npx", "args": [ "-y", "@modelcontextprotocol/server-everything" ], "env": null, "cwd": null, "encoding": "utf-8", "encoding_error_handler": "strict", "timeout": 45 } } } ================================================ FILE: core/mcp/run_server.py ================================================ # core/mcp/run_server.py (FINAL - Direct FastMCP Registration) import os import sys import argparse import traceback import logging from typing import List, Dict, Any, Optional, Type # --- Standard Setup --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger("mcp_server_direct") current_dir = os.path.dirname(os.path.abspath(__file__)); project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir))); sys.path.insert(0, project_root) # --- Imports --- from mcp.server.fastmcp import FastMCP # Import FastMCP directly # Assume registry is populated correctly by preregister_core_tools from core.tools.registry import get_registered_tools, get_tool_instance try: from core.tools import preregister_core_tools; PREREGISTER_AVAILABLE = True except 
# --- Tool Wrapper Creation Logic (as a standalone function) ---
def create_tool_wrapper(tool_instance: BaseTool):
    """
    Creates the async wrapper function for a given tool instance.
    This function will be decorated LATER by the mcp_instance.
    """
    tool_name = getattr(tool_instance, 'name', 'unknown_tool')
    print(f" DEBUG: Defining wrapper function for tool: '{tool_name}'")

    # Define the actual wrapper coroutine
    async def dynamic_tool_wrapper(tool_to_run=tool_instance, **kwargs):  # bind instance
        _tool_name = tool_to_run.name
        log_file = "/tmp/mcp_wrapper.log"
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        log_prefix = f"--- {timestamp} WRAPPER for '{_tool_name}' ---"
        log_lines = [f"{log_prefix} START", f"Received kwargs: {kwargs}"]
        try:  # main execution block
            result = None
            if hasattr(tool_to_run, '_arun'):
                log_lines.append("Calling await tool_to_run._arun(**kwargs)")
                result = await tool_to_run._arun(**kwargs)
                log_lines.append("Await _arun completed.")
            elif hasattr(tool_to_run, '_run'):
                log_lines.append("Calling tool_to_run._run(**kwargs) via run_in_executor")
                loop = asyncio.get_running_loop()
                sync_func_with_args = functools.partial(tool_to_run._run, **kwargs)
                result = await loop.run_in_executor(None, sync_func_with_args)
                log_lines.append("Executor _run completed.")
            else:
                log_lines.append("ERROR: Tool no _arun/_run!")
                raise NotImplementedError(f"Tool {_tool_name} no method.")
            log_lines.append(f"Raw result type: {type(result)}")
            log_lines.append(f"Raw value snippet: {str(result)[:500]}...")
            final_result = result
            try:
                json.dumps(result)
                log_lines.append("Result JSON serializable.")
            except TypeError:
                log_lines.append(f"WARN: Non-JSON type {type(result)}.->str.")
                final_result = str(result)
            log_lines.append(f"Returning final (type {type(final_result)}).")
            log_lines.append(f"{log_prefix} END (Success)")
            return {"result": final_result}
        except Exception as e:  # catch execution errors
            log_lines.append(f"!!! EXCEPTION in tool exec for '{_tool_name}': {e} !!!")
            tb_lines = traceback.format_exc().splitlines()
            log_lines.append("--- Traceback ---")
            log_lines.extend(tb_lines)
            log_lines.append("-----------------")
            log_lines.append(f"{log_prefix} END (Exception)")
            return f"ERROR_EXECUTING_TOOL_{_tool_name}: {str(e)}"  # return error string
        finally:  # ensure logging
            try:
                for line in log_lines:
                    print(line, flush=True, file=sys.stderr)
                with open(log_file, "a") as f:
                    f.write("\n".join(log_lines) + "\n\n")
            except Exception as log_e:
                print(f"!!! Logging Error for tool {_tool_name}: {log_e} !!!", flush=True, file=sys.stderr)

    # Return the created wrapper function AND the original tool's metadata
    return dynamic_tool_wrapper, tool_name, getattr(tool_instance, 'description', f"Tool {tool_name}")
# --- Main Execution Logic ---
def main():
    parser = argparse.ArgumentParser(description='Start Mentis MCP Server (Direct Registration)')
    parser.add_argument('--transport', type=str, choices=['stdio', 'sse'], default='stdio')
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=8000)
    parser.add_argument('--name', type=str, default='MentisMCP')
    parser.add_argument('--tools', nargs='+')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    if args.debug:
        logger.setLevel(logging.DEBUG)
        print("DEBUG Logging Enabled")

    try:
        # --- 1. Preregister tools into the central registry ---
        if PREREGISTER_AVAILABLE:
            print("DEBUG: Calling preregister_core_tools...")
            preregister_core_tools()  # this populates the registry
            print("DEBUG: preregister_core_tools finished.")
        else:
            print("DEBUG: Skipping preregister_core_tools (unavailable).")

        # --- 2. Create FastMCP instance ---
        print(f"DEBUG: Creating FastMCP instance: name='{args.name}'")
        fastmcp_kwargs = {}
        if args.transport == 'sse':
            if args.host:
                fastmcp_kwargs['host'] = args.host
            if args.port:
                fastmcp_kwargs['port'] = args.port
        mcp_instance = FastMCP(args.name, **fastmcp_kwargs)  # create instance directly
        print("DEBUG: FastMCP instance created.")

        # --- 3. Load tools from registry and register wrappers with FastMCP ---
        registered_count = 0
        target_tools = args.tools  # list of names, or None for all
        # Get all tools first if needed
        all_tools_dict = get_registered_tools(as_dict=True)
        tools_to_register = {}
        if target_tools:  # filter if specific tools were requested
            print(f"DEBUG: Filtering for specific tools: {target_tools}")
            for name in target_tools:
                if name in all_tools_dict:
                    tools_to_register[name] = all_tools_dict[name]
                else:
                    print(f"ERROR: Requested tool '{name}' not found in registry.")
        else:  # register all tools found in the registry
            print("DEBUG: Registering all tools found in registry...")
            tools_to_register = all_tools_dict

        # Iterate and register the selected tools
        print(f"DEBUG: Attempting to register {len(tools_to_register)} tools with FastMCP...")
        for tool_name, tool_info in tools_to_register.items():
            tool_instance = tool_info.get("tool")
            if isinstance(tool_instance, BaseTool):
                try:
                    # Create the wrapper function and get metadata
                    wrapper_func, name, description = create_tool_wrapper(tool_instance)
                    # Register the wrapper directly using the mcp_instance decorator method
                    mcp_instance.tool(name=name, description=description)(wrapper_func)
                    print(f"DEBUG: Successfully registered '{name}' with FastMCP.")
                    registered_count += 1
                except Exception as e_register:
                    print(f"ERROR: Failed to register wrapper for tool '{tool_name}': {e_register}")
                    traceback.print_exc()
            else:
                print(f"WARNING: Item '{tool_name}' not a BaseTool, skipping.")
        print(f"DEBUG: Tool registration complete. {registered_count} tools registered with FastMCP.")
        if registered_count == 0:
            print("WARNING: No tools were registered!")
        # --- 4. Run the FastMCP server ---
        print(f"Starting MCP Server '{args.name}' (Transport: {args.transport})...")
        mcp_instance.run(transport=args.transport)
    except KeyboardInterrupt:
        print("Server shutting down...")
        sys.exit(0)
    except Exception as e:
        print(f"Error starting server: {e}")
        traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    main()


================================================
FILE: core/mcp/server.py
================================================
import os
import sys
import traceback
import asyncio
import time
import json
import functools
from typing import Dict, Any, Optional, List

# mcp & fastmcp
from mcp.server.fastmcp import FastMCP
from mcp.types import CallToolResult, TextContent, ErrorData  # <-- key imports

# Fix the path so our own tools and BaseTool can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from core.tools.registry import get_registered_tools, get_tool_instance
from langchain_core.tools import BaseTool

print("--- DEBUG: Loading REFACTORED server.py (Fix InvalidSignature) ---")


class MentisMCPServer:
    def __init__(self, name: str = "MentisMCP", host: Optional[str] = None, port: Optional[int] = None):
        print(f"DEBUG: Initializing MentisMCPServer(name='{name}', host={host}, port={port})")
        fastmcp_kwargs = {}
        if host is not None:
            fastmcp_kwargs['host'] = host
        if port is not None:
            fastmcp_kwargs['port'] = port
        try:
            print(f"DEBUG: Calling FastMCP(name='{name}', **{fastmcp_kwargs})")
            self.mcp = FastMCP(name, **fastmcp_kwargs)
            print("DEBUG: FastMCP initialized successfully.")
        except Exception as e_fastmcp:
            print("ERROR: Failed to initialize FastMCP!")
            print(traceback.format_exc())
            raise
        # Track the successfully registered tool wrappers
        self.registered_tools_wrappers = {}

    def register_all_tools(self):
        """Register every BaseTool found in the registry."""
        tools_dict = get_registered_tools(as_dict=True)
        print(f"DEBUG: Registering all tools ({len(tools_dict)} found)...")
        registered_count = 0
        for tool_name, tool_info in tools_dict.items():
            tool_instance = tool_info.get("tool")
            if isinstance(tool_instance, BaseTool):
                if self._register_tool_with_simplified_wrapper(tool_instance):
                    registered_count += 1
            else:
                print(f"WARNING: Item '{tool_name}' not BaseTool, skipping.")
        print(f"DEBUG: Finished registering all tools. Registered: {registered_count}")

    def register_single_tool(self, tool_name: str):
        """Register a single tool by name."""
        print(f"DEBUG: Attempting to register single tool: {tool_name}")
        try:
            tool_instance = get_tool_instance(tool_name)
            if not tool_instance:
                print(f"ERROR: Tool '{tool_name}' not found in registry.")
                return
            if isinstance(tool_instance, BaseTool):
                if self._register_tool_with_simplified_wrapper(tool_instance):
                    print(f"DEBUG: Successfully registered single tool: {tool_instance.name}")
                else:
                    print(f"ERROR: Failed wrapper registration for: {tool_instance.name}")
            else:
                print(f"WARNING: Tool '{tool_name}' not BaseTool, skipping.")
        except Exception as e:
            print(f"ERROR during register_single_tool for '{tool_name}': {e}")
            print(traceback.format_exc())

    def _register_tool_with_simplified_wrapper(self, tool: BaseTool) -> bool:
        """
        Create and register a simplified wrapper for a tool (fix InvalidSignature),
        making sure the returned data conforms to CallToolResult so the client can parse it.
        """
        try:
            tool_name = getattr(tool, 'name', None)
            tool_description = getattr(tool, 'description', None)
            if not tool_name or not isinstance(tool_name, str):
                print(f"ERROR: Invalid tool name: {tool_name}. Skip.")
                return False
Skip.") return False if not tool_description or not isinstance(tool_description, str): print(f"WARNING: Empty/invalid description for '{tool_name}'.") tool_description = f"Tool {tool_name}" print(f"DEBUG: Defining wrapper for tool: '{tool_name}'") @self.mcp.tool(name=tool_name, description=tool_description) async def simplified_tool_wrapper(tool_for_wrapper=tool, **kwargs): """ 同步或异步地调用 tool_for_wrapper,并将结果包装到 CallToolResult 中返回给客户端,以匹配 .content 或 .error. """ _tool_name = tool_for_wrapper.name log_file = "/tmp/mcp_wrapper.log" timestamp = time.strftime("%Y-%m-%d %H:%M:%S") log_prefix = f"--- {timestamp} WRAPPER for '{_tool_name}' ---" log_lines = [f"{log_prefix} START", f"Received kwargs: {kwargs}"] try: # 根据工具方法签名决定调用 _arun (异步) 或 _run (同步) result = None if hasattr(tool_for_wrapper, '_arun'): log_lines.append("Calling await tool_for_wrapper._arun(**kwargs)") result = await tool_for_wrapper._arun(**kwargs) log_lines.append("Await _arun completed.") elif hasattr(tool_for_wrapper, '_run'): log_lines.append("Calling tool_for_wrapper._run(**kwargs) via run_in_executor") loop = asyncio.get_running_loop() sync_func_with_args = functools.partial(tool_for_wrapper._run, **kwargs) result = await loop.run_in_executor(None, sync_func_with_args) log_lines.append("Executor _run completed.") else: log_lines.append(f"ERROR: Tool '{_tool_name}' has no _arun/_run!") raise NotImplementedError(f"Tool '{_tool_name}' cannot be invoked directly.") # 记录结果类型和内容片段 log_lines.append(f"Raw result type: {type(result)}") log_lines.append(f"Raw value snippet: {str(result)[:500]}...") # 关键:将结果包装成 CallToolResult,让客户端能识别 .content call_result = CallToolResult( content=[TextContent(text=str(result))] ) log_lines.append("Returning standard CallToolResult with .content.") log_lines.append(f"{log_prefix} END (Success)") return call_result except Exception as e: # 出现异常则使用 .error 返回 log_lines.append(f"!!! EXCEPTION in tool exec for '{_tool_name}': {e} !!!") tb_lines = traceback.format_exc().splitlines() log_lines.append("--- Traceback ---") log_lines.extend(tb_lines) log_lines.append("-----------------") log_lines.append(f"{log_prefix} END (Exception)") err_msg = f"ERROR_EXECUTING_TOOL_{_tool_name}: {str(e)}" return CallToolResult(error=ErrorData(message=err_msg)) finally: # 日志记录 try: for line in log_lines: print(line, flush=True, file=sys.stderr) with open(log_file, "a") as f: f.write("\n".join(log_lines) + "\n\n") except Exception as log_e: print(f"!!! 
            # Rename the wrapper to avoid duplicate function names
            simplified_tool_wrapper.__name__ = f"{tool_name}_simplified_wrapper"
            self.registered_tools_wrappers[tool_name] = simplified_tool_wrapper
            print(f"DEBUG: Registered simplified wrapper for tool: '{tool_name}'")
            return True
        except Exception as registration_error:
            failed_tool_name = getattr(tool, 'name', 'unknown')
            print(f"ERROR: Failed to create/register wrapper for tool '{failed_tool_name}': {registration_error}")
            print(traceback.format_exc())
            return False

    def run(self, transport: str = "stdio"):
        """Run the MCP server (host/port removed from this signature)."""
        print(f"DEBUG: MentisMCPServer.run(transport='{transport}') called.")
        print(f"Starting MCP server, transport: {transport}")
        if transport == "sse":
            # SSE mode
            host = 'N/A'
            port = 'N/A'
            if hasattr(self.mcp, 'settings'):
                host = getattr(self.mcp.settings, 'host', 'N/A')
                port = getattr(self.mcp.settings, 'port', 'N/A')
            print(f"Configured SSE server listening on: http://{host}:{port} ('N/A' means not configured or lookup failed)")
            try:
                import importlib
                try:
                    fastmcp_module = importlib.import_module('mcp.server.fastmcp')
                    print(f"FastMCP version: {getattr(fastmcp_module, '__version__', 'unknown')}")
                except:
                    pass
                import uvicorn
                import fastapi
                print(f"FastAPI: {fastapi.__version__}, Uvicorn: {uvicorn.__version__}")
                print(f"DEBUG: Calling self.mcp.run(transport='{transport}') for SSE")
                self.mcp.run(transport=transport)
            except Exception as e:
                print(f"SSE server failed to start: {e}")
                print(traceback.format_exc())
                raise
        else:
            # Default stdio mode
            print("Starting stdio-mode server...")
            try:
                print(f"DEBUG: Calling self.mcp.run(transport='{transport}') for STDIO")
                self.mcp.run(transport=transport)
            except Exception as e:
                print(f"stdio server failed to start: {e}")
                print(traceback.format_exc())
                raise


================================================
FILE: core/mcp/test/README.md
================================================
# MCP Test Framework Notes

## Overview

MCP (Model Context Protocol) is a protocol framework that lets different systems communicate through a standardized interface. This test framework provides a way to exercise the functionality and performance of MCP servers.

## Test File Layout

The framework consists of the following files:

### 1. minimal_fastmcp_test.py

A minimal FastMCP server implementation used to test basic functionality:

- creates a FastMCP instance
- registers a simple tool function (the ping tool)
- runs the server over the STDIO transport

The file can be run standalone or launched as a subprocess by other test scripts.

### 2. test_minimal_client.py

This script uses the MCP client library to test minimal_fastmcp_test.py:

- imports the required MCP client libraries (ClientSession, stdio_client, etc.)
- connects to minimal_fastmcp_test.py and tests the ping tool
- demonstrates how to call tools through the client API

## Test Method

### Client library test (test_minimal_client.py)

This approach talks to the MCP server through the MCP client library, showing how the client is used in a real application. The flow is:

1. create a ClientSession object
2. connect to the MCP server
3. call tools and process the results

## Running the Tests

### Run the client library test

```bash
python core/mcp/test/test_minimal_client.py
```

## Extending the Tests

### Adding a new tool

To add a new tool to minimal_fastmcp_test.py:

1. define a new async tool function
2. register it with the FastMCP instance's decorator

Example:

```python
async def new_tool(param1: str, param2: int = 0) -> str:
    """A new tool description."""
    # tool implementation
    return f"Result: {param1}, {param2}"

mcp_server.tool(name="new_tool", description="New tool description.")(new_tool)
```
### Creating new test scripts

Use the existing test scripts as templates for new scripts covering other features or scenarios.

## Troubleshooting

### Server not responding

- make sure the server process is running
- check that the transport (stdio or sse) is correct
- check the client connection parameters

### Tool call fails

- make sure the tool name is correct
- check that the arguments match the tool's requirements
- check the server logs for more information

## Summary

The MCP test framework demonstrates how to test MCP server functionality with the MCP client library. These tests validate the server's basic behavior and performance, supporting development and debugging.


================================================
FILE: core/mcp/test/__init__.py
================================================
# MCP test module
# Contains test scripts for MCP (Model Context Protocol) functionality


================================================
FILE: core/mcp/test/minimal_fastmcp_test.py
================================================
import asyncio
from mcp.server.fastmcp import FastMCP
import logging

# Basic logging, to surface any extra information from inside FastMCP
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("minimal_test")

print("--- Minimal FastMCP Server Test ---")

# 1. Create the FastMCP instance
# (assuming FastMCP needs no host/port in __init__ for stdio)
mcp_server = FastMCP(name="MinimalServer")
print("FastMCP instance created.")

# 2. Define a simple async tool function
async def ping_tool(query: str = "default ping") -> str:
    """A very basic tool that just returns pong."""
    print(f"\n--- PING TOOL CALLED! ---")  # log from inside the tool
    print(f"Received query: {query}")
    result = f"pong: {query}"
    print(f"Returning: {result}")
    print(f"--- PING TOOL END ---")
    return result

# 3. Register it directly with the FastMCP instance's decorator
try:
    mcp_server.tool(name="ping", description="Returns pong plus the query.")(ping_tool)
    # The line above is equivalent to:
    # @mcp_server.tool(name="ping", description="Returns pong plus the query.")
    # async def ping_tool(...) ...
    print("Tool 'ping' registered directly with FastMCP.")
except Exception as e_reg:
    print(f"Error registering tool directly: {e_reg}")
    import traceback
    traceback.print_exc()
    exit(1)
# 4. Run the server (over STDIO)
try:
    print("Starting minimal server with STDIO transport...")
    # Assuming run() needs only the transport argument for stdio
    mcp_server.run(transport="stdio")
    print("Server finished.")  # should not be reached unless the server stops
except Exception as e_run:
    print(f"Error running minimal server: {e_run}")
    import traceback
    traceback.print_exc()
    exit(1)


================================================
FILE: core/mcp/test/test_minimal_client.py
================================================
# test_minimal_client_fixed.py - client script for testing minimal_fastmcp_test.py (fixed version)
import os
import sys
import asyncio
import json
import traceback
from typing import Optional, Dict, Any

# Add the project root to the path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Import the required MCP client libraries
try:
    from mcp import ClientSession
    from mcp.client.stdio import stdio_client, StdioServerParameters
    from mcp.types import CallToolRequest
    DEPS_OK = True
except ImportError as e:
    print(f"Error: missing required dependency: {e}")
    print("Make sure the mcp library is installed: pip install mcp")
    DEPS_OK = False

async def main():
    """Connect to minimal_fastmcp_test.py and test the ping tool."""
    print("=== MCP Minimal Client Test (fixed) ===\n")
    if not DEPS_OK:
        print("Missing required dependencies; cannot continue.")
        return

    # Path to minimal_fastmcp_test.py
    script_path = os.path.join(os.path.dirname(__file__), "minimal_fastmcp_test.py")
    cmd = [sys.executable, script_path]
    print(f"Preparing to connect to server: {script_path}")
    try:
        # Create the StdioServerParameters object
        server_params = StdioServerParameters(
            command=sys.executable,
            args=[script_path],
            # other parameters such as env, cwd can be added as needed
        )
        print("Server parameter configuration created.")

        # Create the STDIO client connection
        print("\nCreating STDIO client connection...")
        async with stdio_client(server_params) as (reader, writer):
            print("STDIO connection established. Creating ClientSession...")
            async with ClientSession(reader, writer) as session:
                print("ClientSession created. Initializing session...")
                await session.initialize()
                print("Session initialized.")

                # List the tools supported by the server
                print("\nFetching the server's tool list...")
                tools_result = await session.list_tools()
                print(f"Server tools: {tools_result}")

                # Call the ping tool
                print("\nCalling the ping tool...")
                ping_request = CallToolRequest(
                    method="tools/call",
                    params={
                        "name": "ping",
                        "arguments": {"query": "Hello, MCP!"}
                    }
                )
                try:
                    print(f"Sending request: {ping_request}")
                    result = await session.call_tool("ping", {"query": "Hello, MCP!"})
                    print(f"\nResponse received: {result}")
                    if hasattr(result, 'result'):
                        print(f"Result: {result.result}")
                    elif hasattr(result, 'error'):
                        print(f"Error: {result.error}")
                    else:
                        print(f"Unknown response format: {result}")
                except Exception as e:
                    print(f"Error calling tool: {e}")
                    print(traceback.format_exc())
    except Exception as e:
        print(f"Error running test: {e}")
        print(traceback.format_exc())

if __name__ == "__main__":
    asyncio.run(main())
"""预注册核心工具,确保系统启动时这些工具已经可用""" print("开始预注册核心工具...") # 注册搜索类工具 try: # Tavily搜索工具 tavily_search = TavilySearchResults() register_tool(tavily_search, ToolCategory.SEARCH) print(f"已预注册工具: {tavily_search.name} (类别: {ToolCategory.SEARCH.value})") except Exception as e: print(f"预注册Tavily搜索工具失败: {e}") # 注册网页浏览类工具 try: # Arxiv查询工具 arxiv_tool = ArxivQueryRun() register_tool(arxiv_tool, ToolCategory.WEB_BROWSING) print(f"已预注册工具: {arxiv_tool.name} (类别: {ToolCategory.WEB_BROWSING.value})") except Exception as e: print(f"预注册Arxiv查询工具失败: {e}") try: # RequestoolKit请求工具 # 创建TextRequestsWrapper实例作为请求包装器 requests_wrapper = TextRequestsWrapper(headers={}) # 初始化RequestsToolkit,提供必要的参数 requests_toolkit = RequestsToolkit( requests_wrapper=requests_wrapper, allow_dangerous_requests=True # 允许危险请求,使工具可用 ) for req_tool in requests_toolkit.get_tools(): register_tool(req_tool, ToolCategory.WEB_BROWSING) print(f"已预注册工具: {req_tool.name} (类别: {ToolCategory.WEB_BROWSING.value})") except Exception as e: print(f"预注册 RequestoolKit请求工具失败: {e}") # 注册文件系统工具 try: # 获取当前目录作为文件系统工具的根目录 current_dir = os.getcwd() # 创建文件系统工具集 filesystem_toolkit = FileManagementToolkit( root_dir=current_dir, selected_tools=["write_file", "read_file", "list_directory"] ) # 获取文件系统工具并注册 for fs_tool in filesystem_toolkit.get_tools(): register_tool(fs_tool, ToolCategory.FILE_SYSTEM) print(f"已预注册工具: {fs_tool.name} (类别: {ToolCategory.FILE_SYSTEM.value})") except Exception as e: print(f"预注册文件系统工具失败: {e}") # 注册代码解释器工具 # try: # # Python REPL工具 # python_repl = ExecPython() # register_tool(python_repl, ToolCategory.CODE_INTERPRETER) # print(f"已预注册工具: {python_repl.name} (类别: {ToolCategory.CODE_INTERPRETER.value})") # except Exception as e: # print(f"预注册Python REPL工具失败: {e}") # # 注册代码解释器工具 # try: # # Python REPL工具 # javascript_repl = ExecJavaScript() # register_tool(javascript_repl, ToolCategory.CODE_INTERPRETER) # print(f"已预注册工具: {javascript_repl.name} (类别: {ToolCategory.CODE_INTERPRETER.value})") # except Exception as e: # print(f"预注册Python REPL工具失败: {e}") # 注册自定义工具 - FireCrawl工具 try: firecrawl_tool = FireCrawlTool() register_tool(firecrawl_tool, ToolCategory.WEB_BROWSING) print(f"已预注册工具: {firecrawl_tool.name} (类别: {ToolCategory.WEB_BROWSING.value})") except Exception as e: print(f"预注册FireCrawl工具失败: {e}") # 注册E2B代码解释器工具 try: e2b_tool = E2BCodeInterpreterTool() register_tool(e2b_tool, ToolCategory.CODE_INTERPRETER) print(f"已预注册工具: {e2b_tool.name} (类别: {ToolCategory.CODE_INTERPRETER.value})") except Exception as e: print(f"预注册E2B代码解释器工具失败: {e}") from .replicate_flux_tool import ReplicateFluxImageTool, category try: flux_tool = ReplicateFluxImageTool() if flux_tool._is_available: register_tool(flux_tool, category) except Exception as e: print(f"Failed to register ReplicateFluxImageTool: {e}") print("核心工具预注册完成") # 执行预注册 preregister_core_tools() # 注册 LangChain 工具 - 使用load_tools加载的工具列表 try: langchain_tools = load_tools(["serpapi"]) for tool in langchain_tools: register_tool(tool, ToolCategory.SEARCH) print(f"已注册LangChain工具: {tool.name} (类别: {ToolCategory.SEARCH.value})") except Exception as e: print(f"加载LangChain工具失败: {e}") # 工具类别映射 - 用于自动分类直接导入的工具 tool_category_mapping = { # 搜索类工具 "TavilySearchResults": ToolCategory.SEARCH, "GoogleSearchResults": ToolCategory.SEARCH, "GoogleSerperResults": ToolCategory.SEARCH, "WikipediaQueryRun": ToolCategory.SEARCH, "FireCrawl": ToolCategory.SEARCH, # 网页浏览类工具 "WebBrowser": ToolCategory.WEB_BROWSING, "ArxivQueryRun": ToolCategory.WEB_BROWSING, "RequestsGet": ToolCategory.WEB_BROWSING, "RequestsPost": ToolCategory.WEB_BROWSING, # 
================================================
FILE: core/tools/e2b_tool.py
================================================
# core/tools/e2b_tool.py
import os
import json
import asyncio
import traceback
from typing import Dict, Any, Optional, Type, List  # make sure List is imported
from pydantic import BaseModel, Field, PrivateAttr
from langchain_core.tools import BaseTool

# --- E2B Imports ---
try:
    from e2b_code_interpreter import Sandbox
    from e2b_code_interpreter.exceptions import TimeoutException
    try:
        # SandboxException must also be bound on the success path, otherwise
        # the `except (SandboxException, ...)` clauses below raise NameError.
        from e2b_code_interpreter.exceptions import SandboxException
    except ImportError:
        SandboxException = Exception  # type: ignore
    E2B_AVAILABLE = True
except ImportError:
    Sandbox = None  # type: ignore
    SandboxException = Exception  # type: ignore # Fallback to base Exception
    TimeoutException = TimeoutError  # type: ignore # Fallback to base TimeoutError
    E2B_AVAILABLE = False
    print("Warning: 'e2b-code-interpreter' package not installed (pip install e2b-code-interpreter). E2BCodeInterpreterTool will not work.")

# --- Tool Category ---
try:
    from .registry import ToolCategory, register_tool
    if not hasattr(ToolCategory, 'CODE_INTERPRETER'):
        ToolCategory.CODE_INTERPRETER = ToolCategory.OTHER
    category = ToolCategory.CODE_INTERPRETER
except ImportError:
    category = None
    print("Tool registry not found.")

# --- Input Schema (unchanged) ---
class E2BCodeInterpreterToolInput(BaseModel):
    code: str = Field(description="The Python code to execute")

# --- Tool Class (optimized version) ---
class E2BCodeInterpreterTool(BaseTool):
    """
    Tool that executes Python code in a secure sandbox via the E2B SDK
    (with corrected exception handling). Returns a string summary of the
    execution result.
    """
    name: str = "e2b_code_interpreter"
    description: str = (
        # The description emphasizes that this is a Python execution environment
        "Executes Python code in a sandboxed environment. "
        "Input MUST be a JSON object with a 'code' key containing the Python code string. "
        "Libraries like matplotlib, pandas, numpy, sympy are available. Install others using pip (e.g., `import subprocess; subprocess.run(['pip', 'install', 'requests'])`). "
        "Use 'print()' to output results. For plots, save them to a file (e.g., '/home/user/plot.png') and state the path; do not return raw image data. "
        "Returns a string summarizing execution status, stdout, stderr, and any errors."
    )
    args_schema: Type[BaseModel] = E2BCodeInterpreterToolInput

    _sandbox: Optional[Any] = PrivateAttr(default=None)
    _is_available: bool = PrivateAttr(default=False)
    _init_error: Optional[str] = PrivateAttr(default=None)
    # self.ExceptionClass is no longer needed

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._initialize_sandbox()

    def _initialize_sandbox(self):
        """Initialize the sandbox environment."""
        if not E2B_AVAILABLE:
            self._init_error = "Package 'e2b-code-interpreter' not installed."
            print(f"ERROR: {self._init_error}")
            return
        if "E2B_API_KEY" not in os.environ:
            self._init_error = "Environment variable E2B_API_KEY not set."
            print(f"ERROR: {self._init_error}")
            return
        try:
            print("Initializing E2B Sandbox...")
            # Instantiate the Sandbox
            self._sandbox = Sandbox()  # uses the imported Sandbox class
            print("E2B Sandbox initialized successfully!")
            self._is_available = True
            self._init_error = None
        except (SandboxException, TimeoutException) as e:  # <--- catch the specific E2B exceptions
            self._init_error = f"Failed to initialize E2B Sandbox (E2B Error): {e}"
            print(f"ERROR: {self._init_error}")
            self._is_available = False
        except Exception as e:  # catch any other unexpected error
            self._init_error = f"An unexpected error occurred during E2B Sandbox initialization: {e}"
            print(f"ERROR: {self._init_error}")
            self._is_available = False

    def _run(self, code: str, **kwargs) -> str:
        """Execute Python code synchronously and return a summary string."""
        if not self._is_available or self._sandbox is None:
            # Return an error message that includes the setup guide
            error_message = "E2B Sandbox is not available"
            if self._init_error:
                error_message += f": {self._init_error}"
            setup_guide = "\n\nSetup: pip install e2b-code-interpreter; export E2B_API_KEY='...'"
            return f"Execution Failed: {error_message}{setup_guide}"

        output_summary = ""
        try:
            print(f"--- E2B: Executing code synchronously ---\n{code}\n--------------------------------------")
            # Execute via the run_code method
            execution = self._sandbox.run_code(code)

            # Build the result string (same logic as before)
            if execution.error:
                output_summary += f"Execution Failed!\nError Name: {execution.error.name}\nError Value: {execution.error.value}\n"
                if execution.error.traceback:
                    traceback_lines = execution.error.traceback.splitlines()
                    output_summary += f"Traceback (last few lines):\n...\n" + "\n".join(traceback_lines[-5:])
            else:
                output_summary += "Execution Successful.\n"
            if execution.logs.stdout:
                output_summary += f"\nSTDOUT:\n{execution.logs.stdout}"
            if execution.logs.stderr:
                output_summary += f"\nSTDERR:\n{execution.logs.stderr}"
            if execution.results:
                output_summary += "\n\nNote: Execution produced structured results (e.g., plots saved as files)."
            if not output_summary.strip() or output_summary.strip() == "Execution Successful.":
                output_summary = "Code executed successfully with no textual output."
            print(f"--- E2B: Execution finished ---\nResult Summary:\n{output_summary[:500]}...\n-----------------------------")
            return output_summary.strip()
        except (SandboxException, TimeoutException) as e:  # <--- catch the specific E2B exceptions
            error_str = f"Execution Failed (E2B Error)!\nError Name: {getattr(e, 'name', type(e).__name__)}\nDetails: {e}"
            # TimeoutException may not carry a traceback attribute; SandboxException usually does
            tb = getattr(e, 'traceback', traceback.format_exc())
            if tb:
                tb_lines = tb.splitlines()
                error_str += f"\nTraceback (last few lines):\n...\n" + "\n".join(tb_lines[-5:])
            print(f"ERROR during E2B execution: {error_str}")
            return error_str
        except Exception as e:  # any other error
            error_str = f"Execution Failed (Unexpected Error)!\nError Type: {type(e).__name__}\nError Details: {str(e)}\nTraceback:\n{traceback.format_exc()}"
            print(f"ERROR during E2B execution: {error_str}")
            return error_str

    async def _arun(self, code: str, **kwargs) -> str:
        """Execute Python code asynchronously and return a summary string."""
        if not self._is_available or self._sandbox is None:
            error_message = f"E2B Sandbox is not available: {self._init_error}"
            return f"Execution Failed: {error_message}"
        try:
            loop = asyncio.get_running_loop()
            import functools
            # The callable passed to run_in_executor must be invocable without
            # arguments, so bind `code` (and any extra kwargs) with
            # functools.partial (a lambda would also work).
            sync_run_with_args = functools.partial(self._run, code=code, **kwargs)
            print(f"--- E2B: Executing code asynchronously via executor ---\n{code}\n--------------------------------------")
            result_summary = await loop.run_in_executor(None, sync_run_with_args)
            print(f"--- E2B: Async execution finished ---")
            return result_summary
        except Exception as e:
            # Exceptions from run_in_executor or _run are caught here
            error_str = f"Execution Failed (Async Wrapper Error)!\nError Type: {type(e).__name__}\nError Details: {str(e)}"
            tb = traceback.format_exc()
            error_str += f"\nTraceback:\n{tb}"
            print(f"ERROR during E2B async execution: {error_str}")
            return error_str

    def close(self):
        """Shut down the sandbox and release its resources."""
        # Note: the attribute is the private `_sandbox`; the earlier
        # `hasattr(self, "sandbox")` guard was always False.
        if self._is_available and self._sandbox is not None:
            try:
                print("Attempting to close E2B Sandbox...")
                self._sandbox.kill()
                print("E2B Sandbox closed successfully.")
                self._is_available = False
                self._sandbox = None
            except (SandboxException, TimeoutException) as e:  # catch the specific exceptions
                print(f"Error closing E2B Sandbox (E2B Error): {e}")
            except Exception as e:
                print(f"An unexpected error occurred while closing E2B Sandbox: {e}")

    model_config = {
        "arbitrary_types_allowed": True
    }

    # __del__ runs at object destruction and is not guaranteed to execute;
    # do not rely on it to release resources.
    # def __del__(self): self.close()
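A minimal usage sketch for the tool above, assuming E2B_API_KEY is set and e2b-code-interpreter is installed:

# Minimal sketch: run a snippet in the sandbox and close it afterwards.
from core.tools.e2b_tool import E2BCodeInterpreterTool

tool = E2BCodeInterpreterTool()
try:
    # BaseTool.invoke validates the input against args_schema
    summary = tool.invoke({"code": "print(21 * 2)"})
    print(summary)  # on success, contains "STDOUT:" with "42"
finally:
    tool.close()  # kill the sandbox explicitly; __del__ is not relied upon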
" "Libraries like matplotlib, pandas, numpy, sympy are available. Install others using pip (e.g., `import subprocess; subprocess.run(['pip', 'install', 'requests'])`). " "Use 'print()' to output results. For plots, save them to a file (e.g., '/home/user/plot.png') and state the path; do not return raw image data. " "Returns a string summarizing execution status, stdout, stderr, and any errors." ) args_schema: Type[BaseModel] = E2BCodeInterpreterToolInput _sandbox: Optional[Any] = PrivateAttr(default=None) _is_available: bool = PrivateAttr(default=False) _init_error: Optional[str] = PrivateAttr(default=None) # 不再需要 self.ExceptionClass def __init__(self, **kwargs): super().__init__(**kwargs) self._initialize_sandbox() def _initialize_sandbox(self): """初始化沙箱环境""" if not E2B_AVAILABLE: self._init_error = "Package 'e2b' not installed." print(f"ERROR: {self._init_error}") return if "E2B_API_KEY" not in os.environ: self._init_error = "Environment variable E2B_API_KEY not set." print(f"ERROR: {self._init_error}") return try: print("Initializing E2B Sandbox...") # 实例化 Sandbox self._sandbox = Sandbox() # 使用导入的 Sandbox 类 print("E2B Sandbox initialized successfully!") self._is_available = True self._init_error = None except (SandboxException, TimeoutException) as e: # <--- 捕获特定的 E2B 异常 self._init_error = f"Failed to initialize E2B Sandbox (E2B Error): {e}" print(f"ERROR: {self._init_error}") self._is_available = False except Exception as e: # 捕获其他意外错误 self._init_error = f"An unexpected error occurred during E2B Sandbox initialization: {e}" print(f"ERROR: {self._init_error}") self._is_available = False def _run(self, code: str, **kwargs) -> str: """同步执行 Python 代码并返回结果摘要字符串""" if not self._is_available or self._sandbox is None: # ... (返回包含设置指南的错误信息,不变) ... error_message = "E2B Sandbox is not available" if self._init_error: error_message += f": {self._init_error}" setup_guide = "\n\nSetup: pip install e2b; export E2B_API_KEY='...'" return f"Execution Failed: {error_message}{setup_guide}" output_summary = "" try: print(f"--- E2B: Executing code synchronously ---\n{code}\n--------------------------------------") # 使用 run_python 方法 execution = self._sandbox.run_code(code) # 构建结果字符串 (逻辑保持不变) if execution.error: output_summary += f"Execution Failed!\nError Name: {execution.error.name}\nError Value: {execution.error.value}\n" if execution.error.traceback: traceback_lines = execution.error.traceback.splitlines() output_summary += f"Traceback (last few lines):\n...\n" + "\n".join(traceback_lines[-5:]) else: output_summary += "Execution Successful.\n" if execution.logs.stdout: output_summary += f"\nSTDOUT:\n{execution.logs.stdout}" if execution.logs.stderr: output_summary += f"\nSTDERR:\n{execution.logs.stderr}" if execution.results: output_summary += "\n\nNote: Execution produced structured results (e.g., plots saved as files)." if not output_summary.strip() or output_summary.strip() == "Execution Successful.": output_summary = "Code executed successfully with no textual output." 
print(f"--- E2B: Execution finished ---\nResult Summary:\n{output_summary[:500]}...\n-----------------------------") return output_summary.strip() except (SandboxException, TimeoutException) as e: # <--- 捕获特定的 E2B 异常 error_str = f"Execution Failed (E2B Error)!\nError Name: {getattr(e, 'name', type(e).__name__)}\nDetails: {e}" # TimeoutException 可能没有 traceback 属性,SandboxException 通常有 tb = getattr(e, 'traceback', traceback.format_exc()) if tb: tb_lines = tb.splitlines() error_str += f"\nTraceback (last few lines):\n...\n" + "\n".join(tb_lines[-5:]) print(f"ERROR during E2B execution: {error_str}") return error_str except Exception as e: # 其他错误 error_str = f"Execution Failed (Unexpected Error)!\nError Type: {type(e).__name__}\nError Details: {str(e)}\nTraceback:\n{traceback.format_exc()}" print(f"ERROR during E2B execution: {error_str}") return error_str async def _arun(self, code: str, **kwargs) -> str: """异步执行 Python 代码并返回结果摘要字符串""" if not self._is_available or self._sandbox is None: # ... (返回错误信息) ... error_message = f"E2B Sandbox is not available: {self._init_error}" return f"Execution Failed: {error_message}" try: loop = asyncio.get_running_loop() import functools # 注意:传递给 run_in_executor 的函数应该是可调用的 # 这里 _run 是实例方法,所以直接传递 self._run 即可 # 但为了确保 code 参数被正确传递,可以用 lambda 或 partial sync_run_with_args = functools.partial(self._run, code=code, **kwargs) print(f"--- E2B: Executing code asynchronously via executor ---\n{code}\n--------------------------------------") result_summary = await loop.run_in_executor( None, sync_run_with_args ) print(f"--- E2B: Async execution finished ---") return result_summary except Exception as e: # run_in_executor 或 _run 的异常会在这里捕获 error_str = f"Execution Failed (Async Wrapper Error)!\nError Type: {type(e).__name__}\nError Details: {str(e)}" # 尝试获取 Traceback tb = traceback.format_exc() error_str += f"\nTraceback:\n{tb}" print(f"ERROR during E2B async execution: {error_str}") return error_str def close(self): """关闭沙箱,释放资源。""" if hasattr(self, "sandbox") and self._is_available and self._sandbox is not None: try: print("Attempting to close E2B Sandbox...") self._sandbox.kill() print("E2B Sandbox closed successfully.") self._is_available = False self._sandbox = None except (SandboxException, TimeoutException) as e: # 捕获特定异常 print(f"Error closing E2B Sandbox (E2B Error): {e}") except Exception as e: print(f"An unexpected error occurred while closing E2B Sandbox: {e}") model_config = { "arbitrary_types_allowed": True } # __del__ 方法用于对象销毁,通常不保证执行,不建议依赖它来关闭资源 # def __del__(self): self.close() ================================================ FILE: core/tools/firecrawl_tool.py ================================================ # 文件路径: core/tools/firecrawl_tool.py (或您存放工具的文件) import os import json # 虽然不直接返回 JSON,但可能用于处理 metadata from typing import Dict, List, Literal, Optional, Tuple, Type, Union, Any # 确保导入 from pydantic import BaseModel, Field, PrivateAttr # 导入 PrivateAttr from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain_core.tools import BaseTool from dotenv import load_dotenv load_dotenv() # 自动加载 .env 文件 # 尝试导入 FireCrawlLoader,如果失败则标记 try: from langchain_community.document_loaders import FireCrawlLoader FIRECRAWL_LOADER_AVAILABLE = True except ImportError: FireCrawlLoader = None # type: ignore FIRECRAWL_LOADER_AVAILABLE = False print("Warning: langchain_community or firecrawl-py not installed? 
FireCrawlLoader unavailable.") print("Run: pip install -U langchain-community firecrawl-py") # 定义输入 Schema (保持不变) class FireCrawlInput(BaseModel): """Input for the FireCrawl tool.""" url: str = Field(description="URL to crawl or scrape") mode: str = Field( default="scrape", # <-- 将默认模式改为 'scrape' 可能更常用 description="Mode: 'scrape' (single page), 'crawl' (multiple pages). Default: 'scrape'", ) # 可以添加 params 字段如果希望 LLM 控制更多参数 # params: Optional[Dict[str, Any]] = Field(default=None, description="Optional dictionary of additional FireCrawl parameters (e.g., {'pageOptions': {'onlyMainContent': True}})") class FireCrawlTool(BaseTool): """ Tool that uses FireCrawl API to crawl or scrape web content and return a summary. Setup: pip install -U langchain-community firecrawl-py export FIRECRAWL_API_KEY="your-api-key" Instantiate: tool = FireCrawlTool() # Reads API key from env # Or explicitly: tool = FireCrawlTool(api_key="...") Invoke: tool.invoke({"url": "https://example.com", "mode": "scrape"}) """ name: str = "firecrawl_web_content" # 建议用更描述性的名字 description: str = ( "Fetches and extracts the main textual content from a given URL. " "Use 'scrape' mode (default) for a single page, or 'crawl' mode to follow links (use sparingly). " "Input should be a URL. Returns a textual summary of the content." ) args_schema: Type[BaseModel] = FireCrawlInput # --- 配置属性 --- # API Key 可以通过 __init__ 传入,或者留空让 loader 从环境变量读取 _api_key: Optional[str] = PrivateAttr(default=None) # 使用 PrivateAttr 避免 Pydantic 验证 _api_url: Optional[str] = PrivateAttr(default=None) # 可以在 __init__ 中设置默认 mode 和 params,或者在 _run/_arun 中处理 default_mode: str = "scrape" # 工具级别的默认模式 default_params: Dict[str, Any] = Field(default_factory=dict) # 工具级别的默认参数 # 添加 __init__ 以便可以传入 api_key (可选) def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None, mode: str = "scrape", params: Optional[Dict[str, Any]] = None, **kwargs): super().__init__(**kwargs) # Pydantic V2 中,非 model 字段需要用 PrivateAttr 或在 model_config 中设置 self._api_key = api_key self._api_url = api_url self.default_mode = mode self.default_params = params or {} # 检查 Loader 是否可用 if not FIRECRAWL_LOADER_AVAILABLE: print("ERROR: FireCrawlLoader is unavailable. Please install required packages.") def _run( self, url: str, mode: Optional[str] = None, run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: # <--- 返回值必须是字符串 """使用工具同步获取网页内容。""" if not FIRECRAWL_LOADER_AVAILABLE: return "Error: FireCrawlLoader is not available. Required packages might be missing." # 确定使用的 API Key (优先实例属性,其次环境变量) key = self._api_key or os.getenv('FIRECRAWL_API_KEY') if not key: return "Error: FIRECRAWL_API_KEY not found in environment variables or instantiation." # 打印 Debug 信息 (可选) print(f"DEBUG [FireCrawlTool]: Running for URL='{url}', Mode='{mode or self.default_mode}'") # print(f"DEBUG [FireCrawlTool]: Effective API Key = {'*' * (len(key) - 4) + key[-4:] if key else 'None'}") try: current_mode = mode or self.default_mode loader = FireCrawlLoader( url=url, api_key=key, # 传递最终确定的 key api_url=self._api_url, # 传递实例属性或 None mode=current_mode, params=self.default_params, # 传递实例默认参数 ) print(f"--- Calling FireCrawl API (Sync) for: '{url}' ---") docs = loader.load() print(f"--- FireCrawl API call successful for: '{url}', received {len(docs)} document(s) ---") # --- 格式化结果为字符串 --- if not docs: return f"FireCrawl successful but returned no content from {url} (Mode: {current_mode}). The page might be empty or restricted." 
summary_parts = [f"Content summary from {url} (Mode: {current_mode}):"] content_limit = 4000 # 限制返回给 LLM 的总字符数 (可调整) current_length = len(summary_parts[0]) doc_count = 0 for doc in docs: # 可以考虑只返回第一个文档的内容,如果文档很多 # if doc_count >= 1 and current_mode == 'scrape': break source_info = f"\n\n--- Source: {doc.metadata.get('sourceURL', url)} ---" page_content = doc.page_content or "" available_length = content_limit - current_length - len(source_info) - 20 # 预留空间 if available_length <= 0 and doc_count > 0: # 如果已经有内容且空间不足 summary_parts.append("\n\n... (further content truncated)") break content = source_info + "\n" + page_content if len(content) > available_length: content = content[:available_length] + "... (truncated)" summary_parts.append(content) current_length += len(content) doc_count += 1 if current_length >= content_limit: break # 达到总长度限制 return "\n".join(summary_parts).strip() # --- 格式化结束 --- except Exception as e: error_msg = f"Error during FireCrawl for {url} (Mode: {mode or self.default_mode}): {repr(e)}" print(f"ERROR: {error_msg}") return error_msg # 返回错误信息字符串 async def _arun( self, url: str, mode: Optional[str] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: # <--- 返回值必须是字符串 """使用工具异步获取网页内容。""" if not FIRECRAWL_LOADER_AVAILABLE: return "Error: FireCrawlLoader is not available." key = self.api_key or os.getenv('FIRECRAWL_API_KEY') if not key: return "Error: FIRECRAWL_API_KEY not found." print(f"DEBUG [FireCrawlTool]: Running async for URL='{url}', Mode='{mode or self.default_mode}'") try: current_mode = mode or self.default_mode loader = FireCrawlLoader( url=url, api_key=key, api_url=self.api_url, mode=current_mode, params=self.default_params, ) print(f"--- Calling FireCrawl API (Async) for: '{url}' ---") # 使用 aload 进行异步加载 docs = await loader.aload() print(f"--- FireCrawl API call successful for: '{url}', received {len(docs)} document(s) ---") # --- 格式化结果为字符串 (与 _run 逻辑相同) --- if not docs: return f"FireCrawl successful but returned no content from {url} (Mode: {current_mode})." summary_parts = [f"Content summary from {url} (Mode: {current_mode}):"] content_limit = 4000; current_length = len(summary_parts[0]); doc_count = 0 for doc in docs: # if doc_count >= 1 and current_mode == 'scrape': break source_info = f"\n\n--- Source: {doc.metadata.get('sourceURL', url)} ---" page_content = doc.page_content or "" available_length = content_limit - current_length - len(source_info) - 20 if available_length <= 0 and doc_count > 0: summary_parts.append("\n\n... (further content truncated)"); break content = source_info + "\n" + page_content if len(content) > available_length: content = content[:available_length] + "... 
(truncated)" summary_parts.append(content); current_length += len(content); doc_count += 1 if current_length >= content_limit: break return "\n".join(summary_parts).strip() # --- 格式化结束 --- except Exception as e: error_msg = f"Error during Async FireCrawl for {url} (Mode: {mode or self.default_mode}): {repr(e)}" print(f"ERROR: {error_msg}") return error_msg # Pydantic V2: 允许额外的私有属性 model_config = { "arbitrary_types_allowed": True } ================================================ FILE: core/tools/registry.py ================================================ from enum import Enum from typing import List, Dict, Union, Optional from langchain.tools import Tool # 定义工具分类枚举 class ToolCategory(Enum): SEARCH = "Search" CODE_INTERPRETER = "Code Interpreter" WEB_BROWSING = "Web Browsing" DATABASE = "Database" FILE_SYSTEM = "FileSystem" IMAGE_GENERATION = "Image Generation" OTHER = "Other" # 全局工具注册表 _registered_tools = {} def register_tool(tool: Tool, category: ToolCategory) -> None: """注册一个工具到全局字典中,带有分类信息 如果工具名已存在,将覆盖现有的工具注册信息 """ if tool.name in _registered_tools: print(f"警告: 工具名 {tool.name} 已存在,将覆盖现有注册信息") _registered_tools[tool.name] = { "tool": tool, "category": category } def get_registered_tools(as_dict: bool = False) -> Union[List[Tool], Dict[str, Dict]]: """返回所有已注册的工具 Args: as_dict: 如果为True,返回原始字典格式;如果为False,返回工具列表 Returns: 如果as_dict为True,返回原始字典格式;否则返回工具列表 """ if as_dict: return _registered_tools return [info["tool"] for info in _registered_tools.values()] def get_tools_list() -> List[Tool]: """返回所有已注册的工具列表,直接可用于Agent初始化 Returns: 所有已注册工具的列表 """ return [info["tool"] for info in _registered_tools.values()] def get_tools_dict() -> Dict[str, Tool]: """返回工具名称到工具实例的映射字典 Returns: 工具名称到工具实例的映射字典 """ return {name: info["tool"] for name, info in _registered_tools.items()} def get_tool(name: str) -> Optional[Dict]: """根据名称获取工具及其分类 Args: name: 工具名称 Returns: 包含工具和分类的字典,如果工具不存在则返回None """ tool_info = _registered_tools.get(name) if tool_info: return { "tool": tool_info["tool"], "category": tool_info["category"].value } return None def get_tool_instance(name: str) -> Optional[Tool]: """根据名称直接获取工具实例 Args: name: 工具名称 Returns: 工具实例,如果工具不存在则返回None """ tool_info = _registered_tools.get(name) return tool_info["tool"] if tool_info else None def get_tools_by_category(category: ToolCategory, return_instances: bool = True) -> List[Union[str, Tool]]: """返回指定分类的工具列表 Args: category: 工具分类 return_instances: 如果为True,返回工具实例列表;如果为False,返回工具名称列表 Returns: 工具实例列表或工具名称列表 """ if return_instances: return [info["tool"] for name, info in _registered_tools.items() if info["category"] == category] return [name for name, info in _registered_tools.items() if info["category"] == category] ================================================ FILE: core/tools/replicate_flux_tool.py ================================================ # 文件路径: core/tools/replicate_flux_tool.py (或类似) import os import asyncio import json from typing import Dict, Any, Optional, Type, List, Literal from pydantic import BaseModel, Field, PrivateAttr from langchain_core.tools import BaseTool from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) # --- Replicate Client --- try: import replicate REPLICATE_AVAILABLE = True except ImportError: replicate = None # type: ignore REPLICATE_AVAILABLE = False print("Warning: 'replicate' package not installed (pip install replicate). 
ReplicateFluxImageTool will not work.") # --- Tool Category (可选, 用于 Registry) --- try: from .registry import ToolCategory, register_tool if not hasattr(ToolCategory, 'IMAGE_GENERATION'): ToolCategory.IMAGE_GENERATION = ToolCategory.OTHER category = ToolCategory.IMAGE_GENERATION except ImportError: category = None print("Tool registry not found. Cannot auto-register ReplicateFluxImageTool.") # --- Input Schema based on flux-dev --- class ReplicateFluxToolInput(BaseModel): """Input schema for the Replicate Flux Image Generator Tool.""" prompt: str = Field(description="Required. Detailed text description of the image to be generated.") aspect_ratio: Literal["1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21"] = Field( default="1:1", description="Aspect ratio for the generated image." ) num_outputs: int = Field( default=1, description="Number of images to generate (1-4).", ge=1, le=4 ) guidance: float = Field( default=3.0, description="Guidance scale (0-10).", ge=0, le=10 ) num_inference_steps: int = Field( default=28, description="Number of denoising steps (1-50). Lower is faster, lower quality.", ge=1, le=50 ) seed: Optional[int] = Field(default=None, description="Random seed for reproducible generation.") # Add other relevant fields from the schema if needed, e.g., megapixels, output_format # megapixels: Literal["1", "0.25"] = Field(default="1", description="Approximate megapixels for output.") # output_format: Literal["webp", "jpg", "png"] = Field(default="webp", description="Output image format.") # --- Tool Class (修正返回值处理) --- class ReplicateFluxImageTool(BaseTool): """Generates images using 'black-forest-labs/flux-dev' on Replicate.""" name: str = "replicate_flux_image_generator" description: str = ( "Generates high-quality images based on a detailed text prompt using the Flux model on Replicate. " "Specify 'prompt' and optionally other parameters like 'aspect_ratio'. " "Returns a string containing the URL(s) of the generated image(s)." 
================================================
FILE: core/utils/agent_utils.py
================================================
import os
from typing import Dict, Any, Optional, Literal
from langchain_core.messages import AIMessage, ToolMessage
import inspect

def log_agent_actions(state: Dict[str, Any]) -> None:
    """Log the agent's reasoning and actions.

    Prints the agent's reasoning, tool calls, and tool results to the console,
    which makes the agent's behavior easier to observe and debug.

    Args:
        state: A state dict containing the message history.
    """
    print("\n" + "=" * 50)
    print("Current state:")

    # Print the latest message
    if state.get("messages") and len(state["messages"]) > 0:
        latest_message = state["messages"][-1]
        if isinstance(latest_message, AIMessage):
            print(f"\nAI reasoning:")
            print(latest_message.content)
            # If there are tool calls, print their details
            if latest_message.tool_calls:
                print(f"\nTool calls:")
                for tool_call in latest_message.tool_calls:
                    print(f"- Tool: {tool_call['name']}")
                    print(f"- Args: {tool_call['args']}")
        elif isinstance(latest_message, ToolMessage):
            print(f"\nTool result:")
            print(f"- Tool: {latest_message.name}")
            # Print only the first 500 characters of the result to keep output short
            content = latest_message.content
            if len(content) > 500:
                content = content[:500] + "... (rest omitted)"
            print(f"- Result: {content}")
    print("=" * 50)

def save_agent_graph(
    agent,
    caller_file_path: Optional[str] = None,
    output_format: Literal["png", "svg", "mermaid"] = "png",
    custom_filename: Optional[str] = None,
    output_dir: Optional[str] = None
) -> str:
    """Render the agent's graph and save it to a directory.

    By default the output file name matches the caller's file name
    (without extension).

    Args:
        agent: The agent object; it must have a get_graph method.
        caller_file_path: The caller's file path; if None, it is taken from the call stack.
        output_format: "png", "svg", or "mermaid".
        custom_filename: Custom file name (without extension); used if provided.
        output_dir: Custom output directory; used if provided.

    Returns:
        str: The path of the saved graph.
    """
    # If no caller file path was provided, take it from the call stack
    if caller_file_path is None:
        # Grab the caller's stack frame
        frame = inspect.currentframe().f_back
        caller_file_path = frame.f_code.co_filename

    try:
        # Get the graph object
        graph = agent.get_graph()
    except AttributeError:
        raise ValueError("The provided agent object has no get_graph method")
    except Exception as e:
        raise RuntimeError(f"Error getting the graph: {str(e)}")

    # Determine the file name
    if custom_filename:
        file_name_without_ext = custom_filename
    else:
        # Current file name (without path and extension)
        current_file = os.path.basename(caller_file_path)
        file_name_without_ext = os.path.splitext(current_file)[0]

    # Determine the output directory
    if output_dir:
        graph_dir = output_dir
    else:
        # If the caller lives under the examples directory, use examples/graphs;
        # otherwise create a graphs subdirectory next to the caller
        if 'examples' in caller_file_path:
            base_dir = os.path.dirname(os.path.dirname(caller_file_path))
            graph_dir = os.path.join(base_dir, "examples", "graphs")
        else:
            graph_dir = os.path.join(os.path.dirname(caller_file_path), "graphs")

    # Make sure the graphs directory exists
    os.makedirs(graph_dir, exist_ok=True)

    # Generate the file for the requested output format
    try:
        if output_format == "png":
            image_data = graph.draw_mermaid_png()
            graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png")
            with open(graph_path, "wb") as f:
                f.write(image_data)
        elif output_format == "svg":
            image_data = graph.draw_mermaid_svg()
            graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.svg")
            with open(graph_path, "wb") as f:
                f.write(image_data)
        elif output_format == "mermaid":
            mermaid_code = graph.get_mermaid()
            graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.mmd")
            with open(graph_path, "w") as f:
                f.write(mermaid_code)
        else:
            raise ValueError(f"Unsupported output format: {output_format}")
    except Exception as e:
        raise RuntimeError(f"Error saving the graph: {str(e)}")

    print(f"Graph saved as {graph_path}")
    return graph_path

def visualize_agent(agent, **kwargs):
    """Convenience wrapper around save_agent_graph for quick visualization.

    Args:
        agent: The agent object.
        **kwargs: Additional arguments forwarded to save_agent_graph.

    Returns:
        str: The path of the saved graph.
    """
    # Grab the caller's stack frame
    frame = inspect.currentframe().f_back
    caller_file_path = frame.f_code.co_filename
    return save_agent_graph(agent, caller_file_path=caller_file_path, **kwargs)
(更多内容省略)" print(f"- 结果: {content}") print("=" * 50) def save_agent_graph( agent, caller_file_path: Optional[str] = None, output_format: Literal["png", "svg", "mermaid"] = "png", custom_filename: Optional[str] = None, output_dir: Optional[str] = None ) -> str: """保存Agent的图表到指定目录 这个函数用于生成Agent的图表并保存到指定目录, 默认情况下文件名与调用者的文件名保持一致(不含扩展名)。 Args: agent: Agent对象,必须有get_graph方法 caller_file_path: 调用者的文件路径,如果为None则使用调用栈获取 output_format: 输出格式,可选"png"、"svg"或"mermaid" custom_filename: 自定义文件名(不含扩展名),如果提供则使用此名称 output_dir: 自定义输出目录,如果提供则使用此目录 Returns: str: 保存的图表路径 """ # 如果没有提供调用者文件路径,则从调用栈获取 if caller_file_path is None: # 获取调用者的栈帧 frame = inspect.currentframe().f_back caller_file_path = frame.f_code.co_filename try: # 获取图对象 graph = agent.get_graph() except AttributeError: raise ValueError("提供的agent对象没有get_graph方法") except Exception as e: raise RuntimeError(f"获取图表时出错: {str(e)}") # 确定文件名 if custom_filename: file_name_without_ext = custom_filename else: # 获取当前文件名(不含路径和扩展名) current_file = os.path.basename(caller_file_path) file_name_without_ext = os.path.splitext(current_file)[0] # 确定输出目录 if output_dir: graph_dir = output_dir else: # 如果调用者在examples目录下,则使用examples/graphs # 否则在调用者所在目录创建graphs子目录 if 'examples' in caller_file_path: base_dir = os.path.dirname(os.path.dirname(caller_file_path)) graph_dir = os.path.join(base_dir, "examples", "graphs") else: graph_dir = os.path.join(os.path.dirname(caller_file_path), "graphs") # 确保graphs目录存在 os.makedirs(graph_dir, exist_ok=True) # 根据输出格式生成相应文件 try: if output_format == "png": image_data = graph.draw_mermaid_png() graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png") with open(graph_path, "wb") as f: f.write(image_data) elif output_format == "svg": image_data = graph.draw_mermaid_svg() graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.svg") with open(graph_path, "wb") as f: f.write(image_data) elif output_format == "mermaid": mermaid_code = graph.get_mermaid() graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.mmd") with open(graph_path, "w") as f: f.write(mermaid_code) else: raise ValueError(f"不支持的输出格式: {output_format}") except Exception as e: raise RuntimeError(f"保存图表时出错: {str(e)}") print(f"图表已保存为 {graph_path}") return graph_path def visualize_agent(agent, **kwargs): """可视化Agent的快捷方法 这是save_agent_graph的简便包装,用于快速可视化Agent Args: agent: Agent对象 **kwargs: 传递给save_agent_graph的其他参数 Returns: str: 保存的图表路径 """ # 获取调用者的栈帧 frame = inspect.currentframe().f_back caller_file_path = frame.f_code.co_filename return save_agent_graph(agent, caller_file_path=caller_file_path, **kwargs) ================================================ FILE: core/utils/timezone.py ================================================ from datetime import datetime import os from typing import Optional from zoneinfo import ZoneInfo def get_timezone() -> str: """Get timezone from environment variable or use default. Returns: str: Timezone string (e.g. 'Asia/Shanghai') """ return os.getenv('TZ', 'UTC') def get_formatted_date(timezone: Optional[str] = None) -> str: """Get formatted date string with timezone awareness. Args: timezone: Optional timezone string. If not provided, uses TZ from env or UTC. Returns: str: Formatted date string (e.g. 'Today's Date: Mon, Jan 01, 2024') """ tz = ZoneInfo(timezone or get_timezone()) now = datetime.now(tz) return f"Today's Date: {now.strftime('%a, %b %d, %Y')}" def get_current_time(timezone: Optional[str] = None) -> datetime: """Get current time with timezone awareness. Args: timezone: Optional timezone string. 
================================================
FILE: examples/01_supervisor_test.py
================================================
from langgraph.prebuilt import create_react_agent
from core.agents.supervisor import create_supervisor
from langchain_openai import ChatOpenAI
from langgraph.func import entrypoint, task
from langgraph.graph import add_messages
from dotenv import load_dotenv
from core.utils.agent_utils import visualize_agent

load_dotenv()  # Automatically load the .env file

# 1. Initialize the LLM
model = ChatOpenAI(model="gpt-4o-mini")

##############################################################################
# Agent 1: Joke Generator (Functional API)
##############################################################################
@task
def generate_joke(messages):
    """Generate a short joke (no tool calls)."""
    system_message = {
        "role": "system",
        "content": "You are a witty comedian. Write a short joke."
    }
    # Call model.invoke directly, prepending the system_message to the user messages
    msg = model.invoke([system_message] + messages)
    return msg

@entrypoint()
def joke_agent(state):
    # Call the functional task defined above
    joke = generate_joke(state['messages']).result()
    # Append the result to the message list
    messages = add_messages(state["messages"], [joke])
    return {"messages": messages}

joke_agent.name = "joke_agent"

##############################################################################
# Agent 2: Research Expert (Graph API)
##############################################################################
def web_search(query: str) -> str:
    """Search the web for information. (Mocked data here)"""
    return (
        "Here are the headcounts for each of the FAANG companies in 2024:\n"
        "1. **Facebook (Meta)**: 67,317 employees.\n"
        "2. **Apple**: 164,000 employees.\n"
        "3. **Amazon**: 1,551,000 employees.\n"
        "4. **Netflix**: 14,000 employees.\n"
        "5. **Google (Alphabet)**: 181,269 employees."
    )

research_agent = create_react_agent(
    model=model,
    tools=[web_search],
    name="research_expert",
    # The prompt tells it that it is a research agent that may call web_search
    prompt=(
        "You are a world-class researcher. You have access to a 'web_search(query: str)' tool. "
        "Do not do any complicated math, just provide factual info from the web_search if needed."
    ),
)

##############################################################################
# Supervisor Workflow
##############################################################################
# Let the supervisor call joke_agent and research_expert over multiple turns
# within one conversation. The prompt tells it: if the user asks for "a joke
# first, then some research", call joke_agent first and then research_expert,
# so both agents run in sequence for a single user request.
# This is the simplest example and only demonstrates the basic usage of
# create_supervisor; the workflow is not wrapped as an Agent and has no
# planning capability.
workflow = create_supervisor(
    [research_agent, joke_agent],
    model=model,
    prompt=(
        "You are the overall supervisor. You manage two specialized agents:\n"
        "1) joke_agent: for telling jokes.\n"
        "2) research_expert: for factual or data-related questions.\n\n"
        "If the user wants a joke AND some research data in the same query, "
        "you MUST call joke_agent first, get the joke, then call research_expert for the data. "
        "After both calls, provide a final combined response. "
        "Do not call more than one agent in a single LLM message; do it step by step."
), ) # 编译得到一个可调用的"App" agent = workflow.compile() # 保存为一个可视化的图 # visualize_agent(agent) ############################################################################## # 测试:单个用户请求想要 "先讲笑话,再查Apple的2024年人数" 并合并结果 ############################################################################## result = agent.invoke({ "messages": [ { "role": "user", "content": ( "Hi! I'd like to start with a short joke to lighten the mood, " "then please check Apple's headcount in 2024. Summarize both." ) } ] }) ############################################################################## # 打印最终对话消息 ############################################################################## for m in result["messages"]: m.pretty_print() ================================================ FILE: examples/02_supervisor_agent_test.py ================================================ from langgraph.prebuilt import create_react_agent from core.agents.base.react_agent import ReactAgent from core.agents.react_supervisor_agent import SupervisorAgent from langchain_openai import ChatOpenAI from langgraph.func import entrypoint, task from langgraph.graph import add_messages from dotenv import load_dotenv load_dotenv() # 自动加载 .env 文件 # 1. 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") ############################################################################## # Agent 1: Joke Generator (Functional API) ############################################################################## @task def generate_joke(messages): """Generate a short joke (no tool calls).""" system_message = { "role": "system", "content": "You are a witty comedian. Write a short joke." } # 直接调用 model.invoke,拼接 system_message + 用户消息 msg = model.invoke([system_message] + messages) return msg @entrypoint() def joke_agent(state): # 调用上面的函数型任务 joke = generate_joke(state['messages']).result() # 将产物插入消息列表 messages = add_messages(state["messages"], [joke]) return {"messages": messages} joke_agent.name = "joke_agent" ############################################################################## # Agent 2: Research Expert (Graph API) ############################################################################## def web_search(query: str) -> str: """Search the web for information. (Mocked data here)""" return ( "Here are the headcounts for each of the FAANG companies in 2024:\n" "1. **Facebook (Meta)**: 67,317 employees.\n" "2. **Apple**: 164,000 employees.\n" "3. **Amazon**: 1,551,000 employees.\n" "4. **Netflix**: 14,000 employees.\n" "5. **Google (Alphabet)**: 181,269 employees." ) # research_agent = create_react_agent( # model=model, # tools=[web_search], # name="research_expert", # # Prompt 告诉它是一个研究型 Agent,可调用 web_search # prompt=( # "You are a world-class researcher. You have access to a 'web_search(query: str)' tool. " # "Do not do any complicated math, just provide factual info from the web_search if needed." # ), # ) research_agent = ReactAgent( model=model, tools=[web_search], name="research_expert", # Prompt 告诉它是一个研究型 Agent,可调用 web_search prompt=( "You are a world-class researcher. You have access to a 'web_search(query: str)' tool. " "Do not do any complicated math, just provide factual info from the web_search if needed." ), ) ############################################################################## # 使用 SupervisorAgent 类替代直接调用 create_supervisor 函数 ############################################################################## # 创建 SupervisorAgent 实例 supervisor = SupervisorAgent( agents=[research_agent], model=model, # prompt=( # "You are the overall supervisor. 
You manage two specialized agents:\n" # "1) joke_agent: for telling jokes.\n" # "2) research_expert: for factual or data-related questions.\n\n" # "If the user wants a joke AND some research data in the same query, " # "you MUST call joke_agent first, get the joke, then call research_expert for the data. " # "After both calls, provide a final combined response. " # "Do not call more than one agent in a single LLM message; do it step by step." # ), ) ############################################################################## # 测试:单个用户请求想要 "先讲笑话,再查Apple的2024年人数" 并合并结果 ############################################################################## result = supervisor.invoke({ "messages": [ { "role": "user", "content": ( "Hi! I'd like to start with a short joke to lighten the mood, " "then please check Apple's headcount in 2024. Summarize both." ) } ] }) ############################################################################## # 打印最终对话消息 ############################################################################## for m in result["messages"]: m.pretty_print() ================================================ FILE: examples/03_tavily_tools_test.py ================================================ import os from langgraph.prebuilt import create_react_agent from core.agents.react_supervisor_agent import SupervisorAgent from langchain_openai import ChatOpenAI from langgraph.func import entrypoint, task from langgraph.graph import add_messages from langchain_community.tools import TavilySearchResults from dotenv import load_dotenv load_dotenv() # 自动加载 .env 文件 # 1. 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") ############################################################################## # Agent 1: Joke Generator (Functional API) ############################################################################## @task def generate_joke(messages): """Generate a short joke (no tool calls).""" system_message = { "role": "system", "content": "You are a witty comedian. Write a short joke." } # 直接调用 model.invoke,拼接 system_message + 用户消息 msg = model.invoke([system_message] + messages) return msg @entrypoint() def joke_agent(state): # 调用上面的函数型任务 joke = generate_joke(state['messages']).result() # 将产物插入消息列表 messages = add_messages(state["messages"], [joke]) return {"messages": messages} joke_agent.name = "joke_agent" ############################################################################## # Agent 2: Research Expert with Tavily Search (Graph API) ############################################################################## # 创建Tavily搜索工具 tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=False, include_images=False, search_depth="advanced" ) research_agent = create_react_agent( model=model, tools=[tavily_search], name="research_expert", # Prompt 告诉它是一个研究型 Agent,可调用 tavily_search prompt=( "You are a world-class researcher. You have access to the 'tavily_search_results_json' tool " "which can search the web for real-time information. " "When asked a question, use this tool to find accurate and up-to-date information. " "Summarize the search results in a clear and concise manner. " "Always cite your sources by including the URLs from the search results." 
), ) ############################################################################## # 使用 SupervisorAgent 类来协调多个智能体 ############################################################################## # 创建 SupervisorAgent 实例 supervisor = SupervisorAgent( agents=[research_agent, joke_agent], model=model, prompt=( "You are the overall supervisor. You manage two specialized agents:\n" "1) joke_agent: for telling jokes.\n" "2) research_expert: for factual or data-related questions using real-time web search.\n\n" "If the user wants a joke, call joke_agent.\n" "If the user wants factual information or research data, call research_expert.\n" "If the user wants a joke AND some research data in the same query, " "you MUST call joke_agent first, get the joke, then call research_expert for the data. " "After both calls, provide a final combined response. " "Do not call more than one agent in a single LLM message; do it step by step." ), ) # 编译得到一个可调用的"App" app = supervisor.compile() # # 获取当前文件名(不含路径和扩展名) # current_file = os.path.basename(__file__) # file_name_without_ext = os.path.splitext(current_file)[0] # graph_dir = os.path.join(os.path.dirname(__file__), "graphs") # # 确保 graphs 目录存在 # os.makedirs(graph_dir, exist_ok=True) # # 生成与文件名一致的图片名,并保存到 examples/graphs 目录 # image_data = app.get_graph().draw_mermaid_png() # graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png") # # 保存图片(如果已存在则覆盖) # with open(graph_path, "wb") as f: # f.write(image_data) # print(f"Image saved as {graph_path}") # 使用示例 if __name__ == "__main__": # 示例1:只询问笑话 result1 = app.invoke({"messages": [{"role": "user", "content": "讲个笑话"}]}) print("\n示例1 - 只询问笑话:") for message in result1["messages"]: message.pretty_print() # 示例2:只询问研究数据 result2 = app.invoke({"messages": [{"role": "user", "content": "谁是现任美国总统?"}]}) print("\n示例2 - 只询问研究数据:") for message in result2["messages"]: message.pretty_print() # 示例3:同时询问笑话和研究数据 result3 = app.invoke({"messages": [{"role": "user", "content": "讲个关于人工智能的笑话,然后告诉我什么是大型语言模型"}]}) print("\n示例3 - 同时询问笑话和研究数据:") for message in result3["messages"]: message.pretty_print() ================================================ FILE: examples/04_react_agent_test.py ================================================ import os import json from langgraph.prebuilt import create_react_agent from langchain_openai import ChatOpenAI from langchain_community.tools import TavilySearchResults from typing import Dict, Any from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from dotenv import load_dotenv from core.utils.agent_utils import log_agent_actions, save_agent_graph load_dotenv() # 自动加载 .env 文件 # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") ############################################################################## # 创建Tavily搜索工具 - 配置为深度搜索模式 ############################################################################## tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=True, # 包含原始内容,便于分析 include_images=False, search_depth="advanced" # 使用高级搜索深度 ) ############################################################################## # 创建REACT Agent - 使用更详细的提示词引导多步思考 ############################################################################## react_agent = create_react_agent( model=model, tools=[tavily_search], name="tesla_research_expert", # 提示词强调分解问题、多步思考和综合信息 prompt=( "你是一位专业的研究分析师,擅长分析复杂问题并提供深入见解。\n" "你有一个强大的工具'tavily_search_results_json'可以搜索网络获取实时信息。\n\n" "当面对复杂问题时,请遵循以下REACT方法论:\n" "1. 分解问题:将复杂问题分解为更小的子问题\n" "2. 制定计划:确定需要搜索哪些信息,以及搜索的顺序\n" "3. 
执行搜索:使用tavily_search_results_json工具执行搜索\n" "4. 分析结果:分析搜索结果,确定是否需要进一步搜索\n" "5. 综合信息:将所有搜索结果综合成一个连贯的回答\n\n" "重要提示:\n" "- 不要一次性搜索过于宽泛的问题\n" "- 对于复杂问题,进行多次有针对性的搜索\n" "- 每次搜索后评估结果,决定下一步行动\n" "- 在最终回答中引用来源,包括搜索结果中的URL\n" "- 清晰地展示你的思考过程,包括问题分解和计划制定\n" ), ) # 保存Agent图表 # save_agent_graph(react_agent) ############################################################################## # 测试:查询"特斯拉2025年的发展预期" ############################################################################## if __name__ == "__main__": # 复杂查询测试 print("\n开始测试REACT Agent处理复杂查询的能力...\n") print("查询: 特斯拉2025年的发展预期") # 定义输入 inputs = { "messages": [ {"role": "user", "content": "分析特斯拉2025年的发展预期,包括新车型计划、销量目标、技术创新和市场扩张战略。"} ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in react_agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用原有的log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print("\n最终回答:") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: message.pretty_print() ================================================ FILE: examples/05_react_agent_user_input.py ================================================ import asyncio import os from typing import Dict, Any from langchain_openai import ChatOpenAI from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from core.agents.base.react_agent import ReactAgent from langchain_community.tools import TavilySearchResults from dotenv import load_dotenv load_dotenv() # 自动加载 .env 文件 # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") ############################################################################## # 创建一个记录Agent思考过程的函数 ############################################################################## def log_agent_actions(state: Dict[str, Any]) -> None: """记录Agent的思考过程和行动""" print("\n" + "=" * 50) print("当前状态:") # 打印最新消息 if state.get("messages") and len(state["messages"]) > 0: latest_message = state["messages"][-1] if isinstance(latest_message, AIMessage): print(f"\nAI思考过程:") print(latest_message.content) # 如果有工具调用,打印工具调用信息 if latest_message.tool_calls: print(f"\n工具调用:") for tool_call in latest_message.tool_calls: print(f"- 工具: {tool_call['name']}") print(f"- 参数: {tool_call['args']}") elif isinstance(latest_message, ToolMessage): print(f"\n工具返回结果:") print(f"- 工具: {latest_message.name}") # 只打印结果的前200个字符,避免输出过长 content = latest_message.content if len(content) > 200: content = content[:200] + "... (更多内容省略)" print(f"- 结果: {content}") print("=" * 50) ############################################################################## # 创建Tavily搜索工具 - 配置为深度搜索模式 ############################################################################## tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=True, # 包含原始内容,便于分析 include_images=False, search_depth="advanced" # 使用高级搜索深度 ) ############################################################################## # 创建ReactAgent实例 ############################################################################## def create_react_agent_instance(): """创建并返回ReactAgent实例""" react_agent = ReactAgent( model=model, tools=[tavily_search], name="research_assistant", # 提示词强调分解问题、多步思考和综合信息 prompt=( "你是一位专业的研究分析师,擅长分析复杂问题并提供深入见解。\n" "你有一个强大的工具'tavily_search_results_json'可以搜索网络获取实时信息。\n\n" "当面对复杂问题时,请遵循以下REACT方法论:\n" "1. 
分解问题:将复杂问题分解为更小的子问题\n" "2. 制定计划:确定需要搜索哪些信息,以及搜索的顺序\n" "3. 执行搜索:使用tavily_search_results_json工具执行搜索\n" "4. 分析结果:分析搜索结果,确定是否需要进一步搜索\n" "5. 综合信息:将所有搜索结果综合成一个连贯的回答\n\n" "重要提示:\n" "- 不要一次性搜索过于宽泛的问题\n" "- 对于复杂问题,进行多次有针对性的搜索\n" "- 每次搜索后评估结果,决定下一步行动\n" "- 在最终回答中引用来源,包括搜索结果中的URL\n" "- 清晰地展示你的思考过程,包括问题分解和计划制定\n" ), ) # 获取图对象并保存 agent = react_agent.compile() return agent ############################################################################## # 主函数 - 处理用户输入 ############################################################################## async def main(): # 创建ReactAgent实例 react_agent = create_react_agent_instance() while True: # 获取用户输入 user_input = await asyncio.to_thread(input, "\n请输入您想了解的问题 (输入'退出'结束): ") # 检查是否退出 if user_input.lower() in ['退出', 'exit', 'quit']: print("\n感谢使用,再见!") break # 准备初始状态 initial_state = { "messages": [HumanMessage(content=user_input)] } try: print("\n=== 🔍 开始研究 ===\n") # 使用stream方法逐步获取中间状态 final_state = None for partial_state in react_agent.stream(initial_state, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print("\n最终回答:") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print("\n" + "=" * 80) print(message.content) print("=" * 80 + "\n") except Exception as e: print(f"\n处理查询时出错: {e}") ############################################################################## # 程序入口 ############################################################################## if __name__ == "__main__": print("\n欢迎使用ReactAgent研究助手!") print("这个助手可以帮助您研究各种问题,使用Tavily搜索工具获取最新信息。") print("您可以输入任何问题,助手将使用REACT方法论进行分析和回答。") # 运行主函数 asyncio.run(main()) ================================================ FILE: examples/06_web_extraction_tools_test.py ================================================ import os import sys from langgraph.prebuilt import create_react_agent from langchain_openai import ChatOpenAI import json from typing import Dict, Any from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from dotenv import load_dotenv from langchain_community.tools import JinaSearch from core.tools.firecrawl_tool import FireCrawlTool load_dotenv() # 自动加载 .env 文件 # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") ############################################################################## # 创建一个记录Agent思考过程的函数 ############################################################################## def log_agent_actions(state: Dict[str, Any]) -> None: """记录Agent的思考过程和行动""" print("\n" + "=" * 50) print("当前状态:") # 打印最新消息 if state.get("messages") and len(state["messages"]) > 0: latest_message = state["messages"][-1] if isinstance(latest_message, AIMessage): print(f"\nAI思考过程:") print(latest_message.content) # 如果有工具调用,打印工具调用信息 if latest_message.tool_calls: print(f"\n工具调用:") for tool_call in latest_message.tool_calls: print(f"- 工具: {tool_call['name']}") print(f"- 参数: {tool_call['args']}") elif isinstance(latest_message, ToolMessage): print(f"\n工具返回结果:") print(f"- 工具: {latest_message.name}") # 只打印结果的前200个字符,避免输出过长 content = latest_message.content if len(content) > 300: content = content[:300] + "... 
(更多内容省略)" print(f"- 结果: {content}") print("=" * 50) ############################################################################## # 创建Web提取工具 - FireCrawl用于网站结构,Jina用于内容提取 ############################################################################## # 创建FireCrawl工具 - 用于网站结构分析 firecrawl_tool = FireCrawlTool( mode="crawl", # 使用爬取模式 params={"max_pages": 10} # 限制爬取页面数量 ) # 创建Jina Reader工具 - 用于内容提取 jina_reader_tool = JinaSearch() ############################################################################## # 创建REACT Agent - 使用更详细的提示词引导多步思考 ############################################################################## react_agent = create_react_agent( model=model, tools=[firecrawl_tool, jina_reader_tool], name="web_extraction_expert", # 提示词强调分解问题、多步思考和综合信息 prompt=( "你是一位专业的网页内容分析专家,擅长提取和分析网站结构与内容。\n" "你有两个强大的工具:\n" "1. 'firecrawl_tool': 用于爬取网站结构和下级页面\n" "2. 'jina_reader_tool': 用于从特定URL提取结构化内容,获取干净可读的内容\n\n" "当面对网站分析任务时,请遵循以下方法论:\n" "1. 分析任务: 明确需要从网站获取什么信息\n" "2. 网站结构分析: 使用firecrawl_tool爬取网站结构,了解可用页面\n" "3. 内容提取: 根据网站结构,使用jina_reader_tool从关键页面提取内容\n" "4. 信息整合: 将提取的内容整合成有条理的分析结果\n\n" "重要提示:\n" "- 先使用firecrawl_tool了解网站结构,再使用jina_reader_tool提取具体内容\n" "- 对于大型网站,先分析网站结构,再有针对性地选择重要页面进行内容提取\n" "- 每次工具使用后评估结果,决定下一步行动\n" "- 在最终回答中提供结构化的分析,包括网站组织方式和关键内容摘要\n" "- 清晰地展示你的思考过程,包括为什么选择特定页面进行分析\n" ), ) ############################################################################## # 测试:分析LangGraph文档网站 ############################################################################## if __name__ == "__main__": # 测试网站分析 print("\n开始测试Web提取Agent分析网站的能力...\n") print("分析目标: LangGraph文档网站") # 定义输入 inputs = { "messages": [ {"role": "user", "content": "爬取LangGraph文档网站的每个章节的内容(https://langchain-ai.github.io/langgraph/how-tos/) "} ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in react_agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用原有的log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print("\n最终分析结果:") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: message.pretty_print() ================================================ FILE: examples/07_web_extraction_with_filesystem.py ================================================ import os import sys import json import asyncio from datetime import datetime from typing import Dict, Any, List from langchain_openai import ChatOpenAI from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_community.agent_toolkits import FileManagementToolkit from langgraph.prebuilt import create_react_agent from langgraph.checkpoint.memory import MemorySaver from dotenv import load_dotenv from langchain_community.tools import TavilySearchResults from core.agents.react_supervisor_agent import SupervisorAgent load_dotenv() # 自动加载 .env 文件 # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") ############################################################################## # 创建一个记录Agent思考过程的函数 ############################################################################## def log_agent_actions(state: Dict[str, Any]) -> None: """记录Agent的思考过程和行动""" print("\n" + "=" * 50) print("当前状态:") # 打印最新消息 if state.get("messages") and len(state["messages"]) > 0: latest_message = state["messages"][-1] if isinstance(latest_message, AIMessage): print(f"\nAI思考过程:") # 限制内容长度,避免过长输出 content = 
latest_message.content if len(content) > 500: content = content[:250] + "\n... (内容过长,已截断) ...\n" + content[-250:] print(content) # 如果有工具调用,打印工具调用信息 if latest_message.tool_calls: print(f"\n工具调用:") for tool_call in latest_message.tool_calls: print(f"- 工具: {tool_call['name']}") # 限制参数输出长度 args = str(tool_call['args']) if len(args) > 100: args = args[:100] + "... (参数过长,已截断)" print(f"- 参数: {args}") elif isinstance(latest_message, ToolMessage): print(f"\n工具返回结果:") print(f"- 工具: {latest_message.name}") # 只打印结果的前200个字符,避免输出过长 content = latest_message.content if len(content) > 200: content = content[:100] + "\n... (更多内容省略) ...\n" + content[-100:] print(f"- 结果: {content}") print("=" * 50) ############################################################################## # 创建Web提取工具 ############################################################################## # 创建Tavily搜索工具 tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=False, include_images=False, search_depth="advanced" ) ############################################################################## # 创建文件系统工具 - 用于保存提取的内容 ############################################################################## # 设置文件系统工具的根目录为examples/output output_dir = os.path.join(os.path.dirname(__file__), "output") os.makedirs(output_dir, exist_ok=True) # 创建文件系统工具集 filesystem_toolkit = FileManagementToolkit( root_dir=output_dir, selected_tools=["write_file", "read_file", "list_directory"] ) # 获取文件系统工具 filesystem_tools = filesystem_toolkit.get_tools() ############################################################################## # 创建Research Agent - 用于网站内容提取 ############################################################################## research_agent = create_react_agent( model=model, tools=[tavily_search], name="research_agent", # 提示词强调分解问题、多步思考和综合信息 prompt=( "You are a world-class researcher. You have access to the 'tavily_search_results_json' tool " "which can search the web for real-time information. " "When asked a question, use this tool to find accurate and up-to-date information. " "Summarize the search results in a clear and concise manner. " "Always cite your sources by including the URLs from the search results." ), debug=False) ############################################################################## # 创建FileSystem Agent - 用于保存提取的内容 ############################################################################## filesystem_agent = create_react_agent( model=model, tools=filesystem_tools, name="filesystem_agent", # 提示词强调文件操作和内容保存 prompt=( "你是一位专业的文件系统管理专家,负责将网页内容保存到本地文件系统。\n" "你有以下工具可以使用:\n" "1. 'write_file': 用于将内容写入文件\n" "2. 'read_file': 用于读取文件内容\n" "3. 'list_directory': 用于列出目录内容\n\n" "当接收到保存内容的请求时,请遵循以下方法论:\n" "1. 分析内容: 确定内容的类型和结构\n" "2. 确定文件名: 根据内容类型和来源创建合适的文件名\n" "3. 保存内容: 使用write_file工具将内容保存到文件中\n" "4. 验证保存: 使用read_file或list_directory工具验证内容已正确保存\n\n" "重要提示:\n" "- 为文件创建有意义的名称,包含日期和内容描述\n" "- 对于结构化数据,优先使用JSON格式保存\n" "- 对于文本内容,使用TXT或MD格式保存\n" "- 确保文件名不包含非法字符\n" "- 在保存前,检查是否已存在同名文件,避免覆盖重要内容\n" ), ) ############################################################################## # 创建Supervisor Agent - 协调Research Agent和FileSystem Agent ############################################################################## # 创建内存存储器用于保存对话状态 memory_saver = MemorySaver() supervisor = SupervisorAgent( agents=[research_agent, filesystem_agent], model=model, prompt=( "你是一个智能助手的总协调者,负责管理两个专业智能体:\n" "1) research_agent: 网页内容分析专家,可以爬取和分析网站内容\n" "2) filesystem_agent: 文件系统管理专家,可以将内容保存到本地文件系统\n\n" "你的工作流程如下:\n" "1. 
分析用户请求,确定是需要网页内容提取还是文件操作,或两者都需要\n" "2. 如果需要网页内容提取,调用research_agent获取网页内容\n" "3. 如果需要将提取的内容保存到文件,调用filesystem_agent进行保存\n" "4. 如果用户同时需要提取内容并保存,先调用research_agent获取内容,再调用filesystem_agent保存内容\n\n" "重要规则:\n" "- 不要在一个消息中同时调用多个智能体,必须一步一步来\n" "- 当调用filesystem_agent保存内容时,必须提供完整的内容和建议的文件名\n" "- 确保在最终回复中告知用户内容已成功提取和/或保存\n" "- 如果用户只想提取内容而不保存,只调用research_agent\n" "- 如果用户只想操作文件而不提取新内容,只调用filesystem_agent\n\n" "上下文管理指南:\n" "- 当处理大型网站或多个页面时,指导research_agent采用分批处理策略\n" "- 对于大型内容提取任务,先让research_agent获取网站结构,再逐步处理各个页面\n" "- 当发现research_agent返回的内容过大时,指导它进行内容摘要或分批处理\n" "- 如果research_agent一次性尝试处理过多页面导致上下文超限,指导它减少并行处理的页面数量\n" "- 对于需要保存的大型内容,考虑将其分割成多个小文件,而不是一个大文件\n" "- 在处理多页面内容时,可以采用先保存再处理的策略,减轻上下文负担\n" ), checkpointer=memory_saver ) # 编译得到一个可调用的"App",添加checkpointer实现记忆功能 app = supervisor.compile() # # 获取当前文件名(不含路径和扩展名) # current_file = os.path.basename(__file__) # file_name_without_ext = os.path.splitext(current_file)[0] # graph_dir = os.path.join(os.path.dirname(__file__), "graphs") # # 确保 graphs 目录存在 # os.makedirs(graph_dir, exist_ok=True) # # 生成与文件名一致的图片名,并保存到 examples/graphs 目录 # image_data = app.get_graph().draw_mermaid_png() # graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png") # # 保存图片(如果已存在则覆盖) # with open(graph_path, "wb") as f: # f.write(image_data) # print(f"图表已保存为 {graph_path}") ############################################################################## # 主函数 - 处理用户输入 ############################################################################## async def main(): # 创建一个固定的thread_id用于保持对话上下文 thread_id = "user_session_1" # 创建配置对象,包含thread_id config = {"configurable": {"thread_id": thread_id}} print("\n当前会话ID:", thread_id) print("(所有对话将在同一会话中保持上下文记忆)") while True: # 获取用户输入 user_input = await asyncio.to_thread(input, "\n请输入您想了解的问题 (输入'退出'结束): ") # 检查是否退出 if user_input.lower() in ['退出', 'exit', 'quit']: print("\n感谢使用,再见!") break # 准备初始状态 - 只包含当前用户消息 initial_state = { "messages": [HumanMessage(content=user_input)] } try: print("\n=== 🔍 开始研究 ===\n") # 使用stream方法逐步获取中间状态,传入config以使用相同的thread_id for partial_state in app.stream(initial_state, config, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) except Exception as e: print(f"\n处理查询时出错: {e}") print("可能是由于上下文长度超出限制,请尝试减少查询范围或使用'批处理大小设置为X'命令调整批处理大小(1-5之间)") ############################################################################## # 程序入口 ############################################################################## if __name__ == "__main__": print("\n欢迎使用具有记忆功能的网页爬取助手!") print("本助手可以记住您之前的对话内容,实现连续对话体验。") print("您可以询问之前提到过的内容,助手会根据上下文理解您的问题。") # 运行主函数 asyncio.run(main()) ================================================ FILE: examples/08_react_agent_tool_registry_test.py ================================================ import os import sys import json from typing import Dict, Any, List from langchain_openai import ChatOpenAI from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_community.tools import JinaSearch, WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper from dotenv import load_dotenv from core.agents.base.react_agent import ReactAgent from core.tools import register_direct_tool from core.tools.registry import get_registered_tools, ToolCategory from core.tools.firecrawl_tool import FireCrawlTool load_dotenv() # 自动加载 .env 文件 
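##############################################################################
# Editor's note: sketch of the registry flow exercised below
##############################################################################
# A minimal, illustrative sketch of how this script drives the tool registry.
# It is never called; it relies only on names imported above plus
# get_tools_by_category/ToolCategory (imported locally here), and the exact
# registered tool names shown in the print are whatever the registry assigns.
def _registry_usage_sketch():
    # Register concrete tool instances with the shared registry.
    register_direct_tool(JinaSearch())
    register_direct_tool(FireCrawlTool())
    # Inspect everything registered so far, keyed by tool name.
    for name, info in get_registered_tools(as_dict=True).items():
        print(f"- {name} ({info['category'].value})")
    # Hand an agent only the search-category tools.
    from core.tools.registry import get_tools_by_category, ToolCategory
    return get_tools_by_category(ToolCategory.SEARCH)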
############################################################################## # 工具注册和ReactAgent测试 - 美联储研究任务 ############################################################################## def print_separator(title): """打印分隔符""" print("\n" + "=" * 80) print(f" {title} ".center(80, "=")) print("=" * 80) ############################################################################## # 创建一个记录Agent思考过程的函数 ############################################################################## def log_agent_actions(state: Dict[str, Any]) -> None: """记录Agent的思考过程和行动""" print("\n" + "-" * 50) print("当前状态:") # 打印最新消息 if state.get("messages") and len(state["messages"]) > 0: latest_message = state["messages"][-1] if isinstance(latest_message, AIMessage): print(f"\nAI思考过程:") print(latest_message.content) # 如果有工具调用,打印工具调用信息 if latest_message.tool_calls: print(f"\n工具调用:") for tool_call in latest_message.tool_calls: print(f"- 工具: {tool_call['name']}") print(f"- 参数: {tool_call['args']}") elif isinstance(latest_message, ToolMessage): print(f"\n工具返回结果:") print(f"- 工具: {latest_message.name}") # 只打印结果的前200个字符,避免输出过长 content = latest_message.content if len(content) > 200: content = content[:200] + "... (更多内容省略)" print(f"- 结果: {content}") print("-" * 50) ############################################################################## # 注册工具 ############################################################################## print_separator("注册搜索工具") # 创建JinaSearch工具实例 jina_search = JinaSearch() # 创建Wikipedia工具实例 # wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) firecrawl_tool = FireCrawlTool() # 使用register_direct_tool函数注册工具 register_direct_tool(jina_search) register_direct_tool(firecrawl_tool) # 注册自定义工具 - FireCrawlTool # 获取所有已注册的工具(以字典格式) registered_tools = get_registered_tools(as_dict=True) # 打印所有已注册的工具 print("\n已注册的工具:") for name, info in registered_tools.items(): print(f"- {name} (类别: {info['category'].value})") ############################################################################## # 创建ReactAgent实例 ############################################################################## print_separator("创建ReactAgent实例") # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") # 从注册表中只获取搜索类工具列表 from core.tools.registry import get_tools_by_category, ToolCategory tools_list = get_tools_by_category(ToolCategory.SEARCH) # 创建ReactAgent实例 react_agent = ReactAgent( model=model, tools=tools_list, name="fed_research_agent", # 提示词强调分解问题、多步思考和综合信息 prompt=( "你是一位专业的经济研究分析师,擅长分析复杂的经济问题并提供深入见解。\n" "你有多个强大的工具可以搜索网络获取实时信息:\n" "当面对复杂问题时,请遵循以下方法论:\n" "1. 分解问题:将复杂问题分解为更小的子问题\n" "2. 制定计划:确定需要搜索哪些信息,以及使用哪些工具\n" "3. 执行搜索:使用适当的工具执行搜索\n" "4. 分析结果:分析搜索结果,确定是否需要进一步搜索\n" "5. 
综合信息:将所有搜索结果综合成一个连贯的回答\n\n" "重要提示:\n" "- 每次搜索后评估结果,决定下一步行动\n" "- 在最终回答中引用来源\n" "- 清晰地展示你的思考过程,包括问题分解和计划制定\n" ), ) # agent = react_agent.compile() # 获取图对象 # graph = agent.get_graph() # # 获取当前文件名(不含路径和扩展名) # current_file = os.path.basename(__file__) # file_name_without_ext = os.path.splitext(current_file)[0] # graph_dir = os.path.join(os.path.dirname(__file__), "graphs") # # 确保 graphs 目录存在 # os.makedirs(graph_dir, exist_ok=True) # # 生成与文件名一致的图片名,并保存到 examples/graphs 目录 # image_data = graph.draw_mermaid_png() # graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png") # # 保存图片(如果已存在则覆盖) # with open(graph_path, "wb") as f: # f.write(image_data) # print(f"工作流图已保存为 {graph_path}") ############################################################################## # 测试:查询"美联储的详细介绍和它如何影响全球经济" ############################################################################## if __name__ == "__main__": print_separator("开始测试ReactAgent处理美联储研究任务") print("\n查询: 美联储的详细介绍和它如何影响全球经济") # 定义输入 inputs = { "messages": [ HumanMessage(content="请提供2025年美联储(Federal Reserve)的详细介绍,包括其历史、结构、职能,以及它如何通过货币政策影响全球经济。") ] } result = react_agent.run(inputs) ############################################################################## # 打印最终对话消息 ############################################################################## for m in result["messages"]: m.pretty_print() ================================================ FILE: examples/09_e2b_code_interpreter_test.py ================================================ import os import sys import json from typing import Dict, Any, List from langchain_openai import ChatOpenAI from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from dotenv import load_dotenv from core.agents.base.react_agent import ReactAgent from core.tools.registry import get_registered_tools, ToolCategory, get_tools_by_category from core.tools.e2b_tool import E2BCodeInterpreterTool load_dotenv() # 自动加载 .env 文件 ############################################################################## # E2B代码解释器工具测试 ############################################################################## def print_separator(title): """打印分隔符""" print("\n" + "=" * 80) print(f" {title} ".center(80, "=")) print("=" * 80) ############################################################################## # 检查E2B代码解释器工具是否已注册 ############################################################################## print_separator("检查E2B代码解释器工具是否已注册") # 获取所有已注册的工具(以字典格式) registered_tools = get_registered_tools(as_dict=True) # 打印所有已注册的工具 print("\n已注册的工具:") for name, info in registered_tools.items(): print(f"- {name} (类别: {info['category'].value})") # 检查E2B代码解释器工具是否已注册 e2b_tool_name = "e2b_code_interpreter" if e2b_tool_name in registered_tools: print(f"\nE2B代码解释器工具已成功注册: {e2b_tool_name}") else: print(f"\n警告: E2B代码解释器工具未注册") # 手动注册E2B代码解释器工具 print("尝试手动注册E2B代码解释器工具...") try: from core.tools.registry import register_tool e2b_tool = E2BCodeInterpreterTool() register_tool(e2b_tool, ToolCategory.CODE_INTERPRETER) print(f"已手动注册工具: {e2b_tool.name}") except Exception as e: print(f"手动注册E2B代码解释器工具失败: {e}") ############################################################################## # 创建ReactAgent实例 ############################################################################## print_separator("创建ReactAgent实例") # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") # 从注册表中只获取代码解释器类工具列表 tools_list = get_tools_by_category(ToolCategory.CODE_INTERPRETER) # 打印获取到的代码解释器工具 print("\n获取到的代码解释器工具:") for tool in tools_list: print(f"- {tool.name}: {tool.description}") # 
创建ReactAgent实例 react_agent = ReactAgent( model=model, tools=tools_list, name="code_interpreter_agent", # 提示词强调使用代码解释器工具进行数据分析和可视化 prompt=( "你是一位专业的数据分析师和编程助手,擅长使用Python进行数据分析和可视化。\n" "你有多个强大的代码执行工具可以使用:\n" "- e2b_code_interpreter: 用于执行Python代码,支持数据分析和可视化\n" "当面对编程和数据分析问题时,请遵循以下方法论:\n" "1. 分析问题:理解用户的需求和问题本质\n" "2. 制定计划:确定解决方案和需要使用的工具\n" "3. 编写代码:使用适当的工具编写和执行代码\n" "4. 分析结果:解释代码执行结果,提供见解\n" "5. 优化方案:如有必要,优化代码或提供改进建议\n\n" "重要提示:\n" "- 优先使用e2b_code_interpreter工具执行Python代码\n" "- 对于数据分析和可视化任务,确保导入必要的库(如pandas, matplotlib, numpy等)\n" "- 对于不存在的库,工具会自动尝试使用pip install进行安装\n" "- 在代码中添加详细注释,解释关键步骤\n" "- 执行代码后,解释结果含义和见解\n" ), ) # 编译Agent agent = react_agent.compile() # # 获取图对象 # graph = agent.get_graph() # # 获取当前文件名(不含路径和扩展名) # current_file = os.path.basename(__file__) # file_name_without_ext = os.path.splitext(current_file)[0] # graph_dir = os.path.join(os.path.dirname(__file__), "graphs") # # 确保 graphs 目录存在 # os.makedirs(graph_dir, exist_ok=True) # # 生成与文件名一致的图片名,并保存到 examples/graphs 目录 # image_data = graph.draw_mermaid_png() # graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png") # # 保存图片(如果已存在则覆盖) # with open(graph_path, "wb") as f: # f.write(image_data) # print(f"工作流图已保存为 {graph_path}") ############################################################################## # 测试:使用E2B代码解释器执行简单的数据分析任务 ############################################################################## if __name__ == "__main__": print_separator("开始测试ReactAgent使用E2B代码解释器") print("\n查询: 使用Python生成一个简单的正弦波图形") # 定义输入 inputs = { "messages": [ HumanMessage(content="使用Python生成一个简单的正弦波图形,如果有找不到的模块,需要自动安装") ] } result = agent.invoke(inputs) for m in result["messages"]: m.pretty_print() ================================================ FILE: examples/10_financial_data_analysis.py ================================================ import os import sys import json from typing import Dict, Any, List from langchain_openai import ChatOpenAI from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from dotenv import load_dotenv from core.agents.base.react_agent import ReactAgent from core.tools.registry import get_registered_tools, ToolCategory, get_tools_by_category from core.tools.e2b_tool import E2BCodeInterpreterTool load_dotenv() # 自动加载 .env 文件 ############################################################################## # 财务数据分析报表生成示例 ############################################################################## def print_separator(title): """打印分隔符""" print("\n" + "=" * 80) print(f" {title} ".center(80, "=")) print("=" * 80) ############################################################################## # 创建一个记录Agent思考过程的函数 ############################################################################## def log_agent_actions(state: Dict[str, Any]) -> None: """记录Agent的思考过程和行动""" print("\n" + "-" * 50) print("当前状态:") # 打印最新消息 if state.get("messages") and len(state["messages"]) > 0: latest_message = state["messages"][-1] if isinstance(latest_message, AIMessage): print(f"\nAI思考过程:") print(latest_message.content) # 如果有工具调用,打印工具调用信息 if latest_message.tool_calls: print(f"\n工具调用:") for tool_call in latest_message.tool_calls: print(f"- 工具: {tool_call['name']}") print(f"- 参数: {tool_call['args']}") elif isinstance(latest_message, ToolMessage): print(f"\n工具返回结果:") print(f"- 工具: {latest_message.name}") content = latest_message.content print(f"- 结果: {content}") print("-" * 50) ############################################################################## # 检查E2B代码解释器工具是否已注册 
############################################################################## print_separator("检查E2B代码解释器工具是否已注册") # 获取所有已注册的工具(以字典格式) registered_tools = get_registered_tools(as_dict=True) # 打印所有已注册的工具 print("\n已注册的工具:") for name, info in registered_tools.items(): print(f"- {name} (类别: {info['category'].value})") # 检查E2B代码解释器工具是否已注册 e2b_tool_name = "e2b_code_interpreter" if e2b_tool_name in registered_tools: print(f"\nE2B代码解释器工具已成功注册: {e2b_tool_name}") else: print(f"\n警告: E2B代码解释器工具未注册") # 手动注册E2B代码解释器工具 print("尝试手动注册E2B代码解释器工具...") try: from core.tools.registry import register_tool e2b_tool = E2BCodeInterpreterTool() register_tool(e2b_tool, ToolCategory.CODE_INTERPRETER) print(f"已手动注册工具: {e2b_tool.name}") except Exception as e: print(f"手动注册E2B代码解释器工具失败: {e}") ############################################################################## # 创建ReactAgent实例 ############################################################################## print_separator("创建ReactAgent实例") # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") # 从注册表中只获取代码解释器类工具列表 tools_list = get_tools_by_category(ToolCategory.CODE_INTERPRETER) # 打印获取到的代码解释器工具 print("\n获取到的代码解释器工具:") for tool in tools_list: print(f"- {tool.name}: {tool.description}") # 创建ReactAgent实例 react_agent = ReactAgent( model=model, tools=tools_list, name="financial_data_analyst", # 提示词强调使用代码解释器工具进行财务数据分析和可视化 prompt=( "你是一位专业的财务数据分析师,擅长使用Python进行财务数据分析和可视化。\n" "你有强大的代码执行工具可以使用:\n" "- e2b_code_interpreter: 用于执行Python代码,支持数据分析和可视化\n\n" "当面对财务数据分析问题时,请遵循以下方法论:\n" "1. 分析问题:理解用户的需求和问题本质\n" "2. 制定计划:确定解决方案和需要使用的工具\n" "3. 编写代码:使用适当的工具编写和执行代码\n" "4. 分析结果:解释代码执行结果,提供财务见解\n" "5. 优化方案:如有必要,优化代码或提供改进建议\n\n" "重要提示:\n" "- 优先使用e2b_code_interpreter工具执行Python代码\n" "- 对于财务数据分析和可视化任务,确保导入必要的库(如pandas, matplotlib, numpy等)\n" "- 对于不存在的库,工具会自动尝试使用pip install进行安装\n" "- 在代码中添加详细注释,解释关键步骤\n" "- 执行代码后,解释结果含义和财务见解\n" ), ) # # 编译Agent # agent = react_agent.compile() # # 获取图对象 # graph = agent.get_graph() # # 获取当前文件名(不含路径和扩展名) # current_file = os.path.basename(__file__) # file_name_without_ext = os.path.splitext(current_file)[0] # graph_dir = os.path.join(os.path.dirname(__file__), "graphs") # # 确保 graphs 目录存在 # os.makedirs(graph_dir, exist_ok=True) # # 生成与文件名一致的图片名,并保存到 examples/graphs 目录 # image_data = graph.draw_mermaid_png() # graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png") # # 保存图片(如果已存在则覆盖) # with open(graph_path, "wb") as f: # f.write(image_data) # print(f"工作流图已保存为 {graph_path}") ############################################################################## # 从沙箱下载文件到本地的函数 ############################################################################## import os def download_file_from_sandbox(sandbox, sandbox_path, local_path): """从 e2b 沙箱中下载文件并保存到本地,自动区分文本和二进制文件""" try: print(f"读取文件: {sandbox_path}") # 判断是否为常见二进制文件类型(可自行扩展) binary_extensions = ( '.png', '.jpg', '.jpeg', '.gif', '.pdf', '.svg', '.xlsx', '.xls', '.zip', '.bin', '.pyc', '.pyd', '.pptx', '.docx', '.mp3', '.mp4', '.avi', '.mov', ) is_binary = sandbox_path.lower().endswith(binary_extensions) # 创建目录 os.makedirs(os.path.dirname(local_path), exist_ok=True) if is_binary: print("📦 识别为二进制文件,使用 sandbox.download_file()") content = sandbox.files.read(sandbox_path) # 返回 bytes with open(local_path, 'wb') as f: f.write(content) else: print("📄 识别为文本文件,使用 sandbox.files.read()") content = sandbox.files.read(sandbox_path) # 返回 str with open(local_path, 'w', encoding='utf-8') as f: f.write(content) print(f"✅ 文件已保存到本地: {local_path}") return True except Exception as e: print(f"❌ 下载失败: {e}") return False def 
download_directory_from_sandbox(sandbox, sandbox_dir_path, local_dir_path): """从沙箱下载整个目录内容到本地 Args: sandbox: 沙箱实例 sandbox_dir_path: 沙箱中的目录路径 local_dir_path: 本地保存目录路径 Returns: bool: 是否成功下载所有文件 """ try: print(f"尝试下载目录: {sandbox_dir_path} -> {local_dir_path}") # 确保本地目录存在 os.makedirs(local_dir_path, exist_ok=True) # 列出沙箱中指定目录下的所有文件 try: files = sandbox.files.list(sandbox_dir_path) # print(f"获取到文件列表: {sandbox_dir_path}, 类型: {type(files)}") # if files and len(files) > 0: # print(f"第一个文件类型: {type(files[0])}, 内容: {files[0]}") # # 检查对象属性 # print(f"文件对象可用属性: {dir(files[0])}") except Exception as e: print(f"列出文件时出错: {sandbox_dir_path}, 错误: {str(e)}") return False if not files: print(f"沙箱中目录 {sandbox_dir_path} 为空或不存在") return False downloaded_count = 0 # 定义需要跳过的系统文件 skip_files = {'.bashrc', '.bash_logout', '.profile'} # 遍历并下载每个文件 for file_info in files: try: # 使用dir()查看对象有哪些属性 print(f"文件信息对象属性: {dir(file_info)}") # 尝试安全获取name和type属性 file_name = getattr(file_info, "name", None) if file_name is None: print(f"警告: 无法获取文件名, 跳过此文件") continue file_type = getattr(file_info, "type", "file") # 默认为文件类型 # 如果 file_type 是枚举, 使用其 value 进行判断 type_value = file_type.value if hasattr(file_type, "value") else file_type # 跳过不需要的系统文件或系统目录(隐藏文件/目录) if file_name in skip_files or (file_name.startswith('.') and type_value == 'dir'): print(f"跳过系统文件或目录: {file_name}") continue print(f"处理文件: {file_name}, 类型: {type_value}") sandbox_file_path = f"{sandbox_dir_path}/{file_name}" local_file_path = os.path.join(local_dir_path, file_name) if type_value == 'dir': # 递归下载子目录 print(f"发现子目录: {sandbox_file_path}") if download_directory_from_sandbox(sandbox, sandbox_file_path, local_file_path): downloaded_count += 1 else: # 下载文件 print(f"下载文件: {sandbox_file_path} -> {local_file_path}") if download_file_from_sandbox(sandbox, sandbox_file_path, local_file_path): downloaded_count += 1 except Exception as e: print(f"处理文件时出错: {str(e)}") import traceback print(f"详细错误跟踪: {traceback.format_exc()}") continue if downloaded_count > 0: print(f"从 {sandbox_dir_path} 下载了 {downloaded_count} 个文件/目录到 {local_dir_path}") return True return False except Exception as e: print(f"从沙箱下载目录时出错: {str(e)}") import traceback print(f"详细错误跟踪: {traceback.format_exc()}") return False ############################################################################## # 测试:使用E2B代码解释器生成财务数据分析报表 ############################################################################## if __name__ == "__main__": print_separator("开始测试ReactAgent使用E2B代码解释器进行财务数据分析") print("\n查询: 生成模拟财务数据并进行分析,生成财务报表") # 定义输入 inputs = { "messages": [ HumanMessage(content="请生成一组模拟的公司财务数据(包括收入、支出、利润等),对数据进行分析,将处理过程(代码)和最终生成的结果保存到本地。") ] } result = react_agent.run(inputs) for m in result["messages"]: m.pretty_print() print("\n下载沙盒里的文件") try: # 遍历 react_agent.tools 以查找 E2B 相关工具 sandbox = None for tool in react_agent.tools: if hasattr(tool, "sandbox"): sandbox = tool.sandbox break # 找到后就退出循环 if sandbox: # 设定输出目录 output_dir = os.path.join(os.getcwd(), "examples/output/sandbox_files") os.makedirs(output_dir, exist_ok=True) # 直接下载主要工作目录 print("\n从沙箱下载文件到本地...") download_directory_from_sandbox(sandbox, "/home/user", output_dir) # 下载临时目录中可能的图表和数据文件 # download_directory_from_sandbox(sandbox, "/tmp", output_dir) print(f"\n文件已保存到目录: {output_dir}") sandbox.close() except Exception as e: print(f"从沙箱下载文件时出错: {str(e)}") ================================================ FILE: examples/11_e2b_sandbox_test.py ================================================ import os import sys import json from typing import Dict, Any, List from datetime 
import datetime from langchain_openai import ChatOpenAI from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from dotenv import load_dotenv from core.agents.base.react_agent import ReactAgent from core.tools.registry import get_registered_tools, ToolCategory, get_tools_by_category from core.tools.e2b_tool import E2BCodeInterpreterTool load_dotenv() # 自动加载 .env 文件 ############################################################################## # E2B沙盒环境测试程序 ############################################################################## def print_separator(title): """打印分隔符""" print("\n" + "=" * 80) print(f" {title} ".center(80, "=")) print("=" * 80) ############################################################################## # 创建一个记录Agent思考过程的函数 ############################################################################## def log_agent_actions(state: Dict[str, Any]) -> None: """记录Agent的思考过程和行动""" print("\n" + "-" * 50) print("当前状态:") # 打印最新消息 if state.get("messages") and len(state["messages"]) > 0: latest_message = state["messages"][-1] if isinstance(latest_message, AIMessage): print(f"\nAI思考过程:") print(latest_message.content) # 如果有工具调用,打印工具调用信息 if latest_message.tool_calls: print(f"\n工具调用:") for tool_call in latest_message.tool_calls: print(f"- 工具: {tool_call['name']}") print(f"- 参数: {tool_call['args']}") elif isinstance(latest_message, ToolMessage): print(f"\n工具返回结果:") print(f"- 工具: {latest_message.name}") content = latest_message.content if len(content) > 500: content = content[:250] + "\n... (内容过长,已截断) ...\n" + content[-250:] print(f"- 结果: {content}") print("-" * 50) ############################################################################## # 从沙箱下载文件到本地的函数 ############################################################################## def download_file_from_sandbox(sandbox, sandbox_path, local_path): """从 e2b 沙箱中下载文件并保存到本地,自动区分文本和二进制文件""" try: print(f"读取文件: {sandbox_path}") # 判断是否为常见二进制文件类型(可自行扩展) binary_extensions = ( '.png', '.jpg', '.jpeg', '.gif', '.pdf', '.svg', '.xlsx', '.xls', '.zip', '.bin', '.pyc', '.pyd', '.pptx', '.docx', '.mp3', '.mp4', '.avi', '.mov', ) is_binary = sandbox_path.lower().endswith(binary_extensions) # 创建目录 os.makedirs(os.path.dirname(local_path), exist_ok=True) if is_binary: print("📦 识别为二进制文件,使用 sandbox.download_file()") content = sandbox.files.read(sandbox_path) # 返回 bytes with open(local_path, 'wb') as f: f.write(content) else: print("📄 识别为文本文件,使用 sandbox.files.read()") content = sandbox.files.read(sandbox_path) # 返回 str with open(local_path, 'w', encoding='utf-8') as f: f.write(content) print(f"✅ 文件已保存到本地: {local_path}") return True except Exception as e: print(f"❌ 下载失败: {e}") return False def run_ai_generated_code(sandbox, code: str, save_results_dir=None): """在 E2B 沙箱中执行 AI 生成的代码 Args: sandbox: 沙箱实例 code: AI 生成的代码字符串 save_results_dir: 用于保存结果文件的本地目录路径(可选) Returns: dict: 包含执行结果的字典 """ try: print("在沙箱中执行 AI 生成的代码...") # 确保代码是字符串类型 if not isinstance(code, str): code = str(code) # 执行代码 execution = sandbox.run_code(code) print("代码执行完成!") # 准备结果字典 result = { "success": True, "stdout": "", "results": [] } # 提取标准输出 if hasattr(execution, "stdout"): result["stdout"] = execution.stdout # 检查代码是否执行成功 if hasattr(execution, "error") and execution.error: error_name = getattr(execution.error, "name", "Unknown") error_value = getattr(execution.error, "value", "Unknown error") error_traceback = getattr(execution.error, "traceback", "") print("AI 生成的代码执行出错:") print(f"错误类型: {error_name}") print(f"错误信息: {error_value}") if error_traceback: print(f"错误追踪: 
{error_traceback}") result["success"] = False result["error"] = { "name": error_name, "value": error_value, "traceback": error_traceback } return result # 处理执行结果 if hasattr(execution, "results") and execution.results: import base64 result_idx = 0 for res in execution.results: # 默认为文本结果 result_data = {"type": "text", "value": str(res)} # 检查是否有PNG图像 if hasattr(res, "png") and res.png: result_data["type"] = "png" result_data["value"] = res.png # base64编码的字符串 # 如果指定了保存目录,保存图像到本地 if save_results_dir: try: os.makedirs(save_results_dir, exist_ok=True) image_path = os.path.join(save_results_dir, f"result-{result_idx}.png") # 解码并保存图像 with open(image_path, 'wb') as f: f.write(base64.b64decode(res.png)) print(f"图像已保存到: {image_path}") result_data["local_path"] = image_path except Exception as img_err: print(f"保存图像时出错: {str(img_err)}") result["results"].append(result_data) result_idx += 1 return result except Exception as e: print(f"执行AI生成的代码时出错: {str(e)}") import traceback print(f"详细错误: {traceback.format_exc()}") return { "success": False, "error": { "name": type(e).__name__, "value": str(e), "traceback": traceback.format_exc() } } def download_directory_from_sandbox(sandbox, sandbox_dir_path, local_dir_path): """从沙箱下载整个目录内容到本地 Args: sandbox: 沙箱实例 sandbox_dir_path: 沙箱中的目录路径 local_dir_path: 本地保存目录路径 Returns: bool: 是否成功下载所有文件 """ try: print(f"尝试下载目录: {sandbox_dir_path} -> {local_dir_path}") # 确保本地目录存在 os.makedirs(local_dir_path, exist_ok=True) # 列出沙箱中指定目录下的所有文件 try: files = sandbox.files.list(sandbox_dir_path) # print(f"获取到文件列表: {sandbox_dir_path}, 类型: {type(files)}") # if files and len(files) > 0: # print(f"第一个文件类型: {type(files[0])}, 内容: {files[0]}") # # 检查对象属性 # print(f"文件对象可用属性: {dir(files[0])}") except Exception as e: print(f"列出文件时出错: {sandbox_dir_path}, 错误: {str(e)}") return False if not files: print(f"沙箱中目录 {sandbox_dir_path} 为空或不存在") return False downloaded_count = 0 # 定义需要跳过的系统文件 skip_files = {'.bashrc', '.bash_logout', '.profile'} # 遍历并下载每个文件 for file_info in files: try: # 使用dir()查看对象有哪些属性 print(f"文件信息对象属性: {dir(file_info)}") # 尝试安全获取name和type属性 file_name = getattr(file_info, "name", None) if file_name is None: print(f"警告: 无法获取文件名, 跳过此文件") continue file_type = getattr(file_info, "type", "file") # 默认为文件类型 # 如果 file_type 是枚举, 使用其 value 进行判断 type_value = file_type.value if hasattr(file_type, "value") else file_type # 跳过不需要的系统文件或系统目录(隐藏文件/目录) if file_name in skip_files or (file_name.startswith('.') and type_value == 'dir'): print(f"跳过系统文件或目录: {file_name}") continue print(f"处理文件: {file_name}, 类型: {type_value}") sandbox_file_path = f"{sandbox_dir_path}/{file_name}" local_file_path = os.path.join(local_dir_path, file_name) if type_value == 'dir': # 递归下载子目录 print(f"发现子目录: {sandbox_file_path}") if download_directory_from_sandbox(sandbox, sandbox_file_path, local_file_path): downloaded_count += 1 else: # 下载文件 print(f"下载文件: {sandbox_file_path} -> {local_file_path}") if download_file_from_sandbox(sandbox, sandbox_file_path, local_file_path): downloaded_count += 1 except Exception as e: print(f"处理文件时出错: {str(e)}") import traceback print(f"详细错误跟踪: {traceback.format_exc()}") continue if downloaded_count > 0: print(f"从 {sandbox_dir_path} 下载了 {downloaded_count} 个文件/目录到 {local_dir_path}") return True return False except Exception as e: print(f"下载整个目录时出错: {str(e)}") import traceback print(f"详细错误跟踪: {traceback.format_exc()}") ############################################################################## # 检查E2B代码解释器工具是否已注册 ############################################################################## 
print_separator("检查E2B代码解释器工具是否已注册") # 获取所有已注册的工具(以字典格式) registered_tools = get_registered_tools(as_dict=True) # 打印所有已注册的工具 print("\n已注册的工具:") for name, info in registered_tools.items(): print(f"- {name} (类别: {info['category'].value})") # 检查E2B代码解释器工具是否已注册 e2b_tool_name = "e2b_code_interpreter" if e2b_tool_name in registered_tools: print(f"\nE2B代码解释器工具已成功注册: {e2b_tool_name}") else: print(f"\n警告: E2B代码解释器工具未注册") # 手动注册E2B代码解释器工具 print("尝试手动注册E2B代码解释器工具...") try: from core.tools.registry import register_tool e2b_tool = E2BCodeInterpreterTool() register_tool(e2b_tool, ToolCategory.CODE_INTERPRETER) print(f"已手动注册工具: {e2b_tool.name}") except Exception as e: print(f"手动注册E2B代码解释器工具失败: {e}") ############################################################################## # 创建ReactAgent实例 ############################################################################## print_separator("创建ReactAgent实例") # 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") # 从注册表中只获取代码解释器类工具列表 tools_list = get_tools_by_category(ToolCategory.CODE_INTERPRETER) # 打印获取到的代码解释器工具 print("\n获取到的代码解释器工具:") for tool in tools_list: print(f"- {tool.name}: {tool.description}") # 创建ReactAgent实例 react_agent = ReactAgent( model=model, tools=tools_list, name="sandbox_test_agent", # 提示词强调测试沙箱环境的各种功能 prompt=( "你是一位专业的沙箱环境测试专家,负责测试E2B代码解释器沙箱环境的各种功能。\n" "你有强大的代码执行工具可以使用:\n" "- e2b_code_interpreter: 用于在沙箱环境中执行Python代码\n\n" "当进行沙箱环境测试时,请遵循以下方法论:\n" "1. 分析测试需求:理解需要测试的沙箱功能\n" "2. 设计测试用例:针对特定功能设计测试代码\n" "3. 执行测试:使用e2b_code_interpreter工具执行测试代码\n" "4. 分析结果:解释测试结果,判断功能是否正常\n" "5. 记录问题:如有异常,记录问题并提供详细信息\n\n" "重要提示:\n" "- 优先使用e2b_code_interpreter工具执行Python代码\n" "- 测试代码应包含详细注释,解释测试目的和预期结果\n" "- 所有文件和图片必须保存在沙盒环境中的特定目录,不要直接返回图片\n" "- 图片不允许在回复中展示!Images are not allowed in the response!\n" "- 测试应覆盖沙箱的各种功能,包括但不限于:\n" " * 基本Python代码执行\n" " * 文件系统操作(创建、读取、写入文件)\n" " * 包管理(安装和使用第三方包)\n" " * 系统命令执行(使用!前缀执行shell命令)\n" " * 数据处理和可视化\n" " * 异常处理和错误恢复\n" ), ) # 添加调试信息,验证工具列表和沙箱实例的初始状态 print("\n验证ReactAgent工具列表和沙箱实例初始状态:") print(f"react_agent.tools类型: {type(react_agent.tools)}") print(f"react_agent.tools长度: {len(react_agent.tools)}") # 遍历所有工具,检查是否有sandbox属性 for i, tool in enumerate(react_agent.tools): print(f"\n工具[{i}]类型: {type(tool)}") print(f"工具[{i}]名称: {getattr(tool, 'name', '未知')}") print(f"工具[{i}]是否有sandbox属性: {'sandbox' in dir(tool)}") # 如果有sandbox属性,打印沙箱实例信息 if 'sandbox' in dir(tool): print(f"工具[{i}]的sandbox类型: {type(tool.sandbox)}") print(f"工具[{i}]的sandbox是否可用: {getattr(tool, '_is_available', False)}") print(f"工具[{i}]的初始化错误: {getattr(tool, '_init_error', None)}") # 编译Agent agent = react_agent.compile() # # 获取图对象 # graph = agent.get_graph() # # 获取当前文件名(不含路径和扩展名) # current_file = os.path.basename(__file__) # file_name_without_ext = os.path.splitext(current_file)[0] # graph_dir = os.path.join(os.path.dirname(__file__), "graphs") # # 确保 graphs 目录存在 # os.makedirs(graph_dir, exist_ok=True) # # 生成与文件名一致的图片名,并保存到 examples/graphs 目录 # image_data = graph.draw_mermaid_png() # graph_path = os.path.join(graph_dir, f"{file_name_without_ext}.png") # # 保存图片(如果已存在则覆盖) # with open(graph_path, "wb") as f: # f.write(image_data) # print(f"工作流图已保存为 {graph_path}") ############################################################################## # 测试用例1:基本Python代码执行和环境信息 ############################################################################## def run_test_case_1(): print_separator("测试用例1:基本Python代码执行和环境信息") print("\n查询: 测试基本Python代码执行和获取环境信息") # 定义输入 inputs = { "messages": [ HumanMessage(content="请执行一段Python代码,测试基本的数学运算、字符串操作,并获取沙箱环境的系统信息(Python版本、操作系统信息等)。") ] } # 使用stream方法逐步获取中间状态 
final_state = None for partial_state in agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print_separator("测试用例1结果") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) ############################################################################## # 测试用例2:文件系统操作 ############################################################################## def run_test_case_2(): print_separator("测试用例2:文件系统操作") print("\n查询: 测试沙箱环境的文件系统操作") # 定义输入 inputs = { "messages": [ HumanMessage(content="请测试沙箱环境的文件系统操作,包括创建目录、创建文件、写入内容、读取内容、列出目录内容等。创建一个测试目录结构,并将操作结果保存到文件中。文件保存到 /home/user/test_dir") ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print_separator("测试用例2结果") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) # 检查是否有E2B沙箱实例,尝试下载生成的文件 for msg in final_state["messages"]: if isinstance(msg, ToolMessage) and msg.name == "e2b_code_interpreter": try: # 尝试解析工具消息内容 tool_output = json.loads(msg.content) print(f"\n工具消息内容解析成功: {type(tool_output)}") # 检查是否有原始输出 if hasattr(msg, 'raw_output') and msg.raw_output: print(f"\n消息包含raw_output属性: {type(msg.raw_output)}") # 打印react_agent.tools的信息 print(f"\nreact_agent.tools类型: {type(react_agent.tools)}") print(f"react_agent.tools长度: {len(react_agent.tools)}") # 遍历所有工具,检查是否有sandbox属性 for i, tool in enumerate(react_agent.tools): print(f"\n工具[{i}]类型: {type(tool)}") print(f"工具[{i}]名称: {getattr(tool, 'name', '未知')}") print(f"工具[{i}]是否有sandbox属性: {'sandbox' in dir(tool)}") if 'sandbox' in dir(tool): print(f"工具[{i}]的sandbox类型: {type(tool.sandbox)}") # 遍历 react_agent.tools 以查找 E2B 相关工具 sandbox = None for tool in react_agent.tools: if hasattr(tool, "sandbox"): sandbox = tool.sandbox break # 找到后就退出循环 if sandbox: print("\n成功获取沙箱实例!") print(f"沙箱实例类型: {type(sandbox)}") # 从沙箱下载生成的文件 output_dir = os.path.join(os.path.dirname(__file__), "output", "sandbox_test") os.makedirs(output_dir, exist_ok=True) print(f"输出目录已创建: {output_dir}") # 尝试下载测试目录,路径和提示中保持一致 sandbox_test_path = "/home/user/test_dir" print(f"尝试从沙箱下载目录: {sandbox_test_path}") download_directory_from_sandbox(sandbox, sandbox_test_path, os.path.join(output_dir, "test_dir")) else: print("\n错误: 无法获取沙箱实例,没有找到具有sandbox属性的工具") else: print("\n错误: 消息没有raw_output属性") except Exception as e: print(f"处理工具消息时出错: {str(e)}") ############################################################################## # 测试用例3:包管理和第三方库使用 ############################################################################## def run_test_case_3(): print_separator("测试用例3:包管理和第三方库使用") print("\n查询: 测试沙箱环境的包管理和第三方库使用") # 定义输入 inputs = { "messages": [ HumanMessage(content="请测试沙箱环境的包管理功能,安装一个不常见的第三方库(如wordcloud、pycountry等),并使用该库编写一个简单的示例程序。验证包安装和使用是否正常。") ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 
messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print_separator("测试用例3结果") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) ############################################################################## # 测试用例4:Shell命令执行 ############################################################################## def run_test_case_4(): print_separator("测试用例4:Shell命令执行") print("\n查询: 测试沙箱环境的Shell命令执行") # 定义输入 inputs = { "messages": [ HumanMessage(content="请测试沙箱环境中执行Shell命令的功能,使用!前缀执行一系列Linux命令,包括系统信息查询、目录操作、文件查找等。将命令执行结果保存到文件(/home/user/shell_commands_results.txt)中。") ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print_separator("测试用例4结果") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) # 尝试下载生成的文件 for msg in final_state["messages"]: if isinstance(msg, ToolMessage) and msg.name == "e2b_code_interpreter": try: print(f"\n测试用例4: 检查工具消息类型: {type(msg)}") print(f"测试用例4: 工具消息名称: {msg.name}") # 检查react_agent.tools的信息 print(f"\n测试用例4: react_agent.tools类型: {type(react_agent.tools)}") print(f"测试用例4: react_agent.tools长度: {len(react_agent.tools)}") # 遍历 react_agent.tools 以查找 E2B 相关工具 sandbox = None for tool in react_agent.tools: if hasattr(tool, "sandbox"): sandbox = tool.sandbox break # 找到后就退出循环 if sandbox: print("\n测试用例4: 成功获取沙箱实例!") print(f"测试用例4: 沙箱实例类型: {type(sandbox)}") print(f"测试用例4: 沙箱实例属性: {dir(sandbox)[:10]}...") output_dir = os.path.join(os.path.dirname(__file__), "output", "sandbox_test") os.makedirs(output_dir, exist_ok=True) print(f"测试用例4: 输出目录已创建: {output_dir}") # 尝试下载shell命令结果文件,路径和提示中保持一致 sandbox_file_path = "/home/user/shell_commands_results.txt" local_file_path = os.path.join(output_dir, "shell_commands_results.txt") print(f"测试用例4: 尝试下载文件: {sandbox_file_path} -> {local_file_path}") download_file_from_sandbox(sandbox, sandbox_file_path, local_file_path) else: print("\n测试用例4: 错误: 无法获取沙箱实例,没有找到具有sandbox属性的工具") print(f"测试用例4: react_agent.tools的类型和长度: {type(react_agent.tools)}, {len(react_agent.tools)}") except Exception as e: print(f"下载文件时出错: {str(e)}") ############################################################################## # 测试用例5:数据处理和可视化 ############################################################################## def run_test_case_5(): print_separator("测试用例5:数据处理和可视化") print("\n查询: 测试沙箱环境的数据处理和可视化功能") # 定义输入 inputs = { "messages": [ HumanMessage(content=( "请测试沙箱环境的数据处理和可视化功能,生成一些随机数据,使用pandas进行数据处理," "然后使用matplotlib创建多种类型的图表(折线图、柱状图、散点图等)。\n" "严格按照以下要求:\n" "1. 将所有图表保存到 /home/user/visualizations 目录\n" "2. 不要在回复中包含图片 - 图片直接保存到上述目录即可\n" "3. Images are not allowed in the response!\n" "4. 只需描述你做了什么,创建了哪些图表,并说明它们保存在哪里\n" "5. 
请确保目录存在后再保存图片\n" )) ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print_separator("测试用例5结果") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) # 尝试下载生成的图表文件 for msg in final_state["messages"]: if isinstance(msg, ToolMessage) and msg.name == "e2b_code_interpreter": try: # 遍历 react_agent.tools 以查找 E2B 相关工具 sandbox = None for tool in react_agent.tools: if hasattr(tool, "sandbox"): sandbox = tool.sandbox break # 找到后就退出循环 if sandbox: output_dir = os.path.join(os.path.dirname(__file__), "output", "sandbox_test") os.makedirs(output_dir, exist_ok=True) # 针对性地下载可视化目录中的图表 vis_dir = "/home/user/visualizations" local_vis_dir = os.path.join(output_dir, "visualizations") os.makedirs(local_vis_dir, exist_ok=True) print(f"测试用例5: 下载可视化图表目录: {vis_dir} -> {local_vis_dir}") # 尝试列出可视化目录中的文件 try: files = sandbox.files.list(vis_dir) if files: print(f"找到图表文件:") for file_info in files: file_name = getattr(file_info, "name", "未知文件") print(f"- {file_name}") else: print(f"警告: 可视化目录为空或不存在") except Exception as e: print(f"列出可视化目录文件时出错: {str(e)}") # 执行下载 success = download_directory_from_sandbox(sandbox, vis_dir, local_vis_dir) if success: print(f"✅ 成功下载可视化图表") else: print(f"⚠️ 下载可视化图表失败,尝试下载整个用户目录作为备份") download_directory_from_sandbox(sandbox, "/home/user", output_dir) else: print("\n错误: 无法获取沙箱实例,没有找到具有sandbox属性的工具") except Exception as e: print(f"下载文件时出错: {str(e)}") import traceback print(f"错误详情: {traceback.format_exc()}") ############################################################################## # 测试用例6:异常处理和错误恢复 ############################################################################## def run_test_case_6(): print_separator("测试用例6:异常处理和错误恢复") print("\n查询: 测试沙箱环境的异常处理和错误恢复能力") # 定义输入 inputs = { "messages": [ HumanMessage(content="请测试沙箱环境的异常处理和错误恢复能力。编写一段包含各种常见错误的Python代码(如语法错误、除零错误、文件不存在错误等),然后展示如何捕获和处理这些异常。验证沙箱环境是否能正确报告错误并继续执行后续代码。") ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取消息列表 messages = partial_state.get("messages", []) if not messages: continue # 获取最新消息 latest_message = messages[-1] # 使用log_agent_actions函数记录状态 log_agent_actions({"messages": [latest_message]}) # 打印最终回答 print_separator("测试用例6结果") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) ############################################################################## # 主函数 - 运行所有测试用例 ############################################################################## if __name__ == "__main__": print_separator("开始测试E2B沙箱环境") try: # 确保输出目录存在 output_dir = os.path.join(os.path.dirname(__file__), "output", "sandbox_test") os.makedirs(output_dir, exist_ok=True) print(f"创建输出目录: {output_dir}") # 确保可视化输出目录存在 vis_output_dir = os.path.join(output_dir, "visualizations") os.makedirs(vis_output_dir, exist_ok=True) print(f"创建可视化输出目录: {vis_output_dir}") # # 运行测试用例 # # 运行测试用例1:基本Python代码执行和环境信息 # run_test_case_1() # # 运行测试用例2:文件系统操作 # run_test_case_2() # # 运行测试用例3:包管理和第三方库使用 # run_test_case_3() # # 运行测试用例4:Shell命令执行 # 
run_test_case_4() # 运行测试用例5:数据处理和可视化 run_test_case_5() # # 运行测试用例6:异常处理和错误恢复 # run_test_case_6() print_separator("E2B沙箱环境测试完成") print("测试结果已保存到 examples/output/sandbox_test 目录") except Exception as e: print(f"测试过程中出错: {str(e)}") finally: # 关闭E2B沙箱 print("\n正在关闭E2B沙箱...") for tool in react_agent.tools: if hasattr(tool, 'close'): tool.close()
================================================
FILE: examples/12_planning_supervisor_test.py
================================================
from langgraph.prebuilt import create_react_agent from core.agents.react_supervisor_agent import SupervisorAgent from core.agents.research_agent import ResearchAgent from core.agents.base.react_agent import ReactAgent from langchain_openai import ChatOpenAI from langgraph.func import entrypoint, task from langgraph.graph import add_messages from dotenv import load_dotenv from langchain_community.tools import TavilySearchResults load_dotenv() # 自动加载 .env 文件 # 1. 初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") ############################################################################## # Agent 1: Joke Generator (Functional API) ############################################################################## @task def generate_joke(messages): """Generate a short joke (no tool calls).""" system_message = { "role": "system", "content": "You are a witty comedian. Write a short joke." } # 直接调用 model.invoke,拼接 system_message + 用户消息 msg = model.invoke([system_message] + messages) return msg @entrypoint() def joke_agent(state): # 调用上面的函数型任务 joke = generate_joke(state['messages']).result() # 将产物插入消息列表 messages = add_messages(state["messages"], [joke]) return {"messages": messages} joke_agent.name = "joke_agent" ############################################################################## # Agent 2: Research Expert with Tavily Search (Graph API) ############################################################################## # 创建Tavily搜索工具 tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=False, include_images=False, search_depth="advanced" ) # 使用我们自定义的ResearchAgent替代create_react_agent创建的agent research_agent = ResearchAgent( name="research_expert", model=model, max_iterations=5, cache_enabled=True, debug=False ) research_agent_2 = ReactAgent( name="research_expert", model=model, tools=[tavily_search]) ############################################################################## # 使用带有Planning功能的SupervisorAgent ############################################################################## # 创建 SupervisorAgent 实例,启用Planning功能 supervisor = SupervisorAgent( agents=[joke_agent,research_agent_2], model=model, ) ############################################################################## # 测试:复杂请求需要规划和多个步骤 ############################################################################## result = supervisor.run({ "messages": [ { "role": "user", "content": ( "I'm preparing a presentation about tech companies. I need three things: " "1) A joke about tech companies to start with, " "2) The employee count for FAANG, and " "3) A comparison of which company has the most employees."
) } ] }) ############################################################################## # 打印最终对话消息 ############################################################################## for m in result["messages"]: m.pretty_print() # 打印任务列表 print("\n##############################################################################") print("# 最终任务列表") print("##############################################################################") if "plan" in result and result["plan"] and "tasks" in result["plan"]: tasks = result["plan"]["tasks"] print(f"总共 {len(tasks)} 个任务:") for i, task in enumerate(tasks): print(f"\n任务 {i+1}: {task['description']}") print(f" 状态: {task['status']}") print(f" 代理: {task['agent'] if task['agent'] else '未分配'}") print(f" 创建时间: {task['created_at']}") print(f" 完成时间: {task['completed_at'] if task['completed_at'] else '未完成'}") else: print("没有任务列表信息") # 打印原始任务列表(如果存在) if "tasks" in result: print("\n原始任务列表:") for t in result["tasks"]: t.pretty_print() ================================================ FILE: examples/13_multi_agent_roles_test.py ================================================ from langgraph.prebuilt import create_react_agent from core.agents.react_supervisor_agent import SupervisorAgent from core.agents.sub_agents.research_agent import ResearchAgent from core.agents.sub_agents.coder_agent import CoderAgent from core.agents.sub_agents.reporter_agent import ReporterAgent from core.agents.sub_agents.designer_agent import DesignerAgent from core.agents.sub_agents.data_analyst_agent import DataAnalystAgent from langchain_openai import ChatOpenAI from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langgraph.func import entrypoint, task from langgraph.graph import add_messages from dotenv import load_dotenv from langchain_community.tools import TavilySearchResults import os import logging import sys import io import json from contextlib import redirect_stdout, redirect_stderr load_dotenv() # 自动加载 .env 文件 # 1. 
初始化大模型 model = ChatOpenAI(model="gpt-4o-mini") # 设置日志捕获 class LogCapture: def __init__(self): self.log_buffer = io.StringIO() self.log_content = [] def start_capture(self): self.log_buffer = io.StringIO() return self.log_buffer def stop_capture(self): output = self.log_buffer.getvalue() self.log_content.append(output) return output def get_content(self): return "\n".join(self.log_content) log_capture = LogCapture() ############################################################################## # 从沙箱下载文件到本地的函数 ############################################################################## def download_file_from_sandbox(sandbox, sandbox_path, local_path): """从 e2b 沙箱中下载文件并保存到本地,自动区分文本和二进制文件""" try: print(f"读取文件: {sandbox_path}") # 判断是否为常见二进制文件类型(可自行扩展) binary_extensions = ( '.png', '.jpg', '.jpeg', '.gif', '.pdf', '.svg', '.xlsx', '.xls', '.zip', '.bin', '.pyc', '.pyd', '.pptx', '.docx', '.mp3', '.mp4', '.avi', '.mov', ) is_binary = sandbox_path.lower().endswith(binary_extensions) # 创建目录 os.makedirs(os.path.dirname(local_path), exist_ok=True) if is_binary: print("📦 识别为二进制文件,使用 sandbox.download_file()") content = sandbox.files.read(sandbox_path) # 返回 bytes with open(local_path, 'wb') as f: f.write(content) else: print("📄 识别为文本文件,使用 sandbox.files.read()") content = sandbox.files.read(sandbox_path) # 返回 str with open(local_path, 'w', encoding='utf-8') as f: f.write(content) print(f"✅ 文件已保存到本地: {local_path}") return True except Exception as e: print(f"❌ 下载失败: {e}") return False def download_directory_from_sandbox(sandbox, sandbox_dir_path, local_dir_path): """从沙箱下载整个目录内容到本地 Args: sandbox: 沙箱实例 sandbox_dir_path: 沙箱中的目录路径 local_dir_path: 本地保存目录路径 Returns: bool: 是否成功下载所有文件 """ try: print(f"尝试下载目录: {sandbox_dir_path} -> {local_dir_path}") # 确保本地目录存在 os.makedirs(local_dir_path, exist_ok=True) # 列出沙箱中指定目录下的所有文件 try: files = sandbox.files.list(sandbox_dir_path) except Exception as e: print(f"列出文件时出错: {sandbox_dir_path}, 错误: {str(e)}") return False if not files: print(f"沙箱中目录 {sandbox_dir_path} 为空或不存在") return False downloaded_count = 0 # 定义需要跳过的系统文件 skip_files = {'.bashrc', '.bash_logout', '.profile'} # 遍历并下载每个文件 for file_info in files: try: # 尝试安全获取name和type属性 file_name = getattr(file_info, "name", None) if file_name is None: print(f"警告: 无法获取文件名, 跳过此文件") continue file_type = getattr(file_info, "type", "file") # 默认为文件类型 # 如果 file_type 是枚举, 使用其 value 进行判断 type_value = file_type.value if hasattr(file_type, "value") else file_type # 跳过不需要的系统文件或系统目录(隐藏文件/目录) if file_name in skip_files or (file_name.startswith('.') and type_value == 'dir'): print(f"跳过系统文件或目录: {file_name}") continue print(f"处理文件: {file_name}, 类型: {type_value}") sandbox_file_path = f"{sandbox_dir_path}/{file_name}" local_file_path = os.path.join(local_dir_path, file_name) if type_value == 'dir': # 递归下载子目录 print(f"发现子目录: {sandbox_file_path}") if download_directory_from_sandbox(sandbox, sandbox_file_path, local_file_path): downloaded_count += 1 else: # 下载文件 print(f"下载文件: {sandbox_file_path} -> {local_file_path}") if download_file_from_sandbox(sandbox, sandbox_file_path, local_file_path): downloaded_count += 1 except Exception as e: print(f"处理文件时出错: {str(e)}") import traceback print(f"详细错误跟踪: {traceback.format_exc()}") continue if downloaded_count > 0: print(f"从 {sandbox_dir_path} 下载了 {downloaded_count} 个文件/目录到 {local_dir_path}") return True return False except Exception as e: print(f"下载整个目录时出错: {str(e)}") import traceback print(f"详细错误跟踪: {traceback.format_exc()}") ############################################################################## # 
Agent 2: Research Expert - 使用自定义的ResearchAgent ############################################################################## research_agent = ResearchAgent( name="research_expert", model=model, max_iterations=5, cache_enabled=True, debug=True ) ############################################################################## # Agent 3: Coder - 使用自定义的CoderAgent ############################################################################## from core.tools.e2b_tool import E2BCodeInterpreterTool e2b_tool = E2BCodeInterpreterTool() coder_agent = CoderAgent( name="coder_expert", model=model, tools=[e2b_tool], max_iterations=5, cache_enabled=True, debug=True ) ############################################################################## # Agent 4: Reporter - 使用自定义的ReporterAgent ############################################################################## reporter_agent = ReporterAgent( name="reporter_expert", model=model, max_iterations=5, cache_enabled=True, ) ############################################################################## # Agent 5: Designer - 使用自定义的DesignerAgent ############################################################################## designer_agent = DesignerAgent( name="designer_expert", model=model, max_iterations=5, cache_enabled=True, ) ############################################################################## # Agent 6: Data Analyst - 使用自定义的DataAnalystAgent ############################################################################## data_analyst_agent = DataAnalystAgent( name="data_analyst_expert", model=model, max_iterations=5, cache_enabled=True, ) ############################################################################## # 使用带有Planning功能的SupervisorAgent协调所有角色 ############################################################################## # 创建 SupervisorAgent 实例,启用Planning功能 supervisor = SupervisorAgent( agents=[ research_agent, coder_agent, reporter_agent, designer_agent, data_analyst_agent, ], model=model, enable_planning=True, output_mode="last_message" ) # 获取当前文件名(不含路径和扩展名) current_file = os.path.basename(__file__) file_name_without_ext = os.path.splitext(current_file)[0] logs_dir = os.path.join(os.path.dirname(__file__), "logs") # 创建图表输出文件路径 os.makedirs(logs_dir, exist_ok=True) # 创建Markdown输出文件路径 markdown_path = os.path.join(logs_dir, f"{file_name_without_ext}.md") ############################################################################## # 测试:复杂请求需要规划和多个步骤 ############################################################################## def save_markdown_log(): """将执行结果保存为Markdown文件""" with open(markdown_path, "w", encoding="utf-8") as f: f.write(f"# 执行结果: {file_name_without_ext}\n\n") f.write("## 图表\n\n") f.write("## 执行日志\n\n") f.write("```\n") f.write(log_capture.get_content()) f.write("\n```\n") print(f"执行日志已保存到 {markdown_path}") if __name__ == "__main__": try: # 开始捕获输出 log_buffer = log_capture.start_capture() with redirect_stdout(log_buffer), redirect_stderr(log_buffer): print(f"开始执行 {current_file} 测试...") # 测试1:需要研究和编码的任务 print("\n## 测试1:需要研究和编码的任务") final_state = supervisor.run({ "messages": [ { "role": "user", "content": ( "我需要一个Python爬虫来获取 https://www.paulgraham.com/articles.html 所有articles列表,并将结果保存为CSV文件,放在/home/user下面。" "并将你测试通过的爬虫代码返回给我。" "请确保你的代码能够正常运行。" "如果遇到问题,请重试。" ) } ] }) print("\n测试1结果:") for m in final_state["messages"]: m.pretty_print() # 遍历 react_agent.tools 以查找 E2B 相关工具 try: # 遍历 react_agent.tools 以查找 E2B 相关工具 sandbox = None for tool in coder_agent.tools: if hasattr(tool, "sandbox"): sandbox = tool.sandbox break # 找到后就退出循环 if 
sandbox: # 设定输出目录 output_dir = os.path.join(os.getcwd(), "examples/output/sandbox_files") os.makedirs(output_dir, exist_ok=True) # 直接下载主要工作目录 print("\n从沙箱下载文件到本地...") download_directory_from_sandbox(sandbox, "/home/user", output_dir) # 下载临时目录中可能的图表和数据文件 # download_directory_from_sandbox(sandbox, "/tmp", output_dir) print(f"\n文件已保存到目录: {output_dir}") sandbox.close() except Exception as e: print(f"从沙箱下载文件时出错: {str(e)}") finally: # 停止捕获并保存结果 log_capture.stop_capture() save_markdown_log() ================================================ FILE: examples/14_mcp_client_fetch_test.py ================================================ import os import sys import asyncio import traceback from typing import Dict, Optional, Type from dotenv import load_dotenv # 在这里添加项目根目录到路径,方便导入 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) load_dotenv() from langchain_openai import ChatOpenAI from langgraph.prebuilt import create_react_agent from langchain_core.tools import BaseTool from langchain_core.messages import HumanMessage try: from pydantic.v1 import BaseModel, Field except ImportError: from pydantic import BaseModel, Field # type: ignore from core.mcp.client import MCPClient from core.mcp.config_loader import load_config, MCPConfig, StdioConfig from core.llm.llm_manager import LLMManager try: from mcp.types import CallToolRequest CALL_TOOL_REQ_AVAILABLE = True except ImportError: CallToolRequest = None CALL_TOOL_REQ_AVAILABLE = False # 这是唯一保留的 fetch schema try: class FetchInputSchema(BaseModel): url: str = Field(..., description="URL to fetch") max_length: Optional[int] = Field(default=5000) start_index: Optional[int] = Field(default=0) raw: Optional[bool] = Field(default=False) FETCH_SCHEMA_AVAILABLE = True except Exception: FetchInputSchema = None FETCH_SCHEMA_AVAILABLE = False CENTRAL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "..", "core", "mcp", "mcp_server_config.json") LLM_ID_FOR_TESTING = "openai_gpt4o_mini" llm_manager = LLMManager() class MCPToolRunner(BaseTool): name: str = "needs_override" description: str = "needs_override" args_schema: Optional[Type[BaseModel]] = None client: MCPClient = Field(exclude=True) class Config: arbitrary_types_allowed = True async def _arun(self, **kwargs) -> str: if not self.client or not self.client.session: return f"ERROR: MCP Client session inactive for {self.name}." if not CALL_TOOL_REQ_AVAILABLE: return "ERROR: CallToolRequest unavailable." try: print(f" [_arun:{self.name}] Sending MCP request with args: {kwargs}") result_message = await asyncio.wait_for( self.client.session.call_tool(self.name, kwargs), timeout=120.0 ) # 简化: 只检查 result 和 error if hasattr(result_message, "result"): return str(result_message.result) elif hasattr(result_message, "error"): return f"Tool Error: {result_message.error.message}" else: return "Unknown response" except asyncio.TimeoutError: return "Error: Timeout." 
async def run_fetch_test(server_config_key: str, all_configs: Dict[str, MCPConfig]):
    print(f"\n=== Running STDIO BaseTool Test for Server '{server_config_key}' (Tool: 'fetch') ===")
    if not FETCH_SCHEMA_AVAILABLE:
        print("ERROR: Fetch Schema missing.")
        return False
    if not CALL_TOOL_REQ_AVAILABLE:
        print("ERROR: CallToolRequest unavailable.")
        return False
    server_config = all_configs.get(server_config_key)
    if not server_config:
        print(f"ERROR: Config for '{server_config_key}' not found.")
        return False
    if not isinstance(server_config.connection, StdioConfig):
        print(f"ERROR: Config '{server_config_key}' not STDIO.")
        return False
    try:
        model = llm_manager.get_model(LLM_ID_FOR_TESTING)
        print(f"Using LLM: {getattr(model, 'model_name', LLM_ID_FOR_TESTING)}")
    except ValueError as e:
        print(f"Error getting LLM: {e}.")
        return False

    test_success = False
    async with MCPClient(server_config) as client:
        if not client.session:
            print("ERROR: MCP session not established!")
            return False
        try:
            runner = MCPToolRunner(
                client=client,
                name="fetch",
                description="Fetches URL content as markdown.",
                args_schema=FetchInputSchema
            )
            tools = [runner]
        except Exception as e_inst:
            print(f"ERROR: Failed to instantiate MCPToolRunner: {e_inst}")
            return False

        agent = create_react_agent(model, tools)
        query = (
            "Use the fetch tool to get the content of https://www.google.com "
            "and tell me its title (first 50 chars)."
        )
        print(f"\nQuery: {query}")
        try:
            response = await asyncio.wait_for(
                agent.ainvoke({"messages": [{"role": "user", "content": query}]}),
                timeout=180.0
            )
            print("\nAgent Final Response:")
            if response and "messages" in response and response["messages"]:
                response_content = response["messages"][-1].content
                print(response_content)
                if "google" in response_content.lower():
                    print("\n✅ Test PASS")
                    test_success = True
                else:
                    print("\n❌ Test FAIL (title not found)")
                    test_success = False
            else:
                print("No valid response from agent.")
                test_success = False
        except Exception as e:
            print(f"Exception: {e}")
            test_success = False
    return test_success


async def main():
    print("Starting a simplified MCP Integration Test for 'fetch_via_uvx' only...")
    try:
        all_configs = load_config(CENTRAL_CONFIG_PATH)
        print(f"Loaded {len(all_configs)} server configs.")
    except Exception as e:
        print(f"Error loading config: {e}")
        return
    # Only test fetch_via_uvx
    result = await run_fetch_test("fetch_via_uvx", all_configs)
    if result:
        print("\nALL GOOD: 'fetch' test passed.")
    else:
        print("\nTEST FAILED: 'fetch' test didn't pass.")
    print("Done.")


if __name__ == "__main__":
    asyncio.run(main())


================================================
FILE: examples/15_mcp_agent_test.py
================================================
# examples/15_mcp_agent_test.py (final version - BaseTool subclass)
import os
import sys
import asyncio
import json
from dotenv import load_dotenv
import traceback
from typing import List, Dict, Any, Optional, Type

# --- Prerequisites ---
# 1. Make sure core/mcp/client.py and core/mcp/config_loader.py are up to date
#    (including the AsyncExitStack and import fixes).
# 2. Make sure core/mcp/config.json exists and contains a "fetch_via_uvx" entry
#    (using uvx + stdio).
# 3. Make sure uv (`pip install uv`) and mcp-server-fetch are installed.
# 4. Make sure the OpenAI API key (or another LLM key) is set in .env or the environment.
# 5. Setting the LangSmith environment variables is recommended for detailed agent tracing.
# ---

# Add the project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
load_dotenv()

# --- Core imports ---
# LangChain
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from langchain_core.tools import BaseTool
from langchain_core.messages import HumanMessage

try:
    # Try Pydantic v1 first (the version LangChain commonly uses)
    from langchain_core.pydantic_v1 import BaseModel, Field
except ImportError:
    try:
        # Fall back to Pydantic v2 if v1 is unavailable
        from pydantic import BaseModel, Field  # type: ignore
    except ImportError:
        print("CRITICAL ERROR: Pydantic (v1 or v2) not found.")
        sys.exit(1)

# MCP client/config
try:
    from core.mcp.client import MCPClient
except ImportError:
    print("CRITICAL ERROR: Cannot import MCPClient."); sys.exit(1)
try:
    from core.mcp.config_loader import load_config, MCPConfig, StdioConfig
except ImportError:
    print("CRITICAL ERROR: Cannot import config loader."); sys.exit(1)

# LLM
from core.llm.llm_manager import LLMManager

# MCP types
try:
    from mcp.types import CallToolRequest
    CALL_TOOL_REQ_AVAILABLE = True
except ImportError:
    CallToolRequest = None
    CALL_TOOL_REQ_AVAILABLE = False
# ---

# --- Fetch tool schema definition ---
FETCH_SCHEMA_AVAILABLE = False
FetchInputSchema = None
try:
    class FetchInputSchema(BaseModel):  # uses the BaseModel imported above
        url: str = Field(..., description="URL to fetch")
        max_length: Optional[int] = Field(default=5000, description="Maximum number of characters to return")
        start_index: Optional[int] = Field(default=0, description="Start content from this character index")
        raw: Optional[bool] = Field(default=False, description="Get raw content without markdown conversion")
    FETCH_SCHEMA_AVAILABLE = True
except Exception as e_pyd_fetch:
    print(f"ERROR defining FetchInputSchema: {e_pyd_fetch}")
# ---

# --- Global settings ---
# **Important**: make sure this path points at your central config file
CENTRAL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "..", "core", "mcp", "mcp_server_config.json")
# OpenAI models are usually the most stable choice here
LLM_ID_FOR_TESTING = "openai_gpt4o_mini"
# Key of the server to test in config.json
SERVER_KEY_TO_TEST = "fetch_via_uvx"
# Name of the tool to test
TOOL_NAME_TO_TEST = "fetch"
# Correct schema for the tool under test
CORRECT_SCHEMA_FOR_TOOL = FetchInputSchema
# Description of the tool under test
TOOL_DESCRIPTION = "Fetches web content as markdown. Input requires 'url' (string) and optional 'max_length', 'start_index', 'raw'."
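# (Illustrative assumption) The LangSmith tracing recommended in the prerequisites
# above is enabled through environment variables, typically:
#   LANGCHAIN_TRACING_V2=true
#   LANGCHAIN_API_KEY=<your key>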
# --- Everything MCP server settings ---
EVERYTHING_SERVER_KEY = "everything"
EVERYTHING_ECHO_TOOL = "echo"
EVERYTHING_ADD_TOOL = "add"

# --- Everything MCP tool schema definitions ---
ECHO_SCHEMA_AVAILABLE = False
EchoInputSchema = None
try:
    class EchoInputSchema(BaseModel):
        message: str = Field(..., description="Message to echo back")
    ECHO_SCHEMA_AVAILABLE = True
except Exception as e_pyd_echo:
    print(f"ERROR defining EchoInputSchema: {e_pyd_echo}")

ADD_SCHEMA_AVAILABLE = False
AddInputSchema = None
try:
    class AddInputSchema(BaseModel):
        a: float = Field(..., description="First number")
        b: float = Field(..., description="Second number")
    ADD_SCHEMA_AVAILABLE = True
except Exception as e_pyd_add:
    print(f"ERROR defining AddInputSchema: {e_pyd_add}")

llm_manager = LLMManager()


# --- Standard BaseTool subclass that bridges calls to MCP ---
class MCPToolRunner(BaseTool):
    """Standard BaseTool implementation that invokes a server-side tool over MCP."""

    # --- Class attributes (overridden at instantiation time) ---
    name: str = "mcp_tool_runner"  # default name
    description: str = "Runs a tool via MCP"
    args_schema: Optional[Type[BaseModel]] = None

    # --- Instance attributes ---
    client: MCPClient = Field(exclude=True)  # holds the client reference

    # Pydantic config (adjust to the BaseModel version you use)
    class Config:
        arbitrary_types_allowed = True

    async def _arun(self, **kwargs) -> str:
        """Async execution: build the MCP request and call client.session.call_tool."""
        if not self.client or not self.client.session:
            return f"ERROR: MCP Client session inactive for {self.name}."
        if not CALL_TOOL_REQ_AVAILABLE:
            return "ERROR: CallToolRequest unavailable."
        try:
            # kwargs are the arguments LangChain validated against args_schema
            print(f"  [_arun:{self.name}] Preparing MCP request with args: {kwargs}")
            # No need to build a CallToolRequest object; pass the tool name and args directly
            print(f"  [_arun:{self.name}] Calling tool '{self.name}' with args: {kwargs}")
            result_message = await asyncio.wait_for(
                self.client.session.call_tool(self.name, kwargs),
                timeout=120.0  # generous network/execution timeout
            )

            # Handle the result -- simplified logic that checks the `content` attribute first
            print(f"  [_arun:{self.name}] MCP Response received, type: {type(result_message)}")

            # Check for a `content` attribute first (matches the response shape seen in the logs)
            if hasattr(result_message, 'content'):
                content = result_message.content
                print(f"  [_arun:{self.name}] Found content attribute, type: {type(content)}")
                # If content is a non-empty list
                if isinstance(content, list) and len(content) > 0:
                    first_item = content[0]
                    print(f"  [_arun:{self.name}] Content is a list, first item type: {type(first_item)}")
                    # Try the `text` attribute
                    if hasattr(first_item, 'text'):
                        print(f"  [_arun:{self.name}] First item has text attribute, returning text")
                        return first_item.text
                    else:
                        print(f"  [_arun:{self.name}] First item has no text attribute, converting to string")
                        return str(first_item)
                elif hasattr(content, 'text'):
                    print(f"  [_arun:{self.name}] Content has text attribute, returning text")
                    return content.text
                else:
                    print(f"  [_arun:{self.name}] Content has no text attribute, converting to string")
                    return str(content)
            # Fall back to a `result` attribute if there is no `content`
            elif hasattr(result_message, 'result'):
                res_val = result_message.result
                print(f"  [_arun:{self.name}] Found result attribute: {str(res_val)[:500]}...")
                return str(res_val) if not isinstance(res_val, str) else res_val
            elif hasattr(result_message, 'error'):
                err_msg = result_message.error.message
                print(f"  [_arun:{self.name}] MCP Tool Error: {err_msg}")
                # For an agent, returning the error is usually easier to handle than raising
                return f"Tool Error: {err_msg}"
            else:
                # Dump the full response object to help diagnose the problem
                print(f"  [_arun:{self.name}] Unknown MCP response format. Full response object: {result_message}")
                print(f"  [_arun:{self.name}] Response type: {type(result_message)}")
                print(f"  [_arun:{self.name}] Response dir: {dir(result_message)}")
                # Try to extract more information
                response_details = ""
                for attr in dir(result_message):
                    if not attr.startswith('_'):
                        try:
                            value = getattr(result_message, attr)
                            if not callable(value):
                                response_details += f"\n  - {attr}: {value}"
                        except Exception as attr_err:
                            response_details += f"\n  - {attr}: [Error accessing: {attr_err}]"
                print(f"  [_arun:{self.name}] Response details: {response_details}")
                return f"Unknown response from MCP tool {self.name}. Details: {response_details}"
        except asyncio.TimeoutError:
            print(f"  [_arun:{self.name}] MCP call timeout.")
            return f"Error: Timeout calling MCP tool {self.name}."
        except Exception as e:
            print(f"  [_arun:{self.name}] Unexpected error during MCP call: {e}")
            print(traceback.format_exc())
            # Return the traceback in the error string to ease debugging
            return f"Unexpected Error calling {self.name}: {e}\n{traceback.format_exc()}"

    def _run(self, **kwargs) -> str:
        """Sync execution (simple implementation that defers to the async method)."""
        print(f"  [_run:{self.name}] Running async method via asyncio.run()...")
        try:
            # Note: calling asyncio.run inside an already-running event loop raises.
            # A better approach is to check the current loop or use anyio/nest_asyncio,
            # but to satisfy BaseTool we keep it simple; agents that only use the async
            # path are unaffected. If an agent forces the sync path, more elaborate
            # handling is needed.
            # return asyncio.run(self._arun(**kwargs))
            # The safer option is to signal that sync execution is unsupported:
            return "Synchronous execution not fully supported, please use async."
        except Exception as e:
            print(f"  [_run:{self.name}] Error: {e}")
            return f"Error in sync wrapper: {e}"
# ---
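# A minimal alternative sketch (an assumption, not part of the original test):
# _run could detect a running loop via asyncio.get_running_loop(), which raises
# RuntimeError when no loop is active, and only then fall back to asyncio.run():
#
#   def _run(self, **kwargs) -> str:
#       try:
#           asyncio.get_running_loop()
#       except RuntimeError:
#           return asyncio.run(self._arun(**kwargs))
#       return "Sync execution inside a running event loop is unsupported; use async."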
# --- Main test logic ---
async def run_fetch_test():
    """Run the fetch server test (BaseTool subclass method)."""
    print(f"\n=== Running Fetch Server Test (BaseTool Subclass Method) ===")
    # Check dependencies and schema definitions
    if not FETCH_SCHEMA_AVAILABLE:
        print("ERROR: FetchInputSchema not available."); return False
    if not CALL_TOOL_REQ_AVAILABLE:
        print("ERROR: CallToolRequest unavailable."); return False

    # Load the config
    config: Optional[MCPConfig] = None
    try:
        all_configs = load_config(CENTRAL_CONFIG_PATH)
        config = all_configs.get(SERVER_KEY_TO_TEST)
        if not config:
            print(f"ERROR: Config key '{SERVER_KEY_TO_TEST}' not found in '{CENTRAL_CONFIG_PATH}'."); return False
        if not isinstance(config.connection, StdioConfig):
            print("ERROR: Config connection is not STDIO."); return False
        print(f"Successfully loaded config for '{SERVER_KEY_TO_TEST}'.")
    except Exception as e_load:
        print(f"ERROR loading config: {e_load}"); return False

    # Get the LLM
    try:
        model = llm_manager.get_model(LLM_ID_FOR_TESTING)
        print(f"Using LLM: {getattr(model, 'model_name', LLM_ID_FOR_TESTING)}")
    except ValueError as e:
        print(f"Error getting LLM: {e}."); return False

    test_success = False
    # Connect with MCPClient (it starts the server according to the config)
    async with MCPClient(config) as client:
        print("\nMCPClient context entered.")
        if not client.session:
            print("ERROR: MCP session not established!"); return False

        # --- Instantiate our MCPToolRunner ---
        try:
            print(f"Instantiating MCPToolRunner for '{TOOL_NAME_TO_TEST}'...")
            mcp_tool_instance = MCPToolRunner(
                client=client,  # inject the client
                name=TOOL_NAME_TO_TEST,
                description=TOOL_DESCRIPTION,
                args_schema=CORRECT_SCHEMA_FOR_TOOL
            )
            tools = [mcp_tool_instance]
            print("Tool instance created successfully.")
        except Exception as e_inst:
            print(f"ERROR instantiating MCPToolRunner: {e_inst}"); return False
        # ---

        # --- Agent execution ---
        agent = create_react_agent(model, tools)  # the agent uses this standard tool
        query = "Use the fetch tool to get the main content (first 2000 chars) from https://developer.mozilla.org/en-US/docs/Web/HTML"
        print(f"\nRunning Agent Query...")
        print(f"Query: {query}")
        print("--- NOTE: Enable LangSmith for detailed tracing! ---")
        try:
            response = await asyncio.wait_for(
                agent.ainvoke({"messages": [{"role": "user", "content": query}]}),
                timeout=180.0
            )
            print(f"\nAgent Final Response:")
            if response and "messages" in response and response["messages"]:
                response_content = response["messages"][-1].content
                print(response_content)
                # Check that content was fetched and no error was reported
                contains_error = (
                    "error" in response_content.lower()
                    or "fail" in response_content.lower()
                    or "issue" in response_content.lower()
                    or "apologi" in response_content.lower()
                    or "unable" in response_content.lower()
                    or "tool error" in response_content.lower()
                )
                contains_expected = "HTML" in response_content
                if not contains_error and contains_expected:
                    print(f"\n✅ Test PASS: Agent successfully used tool and got expected content.")
                    test_success = True
                else:
                    print(f"\n❌ Test FAIL: Agent reported error or didn't get expected content.")
                    test_success = False
            else:
                print("Agent returned no valid response."); test_success = False
        except asyncio.TimeoutError:
            print("Agent execution timed out"); test_success = False
        except Exception as e:
            print(f"Agent execution failed: {e}")
            print(f"Traceback:\n{traceback.format_exc()}")
            test_success = False
        # ---
    # `async with` calls client.close() automatically
    print(f"\n--- Fetch Server Test Result: {'PASS' if test_success else 'FAIL'} ---")
    return test_success


async def run_everything_test():
    """Run the Everything MCP server test (BaseTool subclass method)."""
    print(f"\n=== Running Everything MCP Server Test (BaseTool Subclass Method) ===")
    # Check dependencies and schema definitions
    if not ECHO_SCHEMA_AVAILABLE:
        print("ERROR: EchoInputSchema not available."); return False
    if not ADD_SCHEMA_AVAILABLE:
        print("ERROR: AddInputSchema not available."); return False
    if not CALL_TOOL_REQ_AVAILABLE:
        print("ERROR: CallToolRequest unavailable."); return False

    # Load the config
    config: Optional[MCPConfig] = None
    try:
        all_configs = load_config(CENTRAL_CONFIG_PATH)
        config = all_configs.get(EVERYTHING_SERVER_KEY)
        if not config:
            print(f"ERROR: Config key '{EVERYTHING_SERVER_KEY}' not found in '{CENTRAL_CONFIG_PATH}'."); return False
        if not isinstance(config.connection, StdioConfig):
            print("ERROR: Config connection is not STDIO."); return False
        print(f"Successfully loaded config for '{EVERYTHING_SERVER_KEY}'.")
    except Exception as e_load:
        print(f"ERROR loading config: {e_load}"); return False

    # Get the LLM
    try:
        model = llm_manager.get_model(LLM_ID_FOR_TESTING)
        print(f"Using LLM: {getattr(model, 'model_name', LLM_ID_FOR_TESTING)}")
    except ValueError as e:
        print(f"Error getting LLM: {e}."); return False

    test_success = False
    # Connect with MCPClient (it starts the server according to the config)
    async with MCPClient(config) as client:
        print("\nMCPClient context entered for Everything MCP.")
        if not client.session:
            print("ERROR: MCP session not established!"); return False

        # --- Instantiate MCPToolRunner instances for the echo and add tools ---
        try:
            print(f"Instantiating MCPToolRunner for '{EVERYTHING_ECHO_TOOL}'...")
            echo_tool = MCPToolRunner(
                client=client,  # inject the client
                name=EVERYTHING_ECHO_TOOL,
                description="Echoes back the input message",
                args_schema=EchoInputSchema
            )
            print(f"Instantiating MCPToolRunner for '{EVERYTHING_ADD_TOOL}'...")
            add_tool = MCPToolRunner(
                client=client,  # inject the client
                name=EVERYTHING_ADD_TOOL,
                description="Adds two numbers together",
                args_schema=AddInputSchema
            )
            tools = [echo_tool, add_tool]
            print("Tool instances created successfully.")
        except Exception as e_inst:
            print(f"ERROR instantiating MCPToolRunner: {e_inst}"); return False
        # ---

        # --- Agent execution ---
        agent = create_react_agent(model, tools)  # the agent uses these tools
        query = (
            "First, use the echo tool to echo back the message 'Hello from Everything MCP!'. "
            "Then, use the add tool to calculate 42 + 58."
        )
        print(f"\nRunning Agent Query...")
        print(f"Query: {query}")
        print("--- NOTE: Enable LangSmith for detailed tracing! ---")
        try:
            response = await asyncio.wait_for(
                agent.ainvoke({"messages": [{"role": "user", "content": query}]}),
                timeout=180.0
            )
            print(f"\nAgent Final Response:")
            if response and "messages" in response and response["messages"]:
                response_content = response["messages"][-1].content
                print(response_content)
                # Check that both tools succeeded and no error was reported
                contains_error = (
                    "error" in response_content.lower()
                    or "fail" in response_content.lower()
                    or "issue" in response_content.lower()
                    or "apologi" in response_content.lower()
                    or "unable" in response_content.lower()
                    or "tool error" in response_content.lower()
                )
                contains_echo = "Hello from Everything MCP!" in response_content
                contains_add = "100" in response_content
                if not contains_error and contains_echo and contains_add:
                    print(f"\n✅ Test PASS: Agent successfully used both tools and got expected content.")
                    test_success = True
                else:
                    print(f"\n❌ Test FAIL: Agent reported error or didn't get expected content.")
                    print(f"  - Contains error: {contains_error}")
                    print(f"  - Contains echo response: {contains_echo}")
                    print(f"  - Contains add result: {contains_add}")
                    test_success = False
            else:
                print("Agent returned no valid response."); test_success = False
        except asyncio.TimeoutError:
            print("Agent execution timed out"); test_success = False
        except Exception as e:
            print(f"Agent execution failed: {e}")
            print(f"Traceback:\n{traceback.format_exc()}")
            test_success = False
        # ---
    # `async with` calls client.close() automatically
    print(f"\n--- Everything MCP Server Test Result: {'PASS' if test_success else 'FAIL'} ---")
    return test_success


async def main():
    """Entry point -- run all tests."""
    print("Starting MCP Integration Tests...")
    # Run the fetch test
    fetch_success = await run_fetch_test()
    # Run the Everything MCP test
    everything_success = await run_everything_test()
    print("\n" + "=" * 20 + " FINAL TEST SUMMARY " + "=" * 20)
    print(f"  Fetch Server Test:   {'PASS' if fetch_success else 'FAIL'}")
    print(f"  Everything MCP Test: {'PASS' if everything_success else 'FAIL'}")
    print("=" * 20 + " MCP Integration Test Finished " + "=" * 20)


if __name__ == "__main__":
    # Simplified dependency check
    print("--- Dependency Check ---")
    deps_ok = True
    try:
        import mcp; print("mcp available: True")
    except ImportError:
        print("mcp available: False"); deps_ok = False
    if CALL_TOOL_REQ_AVAILABLE:
        print("CallToolRequest available: True")
    else:
        print("CallToolRequest available: False"); deps_ok = False  # required
    try:
        import langgraph; print("langgraph available: True")
    except ImportError:
        print("langgraph available: False"); deps_ok = False
    try:
        import langchain_openai; print("langchain_openai available: True")
    except ImportError:
        print("langchain_openai available: False"); deps_ok = False
    try:
        import dotenv; print("dotenv available: True")
    except ImportError:
        print("dotenv available: False"); deps_ok = False
    try:
        import pydantic; print("pydantic available: True")
    except ImportError:
        print("pydantic available: False"); deps_ok = False
    try:
        from core.mcp.client import MCPClient; print("MCPClient available: True")
    except ImportError:
        print("MCPClient available: False"); deps_ok = False
    try:
        from core.mcp.config_loader import load_config; print("config_loader available: True")
    except ImportError:
        print("config_loader available: False"); deps_ok = False
    if not FETCH_SCHEMA_AVAILABLE:
        print("FetchInputSchema available: False"); deps_ok = False
    else:
        print("FetchInputSchema available: True")
    if not ECHO_SCHEMA_AVAILABLE:
        print("EchoInputSchema available: False"); deps_ok = False
    else:
        print("EchoInputSchema available: True")
    if not ADD_SCHEMA_AVAILABLE:
        print("AddInputSchema available: False"); deps_ok = False
    else:
        print("AddInputSchema available: True")
    print(f"------------------------")
    if not deps_ok:
        print("CRITICAL ERROR: Necessary dependencies missing.")
        sys.exit(1)
    asyncio.run(main())
print("EchoInputSchema available: False"); deps_ok=False else: print("EchoInputSchema available: True") if not ADD_SCHEMA_AVAILABLE: print("AddInputSchema available: False"); deps_ok=False else: print("AddInputSchema available: True") print(f"------------------------") if not deps_ok: print("CRITICAL ERROR: Necessary dependencies missing.") sys.exit(1) asyncio.run(main()) ================================================ FILE: examples/16_google_a2a/README.md ================================================ # LangGraph Agent 与 A2A 协议集成框架 ## 概述 本项目提供了一个将 **LangGraph Agent**(特别是基于 ReAct 模式并能调用工具的 Agent)与 **A2A (Agent-to-Agent) 协议** 相集成的框架和示例。目标是展示如何将一个用 LangGraph 构建的复杂 Agent 能力,通过标准化的 A2A 接口暴露给外部客户端或其他 Agent。 此框架的核心在于 `AgentTaskManager`,它充当了 A2A 协议层与具体 Agent 实现之间的桥梁。项目包含了一个完整的端到端示例,其中 `CurrencyAgent`(使用 `create_react_agent` 构建,并带有计算器和搜索工具)通过 `A2AServer` 提供服务,并提供了两个不同的客户端示例 (`client_example.py` 和 `currency_agent_test.py`) 来演示如何与之交互。 关键技术栈包括: * **A2A 协议:** 定义交互规范。 * **LangGraph:** 用于构建具备状态管理和工具调用能力的 Agent。 * **`create_react_agent`:** LangGraph 提供的预构建 ReAct Agent 实现(作为示例)。 * **Pydantic:** 用于定义和验证 A2A 协议中的数据结构 (`core/a2a/types.py`)。 * **Starlette/Uvicorn:** 作为底层 Web 框架运行 A2A 服务器 (`core/a2a/server/server.py`)。 * **OpenAI API:** 作为 LangGraph Agent 使用的后端大语言模型(可替换)。 ## 特性 * **A2A 协议兼容:** 提供符合 A2A 规范的服务端点 (`/.well-known/agent.json` 和主任务端点)。 * **LangGraph Agent 集成:** 可将任意(满足特定接口要求的)LangGraph Agent 作为 A2A 服务的核心处理逻辑。 * **工具使用:** 集成的 Agent 能够根据需要调用外部工具(示例中为计算器和搜索)。 * **同步任务处理:** 支持客户端发送任务并等待最终结果。 * **流式基础:** 包含了处理流式请求和响应的框架(Agent 端流式逻辑需开发者实现)。 * **类型安全:** 使用 Pydantic 进行严格的数据校验。 * **环境配置:** 支持通过 `.env` 文件配置 API 密钥等敏感信息。 * **客户端示例:** 提供了基础和场景化的客户端示例代码。 ## 目录结构 ``` . ├── core/ # 核心 A2A 协议实现 │ └── a2a/ │ ├── client/ │ │ └── client.py # A2AClient 客户端库实现 │ ├── server/ │ │ ├── server.py # A2AServer HTTP 服务器实现 │ │ └── task_manager.py # TaskManager 基础接口 (被 AgentTaskManager 使用) │ ├── agent_task_manager.py # AgentTaskManager 实现 (连接 A2A 与 LangGraph) │ └── types.py # A2A 协议的 Pydantic 模型定义 ├── examples/ # 示例代码 │ └── a2a/ │ ├── langgraph_integration.py # 服务端设置和示例 LangGraph Agent (CurrencyAgent) 定义 │ ├── client_example.py # 基础 A2A 客户端使用示例脚本 │ └── currency_agent_test.py # 场景化 A2A 客户端测试脚本 ├── .env # 存储环境变量 (例如 OPENAI_API_KEY) - *需要自行创建* ├── requirements.txt # Python 依赖项列表 (假设存在) └── README.md # 本文档 ``` ## 核心组件说明 * **`core/a2a/types.py`:** 定义所有 A2A 数据结构,是协议的基础和校验依据。 * **`core/a2a/server/server.py` (`A2AServer`):** 基于 Starlette 的 HTTP 服务器,处理 A2A JSON-RPC 请求路由,将请求交给 `AgentTaskManager`。通过 `.start()` 方法启动。 * **`core/a2a/agent_task_manager.py` (`AgentTaskManager`):** **核心适配器**。连接 A2A 层和 Agent 层。它接收来自 `A2AServer` 的请求,管理任务状态,并调用注入的 Agent 实例的 `invoke` 或 `stream` 方法。 * **`examples/a2a/langgraph_integration.py`:** 包含 `CurrencyAgent` (使用 `create_react_agent` 的示例 Agent) 的定义,以及如何配置和启动 `A2AServer` 来运行这个 Agent 的完整脚本。 * **`core/a2a/client/client.py` (`A2AClient`):** 基础 A2A 客户端库。 * **`examples/a2a/client_example.py`:** 一个简单的脚本,演示如何使用 `A2AClient` 发送基本请求。 * **`examples/a2a/currency_agent_test.py`:** 一个更复杂的客户端脚本,包含多个测试场景,用于测试服务器端 Agent 的不同交互模式。 ## 先决条件 * Python (推荐 3.10 或更高版本) * `pip` (Python 包安装器) * 虚拟环境 (强烈推荐) * 大语言模型 API Key (例如 OpenAI API Key) ## 安装与设置 1. **克隆仓库:** ```bash git clone cd ``` 2. **创建并激活虚拟环境:** ```bash uv venv source .venv/bin/activate ``` 3. **安装依赖项:** ```bash uv sync ``` 4. **设置环境变量:** * 在项目根目录下创建 `.env` 文件。 * 添加所需的 API Key,例如: ```dotenv OPENAI_API_KEY="sk-..." ``` ## 运行示例 1. **启动 A2A 服务器:** * 在终端中,激活虚拟环境后运行: ```bash python -m examples.a2a.langgraph_integration ``` * 服务器将在 `http://127.0.0.1:8000` 启动并监听。 2. 
## Prerequisites

* Python (3.10 or later recommended)
* `pip` (the Python package installer)
* A virtual environment (strongly recommended)
* An LLM API key (e.g. an OpenAI API key)

## Installation and Setup

1. **Clone the repository:**
   ```bash
   git clone <repository-url>
   cd <repository-directory>
   ```

2. **Create and activate a virtual environment:**
   ```bash
   uv venv
   source .venv/bin/activate
   ```

3. **Install the dependencies:**
   ```bash
   uv sync
   ```

4. **Set the environment variables:**
   * Create a `.env` file in the project root.
   * Add the required API keys, for example:
     ```dotenv
     OPENAI_API_KEY="sk-..."
     ```

## Running the Examples

1. **Start the A2A server:**
   * In a terminal with the virtual environment activated, run:
     ```bash
     python -m examples.a2a.langgraph_integration
     ```
   * The server starts and listens on `http://127.0.0.1:8000`.

2. **Run an A2A client:**
   * Open a **new** terminal and activate the virtual environment.
   * Run either client example:
     * **Basic example:**
       ```bash
       python -m examples.a2a.client_example
       ```
     * **Scenario-based tests:**
       ```bash
       python -m examples.a2a.currency_agent_test
       ```

3. **Expected output:**
   * The **server terminal** logs incoming requests, LLM calls, and tool invocations (when triggered).
   * The **client terminal** shows tasks being sent, status polling (for synchronous tasks), and the received results or (simulated) streaming events. `currency_agent_test.py` prints results per scenario.

---

## **Important: Guide to Integrating a New LangGraph Agent**

If you have built a new LangGraph-based agent and want to plug it into this A2A framework, follow these steps and conventions:

### 1. The interface your agent class must implement

Your new agent class (say `MyNewAgent`) will be called by `AgentTaskManager`. It **must** therefore implement the following methods and attributes:

* **`__init__(self, llm, ...)`:**
  * The constructor; initializes whatever the agent needs, such as the LLM instance and the tool list.
  * **Key point:** build or obtain your LangGraph **Runnable** here (e.g. via `create_react_agent` or a hand-built `StateGraph().compile()`) and store it on the instance (e.g. `self.agent_runnable`).
* **`invoke(self, query: str, session_id: Optional[str] = None) -> str`:**
  * Handles synchronous A2A `tasks/send` requests.
  * Receives the plain-text user query `query` (and optional `session_id`) from `AgentTaskManager`.
  * **Internally it should:**
    * Wrap `query` in the input format your LangGraph Runnable expects. For agents built on `create_react_agent`, or anything message-list based, that is usually `{"messages": [("user", query)]}`. Include the `session_id` if your agent needs it.
    * Call the Runnable's `.invoke()` with that input dict.
    * Parse the result dict. For a ReAct agent the final text answer usually sits in the content of the last message of the `messages` list; write the logic to extract it.
  * **Return value:** **must** be a **string** containing the final answer.
* **`stream(self, query: str, session_id: Optional[str] = None) -> AsyncIterable[Dict[str, Any]]`:**
  * Handles streaming A2A `tasks/sendSubscribe` requests.
  * Receives `query` and `session_id`.
  * **Must** be an **async generator** (`async def` containing `yield`).
  * **Internally it should:**
    * Prepare the input for the Runnable's streaming call (usually the same shape as for `invoke`, e.g. `{"messages": [("user", query)]}`).
    * Call a streaming method on the Runnable, e.g. `self.agent_runnable.astream(...)` or `self.agent_runnable.astream_log(...)`.
    * Iterate with `async for chunk in ...:` over the streamed chunks.
    * **Parse each `chunk`**: the chunk format depends on the method you call (`astream` vs `astream_log`) and on the graph's structure. Parse the chunks (state updates, log patches, ...) to extract meaningful intermediate or final content.
    * **`yield` dicts in the expected shape**: for every update you want to push to the client, yield a dict with the following keys (consumed by `AgentTaskManager._run_streaming_agent`):
      * `"content"`: `str` -- the text produced by the current step.
      * `"is_task_complete"`: `bool` -- whether this is the final output / end signal of the task.
      * `"require_user_input"`: `bool` -- whether the task pauses and needs user input.
  * **Return value:** an async iterable (created automatically by `async def` + `yield`).
* **`SUPPORTED_CONTENT_TYPES: List[str]` (class attribute):**
  * The list of output content types the agent supports; for text-centric agents this is usually `["text"]`. `AgentTaskManager` uses it to validate the `acceptedOutputModes` requested by clients.
### 2. Keeping `AgentState` consistent

If you build the LangGraph graph by hand, the `AgentState` you define (and pass to `StateGraph`) must be consistent with how your `invoke` and `stream` methods handle input and output. In particular, if you rely on a `messages` list for conversation history or I/O, `AgentState` must define it accordingly.

### 3. Integration steps

1. **Create the agent class:**
   * Add a new Python file to your project (e.g. `my_new_agent.py`).
   * Define your agent class (e.g. `MyNewAgent`) and make sure it implements the `__init__`, `invoke`, and `stream` methods and the `SUPPORTED_CONTENT_TYPES` attribute described above.
   * Build or load your LangGraph Runnable inside `__init__`.
2. **Adapt the server startup script (e.g. `examples/a2a/langgraph_integration.py`):**
   * **Import** your new agent class: `from my_new_agent import MyNewAgent`.
   * **Instantiate** it: `my_agent = MyNewAgent(llm)` (pass whatever dependencies it needs, such as `llm`).
   * **Update the `AgentCard`**: adjust `name`, `description`, and the `skills` list to reflect the new agent. Make sure each `AgentSkill` has a unique `id` and the right `name`.
   * **Instantiate `AgentTaskManager`** with your agent instance: `task_manager = AgentTaskManager(my_agent)`.
   * **Instantiate `A2AServer`** with the updated `agent_card` and `task_manager`.
3. **Run the server:**
   * Start the adapted script: `python -m examples.a2a.your_server_script`.
4. **Test:**
   * Use `client_example.py` or `currency_agent_test.py` (you may need to adjust the query or the `skill_name` in the `metadata`) to send requests to the new server and verify that your agent works over the A2A protocol.

### Example Agent Skeleton

```python
# my_new_agent.py
import asyncio
import logging
from typing import List, Optional, AsyncIterable, Dict, Any, Tuple, TypedDict

from langchain_core.language_models import BaseChatModel  # example LLM type
from langgraph.graph.state import StateGraph  # if building the graph by hand
# from langgraph.prebuilt import create_some_agent  # if using a prebuilt agent

logger = logging.getLogger(__name__)


# 1. Define the state your agent uses (if needed)
class MyAgentState(TypedDict):
    messages: List[Tuple[str, str]]
    # ... other state fields


class MyNewAgent:
    SUPPORTED_CONTENT_TYPES: List[str] = ["text"]

    def __init__(self, llm: BaseChatModel):
        self.llm = llm
        # TODO: build or load your LangGraph Runnable here,
        # e.g.: self.agent_runnable = self._build_my_graph()
        # or:   self.agent_runnable = create_some_agent(llm, tools)
        self.agent_runnable = self._get_placeholder_runnable()  # example
        logger.info("MyNewAgent initialized.")

    def _get_placeholder_runnable(self):
        # A mock Runnable -- replace it with a real LangGraph Runnable
        class PlaceholderRunnable:
            def invoke(self, input_dict):
                logger.info(f"PlaceholderRunnable received invoke: {input_dict}")
                query = input_dict.get("messages", [("", "")])[-1][1]
                return {"messages": [("assistant", f"Mock reply to '{query}'")]}

            async def astream(self, input_dict):
                logger.info(f"PlaceholderRunnable received astream: {input_dict}")
                query = input_dict.get("messages", [("", "")])[-1][1]
                yield {"messages": [("assistant", f"Mock stream chunk 1 for '{query}' ...")]}
                await asyncio.sleep(0.5)
                yield {"messages": [("assistant", f"Mock stream chunk 2 for '{query}' -- done.")]}

        return PlaceholderRunnable()

    # def _build_my_graph(self):
    #     # Implement this if you build the graph by hand
    #     # workflow = StateGraph(MyAgentState)
    #     # ... add nodes, edges ...
    #     # return workflow.compile()
    #     pass

    def invoke(self, query: str, session_id: Optional[str] = None) -> str:
        logger.debug(f"[MyNewAgent.invoke] query: '{query}', session_id: '{session_id}'")
        # 1. Prepare the input
        invoke_input = {"messages": [("user", query)]}
        # 2. Call the Runnable
        try:
            result = self.agent_runnable.invoke(invoke_input)
            logger.debug(f"[MyNewAgent.invoke] Runnable result: {result}")
            # 3. Parse the result
            final_output = "Error: could not parse the agent response."
            if isinstance(result, dict) and isinstance(result.get("messages"), list) and result["messages"]:
                last_message = result["messages"][-1]
                if isinstance(last_message, tuple) and len(last_message) == 2:
                    final_output = last_message[1]
                elif hasattr(last_message, 'content'):
                    final_output = last_message.content
            return str(final_output)
        except Exception as e:
            logger.error(f"[MyNewAgent.invoke] Error: {e}", exc_info=True)
            raise  # re-raise so the TaskManager can handle it
    async def stream(self, query: str, session_id: Optional[str] = None) -> AsyncIterable[Dict[str, Any]]:
        logger.debug(f"[MyNewAgent.stream] query: '{query}', session_id: '{session_id}'")
        # 1. Prepare the input
        stream_input = {"messages": [("user", query)]}
        # 2. Call the Runnable's streaming method
        try:
            # use astream or astream_log
            async for chunk in self.agent_runnable.astream(stream_input):
                logger.debug(f"[MyNewAgent.stream] Received chunk: {chunk}")
                # 3. Parse the chunk and yield a dict in the expected shape.
                # The parsing logic depends heavily on your graph and streaming method;
                # extract content, is_task_complete, require_user_input from the actual chunk.

                # --- a **highly simplified** example parser ---
                content_to_yield = ""
                is_complete = False        # decide from the chunk whether the task is really done
                is_input_required = False  # decide from the chunk whether input is needed

                # Try to pull the latest 'messages' content out of the chunk
                if isinstance(chunk, dict) and isinstance(chunk.get("messages"), list) and chunk["messages"]:
                    last_message = chunk["messages"][-1]
                    if isinstance(last_message, tuple) and len(last_message) == 2:
                        content_to_yield = last_message[1]
                    elif hasattr(last_message, 'content'):
                        content_to_yield = last_message.content

                if content_to_yield:  # only yield when there is content
                    # In a real implementation you need more robust logic for is_task_complete,
                    # e.g. checking whether the graph reached the END node or a specific final node state.
                    yield {
                        "content": content_to_yield,
                        "is_task_complete": is_complete,          # must be set correctly
                        "require_user_input": is_input_required   # must be set correctly
                    }
                # --- end of the simplified example ---

            # **Important**: after the loop, if the task did finish, yield one final update
            # (unless the last yield above already had is_task_complete=True), e.g.:
            # final_result = await self.agent_runnable.ainvoke(stream_input)  # may need one more invoke for the final state
            # final_text = ...  # parse the final text
            # yield {"content": final_text, "is_task_complete": True, "require_user_input": False}
        except Exception as e:
            logger.error(f"[MyNewAgent.stream] Error: {e}", exc_info=True)
            # Raising inside the stream can break the SSE connection; yielding an error is an option:
            yield {
                "content": f"Error while handling the streaming request: {e}",
                "is_task_complete": True,  # mark the task as failed and finished
                "require_user_input": False
            }
```

## Current Status and Limitations

* Synchronous task execution, including LLM and tool calls from the LangGraph agent, is implemented and verified.
* The A2A server and client infrastructure is in place.
* **Agent-side streaming (`CurrencyAgent.stream`) is currently simulated** and does not call LangGraph's streaming interface; real streaming updates are not implemented yet.
* The current agent implementation (`CurrencyAgent`) does not support multi-turn clarification that requires state across requests.
* Error handling could be further improved.
* Task storage is in-memory only (`InMemoryTaskManager`).

## Future Directions

* **Real streaming output:** implement `stream` in the agent class per the guide above, calling LangGraph's `astream` or `astream_log` and yielding dicts in the A2A shape.
* **Multi-turn conversations:** change `AgentState` to accumulate message history (e.g. `Annotated[List[BaseMessage], operator.add]`) and adapt the agent's `invoke` and `stream` methods to use that history; the agent may also need to return an `input-required` state. A minimal sketch of such a state follows below.
* **Better error handling:** more detailed, user-friendly reporting for network problems, agent execution errors, tool call failures, and validation errors.
* **Persistent task storage:** replace `InMemoryTaskManager`.
* **Configuration management:** externalize the configuration.
* **Multi-skill support:** add routing logic.
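As a concrete illustration of the multi-turn direction above, a message-accumulating state could look like this (a minimal sketch following the `Annotated`/`operator.add` suggestion; not yet wired into `CurrencyAgent`):

```python
import operator
from typing import Annotated, List, TypedDict

from langchain_core.messages import BaseMessage


class MultiTurnAgentState(TypedDict):
    # operator.add makes LangGraph append new messages instead of replacing them
    messages: Annotated[List[BaseMessage], operator.add]
```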
================================================
FILE: examples/16_google_a2a/__init__.py
================================================
# examples/a2a/__init__.py
"""
A2A protocol + LangGraph integration examples

This directory contains examples and documentation for integrating the A2A
protocol with LangGraph agents.
"""


================================================
FILE: examples/16_google_a2a/agent_task_manager_test.py
================================================
# examples/a2a/agent_task_manager_test.py
import os
import sys
import asyncio
import logging
from typing import TypedDict, Any, List, Optional, Tuple

# Add the project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

# A2A components
from core.a2a.types import (
    TaskState, TaskStatus, Task, Artifact, Message,
    SendTaskRequest, SendTaskResponse, SendTaskStreamingRequest,
    TaskSendParams, JSONRPCResponse
)
from core.a2a.agent_task_manager import AgentTaskManager

# LangChain / LLM components
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import create_react_agent

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Define a couple of simple tools
@tool
def search(query: str) -> str:
    """Search the internet for information."""
    return f"These are the search results for '{query}'."


@tool
def calculator(expression: str) -> str:
    """Evaluate a math expression."""
    try:
        result = eval(expression)
        return f"Result: {result}"
    except Exception as e:
        return f"Calculation error: {e}"


# A simple LangGraph agent state
class AgentState(TypedDict):
    messages: List[Tuple[str, str]]
    session_id: Optional[str]  # keep session_id


class TestAgent:
    """Agent used by the tests."""

    # Supported content types
    SUPPORTED_CONTENT_TYPES = ["text"]

    def __init__(self, llm=None):
        if llm is None:
            try:
                llm = ChatOpenAI(model="gpt-4o-mini")
            except Exception as e:
                print(f"Warning: could not create the OpenAI LLM ({e}); falling back to mock mode")
                from langchain.llms.fake import FakeListLLM
                llm = FakeListLLM(responses=["This is a mocked LLM response"])
        self.tools = [search, calculator]
        self.agent = create_react_agent(llm, self.tools)
        self.graph = self._build_graph()

    def _build_graph(self):
        """Build the agent's workflow graph."""
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", self.agent)
        workflow.set_entry_point("agent")
        workflow.add_edge("agent", END)
        return workflow.compile()

    def invoke(self, query: str, session_id: str = None) -> str:
        """Call the agent synchronously."""
        # The compiled graph uses the `messages`-based AgentState, so wrap the
        # query as a message and read the final answer back out of the last message.
        result = self.graph.invoke({"messages": [("user", query)], "session_id": session_id})
        last_message = result["messages"][-1]
        return last_message.content if hasattr(last_message, "content") else str(last_message[1])

    async def stream(self, query: str, session_id: str = None):
        """Call the agent in streaming mode (simulated)."""
        # Simulated streaming output
        chunks = [
            "Processing your request...",
            "Searching for relevant information...",
            "Found some results, organizing them...",
            f"Here is the information about '{query}': this is a simulated streaming response."
        ]
        for i, chunk in enumerate(chunks):
            is_last = i == len(chunks) - 1
            yield {
                "content": chunk,
                "is_task_complete": is_last,
                "require_user_input": False
            }
            await asyncio.sleep(0.5)  # simulate latency


# Test the AgentTaskManager's synchronous task handling
async def test_sync_task():
    print("\n=== Testing synchronous task handling ===\n")
    # Create the agent and the AgentTaskManager
    agent = TestAgent()
    task_manager = AgentTaskManager(agent)

    # Build the task request
    task_id = "test_sync_task_1"
    session_id = "test_session_1"
    content = [{"type": "text", "text": "Compute 123 + 456"}]
    task_params = TaskSendParams(
        id=task_id,
        sessionId=session_id,
        message=Message(role="user", parts=content),
        acceptedOutputModes=["text"],
        historyLength=10
    )
    request = SendTaskRequest(id="req1", params=task_params)

    # Send the task
    response = await task_manager.on_send_task(request)

    # Print the result
    print(f"Task ID: {task_id}")
    print(f"Response type: {type(response)}")
    if hasattr(response, "error") and response.error:
        print(f"Error: {response.error}")
    else:
        print("Task completed successfully")

    # Fetch the task
    task = task_manager.tasks.get(task_id)
    if task:
        print(f"Task state: {task.status.state}")
        if task.artifacts:
            for artifact in task.artifacts:
                for part in artifact.parts:
                    if part.get("type") == "text":
                        print(f"Task result: {part.get('text')}")


# Test the AgentTaskManager's streaming task handling
async def test_streaming_task():
    print("\n=== Testing streaming task handling ===\n")
    # Create the agent and the AgentTaskManager
    agent = TestAgent()
    task_manager = AgentTaskManager(agent)

    # Build the task request
    task_id = "test_stream_task_1"
    session_id = "test_session_1"
    content = [{"type": "text", "text": "Search for information about artificial intelligence"}]
    task_params = TaskSendParams(
        id=task_id,
        sessionId=session_id,
        message=Message(role="user", parts=content),
        acceptedOutputModes=["text"],
        historyLength=10
    )
    request = SendTaskStreamingRequest(id="req2", params=task_params)

    # Send the streaming task
    response_generator = await task_manager.on_send_task_subscribe(request)

    # Check the response type
    if isinstance(response_generator, JSONRPCResponse):
        print(f"Error: {response_generator.error}")
        return

    # Consume the streaming responses
    print("Receiving streaming responses:")
    async for response in response_generator:
        if hasattr(response, "error") and response.error:
            print(f"Streaming response error: {response.error}")
        else:
            result = response.result
            if hasattr(result, "status") and result.status and result.status.message:
                for part in result.status.message.parts:
                    # Access the part's `type` and `text` attributes directly
                    if hasattr(part, 'type') and part.type == "text":
                        text_content = getattr(part, 'text', '')  # read text safely
                        print(f"Streaming update: {text_content}")
{text_content}") # --- 修改结束 --- if hasattr(result, "artifact") and result.artifact: for part in result.artifact.parts: # --- 修改开始 --- # 直接访问对象的属性 type 和 text if hasattr(part, 'type') and part.type == "text": text_content = getattr(part, 'text', '') # 安全获取 text print(f"流式结果: {text_content}") # --- 修改结束 --- if hasattr(result, "final") and result.final: print("流式响应结束") # 主函数 async def main(): print("=== AgentTaskManager 测试 ===\n") # 测试同步任务 await test_sync_task() # 测试流式任务 await test_streaming_task() # 运行测试 if __name__ == "__main__": asyncio.run(main()) ================================================ FILE: examples/16_google_a2a/client_example.py ================================================ # examples/a2a/client_example.py import os import sys import asyncio import json import logging # 添加 logging from typing import Dict, Any, List, Optional from uuid import uuid4 # 用于生成示例 Task ID # 添加项目根目录到路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) # 导入环境变量 from dotenv import load_dotenv load_dotenv() # 导入A2A客户端和类型 from core.a2a.client.client import A2AClient # 导入 Message 和 TextPart 以构建请求,导入响应类型以进行类型提示 from core.a2a.types import ( Part, TextPart, Message, TaskState, # 添加 TaskState SendTaskResponse, GetTaskResponse, SendTaskStreamingResponse, Task, # 添加 Task JSONRPCError # 添加 JSONRPCError ) # 配置日志 logging.basicConfig(level=logging.INFO) # 可以改为 DEBUG 获取更详细客户端日志 logger = logging.getLogger(__name__) # 示例: 使用A2A客户端连接到A2A服务器 async def run_a2a_client(): print("\n=== 运行A2A客户端示例 ===\n") # 创建A2A客户端 client = A2AClient(url="http://127.0.0.1:8000") # 指向你的服务器地址 # 发送同步任务 await send_sync_task(client) # 发送流式任务 await send_streaming_task(client) # --- 修正发送同步任务 --- async def send_sync_task(client: A2AClient): print("\n=== 发送同步任务 ===\n") query = "请计算 123 + 456 的结果" task_id = "client_sync_" + uuid4().hex # 生成一个唯一的任务 ID try: # 1. 构建 Message 对象 message = Message(role="user", parts=[TextPart(text=query)]) # 2. 构建 TaskSendParams 对应的 payload 字典 (添加 id) payload_dict = { "id": task_id, # --- 添加必需的 id 字段 --- "sessionId": "client_session_sync_1", "message": message.model_dump(), "acceptedOutputModes": ["text"], "metadata": {"skill_name": "react_query"} } logger.debug(f"Sending sync task with payload: {payload_dict}") # 3. 调用 send_task,传入 payload 字典 response: SendTaskResponse = await client.send_task(payload=payload_dict) logger.debug(f"Send task response: {response}") # 4. 处理响应 if response.error: # 类型提示帮助访问属性 error: JSONRPCError = response.error print(f"发送任务时出错: Code={error.code}, Message={error.message}") return # SendTaskResponse 的 result 是 Task 对象或 None if not response.result: print(f"发送任务成功,但响应中未包含任务详情: {response}") # 我们可以继续使用我们发送的 task_id 来查询状态 elif response.result.id != task_id: # 理论上服务器应该使用或确认客户端提供的 ID logger.warning(f"服务器返回的任务ID '{response.result.id}' 与客户端发送的ID '{task_id}' 不匹配。") task_id = response.result.id # 以服务器返回的为准(如果存在) print(f"任务已发送,ID: {task_id}") # --- 轮询等待任务完成 --- print("等待任务完成...") task_result: Optional[Task] = None # 用于存储最终的任务对象 for attempt in range(10): # 最多尝试 10 次 await asyncio.sleep(2) # 等待 2 秒 # 5. 构建 get_task 的 payload get_payload = {"id": task_id} logger.debug(f"Getting task with payload: {get_payload} (Attempt {attempt+1})") # 6. 
            # 6. Fetch the task (pass the payload dict)
            get_response: GetTaskResponse = await client.get_task(payload=get_payload)
            logger.debug(f"Get task response: {get_response}")
            if get_response.error:
                error: JSONRPCError = get_response.error
                print(f"Error getting the task: Code={error.code}, Message={error.message}")
                return  # stop polling on error
            if not get_response.result:
                print(f"get_task succeeded but returned no task details: {get_response}")
                continue  # keep polling
            task_result = get_response.result  # the Task object
            print(f"  Current task state: {task_result.status.state}")
            # Stop when the task completed, failed, was canceled, or needs input
            if task_result.status.state in [TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED, TaskState.INPUT_REQUIRED]:
                break
        else:
            print("The task did not finish within the allotted time.")
            return

        # 7. Handle the final task result (via attribute access)
        if task_result.status.state == TaskState.COMPLETED and task_result.artifacts:
            print("Task completed successfully. Results:")
            for artifact in task_result.artifacts:
                if artifact.parts:
                    for part in artifact.parts:
                        if isinstance(part, TextPart):
                            print(f"  - {part.text}")
        elif task_result.status.state == TaskState.FAILED:
            error_msg = "Unknown error"
            if task_result.status.message and task_result.status.message.parts:
                # Assume the error text lives in the first TextPart
                if isinstance(task_result.status.message.parts[0], TextPart):
                    error_msg = task_result.status.message.parts[0].text
            print(f"Task failed: {error_msg}")
        else:
            print(f"Final task state: {task_result.status.state}")
    except Exception as e:
        logger.error(f"Exception while sending/handling the sync task: {e}", exc_info=True)
        print(f"Sending the sync task failed: {e}")


# --- Sending a streaming task (corrected) ---
async def send_streaming_task(client: A2AClient):
    print("\n=== Sending a streaming task ===\n")
    query = "Please search for the latest developments in artificial intelligence"
    task_id = "client_stream_" + uuid4().hex  # generate an ID for the streaming task
    try:
        # 1. Build the Message object
        message = Message(role="user", parts=[TextPart(text=query)])
        # 2. Build the payload dict matching TaskSendParams (including the required id)
        payload_dict = {
            "id": task_id,  # --- the required id field ---
            "sessionId": "client_session_stream_1",
            "message": message.model_dump(),
            "acceptedOutputModes": ["text"],
            "metadata": {"skill_name": "react_query"}
        }
        logger.debug(f"Sending streaming task with payload: {payload_dict}")
        print(f"Task sent, ID: {task_id}")  # for streaming tasks the ID is known upfront
        # 3. Call send_task_streaming (no await here); it returns an async generator
        event_stream_generator = client.send_task_streaming(payload=payload_dict)
        # 4. Consume the streaming events with async for
        print("Receiving streaming responses:")
        async for event_response in event_stream_generator:  # iterate the async generator correctly
            logger.debug(f"Received stream event: {event_response}")
            # Check the whole response for errors
            if event_response.error:
                error: JSONRPCError = event_response.error
                print(f"Error during streaming: Code={error.code}, Message={error.message}")
                continue  # or break
            # Extract the event payload
            event = event_response.result
            if not event:
                logger.warning("Received stream response with empty result.")
                continue
            # Handle message parts in status-update events
            if hasattr(event, "status") and event.status and event.status.message:
                if event.status.message.parts:
                    for part in event.status.message.parts:
                        if isinstance(part, TextPart):
                            print(f"  Streaming update: {part.text}")
            # Handle artifact-update events
            if hasattr(event, "artifact") and event.artifact:
                print("  Received artifact:")
                if event.artifact.parts:
                    for part in event.artifact.parts:
                        if isinstance(part, TextPart):
                            print(f"  Streaming result (TextPart): {part.text}")
            # Check the end-of-stream flag
            if hasattr(event, "final") and event.final:
                print("End-of-stream flag received.")
        print("Streaming task handled.")
    except Exception as e:
        logger.error(f"Exception while sending/handling the streaming task: {e}", exc_info=True)
        print(f"Sending the streaming task failed: {e}")


# Entry point
if __name__ == "__main__":
    # Run the top-level coroutine with asyncio.run
    asyncio.run(run_a2a_client())


================================================
FILE: examples/16_google_a2a/currency_agent_test.py
================================================
# examples/a2a/currency_agent_test.py
import os
import sys
import asyncio
import json
import logging
from typing import Dict, Any, List, Optional
from uuid import uuid4  # import uuid

# Add the project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

# A2A client and the types it needs
from core.a2a.client.client import A2AClient
from core.a2a.types import (
    Part, TextPart, Message,
    TaskState,
    SendTaskResponse, GetTaskResponse,
    Task,
    JSONRPCError,
    SendTaskStreamingResponse  # streaming response type
)

# Configure logging (switch to DEBUG for detailed logs)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Scenario 1: synchronous request -- currency-conversion/calculator query (corrected)
async def test_sync_currency_conversion(client: A2AClient):
    print("\n=== Scenario 1: synchronous request -- agent invocation (calculator) ===")
    # query = "How much is the exchange rate for 1 USD to INR?"  # this query would need the search tool
    query = "Compute 58 * 34"  # use the calculator tool to guarantee a result
    task_id = "test_sync_" + uuid4().hex  # client-generated task ID
    try:
        # 1. Build the Message object
        message = Message(role="user", parts=[TextPart(text=query)])
        # 2. Build the payload dict matching TaskSendParams
        payload_dict = {
            "id": task_id,
            "sessionId": "test_session_sync_1",
            "message": message.model_dump(),  # serialize to a dict
            "acceptedOutputModes": ["text"],
            "metadata": {"skill_name": "react_query"}  # must match the skill name/id in the AgentCard
        }
        logger.debug(f"Sending sync task with payload: {payload_dict}")
        # 3. Send the task with the payload dict
        response: SendTaskResponse = await client.send_task(payload=payload_dict)
        logger.debug(f"Send task response: {response}")
        # 4. Handle the response
        if response.error:
            error: JSONRPCError = response.error
            print(f"Error sending the task: Code={error.code}, Message={error.message}")
            return None
        if not response.result:
            print(f"Task sent, but no task details received: {response}")
            # keep polling with the task_id we sent
        elif response.result.id != task_id:
            logger.warning(f"Server-returned task ID '{response.result.id}' does not match the client-sent ID '{task_id}'.")
            task_id = response.result.id  # prefer the server's ID
        print(f"Task sent, ID: {task_id}")
        # 5. Poll until the task finishes
        print("Waiting for the task to complete...")
        task_result: Optional[Task] = None
        for attempt in range(10):
            await asyncio.sleep(2)
            get_payload = {"id": task_id}
            logger.debug(f"Getting task with payload: {get_payload} (Attempt {attempt+1})")
            get_response: GetTaskResponse = await client.get_task(payload=get_payload)
            logger.debug(f"Get task response: {get_response}")
            if get_response.error:
                error: JSONRPCError = get_response.error
                print(f"Error getting the task: Code={error.code}, Message={error.message}")
                return None
            if not get_response.result:
                print(f"get_task succeeded but returned no task details: {get_response}")
                continue
            task_result = get_response.result
            print(f"  Current task state: {task_result.status.state.value}")  # use .value for the enum value
            if task_result.status.state in [TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED]:
                break
        else:
            print("The task did not finish within the allotted time.")
            return None

        # 6. Handle the final task result (attribute access)
        if task_result.status.state == TaskState.COMPLETED and task_result.artifacts:
            print("Task completed successfully. Results:")
            for artifact in task_result.artifacts:
                if artifact.parts:
                    for part in artifact.parts:
                        if isinstance(part, TextPart):  # check the type
                            print(f"  - {part.text}")  # access the attribute
        elif task_result.status.state == TaskState.FAILED:
            error_msg = "Unknown error"
            if task_result.status.message and task_result.status.message.parts:
                if isinstance(task_result.status.message.parts[0], TextPart):
                    error_msg = task_result.status.message.parts[0].text
            print(f"Task failed: {error_msg}")
        else:
            print(f"Final task state: {task_result.status.state.value}")
        return task_result
    except Exception as e:
        logger.error(f"Exception while handling the sync task: {e}", exc_info=True)
        print(f"Sending the sync task failed: {e}")
        return None


# Scenario 2: multi-turn conversation -- incomplete information (corrected, with known limitations)
async def test_multi_turn_conversation(client: A2AClient):
    print("\n=== Scenario 2: multi-turn conversation (the agent may not support it) ===")
    print("Note: the current server-side agent implementation may not support true multi-turn state.")
    # --- First turn ---
    session_id = "test_session_multi_" + uuid4().hex  # unique session ID for the conversation
    query1 = "How much is 100 US dollars?"  # deliberately omit the target currency
    task_id_1 = "test_multi_1_" + uuid4().hex
    try:
        print(f"\nFirst turn (Session: {session_id}): sending '{query1}'")
        # 1a. Build the Message and payload
        message1 = Message(role="user", parts=[TextPart(text=query1)])
        payload_dict1 = {
            "id": task_id_1,
            "sessionId": session_id,  # pass the session ID
            "message": message1.model_dump(),
            "acceptedOutputModes": ["text"],
            "metadata": {"skill_name": "react_query"}
        }
        logger.debug(f"Sending multi-turn task 1 with payload: {payload_dict1}")
        # 1b. Send the task
        response1: SendTaskResponse = await client.send_task(payload=payload_dict1)
        logger.debug(f"Send task 1 response: {response1}")
        if response1.error:
            error: JSONRPCError = response1.error
            print(f"Error sending the first-turn task: Code={error.code}, Message={error.message}")
            return None
        if response1.result:
            task_id_1 = response1.result.id  # use the server-confirmed ID
        print(f"First-turn task sent, ID: {task_id_1}")
        # 1c. Poll for the result
        print("Waiting for the first-turn response...")
        task1_result: Optional[Task] = None
        for attempt in range(5):  # fewer polling attempts here
            await asyncio.sleep(2)
            get_payload1 = {"id": task_id_1}
            get_response1: GetTaskResponse = await client.get_task(payload=get_payload1)
            if get_response1.result:
                task1_result = get_response1.result
                print(f"  Current task state: {task1_result.status.state.value}")
                if task1_result.status.state != TaskState.WORKING:
                    break
        else:
            print("The first-turn task did not finish (or start) within the allotted time.")
            return None
        # 1d. Check whether the agent asked for more input
        # (the current agent will likely just complete or fail instead)
        if task1_result.status.state == TaskState.INPUT_REQUIRED and task1_result.status.message:
            print("The agent asked for more information:")
            for part in task1_result.status.message.parts:
                if isinstance(part, TextPart):
                    print(f"  Agent: {part.text}")

            # --- Second turn ---
            query2 = "Japanese yen"  # supply the target currency
            task_id_2 = "test_multi_2_" + uuid4().hex
            print(f"\nSecond turn (Session: {session_id}): sending '{query2}'")
            # 2a. Build the Message and payload
            message2 = Message(role="user", parts=[TextPart(text=query2)])
            payload_dict2 = {
                "id": task_id_2,
                "sessionId": session_id,  # must reuse the same session ID
                "message": message2.model_dump(),
                "acceptedOutputModes": ["text"],
                "metadata": {"skill_name": "react_query"}
            }
            logger.debug(f"Sending multi-turn task 2 with payload: {payload_dict2}")
            # 2b. Send the task
            response2: SendTaskResponse = await client.send_task(payload=payload_dict2)
            logger.debug(f"Send task 2 response: {response2}")
            if response2.error:
                error: JSONRPCError = response2.error
                print(f"Error sending the second-turn task: Code={error.code}, Message={error.message}")
                return None
            if response2.result:
                task_id_2 = response2.result.id
            print(f"Second-turn task sent, ID: {task_id_2}")
            # 2c. Poll for the final result
            print("Waiting for the second-turn task to complete...")
            task2_result: Optional[Task] = None
            for attempt in range(10):
                await asyncio.sleep(2)
                get_payload2 = {"id": task_id_2}
                get_response2: GetTaskResponse = await client.get_task(payload=get_payload2)
                if get_response2.result:
                    task2_result = get_response2.result
                    print(f"  Current task state: {task2_result.status.state.value}")
                    if task2_result.status.state != TaskState.WORKING:
                        break
            else:
                print("The second-turn task did not finish within the allotted time.")
                return None
            # 2d. Handle the final result
            if task2_result.status.state == TaskState.COMPLETED and task2_result.artifacts:
                print("Multi-turn task completed successfully. Final results:")
                for artifact in task2_result.artifacts:
                    if artifact.parts:
                        for part in artifact.parts:
                            if isinstance(part, TextPart):
                                print(f"  - {part.text}")
            else:
                print(f"Final state of the second-turn task: {task2_result.status.state.value}")
            return task2_result
        elif task1_result.status.state == TaskState.COMPLETED:
            print("The agent completed the task in the first turn (it may have used a default currency or reported that it cannot answer):")
            if task1_result.artifacts:
                for artifact in task1_result.artifacts:
                    if artifact.parts:
                        for part in artifact.parts:
                            if isinstance(part, TextPart):
                                print(f"  - {part.text}")
            return task1_result
        else:
            print(f"The first-turn task did not request input; final state: {task1_result.status.state.value}")
            return task1_result
    except Exception as e:
        logger.error(f"Exception during the multi-turn conversation: {e}", exc_info=True)
        print(f"Multi-turn conversation test failed: {e}")
        return None


# Scenario 3: streaming response (corrected)
async def test_streaming_response(client: A2AClient):
    print("\n=== Scenario 3: streaming response (the agent side is mocked) ===")
    # query = "What are the current exchange rates between USD, EUR, and JPY?"
    query = "Write a short poem about spring in Chinese"  # a query better suited to streaming
    task_id = "test_stream_" + uuid4().hex
    try:
        # 1. Build the Message and payload
        message = Message(role="user", parts=[TextPart(text=query)])
        payload_dict = {
            "id": task_id,
            "sessionId": "test_session_stream_1",
            "message": message.model_dump(),
            "acceptedOutputModes": ["text"],
            "metadata": {"skill_name": "react_query"}
        }
        logger.debug(f"Sending streaming task with payload: {payload_dict}")
        print(f"Task sent, ID: {task_id}")
        # 2. Call send_task_streaming (no await) and iterate with async for
        event_stream_generator = client.send_task_streaming(payload=payload_dict)
        print("Receiving streaming responses:")
        async for event_response in event_stream_generator:
            logger.debug(f"Received stream event: {event_response}")
            if event_response.error:
                error: JSONRPCError = event_response.error
                print(f"Error during streaming: Code={error.code}, Message={error.message}")
                continue
            event = event_response.result
            if not event:
                logger.warning("Received stream response with empty result.")
                continue
            # Handle status-update events
            if hasattr(event, "status") and event.status and event.status.message:
                if event.status.message.parts:
                    for part in event.status.message.parts:
                        if isinstance(part, TextPart):
                            print(f"  Streaming update: {part.text}")
            # Handle artifact events
            if hasattr(event, "artifact") and event.artifact:
                # print("  Received artifact:")  # printing this repeatedly is noisy, so it stays commented out
                if event.artifact.parts:
                    for part in event.artifact.parts:
                        if isinstance(part, TextPart):
                            print(f"  Streaming result: {part.text}")
            # Check the end-of-stream flag
            if hasattr(event, "final") and event.final:
                print("End-of-stream flag received.")
        print("Streaming task handled.")
        return True
    except Exception as e:
        logger.error(f"Exception while handling the streaming task: {e}", exc_info=True)
        print(f"Sending the streaming task failed: {e}")
        return False


# Entry point (corrected)
async def main():
    print("=== LangGraph Agent A2A protocol tests ===\n")
    # This script exercises three interaction scenarios against the LangGraph agent:
    # 1. synchronous request -- agent invocation (calculator)
    # 2. multi-turn conversation -- incomplete information (the agent may not support it)
    # 3. streaming response -- live status updates (mocked on the agent side)

    # Create the A2A client
    client = A2AClient(url="http://127.0.0.1:8000")
    # (The get_agent_info call was removed; to verify the server is online, just send a simple task.)
    print("Trying to reach the server and run the tests...")
    print("-" * 30)
    # Run the scenarios
    await test_sync_currency_conversion(client)
    print("-" * 30)
    # Note: the multi-turn test depends on the agent's conversation-state handling
    await test_multi_turn_conversation(client)
    print("-" * 30)
    await test_streaming_response(client)
    print("-" * 30)
    print("All test scenarios finished.")


# Run the tests
if __name__ == "__main__":
    asyncio.run(main())


================================================
FILE: examples/16_google_a2a/currency_agent_test_README.md
================================================
# LangGraph Agent A2A Protocol Interaction Tests

## Overview

This test script (`examples/a2a/currency_agent_test.py`) demonstrates, through concrete interaction scenarios, how to use the A2A client to talk to the LangGraph agent service started earlier via `langgraph_integration.py`. It covers a synchronous request (involving a tool call), an attempted multi-turn conversation, and receiving a (simulated) streaming response.

## Test Scenarios

The script contains three main scenarios:

1. **Scenario 1: synchronous request -- agent invocation (with a tool)**
   * **Goal:** send a request that the agent can only fulfil by calling an internal tool (the calculator in this example).
   * **Flow:** the client sends a calculation task -> the server-side agent (LangGraph ReAct) parses it -> calls the `calculator` tool -> gets the result -> the LLM composes the answer -> the server returns the final result -> the client polls for it and prints it.
   * **Expected:** the client receives the agent's correct computed result.

2. **Scenario 2: multi-turn conversation attempt (limited by the current agent)**
   * **Goal:** exercise the client-side request pattern for multi-step interactions (using `sessionId`) and observe how the current agent responds.
   * **Flow:**
     * First turn: the client sends an underspecified query (e.g. "How much is 100 US dollars?" with no target currency) together with a `sessionId`.
     * The client polls for the result.
     * **Note:** *with the current agent implementation (`CurrencyAgent` built on `create_react_agent`, whose `invoke` does not handle conversation history specially), the agent will most likely not return an `input-required` state asking for more information; instead it will either attempt an answer or say it cannot, and mark the task `completed` or `failed`.*
     * (In the ideal flow, if the agent returned `input-required`, the client would send a second request with the missing information under the same `sessionId`.)
   * **Expected:** the client correctly sends requests with a `sessionId` and handles the agent's final response, even if the agent does not enter a multi-turn clarification state. This scenario mainly validates the client's multi-turn request mechanics and documents the agent's current behavior.

3. **Scenario 3: streaming response (mocked on the agent side)**
   * **Goal:** test the client's ability to consume A2A streaming responses (Server-Sent Events).
   * **Flow:** the client sends a query suited to streaming -> the server-side `AgentTaskManager` calls `CurrencyAgent.stream` -> **note:** *`CurrencyAgent.stream` is currently a mock that emits canned text chunks rather than calling LangGraph's streaming interface* -> the client receives and prints the simulated stream events.
   * **Expected:** the client connects to the SSE endpoint and receives and prints the (simulated) streaming events. The payload shape every scenario sends is shown below.
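All three scenarios build the same payload shape; here is a representative example (a sketch mirroring the script's `payload_dict`, with an illustrative task ID):

```python
from uuid import uuid4
from core.a2a.types import Message, TextPart

task_id = "test_sync_" + uuid4().hex
message = Message(role="user", parts=[TextPart(text="Compute 58 * 34")])
payload_dict = {
    "id": task_id,                        # client-generated task ID
    "sessionId": "test_session_sync_1",   # groups the turns of one conversation
    "message": message.model_dump(),
    "acceptedOutputModes": ["text"],
    "metadata": {"skill_name": "react_query"},  # must match the AgentCard skill
}
```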
## Running the Tests

### Prerequisites

* Python (3.10 or later recommended)
* All required Python dependencies installed per the project's `requirements.txt`.
* A valid `OPENAI_API_KEY` (or the API key for whichever LLM you use) configured in the `.env` file at the project root.

### Steps

1. **Start the A2A server:**
   * Make sure you are in the project root.
   * In a terminal, run (if it is not already running):
     ```bash
     python -m examples.a2a.langgraph_integration
     ```
   * The server should start and listen on `http://127.0.0.1:8000`.

2. **Run this test script:**
   * Open **another** terminal.
   * Make sure you are in the project root with the same virtual environment activated.
   * Run the test script:
     ```bash
     python -m examples.a2a.currency_agent_test
     ```

## Sample Test Output (from actual runs)

The output below shows the expected format, reflecting the agent's current actual behavior:

### Synchronous request (calculator call)

```
=== Scenario 1: synchronous request -- agent invocation (calculator) ===
Task sent, ID: test_sync_...
Waiting for the task to complete...
  Current task state: completed
Task completed successfully. Results:
  - 58 * 34 is 1972.
```

### Multi-turn conversation (agent finishes in the first turn)

```
=== Scenario 2: multi-turn conversation (the agent may not support it) ===
Note: the current server-side agent implementation may not support true multi-turn state.

First turn (Session: test_session_multi_...): sending 'How much is 100 US dollars?'
First-turn task sent, ID: test_multi_1_...
Waiting for the first-turn response...
  Current task state: completed
The agent completed the task in the first turn (it may have used a default currency or reported that it cannot answer):
  - I cannot currently tell you exactly what 100 US dollars is worth. Please consult up-to-date exchange rates or a currency conversion tool for an accurate result.
```

*(The agent's exact wording may vary between LLM calls.)*

### Streaming response (mocked on the agent side)

```
=== Scenario 3: streaming response (the agent side is mocked) ===
Task sent, ID: test_stream_...
Receiving streaming responses:
  Streaming update: Processing your request...
  Streaming result: Here is the information about 'Write a short poem about spring in Chinese': this is a simulated response, because real streaming is not implemented.
End-of-stream flag received.
Streaming task handled.
```

## Notes

* Make sure the environment variables in `.env` are set before running the tests.
* The script connects to `http://127.0.0.1:8000` by default. If your server uses a different address or port, change the `A2AClient` URL in the script.
* If the connection fails or a test errors out, first check that the A2A server is up and running, and inspect the server-side logs.

---

## Naming and Differences Between the Two Client Examples

The project ships two client example files, each with a different focus:

1. **`examples/a2a/client_example.py` -> "Basic Client Example"**
   * **Purpose:** a **basic demonstration** of the core `A2AClient` methods (`send_task`, `get_task`, `send_task_streaming`) in their simplest usage.
   * **Characteristics:** short, straightforward code whose main goal is to show how to issue the different kinds of A2A requests and handle the simplest successful responses. Includes a simple polling loop.

2. **`examples/a2a/currency_agent_test.py` -> "Scenario-based Test Client"**
   * **Purpose:** **functional testing and scenario demonstration**. It defines concrete interaction scenarios against the integrated LangGraph agent (synchronous tool call, attempted multi-turn dialogue, streaming consumption) to validate the end-to-end flow and observe the agent's behavior in specific situations.
   * **Characteristics:** clearly split into per-scenario test functions with more concrete queries (some mocked, or exposing the agent's limitations); output is organized per scenario. It also polls, and exercises multi-turn state passing via `sessionId`.

**Summary of the differences:**

| Aspect         | `client_example.py` (basic example)     | `currency_agent_test.py` (scenario-based tests)                  |
| :------------- | :-------------------------------------- | :---------------------------------------------------------------- |
| **Goal**       | Demonstrate basic client API usage      | Test/demonstrate specific interaction scenarios                    |
| **Structure**  | Simple sequential calls                 | Functions split by test scenario                                   |
| **Complexity** | Low; core API calls                     | Somewhat higher; scenario logic (e.g. the multi-turn attempt)      |
| **Queries**    | Generic examples (calculation, search)  | Scenario-specific (calculation, underspecified query, streaming-friendly query) |
| **Focus**      | How to call the API                     | Agent behavior per scenario and end-to-end validation              |


================================================
FILE: examples/16_google_a2a/langgraph_integration.py
================================================
# examples/a2a/langgraph_integration.py
import os
import sys
import asyncio  # still potentially used by dependencies, so keep the import
import logging
# Make sure List, Tuple, Optional, TypedDict are imported
from typing import Dict, Any, List, Optional, AsyncIterable, Union, TypedDict, Tuple

# Add the project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

# A2A components (imported from this project's structure)
from core.a2a.types import (
    AgentCard, AgentCapabilities, AgentSkill,
    Task, TaskState, TaskStatus,
    Artifact, Message, TextPart,  # TextPart may no longer be used directly
    JSONRPCResponse, InvalidParamsError, InternalError,
    SendTaskRequest, SendTaskResponse, TaskSendParams
)
from core.a2a.server.server import A2AServer
from core.a2a.agent_task_manager import AgentTaskManager

# LangChain / LLM components
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
# StateGraph and END are no longer used directly, but keep the import
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import create_react_agent

# Configure logging (switch to DEBUG for more detail)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Tool definitions (unchanged) ---
@tool
def search(query: str) -> str:
    """Search the internet for information."""
    # A real implementation would call an actual search engine API
    logger.info(f"Tool 'search' called with query: {query}")
    return f"These are mocked search results for '{query}'."


@tool
def calculator(expression: str) -> str:
    """Evaluate a math expression."""
    logger.info(f"Tool 'calculator' called with expression: {expression}")
    try:
        # Note: using eval in production is very dangerous; this is for demonstration only.
        # Restrict eval so that only simple math is possible.
        allowed_names = {
            k: v for k, v in __import__("math").__dict__.items() if not k.startswith("_")
        }
        allowed_names.update({"abs": abs, "int": int, "float": float})  # common helpers
        code = compile(expression, "<string>", "eval")
        for name in code.co_names:
            if name not in allowed_names:
                raise NameError(f"Use of name '{name}' not allowed")
        result = eval(code, {"__builtins__": {}}, allowed_names)
        return f"Result: {result}"
    except NameError as e:
        logger.error(f"Calculation error (NameError): {e} in expression '{expression}'")
        return f"Calculation error: name not allowed ({e})"
    except Exception as e:
        logger.error(f"Calculation error: {e} in expression '{expression}'")
        return f"Calculation error: {e}"
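# Illustrative sanity checks (an assumption about usage, not part of the original script;
# LangChain tools created with @tool accept a dict input via .invoke):
#   calculator.invoke({"expression": "sqrt(16) + 2"})      # -> "Result: 6.0" (math names allowed)
#   calculator.invoke({"expression": "__import__('os')"})  # -> rejected, name not in allowed_names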
        logger.warning("[CurrencyAgent.stream] Stream method is currently mocked.")
        # --- 模拟实现 ---
        yield {
            "content": "正在处理您的请求...",
            "is_task_complete": False,
            "require_user_input": False
        }
        await asyncio.sleep(0.5)
        final_simulated_answer = f"关于 '{query}' 的信息如下:这是一个模拟的回应,因为真实流未实现。"
        yield {
            "content": final_simulated_answer,
            "is_task_complete": True,
            "require_user_input": False
        }
        # --- 模拟结束 ---

# --- A2A 服务器设置 (修正函数定义和 AgentCard) ---
# 将函数改为同步定义 (def 而不是 async def)
def setup_a2a_server():
    """设置并返回 A2A 服务器实例 (同步函数)"""
    print("\n=== 配置 LangGraph A2A 服务器 ===\n")

    # 创建LLM
    try:
        llm = ChatOpenAI(model="gpt-4o-mini")
        logger.info("Using OpenAI LLM: gpt-4o-mini")
    except Exception as e:
        print(f"警告: 无法创建OpenAI LLM ({e}),将使用模拟模式")
        from langchain.llms.fake import FakeListLLM
        llm = FakeListLLM(responses=["这是一个模拟的LLM响应"])
        logger.info("Using FakeListLLM (simulation mode)")

    # 创建 Agent 实例
    agent = CurrencyAgent(llm)

    # 创建 Agent 卡片 (添加缺失字段)
    agent_card = AgentCard(
        name="LangGraph ReAct Agent",
        description="一个使用LangGraph ReAct处理查询并调用工具的Agent",
        url="http://127.0.0.1:8000/agent",  # Agent 的访问 URL (示例)
        version="0.1.0",  # Agent 的版本号
        capabilities=AgentCapabilities(  # 设置 Agent 的能力
            streaming=False,  # 当前 stream 是模拟的,设为 False
            pushNotifications=False  # 假设不支持推送
        ),
        skills=[  # skills 列表在 AgentCard 顶层
            AgentSkill(
                id="react_query_skill",  # 技能的唯一 ID
                name="react_query",
                description="处理自然语言查询,可使用搜索和计算器工具",
                inputModes=["text"],
                outputModes=["text"]
            )
        ]
        # 其他可选字段可以按需添加
    )

    # 创建 AgentTaskManager
    task_manager = AgentTaskManager(agent)

    # 创建A2A服务器实例 (不在此处设置 host/port)
    server = A2AServer(agent_card=agent_card, task_manager=task_manager)
    print("A2A服务器实例已创建。")
    return server  # 返回实例

# --- 主函数入口 (修正启动逻辑) ---
if __name__ == "__main__":
    try:
        # 调用同步函数来设置服务器
        server_instance = setup_a2a_server()

        # 定义 HOST 和 PORT
        HOST = "127.0.0.1"
        PORT = 8000
        print(f"准备启动A2A服务器,监听地址 http://{HOST}:{PORT}")

        # 在调用 start 前设置 host 和 port
        # (或者修改 A2AServer 的 __init__ 让其接受 host/port)
        server_instance.host = HOST
        server_instance.port = PORT

        # 启动服务器 (调用同步的 start 方法)
        server_instance.start()
    except KeyboardInterrupt:
        print("\n服务器已手动停止。")
    except Exception as e:
        # 捕获设置或启动过程中的其他异常
        logger.error(f"启动服务器时发生未处理的异常: {e}", exc_info=True)



================================================
FILE: examples/TODO_computer_tool_demo.py
================================================
from typing import Annotated

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
# StateGraph/END/START 来自 langgraph.graph (而非 langchain_core.runnables.graph)
from langgraph.graph import END, START, StateGraph
# Command 位于 langgraph.types,InjectedState 位于 langgraph.prebuilt
from langgraph.prebuilt import InjectedState, ToolNode
from langgraph.types import Command

# Import our custom computer tool
# TODO: MarinaBox - Import our custom computer tool
from marinabox import mb_start_computer, mb_stop_computer, mb_use_computer_tool

# Set up model with tools
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
# 假设 mb_use_computer_tool 为返回工具实例的工厂;若其本身就是工具,请直接传入 mb_use_computer_tool
tools = [mb_use_computer_tool()]
model_with_tools = model.bind_tools(tools)

# Define workflow nodes
def should_continue(state: Annotated[dict, InjectedState()]):
    messages = state["messages"]
    if len(messages) > 0:
        last_message = messages[-1]
        if last_message.tool_calls:
            return Command(goto="tool_node")
        else:
            return Command(goto="stop_computer")
    return Command(goto="stop_computer")

def call_model(state: Annotated[dict, InjectedState()]):
    input_message = input("Enter your message: ")
    if input_message != "stop_computer":
        messages = [HumanMessage(content=input_message)]
        response = model_with_tools.invoke(messages)
        return {"messages": [response], "session_id": state.get("session_id")}
    else:
        return {"messages": [], "session_id": state.get("session_id")}

# Set up workflow
workflow = StateGraph(dict)
workflow.add_node("start_computer", mb_start_computer)
workflow.add_node("agent", call_model)
workflow.add_node("tool_node", ToolNode(tools=tools))
workflow.add_node("stop_computer", mb_stop_computer)
workflow.add_node("should_continue", should_continue)

# Define workflow edges
workflow.add_edge(START, "start_computer")
workflow.add_edge("start_computer", "agent")
workflow.add_edge("tool_node", "agent")
workflow.add_edge("agent", "should_continue")
workflow.add_edge("stop_computer", END)

# Compile and run workflow
app = workflow.compile()

if __name__ == "__main__":
    # messages 应为消息列表而非字符串
    app.invoke({"messages": []})



================================================
FILE: examples/__init__.py
================================================



================================================
FILE: examples/state_based_supervisor_examples/01_simple.py
================================================
import asyncio
import json
import os
import re
import time
from datetime import datetime
from typing import Literal, List, Dict, Any, Optional, cast

# --- LangChain / LangGraph ---
try:
    # 使用 langchain_openai (或你选择的模型提供商)
    from langchain_openai import ChatOpenAI
except ImportError:
    ChatOpenAI = None
    print("Warning: langchain_openai not installed.")

# 核心消息类型
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, ToolMessage
# LangChain 工具相关
from langchain_core.tools import tool, BaseTool

# --- OpenAI 错误处理 ---
try:
    from openai import RateLimitError
except ImportError:
    class RateLimitError(Exception):
        pass

# --- 内部模块导入 (请确保路径正确) ---
try:
    # 假设这些是你当前的路径
    from core.agents.sb_supervisor_agent import SupervisorAgent
    # state_schema 位于 state_based_supervisor 包下 (见项目目录结构)
    from core.agents.state_based_supervisor.state_schema import PlanningAgentState
    from core.agents.base.react_agent import ReactAgent  # 导入 ReactAgent
    # 导入 StreamUpdate (如果需要在最终状态中检查它,但这里主要关注消息)
    # from core.agents.supervisor.schemas import StreamUpdate
except ImportError as e:
    print(f"Error importing agent components: {e}")
    print("Please ensure paths like 'core.agents.sb_supervisor_agent' are correct.")

import traceback

# --- 定义 Web Search 工具 ---
# 使用 @tool 装饰器明确这是一个工具
@tool
def web_search(query: str) -> str:
    """Search the web for current information about a given query. Use this for recent events, data, or facts."""
    print(f"--- TOOL CALLED: web_search(query='{query}') ---")  # 添加日志确认工具被调用
    # Mocked data - 实际使用时会调用 Tavily 或其他搜索引擎
    if "apple" in query.lower() and "headcount" in query.lower() and "2024" in query:
        return (
            "According to recent (mocked) reports for 2024, Apple's headcount is approximately 164,000 employees globally."
        )
    elif "joke" in query.lower():
        # 这个工具不适合讲笑话
        return "I am a web search tool, I cannot tell jokes."
    else:
        return f"Mock search results for query: '{query}'. Found relevant information on various websites."
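# (补充示例,假设性草图) @tool 装饰后的 web_search 是一个标准的 LangChain 工具对象,
# 接入 Supervisor 之前可以先单独调用做冒烟测试,确认输入/输出格式,例如:
#
#   print(web_search.invoke({"query": "apple headcount 2024"}))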
# --- 主执行逻辑 --- async def main(): # --- 初始化 LLM (确保 API Key 在环境中) --- try: model_name = os.getenv("LLM_MODEL_NAME", "gpt-4o") print(f"Using LLM: {model_name}") if not ChatOpenAI: raise ImportError("ChatOpenAI not available.") # 使用温度稍高的模型可能有助于 ReAct 思考和调用工具 model = ChatOpenAI(model=model_name, temperature=0.2) except Exception as e: print(f"Failed to initialize ChatOpenAI model: {e}") return # --- 实例化 Agents --- try: # research_agent 现在有了一个明确定义的 web_search 工具 research_agent = ReactAgent( name="research_expert", tools=[web_search], # <--- 传入工具列表 model=model, # 添加明确的 Prompt 引导工具使用 prompt=( "You are a research expert. Use available tools to find information. " "You have access to 'web_search'. Use it for questions about current data, facts, or events." ), max_context_tokens=8000 ) all_agents = [research_agent] # --- 实例化 Supervisor --- supervisor = SupervisorAgent( agents=all_agents, model=model, # Supervisor 使用相同的模型 state_schema=PlanningAgentState, include_agent_name="inline" # checkpointer=... ) except Exception as e: print(f"Failed to initialize agents or supervisor: {e}") traceback.print_exc() return # --- 准备初始请求 --- # 用户请求包含两个意图:讲笑话 + 查信息 user_request = ( "Hi! I'd like to start with a short joke to lighten the mood, " "then please check Apple's headcount in 2024. Summarize both." ) print(f"Initial Request: '{user_request}'") # --- 准备初始状态 --- initial_graph_state: PlanningAgentState = { "messages": [HumanMessage(content=user_request)], # 使用 HumanMessage "plan": None, "error": None } # --- 执行 Supervisor (使用 ainvoke) --- final_state: Optional[Dict[str, Any]] = None error_occurred: Optional[Exception] = None config = {"recursion_limit": 100} try: print("\n--- Invoking Supervisor Agent (ainvoke) ---") final_state = await supervisor.ainvoke(initial_graph_state, config=config) print("\n--- Supervisor Agent Invocation Complete ---") # --- 错误处理 --- except RateLimitError as e: error_occurred = e; print(f"\n!!! OpenAI Quota Error: {e}") except Exception as e: error_occurred = e; print(f"\n!!! 
Error during graph execution: {e}"); traceback.print_exc() # --- 处理并打印最终结果 --- if error_occurred: print("\n--- Graph Execution INTERRUPTED or FAILED ---") else: print("\n--- Graph Execution Finished ---") if not final_state: print("Error: No final state available.") return print("\n--- FINAL STATE ---") # 打印错误(如果在状态中记录了) if final_state.get("error"): print(f"\nERROR RECORDED IN STATE: {final_state['error']}") # 打印计划 final_plan = final_state.get('plan') if final_plan: print("\nFinal Plan State:", json.dumps(final_plan, indent=2, default=str)) else: print("\nFinal Plan State: Not available.") # 打印消息历史 final_messages = final_state.get("messages", []) if final_messages: print("\nFinal Message History (Last 10):") for m in final_messages[-10:]: try: if hasattr(m, 'pretty_print'): m.pretty_print() else: print(json.dumps(m, indent=2, default=str)) print("-" * 10) except Exception as print_err: print(f"Error printing final message: {print_err}") else: print("\nFinal Message History: Empty.") print("\n--- END OF TEST ---") if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: print("\nExecution interrupted by user.") except Exception as e: print(f"\nAn unexpected top-level error occurred: {e}") traceback.print_exc() ================================================ FILE: examples/state_based_supervisor_examples/02_tavily.py ================================================ # main.py (用于测试 State-Based Supervisor 和 ReactAgent) import asyncio import json import os from typing import Dict, Any, Optional from langchain_community.tools import TavilySearchResults # --- LangChain / LangGraph --- # 假设模型直接在此初始化或从别处导入 from dotenv import load_dotenv load_dotenv() # 自动加载 .env 文件 try: from langchain_openai import ChatOpenAI # 或者你使用的其他模型 except ImportError: ChatOpenAI = None print("Warning: langchain_openai not installed.") # 核心消息类型 from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, ToolMessage # --- OpenAI 错误处理 --- try: from openai import RateLimitError except ImportError: class RateLimitError(Exception): pass # --- 内部模块导入 (请确保路径正确) --- try: # 从你提供的 core.agents... 
路径导入 from core.agents.sb_supervisor_agent import SupervisorAgent # 你的 Supervisor 实现 from core.agents.state_based_supervisor.state_schema import PlanningAgentState # 包含 Plan 的状态 from core.agents.base.react_agent import ReactAgent # 你的 ReactAgent 实现 from core.llm.llm_manager import LLMManager # LLM 管理器 # (如果你的子 Agent 有更具体的类,在这里导入它们) # 例如: # from core.agents.researcher import ResearchAgent # from core.agents.coder import CoderAgent # --- 如果没有具体子 Agent 类,使用 ReactAgent 作为示例 --- # (确保 ReactAgent 可以被直接实例化用于测试) if not issubclass(ReactAgent, object): # 简单检查 ReactAgent 是否有效 raise ImportError("ReactAgent class not found or invalid.") except ImportError as e: print(f"Error importing agent components: {e}") print("Please ensure paths like 'core.agents.sb_supervisor_agent' are correct relative to your execution path.") import traceback # --- 主执行函数 (简化版,只关注最终结果) --- async def run_supervisor_test(supervisor_agent: SupervisorAgent, initial_state: Dict[str, Any]): """Executes the supervisor graph using ainvoke and prints the final state.""" print("--- Starting Supervisor Graph Test ---") # 获取初始消息列表,检查是否为空 messages_list = initial_state.get("messages", []) initial_query = "N/A" # 默认值 if messages_list: first_message = messages_list[0] # 检查第一个消息是否有 content 属性 (更健壮) if hasattr(first_message, 'content'): initial_query = first_message.content else: # 如果第一个元素不是预期的消息对象,记录一下 print(f"Warning: First item in initial messages is not a standard message object: {type(first_message)}") initial_query = str(first_message) # 尝试转换为字符串 print(f"Initial Query: '{initial_query}'") print("-" * 30) config = {"recursion_limit": 100} # 使用较高的递归限制 final_state: Optional[Dict[str, Any]] = None error_occurred: Optional[Exception] = None try: print("--- Invoking Supervisor Agent (ainvoke) ---") # 直接调用 ainvoke 获取最终状态 final_state = await supervisor_agent.ainvoke(initial_state, config=config) print("--- Supervisor Agent Invocation Complete ---") # --- 错误处理 --- except RateLimitError as e: error_occurred = e print("\n" + "="*40 + "\n!!! OpenAI API Error: Insufficient Quota !!!\n" + "="*40) print("Execution stopped. Check OpenAI plan/billing.") print(f"Original error: {e}") except TypeError as e: error_occurred = e print("\n" + "="*40 + "\n!!! TypeError During Graph Execution !!!\n" + "="*40) print(f"Error details: {e}") if "synchronous function provided" in str(e): print("Hint: Ensure all graph nodes support async or run the graph synchronously if needed.") traceback.print_exc() except Exception as e: error_occurred = e print("\n" + "="*40 + "\n!!! 
An Unexpected Error Occurred !!!\n" + "="*40) print(f"Error type: {type(e).__name__}\nError details: {e}") traceback.print_exc() # --- Process Final State --- if error_occurred: print("\n--- Graph Execution INTERRUPTED or FAILED ---") else: print("\n--- Graph Execution Finished ---") if not final_state: # 如果 ainvoke 返回 None 或在出错前未赋值 (理论上 ainvoke 会抛错或返回字典) print("Error: No final state available (Execution might have failed early).") # 尝试从 supervisor agent 获取最后状态 (如果 checkpointer 可用且实现了 get_state) if hasattr(supervisor_agent, 'checkpointer') and supervisor_agent.checkpointer and hasattr(supervisor_agent.checkpointer, 'get'): try: # 需要知道配置中的 thread_id (这里假设是 'test_thread') last_checkpoint = supervisor_agent.checkpointer.get({"configurable": {"thread_id": "test_thread"}}) if last_checkpoint: print("Attempting to display last known checkpoint state:") final_state = last_checkpoint.get('channel_values', {}) else: print("Could not retrieve last checkpoint state.") except Exception as cp_err: print(f"Error retrieving checkpoint state: {cp_err}") # 即使出错,也尝试打印 final_state (可能是包含错误信息的状态) if final_state and isinstance(final_state, dict): print("\n--- FINAL STATE ---") # 1. 打印错误信息 (如果存在) if final_state.get("error"): print(f"\nERROR RECORDED IN STATE: {final_state['error']}") # 2. 打印最终消息历史 (尝试 pretty_print) final_messages = final_state.get("messages", []) if final_messages and isinstance(final_messages, list): print("\nFinal Message History (Last ~10):") for m in final_messages[-10:]: # 只打印最后一部分 try: if hasattr(m, 'pretty_print'): m.pretty_print() else: # Fallback for dict or other types print(json.dumps(m, indent=2, default=str)) print("-" * 10) except Exception as print_err: print(f"Error printing final message: {print_err}") else: print("\nFinal Message History: Not available or empty.") # 3. 打印最终计划状态 final_plan = final_state.get('plan') if final_plan and isinstance(final_plan, dict): print("\nFinal Plan State:") print(json.dumps(final_plan, indent=2, default=str)) else: print("\nFinal Plan State: Not available or not generated.") else: print("\n--- No Final State Could Be Displayed ---") print("\n--- END OF TEST ---") return final_state # --- Main Execution Block --- async def main(): # --- 1. 初始化 LLM 管理器 (它会自动注册配置好的模型) --- try: model_manager = LLMManager() # 可以选择打印一下注册了哪些模型 print("Registered Models:", json.dumps(model_manager.list_models(), indent=2)) print("Capability Mapping:", model_manager.list_capabilities()) except Exception as e: print(f"Failed to initialize LLMManager: {e}") return # --- 2. 实例化 Agents (使用 ModelManager 获取模型) --- try: # 获取默认模型用于基础任务 grok = model_manager.get_model("xai_grok") # 获取 ID 由 config 或第一个注册的决定 deepseek_v3 = model_manager.get_model("deepseek_v3") # 获取 DeepSeek 模型 # 创建Tavily搜索工具 tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=False, include_images=False, search_depth="advanced" ) # 确保 ReactAgent 使用与 Supervisor 兼容的状态 (例如 BasicAgentState) # 或者 Supervisor 能够处理不同类型的子 Agent 状态返回 researcher_system_prompt = """You are a research expert. Use available tools to find the most up-to-date information to answer the user's query. 
You have access to a 'tavily_search_results_json' tool.""" research_agent = ReactAgent( name="research_expert", tools=[tavily_search], description="Research expert with access to Tavily search.", model=deepseek_v3, prompt=researcher_system_prompt, ) all_agents = [research_agent] # 只包含一个子 Agent 用于测试 # --- 实例化 Supervisor (使用 PlanningAgentState) --- supervisor = SupervisorAgent( agents=all_agents, model=deepseek_v3, # Supervisor 使用的 LLM state_schema=PlanningAgentState, # 明确 Supervisor 使用 Planning 状态 # enable_planning=True, # 不再需要此参数,因为 state_schema 暗示了规划 include_agent_name="inline" # 推荐 # checkpointer=... # 添加 Checkpointer 以测试持久化 ) except Exception as e: print(f"Failed to initialize agents or supervisor: {e}") import traceback traceback.print_exc() return # --- 获取用户输入 --- topic = input("Please enter the initial request for the supervisor: ") if not topic: print("No request entered. Exiting.") return # --- 准备初始状态 (使用 PlanningAgentState) --- initial_graph_state: PlanningAgentState = { "messages": [HumanMessage(content=topic)], # 确保是 HumanMessage 对象 "plan": None, # 初始没有计划 "error": None } # --- 运行测试 --- await run_supervisor_test(supervisor, initial_graph_state) if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: print("\nExecution interrupted by user.") except Exception as e: print(f"\nAn unexpected top-level error occurred: {e}") traceback.print_exc() ================================================ FILE: examples/state_based_supervisor_examples/03_multi_agents.py ================================================ # main.py (Multi-Agent Test with State-Based Supervisor) import asyncio import json import os import re import time import traceback # 导入 traceback from datetime import datetime from typing import Dict, Any, Optional, List, Literal, cast # --- LangChain / LangGraph / OpenAI Imports --- from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, ToolMessage # --- Agent 和工具导入 (确保路径正确) --- try: from core.agents.sb_supervisor_agent import SupervisorAgent # 替换为你的 SupervisorAgent 类路径 from core.agents.state_based_supervisor.state_schema import PlanningAgentState from core.agents.base.react_agent import ReactAgent # 导入 ReactAgent 基类 # 导入所有重构后的子 Agent 类 from core.agents.sub_agents.research_agent import ResearchAgent # 假设路径 from core.agents.sub_agents.coder_agent import CoderAgent # 假设路径 from core.agents.sub_agents.reporter_agent import ReporterAgent # 假设路径 from core.agents.sub_agents.designer_agent import DesignerAgent # 假设路径 from core.agents.sub_agents.data_analyst_agent import DataAnalystAgent # 假设路径 # 导入工具注册表函数和枚举 from core.tools.registry import get_tools_by_category, ToolCategory, register_tool # 导入 register_tool from core.llm.llm_manager import LLMManager # LLM 管理器 # 导入特定工具实例或类 (如果 Registry 没有预注册所有工具) from langchain_community.tools.tavily_search import TavilySearchResults # 示例 # from core.tools.e2b_tool import E2BCodeInterpreterTool # 示例 # from core.tools.replicate_flux_tool import ReplicateFluxImageTool # 示例 # --- 确保工具已注册 --- # 运行 registry 初始化 (通常在 core/tools/__init__.py 中完成) try: import core.tools # 尝试导入以触发 __init__.py 中的注册 print("Tool registry potentially initialized.") except ImportError: print("Warning: Could not import 'core.tools' to initialize registry.") except Exception as reg_err: print(f"Error during tool registry initialization: {reg_err}") # (可选) 在这里可以检查或手动注册缺失的核心工具 # Example: Check and register Tavily if not present if not any(getattr(t, 'name', '') == 'tavily_search_results_json' for t in get_tools_by_category(ToolCategory.SEARCH)): 
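        # (补充说明) any(...) 检查会遍历 SEARCH 分类下已注册的工具名称,
        # 仅当 Tavily 尚未注册时才进入下面的手动注册分支,避免重复注册。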
try: print("Attempting to register TavilySearchResults...") tavily_tool = TavilySearchResults(max_results=3) register_tool(tavily_tool, ToolCategory.SEARCH) except Exception as e: print(f"Warning: Failed to register TavilySearchResults manually: {e}") # ... 检查并注册其他必要的工具 ... except ImportError as e: print(f"Error importing agent/tool components: {e}") print("Please ensure all agent/tool class paths and registry setup are correct.") exit(1) # --- 助手函数 --- def slugify(text: str) -> str: """Converts text to a safe filename part.""" # ... (保持不变) ... if not text: return "no_topic" text = text.lower(); text = re.sub(r'\s+', '_', text) text = re.sub(r'[^\w\-]+', '', text); text = text.strip('_') return text[:100] if text else "sanitized_topic" # --- 主研究/测试函数 --- async def run_supervisor_test(supervisor_agent: SupervisorAgent, initial_state: Dict[str, Any]): """Executes the supervisor graph using ainvoke and prints the final state.""" print("\n--- Starting Supervisor Graph Execution ---") initial_query = initial_state.get("messages", [{}])[0].content if initial_state.get("messages") and hasattr(initial_state.get("messages")[0], 'content') else "N/A" print(f"Initial Query: '{initial_query}'") print("-" * 30) config = {"recursion_limit": 100} # 保持较高的递归限制 final_state: Optional[Dict[str, Any]] = None error_occurred: Optional[Exception] = None try: print("--- Invoking Supervisor Agent (ainvoke) ---") # 直接调用 ainvoke 获取最终状态 final_state = await supervisor_agent.ainvoke(initial_state, config=config) print("--- Supervisor Agent Invocation Complete ---") # --- 错误处理 --- except Exception as e: error_occurred = e; print(f"\n!!! Error during graph execution: {e}"); traceback.print_exc() # --- 处理最终状态 --- if error_occurred: print("\n--- Graph Execution INTERRUPTED or FAILED ---") else: print("\n--- Graph Execution Finished ---") if not final_state: print("Error: No final state available (Execution might have failed early).") # 尝试从 checkpointer 获取最后状态 (如果配置了) # ... (checkpoint retrieval logic - optional) ... 
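        # (补充示例,假设性草图) 可参照 02_tavily.py 中同名函数的做法,
        # 从 checkpointer 恢复最后一次检查点状态,例如:
        #
        #   if getattr(supervisor_agent, 'checkpointer', None):
        #       ckpt = supervisor_agent.checkpointer.get(
        #           {"configurable": {"thread_id": "test_thread"}})
        #       if ckpt:
        #           final_state = ckpt.get('channel_values', {})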
return None print("\n--- FINAL STATE ---") # 打印错误 (如果在状态中记录了) if final_state.get("error"): print(f"\nERROR RECORDED IN STATE: {final_state['error']}") # 打印计划 final_plan = final_state.get('plan') if final_plan and isinstance(final_plan, dict): print("\nFinal Plan State:") print(json.dumps(final_plan, indent=2, default=str)) else: print("\nFinal Plan State: Not available or not generated.") # 打印最终消息历史 final_messages = final_state.get("messages", []) if final_messages and isinstance(final_messages, list): print("\nFinal Message History (Last 10):") for m in final_messages[-10:]: try: if hasattr(m, 'pretty_print'): m.pretty_print() else: print(json.dumps(m, indent=2, default=str)) # Fallback print("-" * 10) except Exception as print_err: print(f"Error printing final message: {print_err}") else: print("\nFinal Message History: Empty.") # --- 保存最终报告 (如果 Reporter Agent 被调用且成功) --- # 检查最后一条消息是否来自 Reporter final_report_content = None if final_messages and isinstance(final_messages[-1], AIMessage) and final_messages[-1].name == "reporter_expert": final_report_content = final_messages[-1].content print("\n--- Final Report Found from Reporter Agent ---") if not error_occurred and final_report_content and isinstance(final_report_content, str) and "Failed" not in final_report_content: print("\n--- Saving Final Output to Markdown ---") try: markdown_content = final_report_content # 获取原始请求作为文件名基础 initial_query_text = final_state.get('messages', [{}])[0].content if final_state.get('messages') and hasattr(final_state.get('messages')[0], 'content') else 'unknown_request' topic_slug = slugify(initial_query_text) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"multi_agent_report_{topic_slug}_{timestamp}.md" script_dir = os.path.dirname(os.path.abspath(__file__)) output_dir = os.path.join(script_dir, "Output") os.makedirs(output_dir, exist_ok=True) filepath = os.path.join(output_dir, filename) with open(filepath, "w", encoding="utf-8") as f: f.write(markdown_content) print(f"Successfully saved output to: {filepath}") except Exception as e: print(f"Error saving output to Markdown: {e}") elif error_occurred: print("\nFinal Report: Not saved due to execution error.") else: print("\nFinal Report: Not generated or not found.") print("\n--- END OF TEST ---") return final_state # --- Main Execution Block --- async def main(): # --- 1. 初始化 LLM 管理器 --- try: model_manager = LLMManager() print("Registered Models:", json.dumps(model_manager.list_models(), indent=2)) except Exception as e: print(f"Failed to initialize LLMManager: {e}") return # --- 2. 实例化所有 Agents --- try: # 获取模型实例 # 确保 'deepseek_v3' 和 'gpt-4o' 是你 LLMManager 中有效的 ID deepseek_model = model_manager.get_model("deepseek_v3") gpt4o_model = model_manager.get_model("openai_gpt4o") # 多模态模型 # 实例化 ResearchAgent research_agent = ResearchAgent( model=deepseek_model, ) # 实例化 CoderAgent coder_agent = CoderAgent( model=deepseek_model, ) # 实例化 ReporterAgent reporter_agent = ReporterAgent( model=deepseek_model ) # 实例化 DesignerAgent designer_agent = DesignerAgent( model=gpt4o_model, ) # 实例化 DataAnalystAgent data_analyst_agent = DataAnalystAgent( model=deepseek_model, ) # --- 3. 组合 Agent 列表 --- all_agents = [ research_agent, coder_agent, reporter_agent, designer_agent, data_analyst_agent, ] # --- 4. 实例化 Supervisor --- supervisor = SupervisorAgent( agents=all_agents, model=deepseek_model, # Supervisor 自身使用的模型 # model = gpt4o_model, state_schema=PlanningAgentState, include_agent_name="inline" # checkpointer=... 
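            # (假设性示例) 如需持久化,可 from langgraph.checkpoint.memory import MemorySaver
            # 并传入 checkpointer=MemorySaver()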
            # 可选: 添加 Checkpointer 实现持久化
        )
    except Exception as e:
        print(f"Failed to initialize agents or supervisor: {e}")
        traceback.print_exc()
        return

    # --- 5. 获取用户输入 ---
    topic = input("Please enter the initial request for the supervisor: ")
    if not topic:
        print("No request entered. Using default topic.")
        topic = """我需要获取法国巴黎当前的实时气温。请按以下步骤操作:
1. 首先,帮我调研一个可以免费获取巴黎当前天气数据的 API (例如 Open-Meteo, WeatherAPI.com 或其他类似的),重点是找到获取当前气温的 API 端点(endpoint URL)以及如何构造请求(如果可能,选择不需要 API key 的)。
2. 然后,编写一个 Python 脚本,使用 'requests' 库来调用上一步找到的 API 端点,并从中提取出巴黎当前的温度(摄氏度)。
3. 使用你的代码执行工具来运行这个 Python 脚本。
4. 最后,告诉我你找到的当前巴黎温度是多少。"""

    # --- 6. 准备初始状态 ---
    initial_graph_state: PlanningAgentState = {
        "messages": [HumanMessage(content=topic)],
        "plan": None,
        "error": None
    }

    # --- 7. 运行测试 ---
    await run_supervisor_test(supervisor, initial_graph_state)

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nExecution interrupted by user.")
    except Exception as e:
        print(f"\nAn unexpected top-level error occurred: {e}")
        traceback.print_exc()



================================================
FILE: examples/web_agents/README.md
================================================
# Web Agents

这个目录包含可以通过web界面加载的代理示例。每个子目录代表一个独立的代理实现,可以被server.py动态加载。

## 目录结构

每个代理应遵循以下结构:

```
agent_name/
    __init__.py   # 包含get_graph()函数,返回编译好的LangGraph
    README.md     # 代理的说明文档
```

## 接口规范

每个代理必须实现以下接口:

```python
def get_graph():
    """返回编译好的LangGraph实例"""
    # 构建并返回图
    return compiled_graph
```



================================================
FILE: examples/web_agents/README_SPEC.md
================================================
# Web Agent 开发规范

## 1. 概述

本规范旨在统一Web Agent的开发流程和命名约定,确保前后端协同工作,避免出现前端组件无法正确显示后端数据的问题。本文档基于实际开发经验,特别强调前后端节点命名一致性的重要性。

## 2. 前后端交互核心机制

### 2.1 关键概念

- **节点名称匹配**: 前端渲染组件时,会根据后端节点的名称来选择对应的组件进行渲染
- **状态数据结构**: 后端节点生成的状态数据必须符合前端组件期望的结构
- **渲染函数**: 前端的`renderNode`函数是连接后端节点和前端组件的关键桥梁

### 2.2 渲染流程

1. 后端节点执行并生成状态数据
2. 前端通过`useLangGraphAgent`钩子接收节点数据
3. 前端的`renderNode`函数根据节点名称选择对应组件
4. 组件根据状态数据进行渲染

## 3. 节点命名规范

### 3.1 关键节点命名

所有Web Agent必须在图结构中包含处理消息的节点,这些节点名称必须与前端`renderNode`函数中的case语句匹配:

```python
# 后端节点命名 - 必须与前端renderNode函数中的case匹配
builder.add_node("agent", agent_function)  # 或其他在前端已注册的节点名称
```

**重要提示**: 前端`page.tsx`中的`renderNode`函数定义了可识别的节点名称。目前支持的节点名称有:

- `__start__`
- `agent` (替代了原来的`chatbot`)
- `weather`
- `reminder`
- `research`
- `search`
- `report`

如果后端使用了其他节点名称,必须在前端的`renderNode`函数中添加对应的case语句。

### 3.2 状态字段命名

- 状态字段名称应与前端组件期望的字段名称保持一致
- 使用蛇形命名法(snake_case)命名状态字段
- 复杂数据结构应使用数组形式,即使只有一个元素

### 3.3 必要的状态字段

每个Web Agent必须在`agent-types.ts`文件中定义其状态接口,并确保后端发送的状态与此接口匹配:

```typescript
export interface AgentState extends WithMessages {
  // 定义Agent特有的状态字段
  weather_forecast?: WeatherForecast[];
  research_status?: ResearchStatus[];
  // 其他状态字段
}
```

## 4. 前端组件规范

### 4.1 组件结构

- 主组件应根据节点名称渲染不同的子组件
- 子组件应检查所需状态字段是否存在,并提供合理的默认行为

```typescript
export default function renderNode(checkpoint, node) {
  switch (node.name) {
    case '__start__':
    case 'agent': // 注意:这里使用'agent'替代了原来的'chatbot'
      return <ChatbotNode />;
    case 'weather':
      return <WeatherNode />;
    // 其他节点类型
    default:
      return null;
  }
}
```

### 4.2 组件注册

所有Web Agent的节点组件必须在`page.tsx`的`renderNode`函数中正确注册:

```typescript
const renderNode = (checkpoint, node) => {
  switch (node.name) {
    // 确保这里的节点名称与后端图定义中的节点名称一致
    case '__start__':
    case 'agent': // 注意:这里使用'agent'替代了原来的'chatbot'
      return <ChatbotNode />;
    case 'weather':
      return <WeatherNode />;
    case 'reminder':
      return <Reminder />;
    case 'research':
    case 'search':
    case 'report':
      return <ResearchNode />;
    // 添加新节点类型的渲染逻辑
    default:
      return null;
  }
}
```

## 5. 后端图结构规范

### 5.1 节点函数

- 节点函数应使用适当的参数来处理状态
- 消息处理必须在与前端匹配的节点中进行

```python
async def agent(state):  # 注意:这里使用'agent'替代了原来的'chatbot'
    # 处理消息并返回结果
    return {"messages": [...]}  # 必须包含messages字段
```

### 5.2 图构建

- 图必须包含与前端匹配的节点,用于处理消息
- 必须实现`get_graph()`函数返回编译好的图实例

```python
def get_graph():
    """返回编译好的LangGraph实例"""
    builder = StateGraph(State)  # State 为你定义的图状态类型
    builder.add_node("agent", agent)  # 注意:这里使用'agent'替代了原来的'chatbot'
    # 添加边和其他节点
    graph = builder.compile(checkpointer=MemorySaver())
    return graph
```

## 6. 开发流程

### 6.1 新建Web Agent流程

1. 在`examples/web_agents/`下创建新的Agent目录
2. 创建`graph.py`文件,实现Agent的图结构,确保节点名称与前端`renderNode`函数中的case语句匹配
3. 在`web/app/chat/[id]/agent-types.ts`中添加Agent所需的状态接口
4. 在`web/app/chat/[id]/components/`下创建Agent的组件
5. 在`web/app/chat/[id]/page.tsx`的`renderNode`函数中注册Agent的节点组件(如果使用新的节点名称)

### 6.2 测试验证

在提交代码前,必须进行以下测试:

1. 确认后端图结构中的节点名称与前端`renderNode`函数中的case语句匹配
2. 验证前端组件能正确渲染不同类型的节点
3. 检查状态字段名称与前端组件期望的字段名称一致

## 7. 常见问题与解决方案

### 7.1 前端不显示消息问题

如果前端不显示消息内容,请检查:

1. 后端图结构中的节点名称是否与前端`renderNode`函数中的case语句匹配
2. 前端`renderNode`函数是否正确处理了对应的节点名称
3. 消息是否正确包含在state的messages字段中

### 7.2 状态更新不生效

确保状态更新时,字段名称与前端期望的字段名称一致,并且数据结构符合前端组件的预期。

### 7.3 添加新节点类型

如果需要添加新的节点类型,必须:

1. 在后端图结构中定义新节点
2. 在前端`page.tsx`的`renderNode`函数中添加对应的case语句
3. 创建新节点对应的前端组件
4. 在`agent-types.ts`中添加新节点所需的状态接口

---

遵循本规范可以有效避免前后端不一致导致的显示问题,提高Web Agent的开发效率和质量。



================================================
FILE: examples/web_agents/__init__.py
================================================
# Web Agents Package
# This package contains web agents that can be loaded by the server



================================================
FILE: examples/web_agents/research_assistant/README.md
================================================
# 研究助手

这是一个强大的研究助手代理,可以帮助用户进行在线研究、信息收集和报告生成。

## 功能

- 在线搜索信息
- 提取和总结网页内容
- 生成研究报告
- 实时显示研究进度

## 使用方法

用户可以通过自然语言与代理交互,例如:

- "帮我研究人工智能在医疗领域的应用"
- "查找关于气候变化的最新研究"
- "总结量子计算的基本原理"

## 技术实现

该代理使用LangGraph构建,结合了Supervisor和React模式,包含以下节点:

- supervisor: 协调整个研究流程
- search: 执行在线搜索
- extract: 提取网页内容
- analyze: 分析收集的信息
- report: 生成研究报告

研究过程中会实时更新状态,让用户了解当前进度。



================================================
FILE: examples/web_agents/research_assistant/__init__.py
================================================
# Research Assistant Agent
# This module provides a research assistant agent that can crawl websites and extract content

from .graph import get_graph

__all__ = ["get_graph"]



================================================
FILE: examples/web_agents/research_assistant/graph.py
================================================
from langgraph.prebuilt import create_react_agent
from langchain_openai import ChatOpenAI
from typing import Dict, Any
from dotenv import load_dotenv
from langchain_community.tools import TavilySearchResults
from langgraph.checkpoint.memory import MemorySaver
from core.tools.e2b_tool import E2BCodeInterpreterTool
from core.tools.registry import register_tool, ToolCategory
from core.llm.llm_manager import LLMManager

load_dotenv()  # 自动加载 .env 文件

# 初始化大模型
model = LLMManager().get_model("deepseek_v3")

# 创建Tavily搜索工具
tavily_search = TavilySearchResults(
    max_results=3,
    include_answer=True,
    include_raw_content=False,
    include_images=False,
    search_depth="advanced"
)

# 创建E2B代码解释器工具
e2b_code_interpreter = E2BCodeInterpreterTool()

research_agent = create_react_agent(
    model=model,
    tools=[tavily_search, e2b_code_interpreter],
    name="research_expert",
    # Prompt 告诉它是一个研究型 Agent,可调用 tavily_search 和 e2b_code_interpreter
    prompt=(
        "你是一位世界级的研究专家和数据分析师,擅长信息检索和数据分析。你有两个强大的工具可以使用:\n"
        "1. 'tavily_search_results_json':用于搜索网络获取实时信息\n"
        "2.
'e2b_code_interpreter':用于执行Python代码,支持数据分析和可视化\n\n" "当面对问题时,请遵循以下方法论:\n" "1. 分析问题:理解用户的需求和问题本质\n" "2. 制定计划:确定需要搜索哪些信息,以及是否需要进行数据分析\n" "3. 执行搜索:使用tavily_search_results_json工具获取最新信息\n" "4. 数据分析:如果需要,使用e2b_code_interpreter工具编写和执行Python代码进行数据分析和可视化\n" "5. 综合信息:将搜索结果和数据分析结果综合成一个连贯的回答\n\n" "重要提示:\n" "- 对于信息检索任务,使用tavily_search_results_json工具,并在回答中引用来源URL\n" "- 对于数据分析和可视化任务,使用e2b_code_interpreter工具执行Python代码\n" "- 在使用代码解释器时,确保导入必要的库(如pandas, matplotlib, numpy等)\n" "- 在代码中添加详细注释,解释关键步骤\n" "- 执行代码后,解释结果含义和见解" ), checkpointer=MemorySaver(), ) def get_graph(): return research_agent ================================================ FILE: examples/web_agents/weather_agent/README.md ================================================ # 天气代理 这是一个简单的天气查询代理,可以回答用户关于天气的问题,并提供天气预报信息。 ## 功能 - 查询当前天气 - 创建提醒 ## 使用方法 用户可以通过自然语言与代理交互,例如: - "今天北京的天气怎么样?" - "帮我设置一个提醒,明天早上8点去开会" ## 技术实现 该代理使用LangGraph构建,包含以下节点: - chatbot: 处理用户输入并生成回复 - weather: 处理天气查询请求 - reminder: 处理提醒创建请求 ================================================ FILE: examples/web_agents/weather_agent/__init__.py ================================================ # Weather Agent Example # This is a simple weather agent that can be loaded by the server import operator from typing import Literal, TypedDict, Any, Annotated from dotenv import load_dotenv from langchain_openai import ChatOpenAI from langgraph.graph import StateGraph, MessagesState, START, END from langgraph.checkpoint.memory import MemorySaver from langgraph.types import StreamWriter, interrupt, Send from langchain_core.messages import ToolMessage from langchain_core.tools import tool import random import asyncio load_dotenv() class Weather(TypedDict): location: str search_status: str result: str class State(MessagesState): weather_forecast: Annotated[list[Weather], operator.add] class WeatherInput(TypedDict): location: str tool_call_id: str class ToolNodeArgs(TypedDict): name: str args: dict[str, Any] id: str @tool async def weather_tool(query: str) -> str: """Call to get current weather""" return "Sunny" @tool async def create_reminder_tool(reminder_text: str) -> str: """Call to create a reminder""" return "Reminder created" async def weather(input: WeatherInput, writer: StreamWriter): location = input["location"] tool_call_id = input["tool_call_id"] # Send custom event to the client. It will update the state of the last checkpoint and all child nodes. # Note: if there are multiple child nodes (e.g. parallel nodes), the state will be updated for all of them. writer({"weather_forecast": [ {"location": location, "search_status": f"Checking weather in {location}"}]}) await asyncio.sleep(2) weather = random.choice(["Sunny", "Cloudy", "Rainy", "Snowy"]) return {"messages": [ToolMessage(content=weather, tool_call_id=tool_call_id)], "weather_forecast": [{"location": location, "search_status": "", "result": weather}]} async def reminder(input: ToolNodeArgs): res = interrupt(input['args']['reminder_text']) tool_answer = "Reminder created." if res == 'approve' else "Reminder creation cancelled by user." 
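    # 说明: interrupt() 会在此暂停图的执行,把 reminder_text 发给客户端等待确认;
    # 客户端恢复执行时传回的值即为 res,这里约定 'approve' 表示用户批准创建提醒。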
return {"messages": [ToolMessage(content=tool_answer, tool_call_id=input["id"])]} async def chatbot(state: State): llm = ChatOpenAI( model="gpt-4o-mini").bind_tools([weather_tool, create_reminder_tool]) response = await llm.ainvoke(state["messages"]) return {"messages": [response]} def tool_router(state: State) -> Literal["weather", "reminder", "__end__"]: messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: if last_message.tool_calls[0]["name"] == "weather_tool": return "weather" elif last_message.tool_calls[0]["name"] == "create_reminder_tool": return "reminder" return "__end__" # Chatbot node router. Based on tool calls, creates the list of the next parallel nodes. def assign_tool(state: State) -> Literal["weather", "reminder", "__end__"]: messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: send_list = [] for tool in last_message.tool_calls: if tool["name"] == 'weather_tool': send_list.append( Send('weather', {'location': tool['args']['query'], 'tool_call_id': tool['id']})) elif tool["name"] == 'create_reminder_tool': send_list.append(Send('reminder', tool)) return send_list if len(send_list) > 0 else "__end__" return "__end__" def get_graph(): """Return the compiled graph for this agent""" builder = StateGraph(State) builder.add_node("chatbot", chatbot) builder.add_node("weather", weather) builder.add_node("reminder", reminder) builder.add_edge(START, "chatbot") builder.add_conditional_edges("chatbot", assign_tool) builder.add_edge("weather", "chatbot") builder.add_edge("reminder", "chatbot") builder.add_edge("chatbot", END) memory = MemorySaver() return builder.compile(checkpointer=memory) ================================================ FILE: instructions/00.Langgraph 和 React Agent.md ================================================ # 一、LangGraph 的核心思想 LangGraph 是一个可以让开发者以**图(Graph)**的方式来编排对话式AI流程的库,提供了以下能力: 1. **状态驱动**:在传统的对话模型中,我们经常需要维护对话上下文、剩余步骤等各种内部变量。LangGraph 将这些变量统一到一个“状态(State)”里,并约定任何节点的输入/输出都以“状态(State)”的形式表示。 2. **可视化执行流**:LangGraph 将对话/工具调用/自定义逻辑封装成“节点(Node)”与“边(Edge)”。当图被编译后,执行流会在节点之间穿梭,处理对话消息、调用工具、终止或转向某些分支。 3. **可组合**:你可以把一个复杂的对话逻辑拆分为多个可复用的子图,每个子图都可以独立进行单元测试或复用在更大的图中。 4. **多步思考 + 工具调用**:通过 ReAct Agent(一个经典的多步推理+工具调用范式),LangGraph 可以帮助你自动管理**多次**调用语言模型及其衍生工具的过程——只要你把“工具”注册到图里。 在使用时,你基本会经历如下步骤: 1. **定义状态模式(state schema)**:说明 state 中必须包含哪些字段(如:对话消息 `messages`,剩余可用步骤 `remaining_steps`,等等)。 2. **定义节点(Node)**:比如一个负责调用LLM的节点、一个负责执行特定工具的节点、或者你自定义的Python逻辑节点。 3. **连接边(Edges)**:决定每个节点之后,下一步走到哪个节点;也可以做条件分支或循环。 4. **编译图(Compile)**:LangGraph 会把你的“编排逻辑”转换为一个 LangChain 兼容的“可调用对象(CompiledGraph)”。 5. 
**执行或流式执行**:可以直接一次性 `graph.invoke(...)` 得到最终结果,也可以使用 `graph.stream(...)` 流式获取每个“阶段性状态(partial state)” 。 --- # 二、LangGraph 核心概念详解 LangGraph 构建的是一个"流程图",每个智能体(agent)或功能模块(tool调用、分支逻辑)都是这个流程图的一个节点(node)。让我们深入理解其中的核心概念: ## 2.1 Graph:有状态的数据流图 Graph 是整个 Agent 系统的执行框架,定义了哪些模块怎么串联、怎么流转。你构建的 graph 是一个有向图: ```python graph = StateGraph(state_schema=MyState) # 添加节点 graph.add_node("supervisor", supervisor_runnable) graph.add_node("writer", writer_runnable) # 添加边来连接节点 graph.add_edge("supervisor", "writer") graph.add_edge("writer", "supervisor") ``` LangGraph 根据这些连接关系来控制执行流程,决定在某个节点执行完后下一步应该去哪里。 ## 2.2 Node:图中的"一个执行单元" 每个 node 是图中的一个处理模块(通常就是一个智能体)。它接受一个输入 state,做点事情,然后返回一个新的 state: ```python def my_node(state: dict) -> dict: # 处理 state 中的数据 new_state = state.copy() # 修改状态内容 new_state["some_key"] = "new_value" return new_state ``` 节点可以是: - 函数(同步或异步) - LLM Agent(如 create_react_agent 返回的) - 包装后的 Agent(如 MemorySlidingReactAgent) ## 2.3 State:每一轮节点处理的输入/输出 每轮调用,LangGraph 会传递一个 "state"(字典类型)给当前节点。这个 state 中可以包含: - `messages`: 当前对话历史(主上下文)【默认】 - `memory`: 你自定义的长期记忆(可以注入系统提示) - `todo_list`, `current_task`: 其他任务状态 - 任何你自定义的字段 每个节点执行后,返回新的 state: ```python def writer(state): new_msg = generate_chapter(state["current_task"]) state["messages"].append({"role": "assistant", "content": new_msg}) return state ``` ## 2.4 Runnable:Node 的运行接口 LangGraph 要求,每个节点(node)必须是可以运行的,也就是说:你交给 `add_node()` 的对象必须有 `.invoke(state)` 或 `.ainvoke(state)` 方法。 比如: - 函数本身(它会自动包装成 Runnable) - Agent(React agent 本身就支持 `.invoke`) - `RunnableCallable(...)` 是 LangGraph 用来显式包装函数的工具 举个例子: ```python def my_function(state: dict) -> dict: # 处理逻辑 return state runnable = RunnableCallable(my_function, async_version) graph.add_node("writer", runnable) ``` ## 2.5 执行流程 LangGraph 的执行流程大致如下: ``` LangGraph Graph: [START] ↓ [Supervisor Node] ↓ [Writer Node] ↓ [Supervisor Node] ↓ [END] ``` 每次节点执行时: 1. 传入当前 state 2. `.invoke(state)` 被调用 3. 返回更新后的 state 4. 下一节点接着执行 ## 2.6 概念类比 | LangGraph 概念 | 类比 | |----------------|------| | Graph | 工作流程图/数据流图 | | Node | 每个处理步骤/智能体 | | State | 当前上下文与执行状态(黑匣子) | | Runnable | 每个节点"能被执行"的接口定义 | --- # 三、ReAct Agent 与 create_react_agent 概念 ## 3.1 什么是 ReAct Agent “ReAct” 是一种典型的LLM多步推理与工具调用策略。它主要包含两部分: 1. **Reasoning**:先让语言模型(LLM)进行一步推理,产出一个潜在的思考过程以及可能的工具调用。 2. **Acting**:如果模型说“我要调用某个工具”,则执行该工具,得到结果,再把结果加入对话,然后让模型再次 Reason,看看是否还需要再调用工具,或输出最后的答案。 这个循环可以**多次往返**,直到模型不再调用工具,输出最终结果。 ## 3.2 create_react_agent 做了什么 `create_react_agent(...)` 是 LangGraph 中的一个快捷方法,用于**快速创建**一个可执行的“ReAct风格”图(Graph): - **自动添加“agent节点”**:用来调用你的语言模型(并在对话中发出可能的工具调用)。 - **自动添加“tools节点”**:如果 agent 的输出中含有工具调用(tool_calls),则会交给 tools 节点逐个执行,并把执行结果以 `ToolMessage` 的形式返回到对话中。 - **自动在 agent ↔ tools 之间连线**:只要 agent 产生了工具调用,就进入 tools;tools 执行完返回消息后,再回到 agent;直到不再有工具调用为止。 - **可选 structured output**:如果你传入了 `response_format` 参数,LangGraph 会在最后一步生成一个结构化输出(“JSON Schema”、“Pydantic”、“OpenAI function schema”等),以便你获取可解析的最终结果。 - **控制“剩余步骤”**:Agent 每次回答后会检查是否还可以继续调用工具,或者是否需要中止并返回错误(“抱歉,需要更多步骤”)。 因此,调用 `create_react_agent(...)` 得到的结果,是一个**已经配置好**的 “CompiledGraph”。这个图中带有 “agent” 节点(LLM) 和 “tools” 节点(调用工具),以及检查**是否还有工具要调**的逻辑。你可以直接拿这个对象执行,获得一个 ReAct 流程的多轮对话+工具使用。 --- # 四、执行流程:从输入到输出 创建好 ReAct 图后,你给它一个输入状态(最少包含 `"messages"`,如 `{"messages": [("user", "Hello!")]}`)。执行过程大体是: 1. **entry point: "agent"** 进入 agent 节点,它会从 state["messages"] 中取出消息,交给 LLM 生成一个 AIMessage。如果 AIMessage 包含 tool_calls,那么 state 会更新多一些字段,比如 `messages` 后面多了这个 AIMessage。 2. **检查是否要调用工具** - 如果 `tool_calls` 不为空,则顺着 edges 进入 "tools" 节点。 - 如果没有 tool_calls,则表示 agent 没有想调用任何工具 -> 流程会判断是否要去 “generate_structured_response” 或 “END”。 3. 
**tools 节点执行** "tools" 节点会去匹配 agent 要调用的工具,比如: ```json { "name": "search_tool", "args": {"query": "something"}, "id": "call_abc123" } ``` 然后运行相应的 Python 函数,得到结果后,包装成 `ToolMessage`,附加回 state["messages"] 列表里。 - 如果 agent 一次性请求了多个工具,在 v1 版本中则会并行执行,再把返回结果依次追加到 messages 里。 - 在 v2 版本中,LangGraph 会拆分 tool_calls 分批执行。 4. **回到 agent** 现在 agent 再次拿到新的 state["messages"](多了“ToolMessage”),就会针对最新的对话上下文重新进行思考——是否要再调用别的工具、或者是否直接产出最终回答? 5. **循环,直到不再调用工具** 只要 AIMessage 继续发出 tool_calls,就进入 Tools 节点;Tools 执行完再回到 Agent 节点。这一过程可能多次往返。(如果你设定了 `remaining_steps`,LangGraph 在每一轮都会减少1,直到不足时终止或报错,避免死循环。) 6. **可选:结构化输出** 在最后如果 `response_format` 存在,图会跳到“generate_structured_response”节点,再次对(几乎)所有对话做一次 LLM 调用,要求 LLM 产出符合**你给定schema**的 JSON,并存入 `structured_response` 字段中。然后再返回 END。 7. **结束** 整个 ReAct 流程完成后,图会返回一个最终状态,如: ```python { "messages": [ # 所有对话消息(包含了Human/AI/Tools等), ..., AIMessage(content="Here is the final answer", tool_calls=[]) ], "remaining_steps": 2, "structured_response": { ... } # 如果使用了response_format } ``` 你可以从中拿到想要的最终 AI 回答。 --- # 五、如何查看“中间推理”或“工具调用”? 从 **langgraph 0.3** 开始,`create_react_agent` 及其返回的 Graph 已经**不再支持** `graph.add_state_change_listener` 或在函数参数里传入 `callbacks`。如果你想**监听**或**打印** Agent 的中间思考、工具调用等过程,最好的方式是 **使用 `graph.stream(...)`**——它会在每一小步执行结束后产出一个“部分状态( partial_state )”,你可以在循环里进行日志记录、可视化或其他操作。示例: ```python graph = create_react_agent(model, tools=[...], prompt="...") inputs = { "messages": [ ("user", "请分析特斯拉2025年的发展预期,包括新车型计划、销量目标、技术创新和市场扩张战略。") ] } for partial_state in graph.stream(inputs, stream_mode="values"): messages = partial_state["messages"] last_msg = messages[-1] if last_msg.type == "ai": print("[AIMessage] => ", last_msg.content) if last_msg.tool_calls: print("AI wants to call tools:", last_msg.tool_calls) elif last_msg.type == "tool": print("[ToolMessage] => Name:", last_msg.name, "Content:", last_msg.content) elif last_msg.type == "human": print("[User] => ", last_msg.content) # 最后一次迭代时,partial_state 就是最终结果 final_answer = partial_state["messages"][-1].content print("最终回答:", final_answer) ``` 这样就能够**在每一次** Agent 或 Tools 完成后都获取状态,不需要“回调监听器”。 --- # 六、关于一些进阶用法 1. **`interrupt_before` / `interrupt_after`** 如果你希望在“agent”节点**执行前**或者**后**打断,可以设置这两个可选参数,比如: ```python create_react_agent( model, tools=[...], interrupt_before=["tools"], interrupt_after=["agent"], ... ) ``` 当执行流程跑到 agent 或 tools 时,会先/后给你一个“交互点”机会,你可以在**流式**执行中察觉到这个点,或者抛出异常提前终止等。但是它比较适合做“用户确认”或“调试介入”,而不是实时日志。 2. **`checkpointer` / `store`** - `checkpointer` 主要用来将单个“线程”(单条对话)的状态进行保存、恢复,可以在多回合对话里保留上下文。 - `store` 提供了更跨线程或跨用户的持久化能力。 通过把 `store` 绑定到 Graph,工具调用里还可以使用 `InjectedStore`,把数据写入或读取到 store 中(如相当于“全局数据库”)。 3. **`response_format`** 如果你想让最终输出符合某种 JSON Schema 或 Pydantic 验证,可以这样写: ```python from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field class TeslaPlan(BaseModel): new_models: list[str] = Field(..., description="新车型列表") sales_target: int = Field(..., description="预计销量") technology_innovations: str market_strategy: str my_response_format = TeslaPlan graph = create_react_agent( model, tools, prompt="你是一个专业汽车分析师。", response_format=my_response_format ) ``` 当 ReAct 流程结束后,LangGraph 会调用一次 LLM 并要求它返回符合 `TeslaPlan` 的 JSON。最终的 `state["structured_response"]` 就是一个 Python 字典或 Pydantic 实例。 4. **`version="v1" / "v2"`** - **v1**: 工具调用是“把当前 AIMessage 中的所有 tool_calls 一次性并行执行” → tools → 再回到 agent。 - **v2**: 更细粒度地把每个 tool_call 拆开,每个都进入一个独立的 ToolNode 实例。如果一个 AIMessage 里有 3 个 tool_calls,就会做 3 次独立的“tools执行→回到agent”循环**(通过 Send API)**。这种方式可以在多工具协作里更灵活,也可以插入更多自定义逻辑,但要做好相应的结构化处理。 --- # 七、常见问题与答疑 1. 
**Q: 我在旧版本使用 `graph.add_state_change_listener` 或 `callbacks`,现在为什么报错?** A: 因为新的 LangGraph 0.3 取消了这种回调API,推荐使用 `graph.stream(...)` 在每一步迭代中自行处理日志或监听逻辑。 2. **Q: 如果不想每次都多轮循环,而只想 LLMC 接受一次输入就结束,怎么做?** A: 你可以传递 `tools=[]`(空)到 `create_react_agent`,这样它就生成一个不支持工具调用的图;agent 只会输出一次,然后就结束。此时相当于纯LLM调用。 3. **Q: 要怎么限制调用工具的次数?** A: 你可以在输入的 `state` 里设置 `remaining_steps`,或自定义 `AgentState` 包含 `remaining_steps=3` 一类初始值,每次 agent节点执行后,LangGraph 会自动减少1。用完就不会再允许工具调用了。 4. **Q: ReAct 会在同一个消息里多次请求调用工具吗?** A: 是可能的。尤其是当 LLM 在回答中生成多个 tool_calls,就会全部执行。你可以在 `v1` 模式下并行运行它们,也可以在 `v2` 模式下逐个执行。 5. **Q: structured response 里的提示是如何工作的?** A: 当 `response_format` 是 `(system_prompt, schema)` 这种 tuple 时,LangGraph 会在最后的 LLM 调用里给一个额外的 system_prompt,引导 LLM 返回符合 schema 的 JSON。这样可以做更严格的结构化要求。 --- # 八、总结 - **LangGraph** 是一个以“图”来编排对话和工具调用的框架; - **create_react_agent** 是“快捷构造 ReAct 风格图”的核心函数,一次性帮你搭建“agent(LLM) ↔ tools(工具节点) ↔ agent”循环; - 执行时默认从 `agent` 开始,如果 `AIMessage` 包含 `tool_calls` 就调用 `tools` 并注入结果,直到不再有工具调用; - 可以**流式**(`graph.stream(...)`) 或**一次性**(`graph.invoke(...)`)获取结果; - 要想查看中间推理和调用日志,使用 `stream` 在每一步循环里记录; - 可选地,你能通过 `interrupt_before` / `interrupt_after` 或 `checkpointer` / `store` 等更高级特性进一步定制执行流程或存储/恢复状态。 这就是从**原理**到**源码**再到**执行流程**的完整解析。希望能帮助你在实际项目里更好地运用 `create_react_agent` 和 LangGraph! ================================================ FILE: instructions/01.supervisor_pattern.md ================================================ # Supervisor 模式:多智能体协作的核心实现 ## 1. 引言 在人工智能领域,多智能体系统(Multi-Agent System)是一种将复杂任务分解为多个专业智能体协同完成的架构模式。本文将详细介绍我们在 Mentis 项目中实现的 Supervisor(监督者)模式,这是一种高效组织和协调多个智能体的方法。 ## 2. 多智能体系统的基本概念 多智能体系统由多个具有不同专业能力的智能体组成,每个智能体负责特定的任务领域。在这种系统中,智能体之间需要有效地协作和通信,以完成复杂的任务。 在我们的实现中,主要包含以下角色: - **Supervisor(监督者)**:负责任务分发、协调和结果整合的中央控制智能体 - **Specialized Agents(专业智能体)**:具有特定领域专长的执行智能体 ## 3. Supervisor 模式的工作流程 ### 3.1 基本工作流程 Supervisor 模式的工作流程如下: 1. 用户向系统提交请求 2. Supervisor 接收请求并进行任务分析 3. Supervisor 决定调用哪个专业智能体处理任务 4. 专业智能体执行任务并返回结果 5. Supervisor 接收结果,可能进一步调用其他智能体 6. Supervisor 整合所有结果并返回给用户 ### 3.2 控制权转移机制 Supervisor 模式的核心是控制权的转移机制。在我们的实现中,这通过 `handoff` 工具实现: 1. Supervisor 通过调用特定的 `handoff` 工具将控制权转移给目标智能体 2. 目标智能体完成任务后,通过 `handoff_back_messages` 将控制权返回给 Supervisor 3. 这种机制确保了在任何时刻只有一个智能体在处理任务,避免了冲突 ## 4. Supervisor 的核心实现 ### 4.1 核心代码分析 在 `supervisor.py` 中,`create_supervisor` 函数是实现 Supervisor 模式的核心: ```python def create_supervisor( agents: list[Pregel], *, model: LanguageModelLike, tools: list[BaseTool | Callable] | None = None, prompt: Prompt | None = None, # ... 其他参数 ... ) -> StateGraph: # 检查智能体名称唯一性 agent_names = set() for agent in agents: if agent.name is None or agent.name == "LangGraph": raise ValueError("Please specify a name when you create your agent...") if agent.name in agent_names: raise ValueError(f"Agent with name '{agent.name}' already exists...") agent_names.add(agent.name) # 为每个智能体创建 handoff 工具 handoff_tools = [create_handoff_tool(agent_name=agent.name) for agent in agents] all_tools = (tools or []) + handoff_tools # 绑定工具到模型 model = model.bind_tools(all_tools) # 创建 supervisor 智能体 supervisor_agent = create_react_agent( name=supervisor_name, model=model, tools=all_tools, prompt=prompt, # ... 其他参数 ... 
) # 构建状态图 builder = StateGraph(state_schema, config_schema=config_schema) builder.add_node(supervisor_agent, destinations=tuple(agent_names) + (END,)) builder.add_edge(START, supervisor_agent.name) # 添加智能体节点和边 for agent in agents: builder.add_node( agent.name, _make_call_agent( agent, output_mode, add_handoff_back_messages, supervisor_name, ), ) builder.add_edge(agent.name, supervisor_agent.name) return builder ``` ### 4.2 智能体调用机制 `_make_call_agent` 函数负责创建智能体调用的包装函数: ```python def _make_call_agent( agent: Pregel, output_mode: OutputMode, add_handoff_back_messages: bool, supervisor_name: str, ) -> Callable[[dict], dict] | RunnableCallable: # ... 参数验证 ... def _process_output(output: dict) -> dict: messages = output["messages"] # 根据输出模式处理消息 if output_mode == "full_history": pass elif output_mode == "last_message": messages = messages[-1:] # 添加控制权返回消息 if add_handoff_back_messages: messages.extend(create_handoff_back_messages(agent.name, supervisor_name)) return { **output, "messages": messages, } def call_agent(state: dict) -> dict: output = agent.invoke(state) return _process_output(output) # ... 异步版本 ... return RunnableCallable(call_agent, acall_agent) ``` ### 4.3 设计亮点与最佳实践 Supervisor 模式的实现包含了多个多智能体系统设计的黄金经验,以下是关键设计亮点: #### 4.3.1 自动控制权回传机制 `_make_call_agent` 中的自动 handoff back 机制非常巧妙: ```python if add_handoff_back_messages: messages.extend(create_handoff_back_messages(agent.name, supervisor_name)) ``` 这种设计的优势在于: - **隐式交接**:专业智能体无需知道 supervisor 的存在 - **自动转发**:智能体完成任务后,系统自动将结果打包并转交回 supervisor - **消息插入**:在消息历史中自动插入 AIMessage 和 ToolMessage,表明控制权已转移 - **零侵入性**:对智能体代码没有任何侵入,实现了完全的关注点分离 #### 4.3.2 智能的上下文管理策略 `output_mode` 参数提供了对消息历史的精确控制: ```python if output_mode == "last_message": messages = messages[-1:] ``` 这允许开发者灵活选择: - **全量历史模式**(`full_history`):保留智能体输出的完整历史,提供完整上下文 - **最后消息模式**(`last_message`):仅保留最后一条消息,有效节省 token 消耗 这种灵活的上下文压缩策略,在长对话或多轮智能体调用场景中尤为重要,可以有效防止上下文爆炸。 #### 4.3.3 动态工具生成与绑定 系统会自动为每个智能体创建对应的 handoff 工具: ```python handoff_tools = [create_handoff_tool(agent_name=agent.name) for agent in agents] ``` 这些工具允许 supervisor 通过类似 `transfer_to_writer()` 或 `transfer_to_researcher()` 的函数调用来转移控制权,实现了: - **声明式调度**:调度逻辑由 LLM 决定,而非硬编码规则 - **可解释性**:每次转移都有明确的工具调用,便于追踪和调试 - **灵活性**:可以根据当前状态动态决定下一步调用哪个智能体 #### 4.3.4 统一的 Runnable 接口封装 每个智能体都被统一封装为 `RunnableCallable`: ```python builder.add_node(agent.name, _make_call_agent(...)) ``` 这种封装提供了多种优势: - **统一接口**:所有智能体都遵循相同的调用接口 - **状态管理**:状态由 LangGraph 自动管理,无需手动处理 - **异步支持**:同时支持同步和异步调用,适应不同场景 - **自动处理**:输入/输出状态转换自动完成,减少样板代码 #### 4.3.5 灵活的配置选项 系统支持多种配置选项,适应不同需求: - **多种提示格式**:支持字符串、SystemMessage 或可调用函数作为提示 - **结构化输出**:支持 JSON schema、TypedDict 或 Pydantic 类作为输出格式 - **状态模式**:可自定义状态结构,支持复杂的状态追踪和管理 - **并行工具调用控制**:可以针对不同模型配置是否支持并行工具调用 ## 5. 实践案例:笑话生成与研究专家 在 `01_supervisor_test.py` 中,我们实现了一个包含两个专业智能体的系统: ### 5.1 智能体创建 我们使用了两种不同的方式创建智能体: #### 5.1.1 功能型 API(Functional API) 笑话生成器使用功能型 API 创建: ```python @task def generate_joke(messages): """Generate a short joke (no tool calls).""" system_message = { "role": "system", "content": "You are a witty comedian. Write a short joke." } msg = model.invoke([system_message] + messages) return msg @entrypoint() def joke_agent(state): joke = generate_joke(state['messages']).result() messages = add_messages(state["messages"], [joke]) return {"messages": messages} joke_agent.name = "joke_agent" ``` #### 5.1.2 图形 API(Graph API) 研究专家使用图形 API 创建: ```python def web_search(query: str) -> str: """Search the web for information. (Mocked data here)""" return ( "Here are the headcounts for each of the FAANG companies in 2024:\n" # ... 模拟数据 ... 
) research_agent = create_react_agent( model=model, tools=[web_search], name="research_expert", prompt=( "You are a world-class researcher. You have access to a 'web_search(query: str)' tool. " "Do not do any complicated math, just provide factual info from the web_search if needed." ), ) ``` ### 5.2 Supervisor 配置 我们创建了一个 Supervisor 来协调这两个智能体: ```python workflow = create_supervisor( [research_agent, joke_agent], model=model, prompt=( "You are the overall supervisor. You manage two specialized agents:\n" "1) joke_agent: for telling jokes.\n" "2) research_expert: for factual or data-related questions.\n\n" "If the user wants a joke AND some research data in the same query, " "you MUST call joke_agent first, get the joke, then call research_expert for the data. " "After both calls, provide a final combined response. " "Do not call more than one agent in a single LLM message; do it step by step." ), ) ``` ### 5.3 执行流程 当用户请求同时需要笑话和研究数据时,执行流程如下: 1. Supervisor 接收用户请求 2. Supervisor 分析请求,决定先调用 joke_agent 3. joke_agent 生成笑话并返回结果 4. Supervisor 接收笑话,然后调用 research_expert 5. research_expert 查询数据并返回结果 6. Supervisor 整合两个结果,生成最终回复 ## 6. 可视化与调试 我们使用 LangGraph 的可视化功能生成了工作流图表,保存在 `examples/graphs/1_supervisor_test_01.png`,这有助于理解和调试多智能体系统的工作流程。 ## 7. 总结 Supervisor 模式是一种高效组织多智能体系统的方法,它通过中央控制智能体协调专业智能体的工作,实现复杂任务的分解与协作。在我们的实现中,通过精心设计的 handoff 机制实现了智能体之间的控制权转移,确保系统的有序运行。 这种模式的优势在于: 1. **模块化**:每个智能体专注于特定领域,便于开发和维护 2. **可扩展性**:可以方便地添加新的专业智能体 3. **灵活性**:Supervisor 可以根据任务需求动态调用不同的智能体 4. **结果整合**:Supervisor 负责整合各个智能体的结果,提供一致的用户体验 5. **低耦合**:智能体之间通过消息传递交互,减少直接依赖 6. **可追踪性**:每次控制权转移都有明确的工具调用记录,便于调试和监控 7. **资源优化**:通过上下文管理策略,有效控制 token 消耗 8. **开发便捷**:统一的接口和自动化的状态管理,减少样板代码 通过本文的实践案例和深入分析,我们不仅展示了如何使用 LangGraph 和 LangChain 框架实现 Supervisor 模式,更揭示了背后的设计思想和最佳实践,为构建复杂的多智能体系统提供了宝贵参考。这些设计模式和技巧可以帮助开发者构建更加健壮、可维护和高效的智能体系统。 ================================================ FILE: instructions/02.supervisor_pattern_agent.md ================================================ # Supervisor 模式:多智能体协作的核心实现 (Agent 封装模式) ## 1. 引言 在人工智能领域,多智能体系统(Multi-Agent System)是一种将复杂任务分解为多个专业智能体协同完成的架构模式。本文将详细介绍我们在 Mentis 项目中实现的 Supervisor(监督者)模式,这是一种高效组织和协调多个智能体的方法。 ## 2. 多智能体系统的基本概念 多智能体系统由多个具有不同专业能力的智能体组成,每个智能体负责特定的任务领域。在这种系统中,智能体之间需要有效地协作和通信,以完成复杂的任务。 在我们的实现中,主要包含以下角色: - **Supervisor(监督者)**:负责任务分发、协调和结果整合的中央控制智能体 - **Specialized Agents(专业智能体)**:具有特定领域专长的执行智能体 ## 3. Supervisor 模式的工作流程 ### 3.1 基本工作流程 Supervisor 模式的工作流程如下: 1. 用户向系统提交请求 2. Supervisor 接收请求并进行任务分析 3. Supervisor 决定调用哪个专业智能体处理任务 4. 专业智能体执行任务并返回结果 5. Supervisor 接收结果,可能进一步调用其他智能体 6. Supervisor 整合所有结果并返回给用户 ### 3.2 控制权转移机制 Supervisor 模式的核心是控制权的转移机制。在我们的实现中,这通过 `handoff` 工具实现: 1. Supervisor 通过调用特定的 `handoff` 工具将控制权转移给目标智能体 2. 目标智能体完成任务后,通过 `handoff_back_messages` 将控制权返回给 Supervisor 3. 这种机制确保了在任何时刻只有一个智能体在处理任务,避免了冲突 ## 4. 基础架构:BaseAgent 类 在我们的重构中,我们引入了 `BaseAgent` 基类,作为所有智能体的基础。这种设计使得不同类型的智能体可以共享通用功能,同时保持各自的特性。 ### 4.1 BaseAgent 核心实现 ```python class BaseAgent: _PROMPT_TEMPLATE = """ You have access to the following tools: {tools} Use the above tools to answer the question at the end. 
""" def __init__( self, name: str, model: Union[BaseChatModel, LanguageModelLike], tools: Optional[List[Union[BaseTool, Callable]]] = None, prompt: Optional[Union[str, SystemMessage, Callable]] = None, checkpointer: Optional[Checkpointer] = None, max_context_messages: Optional[int] = None, # 限制最近消息数量 max_context_tokens: Optional[int] = None, # 限制总估计token数 model_name: Optional[str] = "gpt-4o-mini", # 用于未来token估计改进 ): # 初始化基本属性 self.name = name self.model = model self.tools = tools or [] self.prompt = prompt self.checkpointer = checkpointer self.max_context_messages = max_context_messages self.max_context_tokens = max_context_tokens self.model_name = model_name self._workflow = None self._agent = None ``` ### 4.2 上下文管理机制 `BaseAgent` 提供了智能的上下文管理机制,可以根据配置自动截断消息历史: ```python def _inject_context(self, state: Dict[str, Any]) -> Dict[str, Any]: """注入记忆并根据配置截断消息。""" memory = state.get("memory") or [] messages = state.get("messages", []) messages = self._truncate_messages(messages) memory_messages = [SystemMessage(content=chunk) for chunk in memory] state["messages"] = memory_messages + messages return state ``` ### 4.3 通用方法接口 `BaseAgent` 定义了所有智能体共享的核心方法接口: ```python def build(self) -> StateGraph: """构建工作流。""" def compile(self) -> CompiledStateGraph: """编译工作流。""" def invoke(self, state: Dict[str, Any]) -> Dict[str, Any]: """同步调用工作流。""" async def ainvoke(self, state: Dict[str, Any]) -> Dict[str, Any]: """异步调用工作流。""" ``` ## 5. ReactAgent 类实现 `ReactAgent` 是我们实现的基于 ReAct(Reasoning and Acting)模式的智能体,它继承自 `BaseAgent`,专注于推理和工具调用。 ### 5.1 ReactAgent 类设计 ```python class ReactAgent(BaseAgent): """ReAct Agent class for reasoning and acting with tools. This class provides a high-level interface for creating a ReAct agent workflow that can perform multi-step reasoning and tool calling. 
""" def __init__( self, model: LanguageModelLike, tools: Optional[List[Union[BaseTool, Callable]]] = None, prompt: Optional[str] = None, response_format: Optional[ Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] ] = None, state_schema: StateSchemaType = AgentState, config_schema: Type[Any] = None, checkpointer: Optional[Checkpointer] = None, store: Optional[BaseStore] = None, interrupt_before: Optional[List[str]] = None, interrupt_after: Optional[List[str]] = None, debug: bool = False, version: Literal["v1", "v2"] = "v1", name: str = "react_agent", max_context_messages: Optional[int] = None, max_context_tokens: Optional[int] = None, model_name: Optional[str] = "gpt-4o-mini", ): # 调用父类初始化 super().__init__( name=name, model=model, tools=tools or [], prompt=prompt, checkpointer=checkpointer, max_context_messages=max_context_messages, max_context_tokens=max_context_tokens, model_name=model_name ) # 初始化ReactAgent特有属性 self.response_format = response_format self.state_schema = state_schema self.config_schema = config_schema self.store = store self.interrupt_before = interrupt_before self.interrupt_after = interrupt_after self.debug = debug self.version = version self._agent = None ``` ### 5.2 核心方法实现 #### 5.2.1 compile 方法 `compile` 方法负责编译 ReactAgent 工作流: ```python def compile(self) -> CompiledGraph: """构建 ReAct agent 工作流。 Returns: 编译后的 CompiledGraph """ # 如果_agent已经存在,直接返回,避免重复构建 if self._agent is not None: return self._agent _react_agent = create_react_agent( model=self.model, tools=self.tools, prompt=self.prompt, response_format=self.response_format, state_schema=self.state_schema, config_schema=self.config_schema, checkpointer=self.checkpointer, store=self.store, interrupt_before=self.interrupt_before, interrupt_after=self.interrupt_after, debug=self.debug, version=self.version, name=self.name, ) self._agent = CreateReactAgentWrapper(_react_agent, name=self.name, before_invoke=self.invoke, before_ainvoke=self.ainvoke) return self._agent ``` #### 5.2.2 invoke 和 ainvoke 方法 `invoke` 和 `ainvoke` 方法负责调用 ReactAgent 处理用户请求,并提供调试信息: ```python def invoke(self, state: Dict[str, Any]) -> Dict[str, Any]: """同步调用入口 (真正的 Agent 执行逻辑).""" # 打印调试信息 messages = state.get("messages", []) if messages: for i, msg in enumerate(messages, 1): type_str = type(msg).__name__ print(f"第 {i} 条消息 - {type_str} (Name: {msg.name}):") msg.pretty_print() # 上下文注入 state = self._inject_context(state) return state async def ainvoke(self, state: Dict[str, Any]) -> Dict[str, Any]: """异步调用入口.""" # 上下文注入 state = await self._inject_context(state) return state ``` ## 6. SupervisorAgent 类实现 `SupervisorAgent` 类继承自 `BaseAgent`,专注于协调多个智能体的工作。在重构后,它增加了规划功能,可以更有效地管理复杂任务。 ### 6.1 SupervisorAgent 类设计 ```python class SupervisorAgent(BaseAgent): """Supervisor class for managing multiple agents with planning capabilities. This class provides a high-level interface for creating a supervisor workflow that can manage and coordinate multiple agents. It also includes planning capabilities to create and manage a plan for complex tasks using a state-driven approach. The planning functionality is implemented using PlanningStateHandler and PlanningTool, which provide a more structured and flexible way to manage tasks compared to the previous TodolistTool approach. 
""" def __init__( self, agents: List[BaseAgent], model: LanguageModelLike, tools: Optional[List[Union[BaseTool, Callable]]] = None, prompt: Optional[str] = None, state_schema: StateSchemaType = AgentState, supervisor_name: str = "supervisor", checkpointer: Optional[Checkpointer] = None, output_mode: str = "last_message", # * full_history or last_message * enable_planning: bool = True, # * True or False * ): # 设置规划相关属性 self._enable_planning = enable_planning # 如果启用规划功能,设置状态模式为PlanningAgentState if self._enable_planning and state_schema == AgentState: state_schema = PlanningAgentState # 存储特定于智能体的属性 self.agents = agents self.output_mode = output_mode self.supervisor_name = supervisor_name self.state_schema = state_schema self.checkpointer = checkpointer self.tools = tools or [] self._workflow = None self._agent = None # 生成基础提示词 _final_prompt = self._PLANNING_PROMPT_TEMPLATE + "/n/n" + self._PLANNING_TOOL_TEMPLATE if self._enable_planning else self._PROMPT_TEMPLATE # 如果启用规划功能,添加规划工具 if self._enable_planning: tools = tools or [] tools.append(SimplePlanningTool()) # 初始化BaseAgent父类 super().__init__( name=supervisor_name, model=model, tools=tools, checkpointer=checkpointer, prompt=_final_prompt, ) ``` ### 6.2 核心方法实现 #### 6.2.1 build 方法 `build` 方法负责构建 Supervisor 工作流: ```python def build(self) -> StateGraph: """构建 supervisor 工作流。 Returns: 构建的 StateGraph """ if self._workflow is not None: return self._workflow self._workflow = create_supervisor( agents=self.agents, model=self.model, tools=self.tools, prompt=self.prompt, state_schema=self.state_schema, supervisor_name=self.supervisor_name, output_mode=self.output_mode, ) return self._workflow ``` ## 7. create_supervisor 函数实现 `create_supervisor` 函数是 SupervisorAgent 的核心依赖,它负责创建多智能体协作的工作流。 ```python def create_supervisor( agents: list[Pregel], *, model: LanguageModelLike, tools: list[BaseTool | Callable] | None = None, prompt: Prompt | None = None, response_format: Optional[ Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] ] = None, state_schema: StateSchemaType = AgentState, config_schema: Type[Any] | None = None, output_mode: OutputMode = "last_message", add_handoff_back_messages: bool = True, supervisor_name: str = "supervisor", include_agent_name: AgentNameMode | None = None, ) -> StateGraph: # 检查智能体名称唯一性 agent_names = set() for agent in agents: if agent.name is None or agent.name == "LangGraph": raise ValueError( "Please specify a name when you create your agent..." ) if agent.name in agent_names: raise ValueError( f"Agent with name '{agent.name}' already exists. Agent names must be unique." 
) agent_names.add(agent.name) # 为每个智能体创建 handoff 工具 handoff_tools = [create_handoff_tool(agent_name=agent.name) for agent in agents] all_tools = (tools or []) + handoff_tools # 绑定工具到模型 if _supports_disable_parallel_tool_calls(model): model = model.bind_tools(all_tools, parallel_tool_calls=False) else: model = model.bind_tools(all_tools) # 处理智能体名称显示方式 if include_agent_name: model = with_agent_name(model, include_agent_name) # 创建 supervisor 智能体 _react_agent = ReactAgent( name=supervisor_name, model=model, tools=all_tools, prompt=prompt, state_schema=state_schema, response_format=response_format, debug=False, ) supervisor_agent = _react_agent.compile() # 构建状态图 builder = StateGraph(state_schema, config_schema=config_schema) builder.add_node(supervisor_agent, destinations=tuple(agent_names) + (END,)) builder.add_edge(START, supervisor_agent.name) # 添加智能体节点和边 for agent in agents: # 如果智能体是 "ReactAgent" 或类似类型 if hasattr(agent, "get_agent") and callable(agent.get_agent): agent = agent.get_agent() # 获取编译后的子图 builder.add_node( agent.name, _make_call_agent( agent, output_mode, add_handoff_back_messages, supervisor_name, ), ) builder.add_edge(agent.name, supervisor_agent.name) return builder ``` ## 8. 实践案例 ### 8.1 使用 create_supervisor 函数(原始方式) 在 `01_supervisor_test.py` 中,我们使用原始的 `create_supervisor` 函数实现了一个包含两个专业智能体的系统: ```python workflow = create_supervisor( [research_agent, joke_agent], model=model, prompt=( "You are the overall supervisor. You manage two specialized agents:\n" "1) joke_agent: for telling jokes.\n" "2) research_expert: for factual or data-related questions.\n\n" "If the user wants a joke AND some research data in the same query, " "you MUST call joke_agent first, get the joke, then call research_expert for the data. " "After both calls, provide a final combined response. " "Do not call more than one agent in a single LLM message; do it step by step." ), ) # 编译得到一个可调用的"App" app = workflow.compile() ``` ### 8.2 使用 SupervisorAgent 类(封装方式) 在 `02_supervisor_agent_test.py` 中,我们使用封装的 `SupervisorAgent` 类实现了相同的功能,但增加了规划能力: ```python # 创建 SupervisorAgent 实例 supervisor = SupervisorAgent( agents=[research_agent, joke_agent], model=model, prompt=( "You are the overall supervisor. You manage two specialized agents:\n" "1) joke_agent: for telling jokes.\n" "2) research_expert: for factual or data-related questions.\n\n" "If the user wants a joke AND some research data in the same query, " "you MUST call joke_agent first, get the joke, then call research_expert for the data. " "After both calls, provide a final combined response. " "Do not call more than one agent in a single LLM message; do it step by step." ), enable_planning=True, # 启用规划功能 ) # 编译得到一个可调用的"App" app = supervisor.compile() ``` ### 8.3 两种方式的比较 两种实现方式在基本功能上相似,但使用 `SupervisorAgent` 类的方式有以下优势: 1. **更简洁的 API**:封装了复杂的参数和配置,提供了更简洁的接口 2. **更好的封装性**:将相关功能封装在一个类中,便于维护和扩展 3. **更好的可读性**:代码结构更清晰,意图更明确 4. **更好的可重用性**:可以方便地在不同项目中复用 5. **规划功能**:内置了任务规划能力,可以更有效地管理复杂任务 6. **上下文管理**:通过 BaseAgent 继承了智能的上下文管理机制 ## 9. 总结 在重构后的实现中,我们引入了以下关键改进: 1. **BaseAgent 基类**:提供了所有智能体共享的基础功能,如上下文管理、工作流构建等 2. **ReactAgent 重构**:现在继承自 BaseAgent,使用 CreateReactAgentWrapper 增强功能 3. **SupervisorAgent 重构**:现在继承自 BaseAgent,增加了规划功能 4. **统一的接口**:所有智能体类型现在共享相同的核心方法接口 5. 
**智能上下文管理**:可以根据配置自动截断消息历史,优化性能 Supervisor 模式是一种高效组织多智能体系统的方法,它通过中央控制智能体协调专业智能体的工作,实现复杂任务的分解与协作。在我们的重构实现中,通过引入 BaseAgent 基类和增强 SupervisorAgent 的规划能力,使得多智能体系统更加灵活、高效,同时保持了良好的可维护性和可扩展性。 这种模式特别适合以下场景: - 需要多种专业知识协作的复杂任务 - 需要动态决策调用不同专家的场景 - 需要结果整合和质量控制的任务流程 - 需要有计划地执行多步骤任务的场景 未来,我们将继续优化 Supervisor 模式的实现,增强其灵活性和可扩展性,并探索更多的应用场景。 ================================================ FILE: instructions/03.tavily_search_integration.md ================================================ # Tavily搜索工具集成:为多智能体系统提供实时信息能力 ## 1. 引言 在多智能体系统中,获取实时、准确的外部信息是提升系统实用性的关键因素。本文将详细介绍我们在 Mentis 项目中集成 Tavily 搜索工具的实现,这使得我们的智能体系统能够获取最新的网络信息,大幅提升了系统的实用价值。 ## 2. Tavily 搜索工具概述 Tavily 是一个专为 AI 应用设计的搜索 API,它提供了高质量、结构化的网络搜索结果。在我们的实现中,Tavily 工具具有以下特点: - **实时性**:能够获取最新的网络信息 - **结构化输出**:返回格式化的搜索结果,便于智能体处理 - **可配置性**:支持多种参数配置,如搜索深度、结果数量等 - **多媒体支持**:可选择性地包含图片等多媒体内容 ## 3. Tavily 工具的实现 ### 3.1 核心代码分析 在 `tavily_tools.py` 中,我们实现了 `TavilySearchResults` 类,它继承自 LangChain 的 `BaseTool`: ```python class TavilySearchResults(BaseTool): """Tool that queries the Tavily Search API and gets back json.""" name: str = "tavily_search_results_json" description: str = ( "A search engine optimized for comprehensive, accurate, and trusted results. " "Useful for when you need to answer questions about current events. " "Input should be a search query." ) args_schema: Type[BaseModel] = TavilyInput max_results: int = 5 """Max search results to return, default is 5""" search_depth: str = "advanced" """The depth of the search. It can be "basic" or "advanced"""" include_domains: List[str] = [] """A list of domains to specifically include in the search results.""" exclude_domains: List[str] = [] """A list of domains to specifically exclude from the search results.""" include_answer: bool = False """Include a short answer to original query in the search results.""" include_raw_content: bool = False """Include cleaned and parsed HTML of each site search results.""" include_images: bool = False """Include a list of query related images in the response.""" api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper) response_format: Literal["content_and_artifact"] = "content_and_artifact" ``` ### 3.2 搜索执行方法 `TavilySearchResults` 类提供了同步和异步两种搜索方法: ```python def _run( self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None, ) -> Tuple[Union[List[Dict[str, str]], str], Dict]: """Use the tool.""" try: raw_results = self.api_wrapper.raw_results( query, self.max_results, self.search_depth, self.include_domains, self.exclude_domains, self.include_answer, self.include_raw_content, self.include_images, ) except Exception as e: return repr(e), {} return self.api_wrapper.clean_results(raw_results["results"]), raw_results async def _arun( self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> Tuple[Union[List[Dict[str, str]], str], Dict]: """Use the tool asynchronously.""" # 异步实现... ``` ## 4. 在多智能体系统中集成 Tavily 工具 ### 4.1 创建研究型智能体 在我们的多智能体系统中,我们创建了一个专门的研究型智能体,它使用 Tavily 搜索工具获取实时信息: ```python # 创建Tavily搜索工具 tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=False, include_images=False, search_depth="advanced" ) research_agent = create_react_agent( model=model, tools=[tavily_search], name="research_expert", prompt=( "You are a world-class researcher. You have access to the 'tavily_search_results_json' tool " "which can search the web for real-time information. " "When asked a question, use this tool to find accurate and up-to-date information. 
" "Summarize the search results in a clear and concise manner. " "Always cite your sources by including the URLs from the search results." ), ) ``` ### 4.2 与 Supervisor 集成 研究型智能体作为专业智能体,被集成到 Supervisor 模式中: ```python # 创建 SupervisorAgent 实例 supervisor = SupervisorAgent( agents=[research_agent, joke_agent], model=model, prompt=( "You are the overall supervisor. You manage two specialized agents:\n" "1) joke_agent: for telling jokes.\n" "2) research_expert: for factual or data-related questions using real-time web search.\n\n" "If the user wants a joke, call joke_agent.\n" "If the user wants factual information or research data, call research_expert.\n" "If the user wants a joke AND some research data in the same query, " "you MUST call joke_agent first, get the joke, then call research_expert for the data. " "After both calls, provide a final combined response. " "Do not call more than one agent in a single LLM message; do it step by step." ), ) ``` ## 5. 实践案例 ### 5.1 只询问研究数据 当用户只询问研究数据时,Supervisor 会直接调用研究型智能体: ```python # 示例2:只询问研究数据 result2 = app.invoke({"messages": [{"role": "user", "content": "谁是现任美国总统?"}]}) ``` 在这种情况下,研究型智能体会使用 Tavily 搜索工具获取最新信息,并返回结构化的回答,包括引用的来源。 ### 5.2 混合查询 当用户同时需要笑话和研究数据时,Supervisor 会先调用笑话智能体,然后调用研究型智能体: ```python # 示例3:同时询问笑话和研究数据 result3 = app.invoke({"messages": [{"role": "user", "content": "讲个关于人工智能的笑话,然后告诉我什么是大型语言模型"}]}) ``` 这种情况下,Supervisor 会协调两个智能体的工作,并整合它们的结果。 ## 6. 可视化与调试 我们使用 LangGraph 的可视化功能生成了工作流图表,保存在 `examples/graphs/03_tavily_tools_test.png`。这个图表展示了包含 Tavily 搜索工具的多智能体系统的工作流程,有助于理解和调试系统。 ## 7. 总结 Tavily 搜索工具的集成为我们的多智能体系统带来了以下优势: 1. **实时信息获取**:系统能够获取最新的网络信息,不再局限于模型训练数据的时间范围 2. **信息准确性提升**:通过引用可靠的网络来源,提高了系统回答的准确性 3. **功能扩展**:使系统能够回答关于最新事件、数据和信息的问题 4. **灵活配置**:可以根据需要调整搜索参数,优化搜索结果 通过 Tavily 搜索工具的集成,我们的多智能体系统从一个封闭的知识系统转变为一个能够获取实时信息的开放系统,大大提升了系统的实用价值和应用范围。 未来,我们计划进一步优化搜索工具的使用策略,提高搜索效率和结果质量,并探索更多外部工具的集成,使系统能够处理更复杂的任务。 ================================================ FILE: instructions/04.react_agent.md ================================================ # ReactAgent:基于ReAct方法论的多步推理与工具调用框架 ## 1. 引言 ReactAgent是一个基于ReAct方法论的智能体框架,它能够通过多步推理和工具调用来解决复杂问题。本文将详细介绍ReactAgent的核心概念、工作原理、实现方式以及在实际应用中的使用方法。 ## 2. ReactAgent的核心概念 ### 2.1 什么是ReAct方法论 ReAct(Reasoning + Acting)是一种结合推理和行动的AI问题解决方法论,它包含两个核心步骤: 1. **推理(Reasoning)**:让语言模型进行思考,分析问题,并决定下一步行动。 2. **行动(Acting)**:执行具体的工具调用,获取外部信息或执行特定操作。 这两个步骤可以多次循环往复,直到问题被解决。ReAct方法论特别适合处理需要多步骤、多工具协作的复杂问题。 ### 2.2 ReactAgent与LangGraph的关系 ReactAgent是基于LangGraph框架实现的,它利用LangGraph的图结构来编排推理和行动的流程。在LangGraph中,ReactAgent被表示为一个包含多个节点和边的有向图: - **节点(Node)**:包括Agent节点(负责推理)和Tools节点(负责执行工具调用) - **边(Edge)**:定义节点之间的转换条件,例如当Agent生成工具调用时,流程转向Tools节点 ## 3. ReactAgent的实现 ### 3.1 ReactAgent类的设计 在我们的实现中,ReactAgent类继承自LangGraph的Pregel类,提供了一个高级接口来创建和管理ReAct工作流: ```python class ReactAgent(Pregel): """ReAct Agent class for reasoning and acting with tools. This class provides a high-level interface for creating a ReAct agent workflow that can perform multi-step reasoning and tool calling. """ def __init__( self, model: LanguageModelLike, tools: Optional[List[Union[BaseTool, Callable]]] = None, prompt: Optional[str] = None, response_format: Optional[ Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] ] = None, state_schema: StateSchemaType = AgentState, config_schema: Type[Any] = None, interrupt_before: Optional[List[str]] = None, interrupt_after: Optional[List[str]] = None, debug: bool = False, version: Literal["v1", "v2"] = "v1", name: str = "react_agent", ): # 初始化代码... ``` ### 3.2 核心方法 ReactAgent类提供了以下核心方法: 1. **build()**: 构建ReAct工作流图 2. 
**compile()**: 编译工作流为可执行应用 3. **invoke()**: 同步执行ReAct工作流 4. **ainvoke()**: 异步执行ReAct工作流 5. **stream()**: 流式执行,可以获取中间状态 6. **get_graph()**: 获取底层图结构,用于可视化或调试 ### 3.3 与create_react_agent的关系 ReactAgent类内部使用了LangGraph提供的`create_react_agent`函数来构建工作流图。这个函数自动处理了: - 创建Agent节点(用于调用语言模型) - 创建Tools节点(用于执行工具调用) - 在节点之间建立连接 - 处理状态管理和流程控制 ## 4. 使用ReactAgent解决复杂问题 ### 4.1 基本使用流程 使用ReactAgent的基本流程如下: 1. **初始化ReactAgent**:提供语言模型和工具 2. **编译工作流**:调用compile()方法 3. **准备初始状态**:通常包含用户的问题 4. **执行或流式执行**:使用invoke()或stream()方法 5. **处理结果**:分析最终状态或中间状态 ### 4.2 集成Tavily搜索工具 在实际应用中,我们经常将ReactAgent与Tavily搜索工具集成,使其能够获取实时网络信息: ```python # 创建Tavily搜索工具 tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=True, include_images=False, search_depth="advanced" ) # 创建ReactAgent实例 react_agent = ReactAgent( model=model, tools=[tavily_search], prompt=( "你是一位专业的研究分析师,擅长分析复杂问题并提供深入见解。\n" "当面对复杂问题时,请遵循以下REACT方法论:\n" "1. 分解问题:将复杂问题分解为更小的子问题\n" "2. 制定计划:确定需要搜索哪些信息,以及搜索的顺序\n" "3. 执行搜索:使用tavily_search_results_json工具执行搜索\n" "4. 分析结果:分析搜索结果,确定是否需要进一步搜索\n" "5. 综合信息:将所有搜索结果综合成一个连贯的回答\n" ), ) # 编译工作流 agent = react_agent.compile() ``` ### 4.3 处理用户输入 以下是处理用户输入的示例代码: ```python # 准备初始状态 initial_state = { "messages": [HumanMessage(content=user_input)] } # 流式执行并获取中间状态 for partial_state in react_agent.stream(initial_state, stream_mode="values"): # 处理中间状态 messages = partial_state.get("messages", []) if messages: latest_message = messages[-1] # 记录或显示最新消息 log_agent_actions({"messages": [latest_message]}) # 处理最终结果 final_state = partial_state # 最后一个状态就是最终状态 ``` ## 5. ReactAgent的优势与应用场景 ### 5.1 优势 - **多步推理**:能够分解复杂问题,逐步解决 - **工具调用**:可以集成各种外部工具,扩展能力边界 - **状态管理**:自动管理对话状态和中间结果 - **可视化**:支持工作流可视化,便于调试和理解 - **流式执行**:可以获取中间状态,实现更好的用户体验 ### 5.2 应用场景 - **研究助手**:帮助用户研究复杂问题,获取最新信息 - **数据分析**:分步骤处理数据分析任务 - **决策支持**:通过多步推理和信息收集辅助决策 - **教育辅导**:分解复杂概念,逐步引导学习 ## 6. 实际案例:研究特斯拉2025年发展预期 以下是使用ReactAgent研究特斯拉2025年发展预期的实际案例: 1. **问题分解**:将问题分解为新车型计划、销量目标、技术创新和市场扩张战略 2. **执行搜索**:针对每个子问题执行Tavily搜索 3. **分析结果**:分析每个搜索的结果,提取关键信息 4. **综合信息**:将所有信息整合为一个全面的分析报告 通过这种方式,ReactAgent能够提供比单次查询更全面、更深入的分析结果。 ## 7. 总结 ReactAgent是一个强大的基于ReAct方法论的智能体框架,它通过多步推理和工具调用来解决复杂问题。在实际应用中,ReactAgent特别适合需要分步骤思考、收集信息和综合分析的任务。通过与Tavily等工具的集成,ReactAgent能够获取实时信息,大幅提升其实用价值。 在未来的开发中,我们将继续优化ReactAgent的性能,增强其推理能力,并集成更多实用工具,使其能够应对更广泛的应用场景。 ================================================ FILE: instructions/05.react_agent_user_input.md ================================================ # ReactAgent与用户交互:构建交互式研究助手 ## 1. 引言 本文将介绍如何使用ReactAgent构建一个能够与用户进行交互的研究助手,该助手能够接收用户输入,使用搜索工具获取信息,并提供深入的分析结果。这种交互式助手特别适合需要实时信息和多轮对话的场景。 ## 2. 交互式研究助手的核心概念 ### 2.1 用户输入处理 交互式研究助手需要能够处理用户的自然语言输入,理解用户的意图,并将其转化为可执行的搜索查询或其他操作。这涉及到: 1. **输入解析**:分析用户输入,提取关键信息和查询意图 2. **查询重构**:将用户的自然语言问题转化为更有效的搜索查询 3. **上下文维护**:在多轮对话中保持对话上下文的连贯性 ### 2.2 搜索工具集成 研究助手的核心功能是能够获取和分析信息,这通常通过集成各种搜索工具来实现: 1. **Tavily搜索**:提供高质量的网络搜索结果,支持深度搜索模式 2. **结果处理**:对搜索结果进行过滤、排序和整合,提取最相关的信息 3. **多次搜索策略**:对复杂问题进行分解,执行多次有针对性的搜索 ## 3. 实现交互式研究助手 ### 3.1 基本架构 交互式研究助手的基本架构包括: ``` 用户输入 → ReactAgent → 搜索工具 → 结果分析 → 回复生成 → 用户 ``` 这个流程可以多次循环,形成多轮对话。 ### 3.2 ReactAgent配置 以下是创建交互式研究助手的核心代码: ```python def create_react_agent_instance(): """创建并返回ReactAgent实例""" react_agent = ReactAgent( model=model, tools=[tavily_search], name="research_assistant", # 提示词强调分解问题、多步思考和综合信息 prompt=( "你是一位专业的研究分析师,擅长分析复杂问题并提供深入见解。\n" "你有一个强大的工具'tavily_search_results_json'可以搜索网络获取实时信息。\n\n" "当面对复杂问题时,请遵循以下REACT方法论:\n" "1. 分解问题:将复杂问题分解为更小的子问题\n" "2. 制定计划:确定需要搜索哪些信息,以及搜索的顺序\n" "3. 执行搜索:使用tavily_search_results_json工具执行搜索\n" "4. 分析结果:分析搜索结果,确定是否需要进一步搜索\n" "5. 
综合信息:将所有搜索结果综合成一个连贯的回答\n\n" "重要提示:\n" "- 不要一次性搜索过于宽泛的问题\n" "- 对于复杂问题,进行多次有针对性的搜索\n" "- 每次搜索后评估结果,决定下一步行动\n" "- 在最终回答中引用来源,包括搜索结果中的URL\n" "- 清晰地展示你的思考过程,包括问题分解和计划制定\n" ), ) return react_agent ``` ### 3.3 Tavily搜索工具配置 ```python tavily_search = TavilySearchResults( max_results=3, include_answer=True, include_raw_content=True, # 包含原始内容,便于分析 include_images=False, search_depth="advanced" # 使用高级搜索深度 ) ``` ### 3.4 用户交互循环 用户交互循环的核心是通过`stream`方法获取中间状态,并实时显示Agent的思考过程: ```python def process_user_query(query): # 创建ReactAgent实例 react_agent = create_react_agent_instance() agent = react_agent.compile() # 准备输入 inputs = { "messages": [HumanMessage(content=query)] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in react_agent.stream(inputs, stream_mode="values"): # 保存最终状态 final_state = partial_state # 获取最新消息并记录 messages = partial_state.get("messages", []) if messages: latest_message = messages[-1] log_agent_actions({"messages": [latest_message]}) # 返回最终回答 return final_state ``` ## 4. 最佳实践与优化策略 ### 4.1 提示词优化 提示词对研究助手的性能至关重要,应包含以下要素: 1. **角色定义**:明确助手的专业身份和能力 2. **方法论指导**:提供结构化的问题解决方法 3. **工具使用指南**:说明如何有效使用搜索工具 4. **输出格式要求**:规定回答的结构和引用方式 ### 4.2 搜索策略优化 为提高搜索效率和结果质量,可采用以下策略: 1. **渐进式搜索**:从一般到具体,逐步缩小搜索范围 2. **多角度查询**:使用不同的关键词和表述方式进行搜索 3. **结果验证**:通过交叉检查多个来源验证信息的准确性 4. **深度参数调整**:根据问题复杂度调整搜索深度参数 ### 4.3 用户体验优化 提升用户体验的关键点包括: 1. **透明的思考过程**:展示Agent的推理过程,增强可信度 2. **实时反馈**:通过流式输出提供即时反馈 3. **引用来源**:清晰标注信息来源,便于用户进一步探索 4. **交互式引导**:在复杂问题上引导用户提供更多上下文或澄清问题 ## 5. 应用场景 交互式研究助手适用于多种场景: 1. **学术研究**:帮助研究人员快速获取和分析相关文献 2. **市场分析**:收集和整合市场趋势、竞争对手信息 3. **新闻摘要**:汇总和分析最新新闻事件 4. **技术调研**:探索新技术、框架或工具的特性和评价 5. **教育辅助**:为学生提供学习资料和解答问题 ## 6. 总结 ReactAgent结合用户交互和搜索工具,可以构建功能强大的研究助手,能够处理复杂查询并提供深入分析。通过优化提示词、搜索策略和用户体验,可以进一步提升助手的性能和实用性。未来的发展方向包括集成更多专业数据源、增强多模态能力,以及提供更个性化的信息服务。 ================================================ FILE: instructions/06.web_extraction_tools.md ================================================ # 网页提取工具:FireCrawl与Jina的集成与应用 ## 1. 引言 网页内容提取是智能体系统中的重要能力,它使智能体能够从互联网获取、分析和处理结构化和非结构化的网页内容。本文将详细介绍如何在Mentis框架中集成和使用FireCrawl和Jina两种强大的网页提取工具,以实现高效的网站结构分析和内容提取。 ## 2. 网页提取工具的核心概念 ### 2.1 网页提取的两个关键步骤 高效的网页内容提取通常包含两个关键步骤: 1. **网站结构分析**:了解网站的组织结构、页面之间的链接关系,以及重要页面的位置。 2. **内容提取**:从特定页面中提取有价值的文本、图像或其他结构化信息。 ### 2.2 FireCrawl与Jina的角色分工 在Mentis框架中,我们使用两种工具来分别处理这两个步骤: 1. **FireCrawl**:专注于网站结构分析,能够爬取网站的页面结构和链接关系。 2. **Jina**:专注于内容提取,能够从特定URL获取干净、结构化的内容。 ## 3. FireCrawlTool的实现与使用 ### 3.1 FireCrawlTool的基本结构 FireCrawlTool是对FireCrawl API的封装,提供了网站爬取和内容分析的能力: ```python class FireCrawlTool(BaseTool): """Tool that uses FireCrawl API to crawl or scrape web content.""" name: str = "firecrawl_tool" description: str = ( "A web crawler and scraper that extracts content from websites. " "Useful for when you need to analyze the content of a specific website or webpage. " "Input should be a URL to crawl or scrape." ) args_schema: Type[BaseModel] = FireCrawlInput api_key: Optional[str] = None api_url: Optional[str] = None mode: str = "crawl" params: Dict[str, Any] = Field(default_factory=dict) ``` ### 3.2 FireCrawlTool的配置选项 FireCrawlTool提供了多种配置选项: 1. **mode**:工作模式,可选值包括: - `crawl`:爬取网站结构和链接 - `scrape`:提取特定页面的内容 - `map`:生成网站地图 2. 
**params**:额外参数,常用的包括: - `max_pages`:限制爬取的最大页面数量 - `max_depth`:限制爬取的最大深度 - `follow_links`:是否跟踪页面中的链接 ### 3.3 使用FireCrawlTool爬取网站结构 以下是使用FireCrawlTool爬取网站结构的示例代码: ```python # 创建FireCrawl工具 - 用于网站结构分析 firecrawl_tool = FireCrawlTool( mode="crawl", # 使用爬取模式 params={ "max_pages": 5, # 限制爬取页面数量 } ) # 在Agent中使用该工具 react_agent = create_react_agent( model=model, tools=[firecrawl_tool], name="web_crawler", prompt="你是一位网站结构分析专家..." ) ``` ## 4. JinaSearch的实现与使用 ### 4.1 JinaSearch的基本功能 JinaSearch是LangChain提供的一个工具,能够从网页中提取干净、可读的内容,去除广告、导航栏等干扰元素: ```python from langchain_community.tools import JinaSearch # 创建Jina Reader工具 - 用于内容提取 jina_reader_tool = JinaSearch() ``` ### 4.2 使用JinaSearch提取网页内容 JinaSearch特别适合在确定了目标页面后,提取其中的核心内容: ```python # 在Agent中结合FireCrawl和Jina react_agent = create_react_agent( model=model, tools=[firecrawl_tool, jina_reader_tool], name="web_extraction_expert", prompt="你是一位网页内容分析专家..." ) ``` ## 5. 网页提取的最佳实践 ### 5.1 两阶段提取策略 为了高效地提取网页内容,建议采用两阶段策略: 1. **第一阶段**:使用FireCrawlTool爬取网站结构,了解网站的组织方式和重要页面。 2. **第二阶段**:根据第一阶段的结果,使用JinaSearch有针对性地提取重要页面的内容。 ### 5.2 提示词优化 为了引导Agent正确使用这两个工具,提示词应该明确指出工具的使用顺序和方法: ```python prompt = ( "你是一位专业的网页内容分析专家,擅长提取和分析网站结构与内容。\n" "你有两个强大的工具:\n" "1. 'firecrawl_tool': 用于爬取网站结构和下级页面\n" "2. 'jina_reader_tool': 用于从特定URL提取结构化内容\n\n" "当面对网站分析任务时,请遵循以下方法论:\n" "1. 先使用firecrawl_tool了解网站结构\n" "2. 再使用jina_reader_tool提取关键页面内容\n" "3. 最后整合信息提供分析结果" ) ``` ### 5.3 处理大型网站的策略 对于大型网站,可以采用以下策略: 1. **限制爬取范围**:设置合理的`max_pages`和`max_depth`参数 2. **分批处理**:先获取网站结构,然后每次只处理1-3个重要页面 3. **内容摘要**:对提取的内容进行摘要,减少token消耗 ## 6. 实际应用案例 ### 6.1 分析LangGraph文档网站 以下是使用FireCrawl和Jina分析LangGraph文档网站的示例: ```python # 定义输入 inputs = { "messages": [ {"role": "user", "content": "爬取LangGraph文档网站的每个章节的内容(https://langchain-ai.github.io/langgraph/how-tos/) "} ] } # 使用stream方法逐步获取中间状态 final_state = None for partial_state in react_agent.stream(inputs, stream_mode="values"): # 处理中间状态... pass ``` ### 6.2 结果分析与处理 Agent会首先使用FireCrawl获取网站结构,然后使用Jina提取重要页面的内容,最后整合信息提供分析结果: 1. **网站结构分析**:识别主要章节和子页面 2. **内容提取**:获取每个章节的详细内容 3. **信息整合**:将内容组织成结构化的文档或摘要 ## 7. 总结 FireCrawl和Jina的结合为智能体提供了强大的网页内容提取能力。通过两阶段提取策略,可以高效地分析网站结构并提取有价值的内容。这种能力使智能体能够从互联网获取实时信息,为用户提供更加全面和准确的回答。 未来的发展方向包括增强对JavaScript渲染页面的支持、提高内容提取的准确性,以及集成更多专业领域的内容分析能力。 ================================================ FILE: instructions/07.web_extraction_with_filesystem.md ================================================ # 网页提取与文件系统集成:构建内容采集与存储系统 ## 1. 引言 在智能体系统中,网页内容提取通常需要与文件系统操作相结合,以便将提取的内容持久化存储。本文将详细介绍如何在Mentis框架中集成网页提取工具和文件系统工具,并使用SupervisorAgent协调多个专业智能体,构建一个完整的内容采集与存储系统。 ## 2. 系统架构设计 ### 2.1 三层架构模式 我们采用三层架构设计,包括: 1. **Supervisor层**:负责协调和管理其他智能体,接收用户指令并分配任务 2. **Research层**:负责网页内容提取,包括网站结构分析和内容提取 3. **FileSystem层**:负责文件操作,包括内容保存、读取和目录管理 ### 2.2 智能体角色分工 系统中的三个智能体各自承担不同的职责: 1. **SupervisorAgent**:总协调者,负责理解用户需求,并将任务分配给适当的专业智能体 2. **Research Agent**:网页内容分析专家,负责使用FireCrawl和Jina工具提取网页内容 3. **FileSystem Agent**:文件系统管理专家,负责将提取的内容保存到本地文件系统 ## 3. 
组件实现 ### 3.1 Research Agent实现 Research Agent负责网页内容提取,使用FireCrawl和Jina工具: ```python # 创建FireCrawl工具 - 用于网站结构分析 firecrawl_tool = FireCrawlTool( mode="crawl", # 使用爬取模式 params={ "max_pages": 5, # 限制爬取页面数量 } ) # 创建Jina Reader工具 - 用于内容提取 jina_reader_tool = JinaSearch() # 创建Research Agent research_agent = create_react_agent( model=model, tools=[firecrawl_tool, jina_reader_tool], name="research_agent", prompt=( "你是一位专业的网页内容分析专家,擅长提取和分析网站结构与内容。\n" "你有两个强大的工具...\n" # 提示词内容 ), ) ``` ### 3.2 FileSystem Agent实现 FileSystem Agent负责文件操作,使用LangChain的FileManagementToolkit: ```python # 设置文件系统工具的根目录 output_dir = os.path.join(os.path.dirname(__file__), "output") os.makedirs(output_dir, exist_ok=True) # 创建文件系统工具集 filesystem_toolkit = FileManagementToolkit( root_dir=output_dir, selected_tools=["write_file", "read_file", "list_directory"] ) # 获取文件系统工具 filesystem_tools = filesystem_toolkit.get_tools() # 创建FileSystem Agent filesystem_agent = create_react_agent( model=model, tools=filesystem_tools, name="filesystem_agent", prompt=( "你是一位专业的文件系统管理专家,负责将网页内容保存到本地文件系统。\n" "你有以下工具可以使用...\n" # 提示词内容 ), ) ``` ### 3.3 SupervisorAgent实现 SupervisorAgent负责协调Research Agent和FileSystem Agent: ```python # 创建Supervisor Agent supervisor = SupervisorAgent( agents=[research_agent, filesystem_agent], model=model, prompt=( "你是一个智能助手的总协调者,负责管理两个专业智能体:\n" "1) research_agent: 网页内容分析专家,可以爬取和分析网站内容\n" "2) filesystem_agent: 文件系统管理专家,可以将内容保存到本地文件系统\n\n" # 提示词内容 ), ) # 创建内存存储器用于保存对话状态 memory_saver = MemorySaver() # 编译得到一个可调用的"App",添加checkpointer实现记忆功能 app = supervisor.compile(checkpointer=memory_saver) ``` ## 4. 工作流程 ### 4.1 基本工作流程 系统的基本工作流程如下: 1. **用户请求**:用户提出网页内容提取和保存的请求 2. **Supervisor分析**:SupervisorAgent分析用户请求,确定需要调用哪个专业智能体 3. **内容提取**:如果需要提取网页内容,SupervisorAgent调用Research Agent 4. **内容保存**:如果需要保存内容,SupervisorAgent将Research Agent的结果传递给FileSystem Agent 5. **结果返回**:SupervisorAgent将最终结果返回给用户 ### 4.2 上下文管理策略 为了有效管理上下文长度,系统采用以下策略: 1. **分批处理**:对于大型网站,采用分批处理策略,每次只处理少量页面 2. **内容摘要**:对于大型内容,进行摘要处理,减少传递的token数量 3. **先保存再处理**:对于多页面内容,采用先保存再处理的策略,减轻上下文负担 ## 5. 提示词设计 ### 5.1 SupervisorAgent提示词 SupervisorAgent的提示词强调任务分配和协调: ``` 你是一个智能助手的总协调者,负责管理两个专业智能体: 1) research_agent: 网页内容分析专家,可以爬取和分析网站内容 2) filesystem_agent: 文件系统管理专家,可以将内容保存到本地文件系统 你的工作流程如下: 1. 分析用户请求,确定是需要网页内容提取还是文件操作,或两者都需要 2. 如果需要网页内容提取,调用research_agent获取网页内容 3. 如果需要将提取的内容保存到文件,调用filesystem_agent进行保存 4. 如果用户同时需要提取内容并保存,先调用research_agent获取内容,再调用filesystem_agent保存内容 重要规则: - 不要在一个消息中同时调用多个智能体,必须一步一步来 - 当调用filesystem_agent保存内容时,必须提供完整的内容和建议的文件名 - 确保在最终回复中告知用户内容已成功提取和/或保存 ``` ### 5.2 Research Agent提示词 Research Agent的提示词强调网页内容提取的方法论: ``` 你是一位专业的网页内容分析专家,擅长提取和分析网站结构与内容。 你有两个强大的工具: 1. 'firecrawl_tool': 用于爬取网站结构和下级页面 2. 'jina_reader_tool': 用于从特定URL提取结构化内容,获取干净可读的内容 当面对网站分析任务时,请遵循以下方法论: 1. 分析任务: 明确需要从网站获取什么信息 2. 网站结构分析: 使用firecrawl_tool爬取网站结构,了解可用页面 3. 内容提取: 根据网站结构,使用jina_reader_tool从关键页面提取内容 4. 信息整合: 将提取的内容整合成有条理的分析结果 ``` ### 5.3 FileSystem Agent提示词 FileSystem Agent的提示词强调文件操作和内容保存: ``` 你是一位专业的文件系统管理专家,负责将网页内容保存到本地文件系统。 你有以下工具可以使用: 1. 'write_file': 用于将内容写入文件 2. 'read_file': 用于读取文件内容 3. 'list_directory': 用于列出目录内容 当接收到保存内容的请求时,请遵循以下方法论: 1. 分析内容: 确定内容的类型和结构 2. 确定文件名: 根据内容类型和来源创建合适的文件名 3. 保存内容: 使用write_file工具将内容保存到文件中 4. 验证保存: 使用read_file或list_directory工具验证内容已正确保存 ``` ## 6. 记忆功能实现 ### 6.1 使用MemorySaver实现记忆 系统使用LangGraph的MemorySaver实现对话状态的持久化: ```python # 创建内存存储器用于保存对话状态 memory_saver = MemorySaver() # 编译得到一个可调用的"App",添加checkpointer实现记忆功能 app = supervisor.compile(checkpointer=memory_saver) ``` ### 6.2 记忆功能的应用场景 记忆功能在以下场景中特别有用: 1. **多轮对话**:在多轮对话中保持上下文连贯性 2. **长时间任务**:对于需要长时间处理的任务,可以保存中间状态 3. 
**断点续传**:支持任务的暂停和恢复 ## 7. 应用案例 ### 7.1 提取并保存LangGraph文档 以下是一个完整的应用案例,提取并保存LangGraph文档: ```python # 用户请求 inputs = { "messages": [ HumanMessage(content="请爬取LangGraph文档网站(https://langchain-ai.github.io/langgraph/how-tos/)的内容,并保存为Markdown文件") ] } # 执行工作流 final_state = None for partial_state in app.stream(inputs, stream_mode="values"): # 处理中间状态... final_state = partial_state # 记录状态 log_agent_actions(partial_state) # 最终结果 print("\n最终结果:") if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) ``` ## 8. 总结 网页提取与文件系统集成是构建完整内容采集系统的关键。通过SupervisorAgent协调Research Agent和FileSystem Agent,我们可以实现网页内容的提取、分析和持久化存储。这种多智能体协作模式不仅提高了系统的模块化程度,也使得每个智能体可以专注于自己的专业领域,从而提高整体系统的效率和质量。 未来的发展方向包括增强对复杂网站的处理能力、支持更多文件格式的存储和处理,以及集成数据库存储以支持更大规模的内容管理。 ================================================ FILE: instructions/08.react_agent_tool_registry.md ================================================ # 工具注册机制与ReactAgent集成:构建可扩展的智能体系统 ## 1. 引言 工具注册机制是构建可扩展智能体系统的关键组件,它允许我们以统一的方式管理和使用各种工具,并将这些工具与ReactAgent集成。本文将详细介绍Mentis框架中的工具注册机制,包括工具注册、分类管理以及与ReactAgent的集成方式。 ## 2. 工具注册机制的核心概念 ### 2.1 工具注册的意义 工具注册机制提供了以下优势: 1. **统一管理**:集中管理所有可用工具,避免重复创建和配置 2. **分类组织**:按功能和用途对工具进行分类,便于查找和使用 3. **动态加载**:支持动态注册和加载工具,提高系统的灵活性 4. **简化集成**:简化工具与Agent的集成过程,只需从注册表中获取工具列表 ### 2.2 工具分类体系 在Mentis框架中,我们使用`ToolCategory`枚举定义了工具的分类体系: ```python class ToolCategory(Enum): SEARCH = "Search" CODE_INTERPRETER = "Code Interpreter" WEB_BROWSING = "Web Browsing" DATABASE = "Database" FILE_SYSTEM = "FileSystem" OTHER = "Other" ``` 这种分类体系使我们能够根据任务需求选择特定类别的工具,提高工具使用的针对性和效率。 ## 3. 工具注册机制的实现 ### 3.1 全局工具注册表 工具注册机制的核心是一个全局工具注册表,它是一个字典,用于存储所有已注册的工具及其分类信息: ```python # 全局工具注册表 _registered_tools = {} ``` ### 3.2 工具注册函数 `register_tool`函数用于将工具注册到全局注册表中: ```python def register_tool(tool: Tool, category: ToolCategory) -> None: """注册一个工具到全局字典中,带有分类信息""" if tool.name in _registered_tools: raise ValueError(f"工具名 {tool.name} 已存在,请确保工具名唯一") _registered_tools[tool.name] = { "tool": tool, "category": category } ``` ### 3.3 工具获取函数 框架提供了多种函数来获取已注册的工具: ```python def get_registered_tools(as_dict: bool = False) -> Union[List[Tool], Dict[str, Dict]]: """返回所有已注册的工具""" if as_dict: return _registered_tools return [info["tool"] for info in _registered_tools.values()] def get_tools_by_category(category: ToolCategory, return_instances: bool = True) -> List[Union[str, Tool]]: """返回指定分类的工具列表""" if return_instances: return [info["tool"] for name, info in _registered_tools.items() if info["category"] == category] return [name for name, info in _registered_tools.items() if info["category"] == category] ``` ## 4. 
简化工具注册的辅助函数

### 4.1 直接注册工具的函数

为了简化工具注册过程,框架提供了`register_direct_tool`函数,它可以根据工具类名自动判断工具类别:

```python
def register_direct_tool(tool_instance: BaseTool, category: ToolCategory = None) -> None:
    """注册直接从langchain_community.tools导入的工具"""
    if not category:
        # 获取工具类名
        tool_class_name = tool_instance.__class__.__name__
        # 根据工具类名自动判断类别
        category = tool_category_mapping.get(tool_class_name, tool_category_mapping["default"])
    # 注册工具
    register_tool(tool_instance, category)
    print(f"已注册工具: {tool_instance.name} (类别: {category.value})")
```

### 4.2 自动注册自定义工具

框架还支持自动扫描和注册自定义工具。在`__init__.py`中,我们使用以下代码自动注册自定义工具:

```python
# 遍历目录中的所有文件,注册自定义工具
for filename in os.listdir(tools_dir):
    # 只处理 .py 文件,且排除 __init__.py 和 registry.py
    if filename.endswith('.py') and filename not in ['__init__.py', 'registry.py']:
        # 提取模块名(去掉 .py 后缀)
        module_name = filename[:-3]
        try:
            # 动态导入模块
            module = importlib.import_module(f'.{module_name}', package='core.tools')
            # 查找模块中的工具类(继承自BaseTool的类)
            for name, obj in inspect.getmembers(module):
                # 检查是否是类且是BaseTool的子类
                if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
                    # 检查该类是否已经被实例化并注册
                    tool_name = getattr(obj, 'name', None)
                    if tool_name and tool_name not in [
                        info['tool'].name
                        for info in get_registered_tools(as_dict=True).values()
                    ]:
                        # 确定工具类别
                        category = getattr(module, 'category', ToolCategory.OTHER)
                        # 实例化并注册工具
                        try:
                            tool_instance = obj()
                            register_tool(tool_instance, category)
                            print(f"已注册工具类: {name} (工具名: {tool_instance.name}, 类别: {category.value})")
                        except Exception as e:
                            print(f"实例化工具类 {name} 时出错: {e}")
        except Exception as e:
            print(f"导入 {module_name} 时出错: {e}")
```

这段代码会自动扫描`core/tools`目录中的所有Python文件,查找继承自`BaseTool`的类,并自动实例化和注册这些工具。

## 5. 与ReactAgent的集成

### 5.1 从注册表获取工具列表

在创建ReactAgent实例时,我们可以从注册表中获取工具列表:

```python
# 从注册表中获取工具列表
tools_list = [info["tool"] for info in registered_tools.values()]

# 创建ReactAgent实例
react_agent = ReactAgent(
    model=model,
    tools=tools_list,
    name="fed_research_agent",
    prompt=(
        "你是一位专业的经济研究分析师,擅长分析复杂的经济问题并提供深入见解。\n"
        "你有多个强大的工具可以搜索网络获取实时信息:\n"
        "- jina_search: 用于进行网络搜索获取最新信息\n"
        "- wikipedia_query_run: 用于查询维基百科获取基础知识\n"
        "- firecrawl_tool: 用于抓取和分析特定网页内容\n\n"
        # 提示词内容
    ),
)
```

### 5.2 按类别选择工具

在某些场景下,我们可能只需要特定类别的工具。这时,可以使用`get_tools_by_category`函数:

```python
# 获取所有搜索类工具
search_tools = get_tools_by_category(ToolCategory.SEARCH)

# 创建专注于搜索的ReactAgent
search_agent = ReactAgent(
    model=model,
    tools=search_tools,
    name="search_agent",
    prompt="你是一位专业的信息搜索专家..."
)
```

## 6. 实际应用案例

### 6.1 美联储研究任务

以下是一个完整的应用案例,使用工具注册机制和ReactAgent进行美联储研究:

```python
# 注册搜索工具
jina_search = JinaSearch()
wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())

# 使用register_direct_tool函数注册工具
register_direct_tool(jina_search)
register_direct_tool(wiki_tool)
# 注意:FireCrawlTool已经在core/tools/__init__.py中被注册,这里不需要再次注册

# 获取所有已注册的工具(以字典格式)
registered_tools = get_registered_tools(as_dict=True)

# 从注册表中获取工具列表
tools_list = [info["tool"] for info in registered_tools.values()]

# 创建ReactAgent实例
react_agent = ReactAgent(
    model=model,
    tools=tools_list,
    name="fed_research_agent",
    prompt=(
        "你是一位专业的经济研究分析师,擅长分析复杂的经济问题并提供深入见解。\n"
        # 提示词内容
    ),
)

# 编译Agent
agent = react_agent.compile()

# 定义输入
inputs = {
    "messages": [
        HumanMessage(content="请提供美联储(Federal Reserve)的详细介绍,包括其历史、结构、职能,以及它如何通过货币政策影响全球经济。")
    ]
}

# 执行Agent
final_state = None
for partial_state in react_agent.stream(inputs, stream_mode="values"):
    # 处理中间状态...
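    # Illustrative sketch: print the newest message at each step, and keep the
    # latest state so the report-saving code in the next section can read
    # `final_state` (which is initialized to None above).
    messages = partial_state.get("messages", [])
    if messages:
        messages[-1].pretty_print()
    final_state = partial_state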
pass ``` ### 6.2 结果保存 执行完成后,我们可以将结果保存到文件: ```python # 打印最终回答 if final_state and final_state.get("messages"): for message in final_state["messages"]: if isinstance(message, AIMessage) and not message.tool_calls: print(message.content) # 将结果保存到文件 output_dir = os.path.join(os.path.dirname(__file__), "output") os.makedirs(output_dir, exist_ok=True) output_file = os.path.join(output_dir, "fed_research_report.md") with open(output_file, "w", encoding="utf-8") as f: f.write("# 美联储研究报告\n\n") f.write(message.content) print(f"\n研究报告已保存到: {output_file}") ``` ## 7. 最佳实践 ### 7.1 工具命名规范 为了避免工具名冲突,建议遵循以下命名规范: 1. 使用有意义的名称,反映工具的功能 2. 对于同一类别的工具,使用统一的前缀或后缀 3. 避免使用过于通用的名称,如`search`、`get`等 ### 7.2 工具分类策略 合理的工具分类策略可以提高工具使用的效率: 1. 根据工具的主要功能进行分类,而不是实现方式 2. 对于多功能工具,根据其主要功能进行分类 3. 只有在无法确定主要功能时,才将工具归类为`OTHER` ### 7.3 提示词优化 在提示词中明确说明可用工具及其用途,可以提高Agent的工具使用效率: ``` 你是一位专业的经济研究分析师,擅长分析复杂的经济问题并提供深入见解。 你有多个强大的工具可以搜索网络获取实时信息: - jina_search: 用于进行网络搜索获取最新信息 - wikipedia_query_run: 用于查询维基百科获取基础知识 - firecrawl_tool: 用于抓取和分析特定网页内容 当面对复杂问题时,请遵循以下方法论: 1. 分解问题:将复杂问题分解为更小的子问题 2. 制定计划:确定需要搜索哪些信息,以及使用哪些工具 3. 执行搜索:使用适当的工具执行搜索 4. 分析结果:分析搜索结果,确定是否需要进一步搜索 5. 综合信息:将所有搜索结果综合成一个连贯的回答 ``` ## 8. 总结 工具注册机制为Mentis框架提供了强大的可扩展性,使得智能体系统能够轻松集成各种工具,并根据任务需求灵活选择合适的工具组合。通过分类管理和自动注册,工具注册机制简化了工具的管理和使用流程,提高了开发效率。 结合ReactAgent,工具注册机制使得智能体能够访问丰富的外部功能,从而处理更复杂的任务。未来的发展方向包括支持更多类型的工具、增强工具的自动发现和选择能力,以及提供更细粒度的工具权限控制。 ================================================ FILE: instructions/09.e2b_sandbox_integration.md ================================================ # E2B沙箱环境与智能代理集成指南 ## 1. 引言 E2B沙箱环境是一个强大的代码执行工具,它提供了安全、隔离的环境来运行Python代码和Shell命令。将E2B沙箱与智能代理(如ReactAgent)集成,可以显著增强代理的能力,使其能够执行代码、处理数据、创建可视化,甚至与文件系统交互。本文将详细介绍E2B沙箱的核心概念、工作原理、实现方式以及在智能代理系统中的应用。 ## 2. E2B沙箱环境的核心概念 ### 2.1 什么是E2B沙箱 E2B(Execution Environment for Bots)是一个专为AI代理设计的代码执行环境,它提供以下核心功能: 1. **安全隔离**:在隔离的容器中执行代码,防止恶意代码影响宿主系统 2. **多语言支持**:主要支持Python,同时可通过Shell命令执行其他语言代码 3. **文件系统操作**:允许创建、读取、写入和管理文件 4. **包管理**:支持安装和使用第三方Python库 5. **持久化**:可以在会话之间保持状态和文件 ### 2.2 E2B沙箱与代码解释器的关系 E2B沙箱是一种特殊的代码解释器实现,它不仅能执行代码,还提供了完整的操作系统环境(基于Debian)。这使得它比简单的代码解释器功能更强大,能够: - 执行系统命令 - 管理文件和目录 - 安装和使用各种软件包 - 运行网络服务 - 处理复杂的数据分析和可视化任务 ## 3. E2B沙箱的实现 ### 3.1 E2BCodeInterpreterTool类的设计 在我们的实现中,`E2BCodeInterpreterTool`类继承自LangChain的`BaseTool`,提供了与E2B沙箱交互的接口: ```python class E2BCodeInterpreterTool(BaseTool): """使用E2B SDK执行Python代码的工具 该工具创建一个安全的沙箱环境,用于执行Python代码,并返回执行结果、 标准输出、标准错误和任何错误信息。 """ name: str = "e2b_code_interpreter" description: str = ( "在安全的 Debian 基础沙箱环境中执行 Python 代码或 shell 命令,并返回结果。" "适用于数据分析、可视化、复杂计算以及系统操作。" "输入应为有效的 Python 代码字符串,或以 '!' 开头的 shell 命令。" "常见 Python 库(如 numpy、pandas 和 matplotlib)已预装,若需其他库,可通过 pip 安装。" "沙箱环境充分利用 Debian 系统的强大功能,支持广泛的操作。" ) ``` ### 3.2 核心方法 `E2BCodeInterpreterTool`类提供了以下核心方法: 1. **_initialize_sandbox()**: 初始化沙箱环境 2. **_run()**: 在沙箱中执行代码并返回结果 3. **close()**: 关闭沙箱并释放资源 4. **format_to_tool_message()**: 将执行结果格式化为工具消息 ### 3.3 沙箱初始化与资源管理 沙箱初始化过程包括: 1. 检查是否安装了`e2b_code_interpreter`包 2. 验证是否设置了`E2B_API_KEY`环境变量 3. 创建`Sandbox`实例 4. 设置沙箱状态标志 资源管理方面,工具提供了`close()`方法来释放沙箱资源: ```python def close(self): """关闭沙箱,释放资源""" if hasattr(self, "sandbox") and self._is_available and self.sandbox is not None: try: print("正在关闭E2B沙箱并释放资源...") self.sandbox.kill() print("E2B沙箱已成功关闭") except Exception as e: print(f"关闭E2B沙箱时出错: {str(e)}") ``` ## 4. 将E2B沙箱与ReactAgent集成 ### 4.1 基本集成流程 将E2B沙箱与ReactAgent集成的基本流程如下: 1. **注册E2B工具**:将`E2BCodeInterpreterTool`注册到工具注册表中 2. **创建ReactAgent**:使用包含E2B工具的工具列表初始化ReactAgent 3. **设计提示词**:编写强调代码执行能力的提示词 4. 
**执行工作流**:让Agent使用E2B工具执行代码并处理结果 ### 4.2 代码示例 以下是一个基本的集成示例: ```python # 导入必要的库 from core.agents.react_agent import ReactAgent from core.tools.registry import get_tools_by_category, ToolCategory from langchain_openai import ChatOpenAI # 获取代码解释器工具 tools_list = get_tools_by_category(ToolCategory.CODE_INTERPRETER) # 创建ReactAgent实例 react_agent = ReactAgent( model=ChatOpenAI(model="gpt-4o-mini"), tools=tools_list, prompt=( "你是一位专业的数据分析师,可以使用Python代码解决问题。\n" "你有强大的代码执行工具可以使用:\n" "- e2b_code_interpreter: 用于执行Python代码和shell命令\n" ), ) # 编译Agent agent = react_agent.compile() # 执行任务 result = agent.invoke({"messages": [HumanMessage(content="分析以下数据并创建可视化...")]}) ``` ## 5. E2B沙箱的高级功能 ### 5.1 文件系统操作 E2B沙箱提供了完整的文件系统操作能力,可以: - 创建和管理目录结构 - 读写文本和二进制文件 - 列出目录内容 - 移动和删除文件 示例代码: ```python # 在沙箱中创建目录和文件 code = """ # 创建目录 import os os.makedirs('test_dir/subdir', exist_ok=True) # 创建并写入文件 with open('test_dir/example.txt', 'w') as f: f.write('Hello from E2B sandbox!') # 列出目录内容 print(os.listdir('test_dir')) # 读取文件内容 with open('test_dir/example.txt', 'r') as f: content = f.read() print(f'文件内容: {content}') """ # 执行代码 result = e2b_tool.invoke({"code": code}) ``` ### 5.2 包管理 E2B沙箱允许安装和使用第三方Python库: ```python # 安装并使用第三方库 code = """ # 安装pandas库 !pip install pandas matplotlib # 使用pandas进行数据分析 import pandas as pd import matplotlib.pyplot as plt # 创建示例数据 data = {'Category': ['A', 'B', 'C', 'D'], 'Values': [10, 25, 15, 30]} df = pd.DataFrame(data) # 打印数据 print(df) # 创建可视化 plt.figure(figsize=(8, 4)) plt.bar(df['Category'], df['Values']) plt.title('Sample Bar Chart') plt.savefig('chart.png') print('图表已保存为chart.png') """ # 执行代码 result = e2b_tool.invoke({"code": code}) ``` ### 5.3 从沙箱下载文件 可以将沙箱中生成的文件下载到本地系统: ```python def download_file_from_sandbox(sandbox, sandbox_path, local_path): """从沙箱下载文件到本地""" try: # 从沙箱读取文件内容 content = sandbox.files.read(sandbox_path) # 确保目标目录存在 os.makedirs(os.path.dirname(local_path), exist_ok=True) # 写入本地文件 with open(local_path, 'w', encoding='utf-8') as file: file.write(content) print(f"文件已从沙箱下载到本地: {local_path}") return True except Exception as e: print(f"从沙箱下载文件时出错: {str(e)}") return False ``` ## 6. 实际应用案例 ### 6.1 数据分析与可视化 E2B沙箱特别适合数据分析和可视化任务,可以: - 加载和处理各种格式的数据(CSV、JSON、Excel等) - 使用pandas进行数据清洗和转换 - 使用matplotlib、seaborn等创建可视化 - 生成分析报告 ### 6.2 文件处理与转换 E2B沙箱可以处理各种文件格式的转换和处理: - 文本文件处理(如日志分析) - 图像处理和转换 - 数据格式转换(如CSV到JSON) - 文档生成(如生成HTML或PDF报告) ### 6.3 Web爬虫与API调用 E2B沙箱可以执行网络相关任务: - 使用requests或BeautifulSoup进行网页爬取 - 调用各种API并处理响应 - 下载和处理网络资源 ## 7. 最佳实践与注意事项 ### 7.1 安全考虑 虽然E2B沙箱提供了隔离环境,但在使用时仍需注意: - 不要在沙箱中处理敏感数据 - 避免执行未经验证的用户输入代码 - 限制沙箱的网络访问权限 - 定期关闭和重新创建沙箱实例 ### 7.2 资源管理 E2B沙箱会消耗系统资源,因此: - 在不需要时关闭沙箱(使用`close()`方法) - 避免在单个沙箱中运行过多或过大的任务 - 监控沙箱的内存和CPU使用情况 ### 7.3 错误处理 在与E2B沙箱交互时,应当实施健壮的错误处理: - 捕获并处理代码执行异常 - 验证沙箱初始化是否成功 - 提供有意义的错误消息给用户 - 实现重试机制处理临时故障 ## 8. 总结 E2B沙箱为智能代理提供了强大的代码执行能力,使其能够处理各种复杂任务。通过将E2B沙箱与ReactAgent集成,我们可以创建能够执行代码、处理数据、创建可视化,甚至与文件系统交互的智能系统。 正确使用E2B沙箱需要理解其核心概念、实现方式和最佳实践。通过本文的指导,开发者应能够有效地将E2B沙箱集成到自己的智能代理系统中,并充分利用其强大功能。 ## 9. 
参考资源 - [E2B官方文档](https://e2b.dev/docs) - [E2B Code Interpreter SDK](https://github.com/e2b-dev/code-interpreter) - [LangChain工具集成指南](https://python.langchain.com/docs/integrations/tools) - [ReactAgent文档](https://python.langchain.com/docs/modules/agents/agent_types/react) ================================================ FILE: log_analyzer.py ================================================ import re import sys import argparse from collections import defaultdict import json def parse_log_file(file_path): """Parse the execution log file and extract agent interactions.""" with open(file_path, 'r', encoding='utf-8') as f: content = f.read() # Extract different sections of the log sections = content.split("================================ Human Message =================================") if len(sections) > 1: main_content = sections[1] # Skip header else: main_content = content # Extract messages messages = [] # Pattern for AI messages ai_pattern = r"================================== Ai Message ==================================\nName: (\w+)\n\n(.*?)(?=(==================================|$))" ai_matches = re.finditer(ai_pattern, main_content, re.DOTALL) for match in ai_matches: agent_name = match.group(1) message_content = match.group(2).strip() # Check if message has tool calls tool_calls = [] tool_call_pattern = r"Tool Calls:\n(.*?)(?=\n==================================|$)" tool_call_match = re.search(tool_call_pattern, message_content, re.DOTALL) if tool_call_match: # Extract tool calls tool_calls_text = tool_call_match.group(1) tool_call_entries = re.findall(r" (\w+) \(([^)]+)\)", tool_calls_text) tool_calls = [{"name": name, "id": call_id} for name, call_id in tool_call_entries] # Remove tool calls from the message content message_content = re.sub(r"Tool Calls:.*?(?=\n==================================|$)", "", message_content, flags=re.DOTALL).strip() messages.append({ "role": "agent", "agent": agent_name, "content": message_content, "tool_calls": tool_calls }) # Pattern for Tool messages tool_pattern = r"================================= Tool Message =================================\nName: (\w+)\n\n(.*?)(?=(==================================|$))" tool_matches = re.finditer(tool_pattern, main_content, re.DOTALL) for match in tool_matches: tool_name = match.group(1) tool_content = match.group(2).strip() messages.append({ "role": "tool", "tool": tool_name, "content": tool_content }) # Sort messages by their position in the log messages.sort(key=lambda x: main_content.find(x["content"])) return messages def analyze_agent_interactions(messages): """Analyze the interactions between agents.""" interactions = [] current_sender = None tool_call_map = {} for i, msg in enumerate(messages): if msg["role"] == "agent": current_sender = msg["agent"] # Check if this agent is using tool calls for tool_call in msg.get("tool_calls", []): tool_name = tool_call["name"] tool_id = tool_call["id"] tool_call_map[tool_id] = { "sender": current_sender, "tool": tool_name } interactions.append({ "step": i, "from": current_sender, "to": f"SYSTEM ({tool_name})", "action": f"Called tool {tool_name}", "content": f"Tool call ID: {tool_id}" }) elif msg["role"] == "tool": # Find which agent invoked this tool for prev_msg in reversed(messages[:i]): if prev_msg["role"] == "agent" and any(tc["name"] == msg["tool"] for tc in prev_msg.get("tool_calls", [])): sender = prev_msg["agent"] break else: sender = "SYSTEM" interactions.append({ "step": i, "from": f"SYSTEM ({msg['tool']})", "to": sender, "action": f"Tool response", 
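                # route the tool's output back to the agent matched above
                # (falls back to "SYSTEM" when no caller was found)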
"content": msg["content"] }) return interactions def visualize_interactions(interactions): """Visualize the interactions between agents.""" print("\n" + "="*100) print(" "*40 + "AGENT INTERACTIONS SUMMARY") print("="*100 + "\n") for idx, interaction in enumerate(interactions): print(f"[{idx+1}] {interaction['from']} → {interaction['to']}") print(f" Action: {interaction['action']}") content = interaction['content'] if len(content) > 100: content = content[:97] + "..." print(f" Content: {content}\n") def visualize_conversation_flow(messages): """Visualize the conversation flow between agents.""" print("\n" + "="*100) print(" "*40 + "CONVERSATION FLOW") print("="*100 + "\n") for idx, message in enumerate(messages): if message["role"] == "agent": agent_name = message["agent"] print(f"[{idx+1}] Agent: {agent_name}") content = message["content"] if len(content) > 150: content = content[:147] + "..." print(f" Content: {content}") if message.get("tool_calls"): tools = ", ".join([tc["name"] for tc in message["tool_calls"]]) print(f" Tools Called: {tools}") else: print(f"[{idx+1}] Tool: {message['tool']}") content = message["content"] if len(content) > 100: content = content[:97] + "..." print(f" Response: {content}") print() def main(): parser = argparse.ArgumentParser(description='Analyze Mentis execution logs.') parser.add_argument('log_file', help='Path to the log file') parser.add_argument('--format', choices=['interactions', 'flow', 'all'], default='all', help='Output format: interactions, flow, or all') args = parser.parse_args() try: messages = parse_log_file(args.log_file) interactions = analyze_agent_interactions(messages) if args.format in ['interactions', 'all']: visualize_interactions(interactions) if args.format in ['flow', 'all']: visualize_conversation_flow(messages) except Exception as e: print(f"Error: {e}") sys.exit(1) if __name__ == "__main__": main() ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" readme = "README.md" requires-python = ">=3.11" [project] name = "mentis" version = "0.1.0" description = "A Multi-Agents project based on langgraph" requires-python = ">=3.11" dependencies = [ "dotenv>=0.9.9", "langchain-community>=0.3.19", "langchain-core>=0.3.45", "langchain-openai>=0.3.8", "langgraph>=0.3.11", "pydantic>=2.10.6", "typing-extensions>=4.12.2", "python-dotenv>=1.0.0", "firecrawl-py", "wikipedia>=1.4.0", "serpapi>=0.1.5", "google-search-results>=2.4.2", "duckduckgo-search>=7.5.2", "arxiv>=2.1.3", "rizaio>=0.9.0", "e2b-code-interpreter>=1.1.0", "fastapi>=0.115.11", "uvicorn>=0.34.0", "sse-starlette>=2.2.1", "exa-py>=1.9.1", "tavily-python>=0.5.1", "replicate>=1.0.4", "langchain-mcp-adapters>=0.0.7", "mcp>=1.6.0", "playwright>=1.51.0", "pillow>=11.2.1", "yfinance>=0.2.55", ] [tool.setuptools] packages = ["core"] ================================================ FILE: requirements.txt ================================================ # This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml -o requirements.txt aiohappyeyeballs==2.6.1 # via aiohttp aiohttp==3.11.14 # via langchain-community aiosignal==1.3.2 # via aiohttp annotated-types==0.7.0 # via pydantic anyio==4.9.0 # via # httpx # openai # rizaio arxiv==2.1.3 # via mentis (pyproject.toml) attrs==25.3.0 # via # aiohttp # e2b # e2b-code-interpreter beautifulsoup4==4.13.3 # via wikipedia certifi==2025.1.31 # via # httpcore 
# httpx # requests charset-normalizer==3.4.1 # via requests click==8.1.8 # via duckduckgo-search dataclasses-json==0.6.7 # via langchain-community distro==1.9.0 # via # openai # rizaio dotenv==0.9.9 # via mentis (pyproject.toml) duckduckgo-search==7.5.2 # via mentis (pyproject.toml) e2b==1.1.0 # via e2b-code-interpreter e2b-code-interpreter==1.1.0 # via mentis (pyproject.toml) feedparser==6.0.11 # via arxiv firecrawl-py==1.14.1 # via mentis (pyproject.toml) frozenlist==1.5.0 # via # aiohttp # aiosignal google-search-results==2.4.2 # via mentis (pyproject.toml) greenlet==3.1.1 # via sqlalchemy h11==0.14.0 # via httpcore httpcore==1.0.7 # via # e2b # httpx httpx==0.28.1 # via # e2b # e2b-code-interpreter # langgraph-sdk # langsmith # openai # rizaio httpx-sse==0.4.0 # via langchain-community idna==3.10 # via # anyio # httpx # requests # yarl jiter==0.9.0 # via openai jsonpatch==1.33 # via langchain-core jsonpointer==3.0.0 # via jsonpatch langchain==0.3.20 # via langchain-community langchain-community==0.3.19 # via mentis (pyproject.toml) langchain-core==0.3.45 # via # mentis (pyproject.toml) # langchain # langchain-community # langchain-openai # langchain-text-splitters # langgraph # langgraph-checkpoint # langgraph-prebuilt langchain-openai==0.3.9 # via mentis (pyproject.toml) langchain-text-splitters==0.3.6 # via langchain langgraph==0.3.11 # via mentis (pyproject.toml) langgraph-checkpoint==2.0.20 # via # langgraph # langgraph-prebuilt langgraph-prebuilt==0.1.3 # via langgraph langgraph-sdk==0.1.57 # via langgraph langsmith==0.3.15 # via # langchain # langchain-community # langchain-core lxml==5.3.1 # via duckduckgo-search marshmallow==3.26.1 # via dataclasses-json msgpack==1.1.0 # via langgraph-checkpoint multidict==6.2.0 # via # aiohttp # yarl mypy-extensions==1.0.0 # via typing-inspect nest-asyncio==1.6.0 # via firecrawl-py numpy==2.2.4 # via langchain-community openai==1.66.3 # via langchain-openai orjson==3.10.15 # via # langgraph-sdk # langsmith packaging==24.2 # via # e2b # langchain-core # langsmith # marshmallow primp==0.14.0 # via duckduckgo-search propcache==0.3.0 # via # aiohttp # yarl protobuf==5.29.3 # via e2b pydantic==2.10.6 # via # mentis (pyproject.toml) # firecrawl-py # langchain # langchain-core # langsmith # openai # pydantic-settings # rizaio pydantic-core==2.27.2 # via pydantic pydantic-settings==2.8.1 # via langchain-community python-dateutil==2.9.0.post0 # via e2b python-dotenv==1.0.1 # via # mentis (pyproject.toml) # dotenv # firecrawl-py # pydantic-settings pyyaml==6.0.2 # via # langchain # langchain-community # langchain-core regex==2024.11.6 # via tiktoken requests==2.32.3 # via # arxiv # firecrawl-py # google-search-results # langchain # langchain-community # langsmith # requests-toolbelt # serpapi # tiktoken # wikipedia requests-toolbelt==1.0.0 # via langsmith rizaio==0.9.0 # via mentis (pyproject.toml) serpapi==0.1.5 # via mentis (pyproject.toml) sgmllib3k==1.0.0 # via feedparser six==1.17.0 # via python-dateutil sniffio==1.3.1 # via # anyio # openai # rizaio soupsieve==2.6 # via beautifulsoup4 sqlalchemy==2.0.39 # via # langchain # langchain-community tenacity==9.0.0 # via # langchain-community # langchain-core tiktoken==0.9.0 # via langchain-openai tqdm==4.67.1 # via openai typing-extensions==4.12.2 # via # mentis (pyproject.toml) # anyio # beautifulsoup4 # e2b # langchain-core # openai # pydantic # pydantic-core # rizaio # sqlalchemy # typing-inspect typing-inspect==0.9.0 # via dataclasses-json urllib3==2.3.0 # via requests websockets==15.0.1 # via 
firecrawl-py wikipedia==1.4.0 # via mentis (pyproject.toml) yarl==1.18.3 # via aiohttp zstandard==0.23.0 # via langsmith ================================================ FILE: setup.py ================================================ from setuptools import setup setup() ================================================ FILE: super_agents/__init__.py ================================================ ================================================ FILE: super_agents/browser_use/README.md ================================================ n# Browser Agent (基于 LangGraph) - super_agents/browser_use ## 概述 本项目实现了一个基于 LangGraph 框架的 Web 浏览和交互 Agent。其核心目标是让一个大型语言模型 (LLM) 能够像人一样理解任务指令,自主地控制浏览器(通过 Playwright)来访问网页、分析内容、与页面元素交互(点击、输入、滚动等),并最终完成用户指定的任务,例如信息提取、表单填写、在线搜索等。 该 Agent 采用了多模态感知的设计思路,结合了传统的 DOM/Accessibility Tree 分析和可选的视觉语言模型 (VLM) 分析,以期在复杂网页上获得更鲁棒的理解和定位能力。 ## 核心技术栈 * **流程编排:** LangGraph (LangChain 的状态图编排库) * **浏览器自动化:** Playwright (异步 Python 版本) * **模型调用:** LangChain ChatModels (`langchain-openai`, `langchain-community` 等) * **语言模型 (LLM/VLM):** * **规划/决策 LLM:** 可配置,支持 OpenAI, Groq, xAI (Grok), 及其他 OpenAI 兼容 API (通过 `llm.py` 和 `.env` 配置)。 * **视觉分析 VLM:** 可选,通过 OpenRouter 调用支持 Vision 的模型 (如 Qwen-VL, GPT-4o, Claude 3.5 Sonnet 等) (通过 `detector.py` 和 `.env` 配置)。 * **依赖管理:** `uv` (或 `pip`) * **配置:** `.env` 文件 ## 项目架构 项目主要文件和目录结构如下: ``` super_agents/ └── browser_use/ # Agent 根目录 ├── agent/ # LangGraph 核心实现 │ ├── __init__.py │ ├── graph.py # 定义 LangGraph 图结构、节点连接、条件边 │ ├── nodes.py # 定义图中各节点 (Node) 的具体执行逻辑 (AgentNodes 类) │ ├── state.py # 定义 Agent 在图中流转的状态 (AgentState) │ ├── schemas.py # 定义数据模型 (如动作指令 Action Schema, VLM 输出 Schema) │ └── prompts.py # 管理发送给规划 LLM 和 VLM 的 Prompt 模板 │ ├── browser/ # 浏览器交互底层实现 (基于原始项目代码) │ ├── __init__.py │ ├── browser.py # 核心 Browser 类,封装 Playwright 操作、感知方法 (get_content, update_state) │ ├── detector.py # 视觉检测器类,实现 VLM 调用逻辑 │ ├── models.py # 定义浏览器状态、元素等 Pydantic 模型 │ ├── utils.py # 浏览器相关的工具函数 │ └── findVisibleInteractiveElements.js # 用于 DOM 元素检测的 JS 脚本 │ ├── llm/ # LLM 相关实现 │ ├── __init__.py │ └── llm.py # 定义 ChatOpenRouter (VLM 调用), initialize_llms (规划 LLM 初始化), generate_structured_output │ ├── main.py # Agent 的主入口脚本 ├── requirements.txt # Python 依赖列表 ├── README.md # 本文件 └── .env # 环境变量配置文件 (需要手动创建) ``` ## 核心概念与设计 ### 1. LangGraph 状态机 Agent 的核心控制流由 LangGraph 管理。它被实现为一个状态机 (`StateGraph`): * **状态 (State):** `agent/state.py` 中的 `AgentState` (TypedDict) 定义了在节点间传递的数据,包含当前任务、浏览器内容/状态、LLM 解析出的动作、历史记录、错误信息等。 * **节点 (Nodes):** `agent/nodes.py` 中的 `AgentNodes` 类定义了主要的处理步骤,作为图的节点: * `get_browser_state`: 调用 `Browser` 类的感知方法 (当前是 `get_content`) 获取页面信息。 * `plan_action`: 将感知信息和任务包装成 Prompt,调用**规划 LLM** (通过 `llm.py` 的 `generate_structured_output`) 获取结构化的下一步动作 JSON。 * `execute_action`: 解析 `plan_action` 返回的动作 JSON,并调用 `Browser` 类中相应的交互方法 (如 `Maps_to`, `click`, `type`, `scroll`, `wait`) 执行操作。 * **边 (Edges):** `agent/graph.py` 定义了节点间的固定跳转(如 `get_browser_state` -> `plan_action`)和条件跳转(如 `execute_action` 后根据 `should_end` 函数判断是结束 `END` 还是回到 `get_browser_state`)。 ### 2. 
感知 (Perception) Agent 通过 `browser.py` 中的 `Browser.get_content()` 方法(被 `get_browser_state` 节点调用)来理解当前网页状态。该方法整合了多种信息源,旨在为 LLM 提供丰富且相对简洁的页面表示: * **简化 DOM:** 通过注入并执行 `SIMPLIFY_PAGE_SCRIPT` JavaScript,移除无关标签(脚本、样式等),提取关键交互元素及其属性,并为这些元素添加 `x-pw-id` 唯一标识。结果以伪 HTML 字符串形式返回。 * **可访问性树 (AX Tree):** (当前实现中暂时禁用/存在错误) 理论上通过 `page.accessibility.snapshot()` 获取页面的语义结构信息(角色、名称等),以 JSON 字符串形式返回。 * **视觉元素 (VLM):** (可选,需配置) * 如果 `.env` 文件中配置了 VLM (`OPENROUTER_API_KEY`, `VLM_API_MODEL`),`get_content` 会调用 `Detector` 实例。 * `Detector` (在 `browser/detector.py` 中) 使用 LangChain 的 `ChatOpenRouter` (在 `llm.py` 中定义) 调用配置的 VLM API。 * 通过精心设计的 Prompt (`VLM_PROMPT_TEMPLATE`) 请求 VLM 返回页面交互元素的**描述、类型和边界框百分比坐标** (JSON 格式)。 * `Detector` 解析 VLM 返回的 JSON,创建 `InteractiveElement` 对象列表(目前坐标是占位符)。 * `get_content` 将这些视觉元素信息格式化为**文本摘要** (包含 VLM 分配的 ID 和边界框信息)。 * **合并与截断:** `get_content` 将 URL、简化 DOM、AX Tree (如果成功)、视觉元素摘要合并为一个长的文本字符串,并在超过 `max_length` 时进行截断,最后返回给 `plan_action` 节点。 ### 3. 规划 (Planning) * `plan_action` 节点接收 `get_content` 返回的**混合文本字符串**。 * `agent/prompts.py` 中的 `create_agent_prompt` 函数将任务描述、历史记录、错误信息(如果有)和这段混合文本整合成一个 Prompt。 * 该 Prompt 被发送给**规划 LLM**(通过 `llm.py` 中的 `generate_structured_output` 函数,该函数使用 LangChain 的 `.with_structured_output()` 功能)。 * LLM 被要求分析输入信息,决定下一步动作,并**严格按照 `agent/schemas.py` 中定义的 `LLMResponse` Pydantic 模型返回一个包含具体动作指令的 JSON**。Prompt 中包含了对生成**健壮 CSS 选择器**(优先使用稳定 ID、aria-label、文本内容,结合 `x-pw-id`)的明确指导。 ### 4. 行动 (Action Execution) * `execute_action` 节点接收规划 LLM 返回的结构化动作 JSON (存储在 `state['parsed_action']`)。 * 它解析出动作类型 (`type`) 和参数 (`selector`, `url`, `text`, `direction` 等)。 * 根据动作类型,调用 `browser/browser.py` 中 `Browser` 类对应的**简单交互方法**: * `Maps_to(url)` * `click(selector)` * `type(selector, text)` * `scroll(direction)` * `wait(milliseconds)` * 这些方法内部使用 Playwright 的 `page.goto`, `page.locator(...).click`, `page.locator(...).fill`, `page.evaluate(...)` 等函数执行实际的浏览器操作。 * 如果动作是 `finish` 或 `error`,图流程会根据 `graph.py` 中的 `should_end` 函数判断并终止。 ## 安装与配置 1. **环境:** 推荐使用 Python 3.10+。 2. **依赖安装:** * 克隆项目。 * 进入 `super_agents/browser_use/` 目录。 * 创建并激活虚拟环境 (使用 uv): ```bash uv venv source .venv/bin/activate # Linux/macOS # 或者 .venv\Scripts\activate # Windows ``` * 安装依赖项 (使用 uv): ```bash uv sync ``` 3. **Playwright 浏览器:** 运行 `playwright install` (至少需要 `playwright install chromium`) 来下载浏览器驱动。 4. **环境变量 (`.env` 文件):** * 在 `super_agents/browser_use/` 目录下创建一个名为 `.env` 的文件。 * 参考我们之前讨论的 `.env` 示例,**至少需要配置**: * **规划 LLM:** 选择一个 Provider (如 `openai`), 设置 `LLM_PROVIDER`, `LLM_MODEL_NAME`, 以及对应的 API Key (如 `OPENAI_API_KEY`)。 * **VLM (可选):** 如果要启用视觉分析,设置 `OPENROUTER_API_KEY` 和 `VLM_API_MODEL` (设置为 OpenRouter 上支持视觉的模型 ID,如 `openai/gpt-4.1`等)。 * 确保 `.env` 文件被正确加载(`main.py` 和 `llm.py` 中包含 `load_dotenv()`)。 ## 如何运行 1. 确保已完成安装和配置。 2. 激活虚拟环境。 3. 从 `super_agents/` 目录(即 `browser_use` 的**上级**目录)运行 `main.py`: ```bash # 基本运行 python -m browser_use.main "您的任务描述" # 示例:访问 Hacker News 并获取导航栏信息 python -m browser_use.main "访问 news.ycombinator.com,返回页面导航栏信息" # 示例:使用其他命令行参数(如果有定义,如下面的最大步骤数) python -m browser_use.main "您的任务描述" --max-steps 30 ``` ## 当前状态、局限性与未来工作 * **核心流程:** Agent 的基本 LangGraph 流程(感知-规划-行动循环)、浏览器操作(导航、点击、输入、滚动、等待)、规划 LLM 调用、可选的 VLM 调用**已经跑通**,能够完成一些多步骤的 Web 任务。 * **视觉集成 (部分):** VLM 调用流程已集成到 `Detector` 类并通过 `get_content` 触发(需配置 API Key 和 Model)。VLM 能够返回 JSON 格式的检测结果,并且可以被成功解析为内部数据结构 (`InteractiveElement`)。 * **局限性 & 待完善:** 1. **VLM 坐标处理:** VLM 返回的是百分比坐标,但在解析时 (`_parse_vlm_detections`) 目前使用的是**占位符像素坐标**。需要获取截图的实际尺寸,实现准确的百分比到像素的转换,才能真正利用视觉信息进行定位。 2. 
## Current Status, Limitations & Future Work

* **Core flow:** the agent's basic LangGraph loop (perceive-plan-act), browser operations (navigate, click, type, scroll, wait), planning-LLM calls, and optional VLM calls are **working end to end**, and the agent can complete some multi-step web tasks.
* **Vision integration (partial):** the VLM call path is integrated into the `Detector` class and triggered via `get_content` (requires an API key and model config). The VLM returns JSON detections that parse successfully into the internal data structure (`InteractiveElement`).
* **Limitations & TODO:**
    1. **VLM coordinate handling:** the VLM returns percentage coordinates, but parsing (`_parse_vlm_detections`) currently substitutes **placeholder pixel coordinates**. Accurate percentage-to-pixel conversion using the real screenshot dimensions is needed before visual information can actually be used for grounding.
    2. **Action execution:** `execute_action` still **relies entirely on the CSS selectors generated by the planning LLM**. Clicking/typing by VLM element ID or coordinates is not implemented yet, which limits the practical value of the vision capability, especially on complex pages where CSS selectors are unreliable.
    3. **Perception completeness:**
        * **Content truncation:** `get_content` output is truncated at `max_length`, which hurts tasks that need the whole page (e.g. "extract the full text"). Either raise `max_length` or implement smarter content extraction/scrolling.
        * **Missing AX tree:** the accessibility-tree code is currently commented out or broken, so important semantic information is missing. The `page.accessibility.snapshot()` call needs fixing.
    4. **Scrolling strategy:** scrolling currently relies on prompting the LLM. A more robust mechanism may be needed for long pages, e.g. the agent deciding internally whether to scroll, or exposing scroll-position information to the LLM.
    5. **Pydantic V1 warnings:** calling the planning LLM's `with_structured_output` still emits Pydantic V1 warnings; keep the LangChain libraries and Pydantic up to date.
    6. **Error handling:** error handling is fairly simple (e.g. VLM parse failures are skipped, execution errors terminate the graph); more sophisticated retry, fallback, or human-in-the-loop mechanisms could be added.
    7. **VLM stability:** whether the VLM reliably returns the required JSON format and bounding boxes depends heavily on the chosen model and prompt, and may need further tuning.
* **Future work:**
    * Fix AX-tree retrieval.
    * Implement accurate VLM percentage-to-pixel coordinate conversion.
    * Extend `execute_action` and the `Browser` class to support coordinate-based interaction.
    * Tune prompts so the LLM can output VLM element IDs, or fall back to coordinates when CSS selectors fail.
    * Implement smarter scrolling for long pages and full-content extraction.
    * Keep dependencies up to date to resolve the Pydantic warnings.
    * Strengthen error handling and recovery.

================================================
FILE: super_agents/browser_use/__init__.py
================================================

================================================
FILE: super_agents/browser_use/agent/__init__.py
================================================
# super_agents/browser_use/agent/__init__.py
"""
Browser agent module that handles browser automation using LLM guidance.
"""

================================================
FILE: super_agents/browser_use/agent/graph.py
================================================
# super_agents/browser_use/agent/graph.py
import logging
from typing import Dict, Any

from langchain_core.runnables.base import RunnableSerializable
from langgraph.graph import StateGraph, END

from .state import AgentState
from .nodes import AgentNodes
from ..browser.browser import Browser

logger = logging.getLogger(__name__)

NODE_GET_BROWSER_STATE = "get_browser_state"
NODE_PLAN_ACTION = "plan_action"
NODE_EXECUTE_ACTION = "execute_action"


# --- UPDATED Conditional Edge Logic ---
def should_end(state: AgentState) -> bool:
    """Determines if the graph should end."""
    action = state.get("parsed_action", {})
    action_type = action.get("type")
    error_occurred = state.get("error") is not None  # Check if execute_action reported an error

    # End if the LLM planned action is 'finish' or 'error'
    if action_type == "finish":
        logger.info("Graph execution: 'finish' action planned. Ending.")
        return True
    if action_type == "error":
        # Log the error message from the action payload
        logger.error(f"Graph execution: 'error' action planned by LLM: {action.get('message', 'Unknown error')}. Ending.")
        return True

    # End if the execute_action node reported an error in the state
    # Note: Depending on desired behavior, you might want to retry instead of ending on execution errors
    # if error_occurred:
    #     logger.error(f"Graph execution: Error occurred during execution: {state['error']}. Ending.")
    #     return True  # Uncomment this line if ANY execution error should terminate the graph

    return False  # Continue otherwise


def create_graph_app(browser: Browser, llm: RunnableSerializable):
    """
    Creates the LangGraph application using class-based nodes.
""" agent_nodes = AgentNodes(browser=browser, llm=llm) workflow = StateGraph(AgentState) workflow.add_node(NODE_GET_BROWSER_STATE, agent_nodes.get_browser_state) workflow.add_node(NODE_PLAN_ACTION, agent_nodes.plan_action) workflow.add_node(NODE_EXECUTE_ACTION, agent_nodes.execute_action) workflow.set_entry_point(NODE_GET_BROWSER_STATE) workflow.add_edge(NODE_GET_BROWSER_STATE, NODE_PLAN_ACTION) workflow.add_edge(NODE_PLAN_ACTION, NODE_EXECUTE_ACTION) # After executing action, decide whether to end or loop back workflow.add_conditional_edges( NODE_EXECUTE_ACTION, # Function to decide the next step based on the state *after* execution lambda state: END if should_end(state) else NODE_GET_BROWSER_STATE, { END: END, NODE_GET_BROWSER_STATE: NODE_GET_BROWSER_STATE } ) logger.info("Compiling LangGraph workflow...") app = workflow.compile() logger.info("LangGraph workflow compiled successfully.") return app ================================================ FILE: super_agents/browser_use/agent/nodes.py ================================================ # super_agents/browser_use/agent/nodes.py import asyncio import logging from typing import Dict, Any, Optional # --- LangChain Core Import for Type Hint --- from langchain_core.runnables.base import RunnableSerializable # <--- Import this from .state import AgentState from .schemas import ( BaseAction, LLMResponse ) from .prompts import create_agent_prompt # --- CORRECTED LLM IMPORT --- # Import only the necessary functions/classes that actually exist in llm.py from ..llm import generate_structured_output # Import the correct Browser from the browser subdirectory from ..browser.browser import Browser logger = logging.getLogger(__name__) # --- Class to hold nodes and dependencies --- class AgentNodes: """Encapsulates agent nodes and their dependencies (browser, llm).""" # --- CORRECTED TYPE HINT for llm --- def __init__(self, browser: Browser, llm: RunnableSerializable): # <--- Use RunnableSerializable if not isinstance(llm, RunnableSerializable): logger.warning(f"LLM instance provided to AgentNodes is not of type RunnableSerializable (actual type: {type(llm)}).") self.browser = browser self.llm = llm logger.info("AgentNodes initialized with browser and llm instances.") # --- Node method implementations remain the same --- async def get_browser_state(self, state: AgentState) -> Dict[str, Any]: """Node method to get the current state of the browser page.""" logger.info("Node: get_browser_state") try: content = await self.browser.get_content() return {"browser_content": content, "error": None} except Exception as e: logger.error(f"Error getting browser state: {e}", exc_info=True) return {"error": f"Failed to get browser state: {e}"} async def plan_action(self, state: AgentState) -> Dict[str, Any]: """Node method to decide the next action using the LLM's structured output.""" logger.info("Node: plan_action") if state.get("error"): logger.warning(f"Planning action with existing error: {state['error']}") prompt = create_agent_prompt( task=state["task"], current_browser_content=state["browser_content"], history=state.get("history", []), error_message=state.get("error") ) system_message = "You are an AI agent controlling a web browser. Respond with the single next action formatted as JSON matching the required schema." 
try: llm_response: Optional[LLMResponse] = await generate_structured_output( model=self.llm, # Pass the llm instance schema=LLMResponse, prompt=prompt, system_message=system_message ) if llm_response and isinstance(llm_response, LLMResponse): parsed_action_model: BaseAction = llm_response.action parsed_action_dict = parsed_action_model.dict() logger.info(f"LLM proposed action: {parsed_action_dict.get('type', 'unknown')}") return {"parsed_action": parsed_action_dict, "error": None} else: logger.error("Failed to get valid structured output from LLM.") error_action_dict = {"type": "error", "message": "Failed to get valid structured output from LLM."} return {"parsed_action": error_action_dict, "error": "LLM did not return valid structured output."} except Exception as e: logger.error(f"Error during structured action planning: {e}", exc_info=True) error_action_dict = {"type": "error", "message": f"LLM planning exception: {e}"} return {"parsed_action": error_action_dict, "error": f"LLM planning exception: {e}"} async def execute_action(self, state: AgentState) -> Dict[str, Any]: """Node method to execute the action dictionary from the state.""" logger.info("Node: execute_action") action_dict = state.get("parsed_action") history = state.get("history", []) if not action_dict or not isinstance(action_dict, dict) or "type" not in action_dict: error_msg = "No valid action dictionary provided to execute." logger.error(error_msg) return {"error": error_msg} action_type = action_dict.get("type") action_repr = f"Action: {action_type}, Details: { {k:v for k,v in action_dict.items() if k != 'type'} }" logger.info(f"Executing {action_repr}") new_history = history + [action_repr] try: if action_type == "navigate": await self.browser.navigate_to(action_dict["url"]) # Check if method name/args match Browser class elif action_type == "click": await self.browser.click(action_dict["selector"]) # Check Browser class for click method/args elif action_type == "type": await self.browser.type(action_dict["selector"], action_dict["text"]) # Check Browser class for type method/args elif action_type == "scroll": await self.browser.scroll(action_dict["direction"]) # Check Browser class for scroll method/args elif action_type == "wait": await self.browser.wait(action_dict["milliseconds"]) # Check Browser class for wait method/args elif action_type == "get_content": logger.info("Action 'get_content' requested (will be handled by next cycle)") pass elif action_type == "finish": logger.info(f"Action 'finish' received. 
Result: {action_dict.get('result')}")
                pass
            elif action_type == "error":
                error_msg = action_dict.get("message", "LLM signaled an error.")
                logger.error(f"Executing 'error' action from LLM: {error_msg}")
                return {"error": error_msg, "history": new_history}
            else:
                error_msg = f"Attempted to execute unknown/unhandled action type: {action_type}"
                logger.error(error_msg)
                return {"error": error_msg, "history": new_history}

            return {"error": None, "history": new_history}
        except Exception as e:
            logger.error(f"Error executing action '{action_type}': {e}", exc_info=True)
            return {"error": f"Failed to execute action '{action_type}': {e}", "history": new_history}

================================================
FILE: super_agents/browser_use/agent/prompts.py
================================================
from typing import List, Optional

def create_agent_prompt(
    task: str,
    current_browser_content: str,  # This string now potentially contains URL, DOM, AX Tree, and Visual Elements
    history: List[str],
    error_message: Optional[str] = None
) -> str:
    """
    Generates the prompt to be sent to the LLM based on the current state.
    Includes sections for Simplified DOM, Accessibility Tree, and Visual Elements.
    """
    prompt_parts = []
    prompt_parts.append("You are an AI agent controlling a web browser to complete a task.")
    prompt_parts.append(f"Your current task is: {task}")

    if error_message:
        prompt_parts.append(f"\nAn error occurred in the previous step: {error_message}")
        prompt_parts.append("Please analyze the error and the current browser state, then decide the next best action.")

    prompt_parts.append("\n\n# Current Browser Perception:")
    # The browser_content string now contains multiple sections, as generated by get_content
    prompt_parts.append(current_browser_content)

    if history:
        prompt_parts.append("\n\n# History of Previous Actions:")
        for i, item in enumerate(history[-5:], 1):
            prompt_parts.append(f"{i}. {item}")

    # --- Instructions with guidance on using all perception data ---
    instructions = """
# Instructions:
Analyze the **Current Browser Perception** section above, which includes:
1. **Page URL:** The current web address.
2. **Simplified DOM:** A structural view of the page with interactive elements marked with `x-pw-id` attributes.
3. **Accessibility Tree:** Semantic information about elements (roles, names).
4. **Visual Elements:** Elements detected visually via Computer Vision (CV), including their bounding boxes `[L:left, T:top, R:right, B:bottom]` and IDs (e.g., `cv-0`, `cv-1`).

Based on the task and ALL available perception information, decide the single next action to take.
Your response MUST be a JSON object with a single top-level key named "action".
The value of the "action" key MUST be an object matching one of the following action schemas:
- Navigate: {{"type": "navigate", "url": "<url>"}}
- Click: {{"type": "click", "selector": "<css selector>", "description": "<why>"}}
- Type: {{"type": "type", "selector": "<css selector>", "text": "<text to type>", "description": "<why>"}}
- Scroll: {{"type": "scroll", "direction": "<up|down|left|right>"}}
- Finish: {{"type": "finish", "result": "<final answer>"}}
- Error: {{"type": "error", "message": "<error description>"}} (Use if you detect an unrecoverable error or loop)
- GetContent: {{"type": "get_content", "description": "<reason>"}}

**Important Task Handling Guidance:**
1. **Identify elements** using the DOM, AX Tree (if available), and Visual Elements. Use robust selectors as previously guided.
2.
**If the task requires reading or extracting content that might extend beyond the current view (e.g., '摘录全文', 'find all items', 'read the article'), and you haven't finished scrolling, your next action should likely be to SCROLL DOWN.** Use: `{{"action": {{"type": "scroll", "direction": "down"}}}}` 3. Only use `get_content` if you believe scrolling will not help or if you need to re-analyze after a non-scroll action. 4. Once you believe you have scrolled enough and have all necessary information visible in the content provided, proceed with the extraction or final action. 5. If the task is complete, use the 'finish' action. Example Response: ```json {{ "action": {{ "type": "click", "selector": "a[x-pw-id='pw-16']:has-text('new')", "description": "Click the 'new' link, corresponds to visual element cv-3" }} }} {{ "action": {{ "type": "scroll", "direction": "down" }} }} ``` Provide ONLY the JSON object containing the 'action' key in a ```json ... ``` block. Think step-by-step. Correlate information from the DOM, AX Tree, and Visual Elements if possible. Choose the most precise and stable selector. If the task is complete, use the 'finish' action. """ prompt_parts.append(instructions) # --- End Instructions --- final_prompt = "\n".join(prompt_parts) return final_prompt ================================================ FILE: super_agents/browser_use/agent/schemas.py ================================================ # super_agents/browser_use/agent/schemas.py from typing import Literal, Optional, Union, List, Dict, Any, Type # Use Pydantic V2+ if installed, otherwise V1 syntax try: from pydantic.v1 import BaseModel, Field except ImportError: from pydantic import BaseModel, Field # Fallback to V2 # --- Action Type --- ActionTypeLiteral = Literal[ "navigate", "click", "type", "scroll", "wait", "get_content", "finish", "error" ] # --- Pydantic Schemas for Actions --- # Using Pydantic allows for better validation and compatibility # with LangChain's structured output features. 
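# Example (illustrative, not part of the original file) of the JSON shape the
# planning LLM is asked to emit, which parses into the LLMResponse model below:
#   {"action": {"type": "click", "selector": "a[x-pw-id='pw-16']",
#               "description": "Open the 'new' link"}}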
class BaseAction(BaseModel): """Base schema for all actions, containing the type.""" type: ActionTypeLiteral = Field(..., description="The type of action to perform.") class NavigateAction(BaseAction): type: Literal["navigate"] = "navigate" url: str = Field(..., description="The URL to navigate to.") class ClickAction(BaseAction): type: Literal["click"] = "click" selector: str = Field(..., description="CSS selector for the element to click.") description: Optional[str] = Field(None, description="Optional description of the element being clicked.") class TypeAction(BaseAction): type: Literal["type"] = "type" selector: str = Field(..., description="CSS selector for the input field.") text: str = Field(..., description="The text to type into the field.") description: Optional[str] = Field(None, description="Optional description of the element being typed into.") class ScrollAction(BaseAction): type: Literal["scroll"] = "scroll" direction: Literal["up", "down", "left", "right"] = Field(..., description="The direction to scroll the page.") # selector: Optional[str] = Field(None, description="Optional CSS selector of element to scroll within.") # Add if needed class WaitAction(BaseAction): type: Literal["wait"] = "wait" milliseconds: int = Field(..., description="Duration to wait in milliseconds.") class GetContentAction(BaseAction): type: Literal["get_content"] = "get_content" # No extra fields needed, just signifies intent to refresh state description: Optional[str] = Field("Requesting updated browser content", description="Reason for requesting content.") class FinishAction(BaseAction): type: Literal["finish"] = "finish" result: str = Field(..., description="The final answer or summary of the completed task.") class ErrorAction(BaseAction): type: Literal["error"] = "error" message: str = Field(..., description="Description of the error encountered or signaled by the LLM.") # --- Union for Parsing --- # LangChain's with_structured_output often works best when targeting a single Pydantic model # that uses discriminated unions (if available in your Pydantic version) or by prompting # the LLM clearly to only output ONE type of action JSON matching the base structure. # For simplicity here, we define the *expected output structure* the LLM should generate. # The parsing function might need refinement based on how the LLM structures the output. # Define the overall structure the LLM should output, which includes one of the actions. # This structure helps `with_structured_output`. class LLMResponse(BaseModel): action: Union[ NavigateAction, ClickAction, TypeAction, ScrollAction, WaitAction, GetContentAction, FinishAction, ErrorAction ] = Field(..., description="The specific action determined by the LLM.") # --- Parsing Function (Placeholder/Example) --- # The `generate_structured_output` function in llm.py now handles the parsing # directly into the Pydantic schema (LLMResponse). # So, we might not need a separate manual parsing function here if using that. # If you need manual parsing from raw text (less reliable): # def parse_llm_response_manual(response: str) -> Optional[BaseAction]: # # ... (complex logic using regex or JSON parsing as in previous example) # # This would return one of the action models (NavigateAction, ClickAction, etc.) 
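# Usage sketch (illustrative, not part of the original file): binding this schema
# through LangChain's structured-output support, as generate_structured_output in
# llm.py does internally:
#
#   structured_llm = llm.with_structured_output(LLMResponse)
#   response = await structured_llm.ainvoke("Go to https://example.com")
#   response.action.type  # -> "navigate"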
#     pass

================================================
FILE: super_agents/browser_use/agent/state.py
================================================
# super_agents/browser_use/agent/state.py
from typing import Dict, List, Optional, Any, TypedDict

# Define the state structure using TypedDict for type hinting
class AgentState(TypedDict, total=False):
    """
    TypedDict representing the state of the browser agent during execution.

    Attributes:
        task: The user task description
        browser_content: The current HTML content of the browser
        parsed_action: The last action parsed from LLM response
        history: List of previous actions taken
        error: Any error message from the last operation
    """
    task: str
    browser_content: str
    parsed_action: Dict[str, Any]
    history: List[str]
    error: Optional[str]

================================================
FILE: super_agents/browser_use/agent/tools.py
================================================

================================================
FILE: super_agents/browser_use/agent.py
================================================
# super_agents/browser_use/agent.py
"""
Agent API for browser-based task execution.
Provides a simplified interface similar to the original implementation.
"""
import asyncio
import logging
from typing import Any, Dict, Optional

from .agent.graph import create_graph_app
from .agent.state import AgentState
from .browser.browser import Browser, BrowserConfig  # BrowserConfig is defined in browser.py; there is no browser/config.py
from .llm import initialize_llms

logger = logging.getLogger(__name__)

class Agent:
    """
    Agent class that provides a simple interface for browser automation with LLM.
    This implementation is similar to the original API but uses the current
    browser automation stack with LangGraph.
    """
    def __init__(
        self,
        llm=None,
        browser_config: Optional[BrowserConfig] = None,
        max_steps: int = 50
    ):
        """
        Initialize the Agent with optional LLM and browser configuration.

        Args:
            llm: LLM instance to use (if None, will initialize from environment)
            browser_config: Browser configuration options
            max_steps: Maximum number of steps the agent can take
        """
        self.browser_config = browser_config or BrowserConfig()
        self.llm = llm
        self.max_steps = max_steps
        self.browser = None
        self._app = None

    async def _initialize(self):
        """Initialize the browser and LLM if not already initialized."""
        # Initialize LLM if not provided
        if self.llm is None:
            logger.info("Initializing LLM from environment variables")
            self.llm, _ = initialize_llms()
            if self.llm is None:
                raise ValueError("Failed to initialize LLM. Check API keys and .env settings.")

        # Initialize browser
        self.browser = Browser(config=self.browser_config)
        await self.browser.initialize()

        # Initialize LangGraph app
        self._app = create_graph_app(browser=self.browser, llm=self.llm)

    async def run(self, prompt: str) -> Dict[str, Any]:
        """
        Run the agent with the given prompt/task.
Args: prompt: The task description or prompt for the agent Returns: Dictionary containing the execution result """ # Ensure initialization if self.browser is None or self._app is None: await self._initialize() # Define the initial state initial_state = AgentState( task=prompt, browser_content="", parsed_action={}, history=[], error=None ) # Run the graph logger.info(f"Starting agent execution for task: {prompt}") try: final_state = await self._app.ainvoke( initial_state, config={"recursion_limit": self.max_steps} ) # Process result if final_state.get("error"): logger.error(f"Agent finished with error: {final_state['error']}") return {"result": f"Error: {final_state['error']}", "success": False} elif final_state.get("parsed_action", {}).get("type") == "finish": result = final_state["parsed_action"].get("result", "Task finished, but no result extracted.") logger.info(f"Agent finished successfully. Result: {result}") return {"result": result, "success": True} else: logger.warning("Agent finished without a 'finish' action or error.") return { "result": "Agent stopped without producing a final answer.", "success": False, "state": final_state } except Exception as e: logger.error(f"Agent execution failed: {e}", exc_info=True) return {"result": f"Error during execution: {str(e)}", "success": False} finally: # Clean up resources if self.browser: await self.browser.close() self.browser = None self._app = None def __del__(self): """Ensure resources are cleaned up.""" if self.browser: asyncio.create_task(self.browser.close()) # Provider classes for compatibility with original API class OpenAIProvider: """OpenAI provider compatible with the interface""" def __init__(self, model="gpt-4o-mini", api_key=None, temperature=0.1): """ Initialize OpenAI provider. Args: model: Model name to use api_key: OpenAI API key (if None, will use from environment) temperature: Temperature for generation """ self.model = model self.api_key = api_key self.temperature = temperature # These parameters will be used by initialize_llms() internally import os if api_key: os.environ["OPENAI_API_KEY"] = api_key os.environ["LLM_PROVIDER"] = "openai" os.environ["LLM_MODEL_NAME"] = model os.environ["LLM_TEMPERATURE"] = str(temperature) class AnthropicProvider: """Anthropic provider compatible with the interface""" def __init__(self, model="claude-3-opus-20240229", api_key=None, temperature=0.1, enable_thinking=False, thinking_token_budget=None): """ Initialize Anthropic provider. 
Args: model: Model name to use api_key: Anthropic API key (if None, will use from environment) temperature: Temperature for generation enable_thinking: Enable thinking step (not fully supported in current implementation) thinking_token_budget: Tokens for thinking (not fully supported) """ self.model = model self.api_key = api_key self.temperature = temperature self.enable_thinking = enable_thinking self.thinking_token_budget = thinking_token_budget # These parameters will be used by initialize_llms() internally import os if api_key: os.environ["ANTHROPIC_API_KEY"] = api_key os.environ["LLM_PROVIDER"] = "anthropic" os.environ["LLM_MODEL_NAME"] = model os.environ["LLM_TEMPERATURE"] = str(temperature) # Add convenience imports to __init__.py # This will allow: from super_agents.browser_use import Agent, OpenAIProvider, BrowserConfig ================================================ FILE: super_agents/browser_use/browser/browser.py ================================================ # super_agents/browser_use/browser/browser.py """ Streamlined Playwright browser implementation with integrated perception capabilities. Includes DOM/AX Tree/Visual analysis and basic interaction methods. """ import asyncio import json import logging import functools import base64 import os from dataclasses import dataclass, field # from importlib import resources # Not used from typing import Any, Optional, TypedDict, List, Dict # Added List, Dict # --- Local Imports (Ensure these files exist in the same directory) --- try: from .observe_helper import observe except ImportError: def observe(name, ignore_input=False, ignore_output=False): def decorator(func): return func return decorator logging.basicConfig(level=logging.WARNING) # Setup basic logging if needed logger_observe = logging.getLogger(__name__) logger_observe.warning("observe_helper not found, using dummy decorator.") try: from .detector import Detector from .models import ( BrowserError, BrowserState, InteractiveElementsData, TabInfo, InteractiveElement, ) from .utils import ( combine_and_filter_elements, put_highlight_elements_on_screenshot, ) except ImportError as e: logging.basicConfig(level=logging.ERROR) logger_import = logging.getLogger(__name__) logger_import.error(f"Failed to import local browser dependencies (detector, models, utils): {e}. 
Browser class may not function correctly.", exc_info=True) # Define dummy classes to allow file loading, but functionality will be broken class Detector: enabled=False class BrowserError(Exception): pass class BrowserState: pass class InteractiveElementsData: elements=[]; viewport={} class TabInfo: pass class InteractiveElement: pass def combine_and_filter_elements(a, b): return [] def put_highlight_elements_on_screenshot(a, b): return None # --- End Local Imports --- # --- Playwright Imports --- from playwright.async_api import ( Browser as PlaywrightBrowser, BrowserContext as PlaywrightBrowserContext, Page, Playwright, StorageState, async_playwright, Error as PlaywrightError ) # --- Tenacity Import --- from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) logger = logging.getLogger(__name__) # Ensure basic logging is configured if not done elsewhere if not logger.hasHandlers(): logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') # --- Load JavaScript Files --- INTERACTIVE_ELEMENTS_JS_CODE = "" SIMPLIFY_PAGE_SCRIPT = "" try: current_dir = os.path.dirname(os.path.abspath(__file__)) # JS for DOM-based interactive elements used in get_interactive_elements_data js_file_path_interactive = os.path.join(current_dir, 'findVisibleInteractiveElements.js') with open(js_file_path_interactive, 'r', encoding='utf-8') as js_file: INTERACTIVE_ELEMENTS_JS_CODE = js_file.read() # JS for DOM simplification used in get_content # (Re-paste the script here for completeness) SIMPLIFY_PAGE_SCRIPT = """ (() => { const MAX_ELEMENTS = 250; const MAX_TEXT_LENGTH = 200; const INTERACTIVE_TAGS = ['a', 'button', 'input', 'textarea', 'select', 'option', 'details', 'summary', 'label']; const EXCLUDED_TAGS = ['script', 'style', 'noscript', 'svg', 'link', 'meta', 'head', 'embed', 'object', 'path', 'canvas', 'iframe', 'video', 'audio']; let elementCount = 0; let uniqueIdCounter = 0; function isVisible(el) { if (!el || !el.checkVisibility) return false; return el.checkVisibility({checkOpacity: true, checkVisibilityCSS: true}); } function truncateText(text, maxLength = MAX_TEXT_LENGTH) { if (typeof text !== 'string') return text; return text.length > maxLength ? text.substring(0, maxLength) + '...' 
: text;
        }

        function getElementData(el) {
            const data = { tag: el.tagName.toLowerCase(), attributes: {}, text: '', children: [], pw_id: `pw-${uniqueIdCounter++}` };
            try { if (document.body.contains(el)) el.setAttribute('x-pw-id', data.pw_id); } catch(e){}
            const attrsToKeep = ['id', 'class', 'role', 'aria-label', 'aria-labelledby', 'aria-describedby', 'aria-hidden', 'aria-invalid', 'aria-required', 'placeholder', 'title', 'alt', 'for', 'name', 'type', 'href', 'value', 'selected', 'checked', 'disabled', 'readonly', 'open'];
            for (const attr of attrsToKeep) {
                if (el.hasAttribute(attr)) {
                    let value = el.getAttribute(attr);
                    if (attr === 'class' && value) value = value.split(' ').filter(c => c && c.length > 1 && c.length < 30 && !/^[0-9]+$/.test(c)).slice(0, 5).join(' ');
                    if (value !== null && value !== '') data.attributes[attr] = truncateText(String(value), 80);
                }
            }
            if (['button', 'a', 'label', 'summary'].includes(data.tag) && !data.attributes['aria-label'] && el.textContent) data.attributes['aria-label'] = truncateText(el.textContent.trim(), 80);
            try {
                if (el.tagName.toLowerCase() === 'input' && !data.attributes.value && el.value) data.attributes.value = truncateText(el.value);
                else if (el.tagName.toLowerCase() === 'textarea' && !data.attributes.value && el.value) data.attributes.value = truncateText(el.value);
                else if (el.tagName.toLowerCase() === 'select' && el.options && el.selectedIndex !== -1 && !data.attributes.value) data.attributes.value = truncateText(el.options[el.selectedIndex].text);
            } catch (e) {}
            try {
                const directText = Array.from(el.childNodes).filter(node => node.nodeType === Node.TEXT_NODE && node.textContent.trim().length > 0).map(node => node.textContent.trim()).join(' ').replace(/\s+/g, ' ');
                if (directText) data.text = truncateText(directText);
            } catch (e) {}
            return data;
        }

        function simplifyNode(node) {
            if (elementCount >= MAX_ELEMENTS) return null;
            if (node.nodeType !== Node.ELEMENT_NODE || EXCLUDED_TAGS.includes(node.tagName.toLowerCase())) {
                if (node.nodeType === Node.TEXT_NODE && node.textContent.trim().length === 0) return null;
                return null;
            }
            elementCount++;
            const elementData = getElementData(node);
            if (node.hasChildNodes()) {
                Array.from(node.childNodes).forEach(child => {
                    if (INTERACTIVE_TAGS.includes(node.tagName.toLowerCase()) && child.nodeType === Node.ELEMENT_NODE) return;
                    const simplifiedChild = simplifyNode(child);
                    if (simplifiedChild) elementData.children.push(simplifiedChild);
                });
            }
            const isInteractive = INTERACTIVE_TAGS.includes(elementData.tag);
            const hasMeaningfulAttrs = Object.keys(elementData.attributes).some(k => k !== 'x-pw-id');
            if (!isInteractive && !hasMeaningfulAttrs && elementData.children.length === 0 && !elementData.text) {
                try { if (document.body.contains(node)) node.removeAttribute('x-pw-id'); } catch(e){}
                return null;
            }
            return elementData;
        }
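        // (Illustrative comment, not part of the original script.) The pseudo-HTML
        // emitted at the end of this IIFE looks like:
        //   <a x-pw-id="pw-16" href="newest" aria-label="new">new</a>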
        if (!document.body) return "<body> element not found.";
        const simplifiedBody = simplifyNode(document.body);

        function convertToPseudoHTML(node) {
            if (!node) return '';
            let attrs = `x-pw-id="${node.pw_id}"`;
            for (const [key, value] of Object.entries(node.attributes)) attrs += ` ${key}="${String(value).replace(/"/g, '&quot;')}"`;
            let childrenHTML = node.children.map(convertToPseudoHTML).join('');
            let textContent = node.text ? String(node.text).replace(/</g, '&lt;').replace(/>/g, '&gt;') : '';
            if (['input', 'img', 'br', 'hr'].includes(node.tag)) return `<${node.tag} ${attrs} />`;
            else return `<${node.tag} ${attrs}>${textContent}${childrenHTML}</${node.tag}>`;
        }
        return convertToPseudoHTML(simplifiedBody);
    })()
    """
except FileNotFoundError:
    logger.error(f"JavaScript file 'findVisibleInteractiveElements.js' not found in {current_dir}. Interactive element detection (JS based) will fail.")
    INTERACTIVE_ELEMENTS_JS_CODE = "() => ({ viewport: { width: window.innerWidth, height: window.innerHeight }, elements: [] });"  # Provide fallback
except Exception as e:
    logger.error(f"Error loading JavaScript file(s): {e}", exc_info=True)
    INTERACTIVE_ELEMENTS_JS_CODE = "() => ({ viewport: { width: window.innerWidth, height: window.innerHeight }, elements: [] });"
    SIMPLIFY_PAGE_SCRIPT = "() => 'Error loading simplification script.';"

# --- TypedDict for Viewport Size ---
class ViewportSize(TypedDict):
    width: int
    height: int

# --- BrowserConfig Dataclass (Corrected: No CV Endpoints) ---
@dataclass
class BrowserConfig:
    """
    Configuration for the Browser.
    """
    cdp_url: Optional[str] = None
    viewport_size: ViewportSize = field(default_factory=lambda: {"width": 1200, "height": 900})
    storage_state: Optional[StorageState] = None
    # CV/Sheets Endpoints Removed

# --- Main Browser Class ---
class Browser:
    """
    Unified Browser responsible for interacting with the browser via Playwright.
    Includes methods for navigation, simple actions, perception (DOM, AX Tree, optional VLM),
    and state management. Initializes its own VLM detector based on environment variables.
    """
    def __init__(self, config: BrowserConfig = BrowserConfig(), close_context: bool = True):
        """
        Initializes the Browser instance.
        """
        logger.debug('Initializing browser')
        self.config = config
        self.close_context = close_context

        # Playwright attributes
        self.playwright: Optional[Playwright] = None
        self.playwright_browser: Optional[PlaywrightBrowser] = None
        self.context: Optional[PlaywrightBrowserContext] = None

        # Page and state management
        self.current_page: Optional[Page] = None
        self._state: Optional[BrowserState] = None  # This holds the rich state from update_state
        self._cdp_session = None

        # Initialize Detector internally
        try:
            self.detector: Optional[Detector] = Detector()
            if not self.detector.enabled:
                self.detector = None
                logger.warning("Detector initialized but disabled due to missing config/errors.")
            else:
                logger.info("Detector initialized successfully.")
        except NameError:
            logger.error("Detector class not found (likely due to import errors).
Vision disabled.") self.detector = None except Exception as e: logger.error(f"Unexpected error initializing Detector: {e}", exc_info=True) self.detector = None # REMOVED self._init_state() call as method doesn't exist / state init is implicit # --- Context Management Methods --- async def __aenter__(self): await self.initialize() return self async def __aexit__(self, exc_type, exc_val, exc_tb): if self.close_context: await self.close() # --- Public Initialization and Closing --- async def initialize(self): """Initializes browser, context, page if not already done.""" if self.current_page and self.context and self.playwright_browser and self.playwright: logger.debug("Browser already initialized.") return self logger.info("Initializing browser instance via initialize()") # Changed level await self._init_browser() return self async def close(self): """Closes the browser and cleans up Playwright resources.""" if not self.playwright: return logger.info('Closing browser...') try: self._cdp_session = None if self.context: try: await self.context.close() except Exception as e: logger.warning(f'Failed to close context: {e}') if self.playwright_browser and not self.config.cdp_url: try: await self.playwright_browser.close() except Exception as e: logger.warning(f'Failed to close browser: {e}') if self.playwright: try: await self.playwright.stop() except Exception as e: logger.warning(f'Failed to stop Playwright: {e}') except Exception as e: logger.error(f'Error during browser cleanup: {e}', exc_info=True) finally: # Ensure attributes are cleared self.context = None; self.current_page = None; self._state = None self.playwright_browser = None; self.playwright = None; self._cdp_session = None logger.info("Browser closed.") # --- Internal Initialization Helper --- async def _init_browser(self): """Internal method to initialize Playwright components.""" if self.current_page and self.context: return # Avoid re-init if basics exist logger.debug('Running internal browser context initialization _init_browser()') try: if self.playwright is None: self.playwright = await async_playwright().start() if self.playwright_browser is None: if self.config.cdp_url: logger.info(f'Connecting to remote browser via CDP {self.config.cdp_url}') self.playwright_browser = await self.playwright.chromium.connect_over_cdp(self.config.cdp_url, timeout=5000) else: logger.info(f'Launching new browser instance (headless=False assumed)') # Note: Headless mode might need to be configurable via BrowserConfig again if needed self.playwright_browser = await self.playwright.chromium.launch( headless=False, args=[ # Common args for stability/anti-detection '--no-sandbox', '--disable-setuid-sandbox', '--disable-infobars', '--disable-blink-features=AutomationControlled', '--disable-dev-shm-usage', '--disable-gpu', '--window-size=1200,900', # Use configured size later # '--disable-web-security', # Use with caution # '--disable-site-isolation-trials', # '--disable-features=IsolateOrigins,site-per-process', ] ) if self.context is None: existing_contexts = self.playwright_browser.contexts if existing_contexts and not self.config.cdp_url: # Reuse only if we launched it? Be careful. 
self.context = existing_contexts[0] logger.info("Reusing existing browser context.") else: logger.info("Creating new browser context.") self.context = await self.playwright_browser.new_context( viewport=self.config.viewport_size, user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36', java_script_enabled=True, bypass_csp=True, ignore_https_errors=True, storage_state=self.config.storage_state if self.config.storage_state else None ) await self._apply_anti_detection_scripts() # Apply only to new contexts self.context.on('page', self._on_page_change) # Attach listener if self.current_page is None: if len(self.context.pages) > 0: self.current_page = self.context.pages[-1] # Default to last open page logger.info(f"Using existing page: {self.current_page.url}") else: self.current_page = await self.context.new_page() logger.info("Created new page.") # Ensure viewport is applied regardless try: await self.current_page.set_viewport_size(self.config.viewport_size) except Exception as vp_err: logger.warning(f"Failed to set viewport: {vp_err}") if not self.current_page: raise BrowserError("Failed to get or create a page.") await self.get_cdp_session() # Initialize CDP session for current page except PlaywrightError as pe: logger.error(f"Playwright Error during browser init: {pe}", exc_info=True) await self.close(); raise BrowserError(f"Playwright initialization failed: {pe}") from pe except Exception as e: logger.error(f"Unexpected error during browser init: {e}", exc_info=True) await self.close(); raise BrowserError(f"Unexpected browser initialization failed: {e}") from e # --- Method Implementations (Ensure ALL referenced methods are defined) --- async def _apply_anti_detection_scripts(self): """Apply scripts to avoid detection as automation""" if self.context is None: return # Should not happen if called from _init_browser correctly try: await self.context.add_init_script( """ Object.defineProperty(navigator, 'webdriver', { get: () => undefined }); Object.defineProperty(navigator, 'languages', { get: () => ['en-US', 'en'] }); Object.defineProperty(navigator, 'plugins', { get: () => [] }); // Empty is safer // ... other scripts from previous version ... const originalQuery = window.navigator.permissions.query; window.navigator.permissions.query = (parameters) => ( parameters.name === 'notifications' ? Promise.resolve({ state: Notification.permission }) : originalQuery(parameters) ); """ ) logger.debug("Applied anti-detection init script.") except Exception as e: logger.error(f"Failed to add anti-detection init script: {e}", exc_info=True) async def _on_page_change(self, page: Page): """Handle page creation/popup events.""" # Don't automatically switch current page, just log logger.info(f'Page event detected. 
New/Popup URL: {page.url}') self._cdp_session = None # Invalidate CDP session as context changed async def get_current_page(self) -> Page: """Get the current page, ensuring browser is initialized.""" if self.current_page is None or self.current_page.is_closed(): logger.warning("Current page is None or closed, re-initializing.") await self._init_browser() if self.current_page is None: raise BrowserError("Unable to get a valid page.") return self.current_page # Inside Browser class in browser.py async def get_cdp_session(self): """Get or create a CDP session for the *current* page.""" page = await self.get_current_page() session_invalid = True # Assume invalid if self._cdp_session: # More robust check: try a simple CDP command to see if session is active try: # Example: Get cookies via CDP (relatively harmless check) await self._cdp_session.send("Network.getAllCookies") # Check if session page matches current page (using internal attr is risky) if hasattr(self._cdp_session, '_client') and hasattr(self._cdp_session._client, '_page') and self._cdp_session._client._page == page: session_invalid = False # Session seems alive and for the correct page else: logger.debug("CDP session page mismatch or internals unclear, recreating.") except Exception as session_check_err: logger.debug(f"Existing CDP session check failed ({session_check_err}), recreating.") session_invalid = True if session_invalid: try: if self.context is None: await self._init_browser() logger.debug(f"Attempting to create new CDP session for page: {page.url}") self._cdp_session = await self.context.new_cdp_session(page) logger.debug(f"Created new CDP session successfully.") except Exception as e: logger.error(f"Failed to create CDP session: {e}", exc_info=True) self._cdp_session = None raise BrowserError(f"Failed to create CDP session: {e}") from e return self._cdp_session @observe(name='browser.fast_screenshot', ignore_output=True) async def fast_screenshot(self) -> str: """Returns a base64 encoded screenshot using CDP.""" cdp_session = await self.get_cdp_session() try: screenshot_data = await cdp_session.send("Page.captureScreenshot", {"format": "png", "fromSurface": False, "captureBeyondViewport": False}) return screenshot_data["data"] except Exception as e: logger.error(f"Failed to capture screenshot via CDP: {e}") # Fallback to playwright's screenshot? Or raise error? page = await self.get_current_page() try: logger.warning("CDP screenshot failed, falling back to Playwright screenshot.") buffer = await page.screenshot() return base64.b64encode(buffer).decode() except Exception as pw_e: logger.error(f"Fallback Playwright screenshot also failed: {pw_e}") raise BrowserError(f"Failed to take screenshot: {e}") from e # --- Simple Action Methods --- @observe(name='browser.navigate_to') async def navigate_to(self, url: str): page = await self.get_current_page() logger.info(f"Navigating to: {url}") try: await page.goto(url, wait_until='domcontentloaded', timeout=60000) logger.info(f"Navigation successful. 
Current URL: {page.url}") except PlaywrightError as e: raise BrowserError(f"Navigation failed: {e}") from e except Exception as e: raise BrowserError(f"Navigation failed unexpectedly: {e}") from e @observe(name='browser.click') async def click(self, selector: str): page = await self.get_current_page() logger.info(f"Attempting to click element: '{selector}'") try: element = page.locator(selector).first await element.wait_for(state="visible", timeout=15000) await element.scroll_into_view_if_needed(timeout=10000) await element.click(timeout=15000, delay=50) logger.info(f"Successfully clicked element: '{selector}'") except PlaywrightError as e: raise BrowserError(f"Click action failed: {e}") from e except Exception as e: raise BrowserError(f"Click action failed unexpectedly: {e}") from e @observe(name='browser.type') async def type(self, selector: str, text: str): page = await self.get_current_page() log_text = '***' if 'password' in selector.lower() else text logger.info(f"Attempting to type into element: '{selector}', Text: '{log_text}'") try: element = page.locator(selector).first await element.wait_for(state="visible", timeout=15000) await element.scroll_into_view_if_needed(timeout=10000) await element.fill(text, timeout=15000) logger.info(f"Successfully typed into element: '{selector}'") except PlaywrightError as e: raise BrowserError(f"Type action failed: {e}") from e except Exception as e: raise BrowserError(f"Type action failed unexpectedly: {e}") from e @observe(name='browser.scroll') async def scroll(self, direction: str): page = await self.get_current_page() logger.info(f"Scrolling page {direction}") try: if direction == "down": await page.evaluate("window.scrollBy(0, window.innerHeight)") elif direction == "up": await page.evaluate("window.scrollBy(0, -window.innerHeight)") elif direction == "left": await page.evaluate("window.scrollBy(-window.innerWidth, 0)") elif direction == "right": await page.evaluate("window.scrollBy(window.innerWidth, 0)") else: logger.warning(f"Unknown scroll direction: {direction}"); return await asyncio.sleep(0.3) logger.info(f"Scrolled page {direction}") except PlaywrightError as e: raise BrowserError(f"Scroll action failed: {e}") from e except Exception as e: raise BrowserError(f"Scroll action failed unexpectedly: {e}") from e async def wait(self, milliseconds: int): logger.info(f"Waiting for {milliseconds} ms") if milliseconds <= 0: return await asyncio.sleep(milliseconds / 1000.0) logger.info("Wait finished") # --- Perception & State Methods --- async def get_content(self, max_length: int = 120000) -> str: """Gets comprehensive text representation: URL, DOM, AX Tree, VLM Elements.""" page = await self.get_current_page() logger.info("Getting comprehensive page content with vision...") combined_content = "" error_messages = [] current_url = "Unknown" screenshot_b64 = None try: current_url = page.url combined_content += f"# Page URL:\n{current_url}\n\n" try: screenshot_b64 = await self.fast_screenshot() logger.debug(f"Screenshot captured (size: {len(screenshot_b64) if screenshot_b64 else 0})") except Exception as ss_err: error_messages.append(f"Screenshot Error: {ss_err}"); logger.error("Screenshot error", exc_info=False); combined_content += "# Screenshot Error\n" try: if SIMPLIFY_PAGE_SCRIPT: simplified_dom = await page.evaluate(SIMPLIFY_PAGE_SCRIPT) if simplified_dom: combined_content += f"# Simplified DOM:\n```html\n{simplified_dom}\n```\n\n"; logger.debug(f"DOM length: {len(simplified_dom)}") else: combined_content += "# Simplified 
DOM:\n(Empty)\n\n"; logger.warning("JS simplification empty.") else: combined_content += "# Simplified DOM:\n(JS Script Error)\n\n"; logger.error("SIMPLIFY_PAGE_SCRIPT empty.") except Exception as js_err: error_messages.append(f"JS Error: {js_err}"); logger.error("JS Simp. Error", exc_info=False); combined_content += f"# Simplified DOM Error: {js_err}\n" try: ax_tree = await page.accessibility.snapshot(interesting_only=False) # No root arg if ax_tree: try: ax_tree_str = json.dumps(ax_tree, separators=(',', ':')) # Compact ax_max_len = 2000 if len(ax_tree_str) > ax_max_len: ax_tree_str = ax_tree_str[:ax_max_len] + "...(AX Tree truncated)" combined_content += f"# Accessibility Tree (JSON, Partial):\n```json\n{ax_tree_str}\n```\n\n"; logger.debug(f"AX Tree length: {len(ax_tree_str)}") except Exception as json_err: error_messages.append(f"AX JSON Error: {json_err}"); logger.error("AX JSON Error", exc_info=False); combined_content += "# AX Tree Error (JSON)\n" else: combined_content += "# Accessibility Tree:\n(Empty)\n\n"; logger.warning("AX snapshot empty.") except Exception as ax_err: error_messages.append(f"AX Tree Error: {ax_err}"); logger.error("AX Tree Error", exc_info=False); combined_content += f"# Accessibility Tree Error: {ax_err}\n" if self.detector and screenshot_b64: logger.info("Attempting visual detection via Detector...") try: detect_sheets = 'docs.google.com/spreadsheets/d' in current_url visual_elements = await self.detector.detect_from_image(screenshot_b64, detect_sheets) if visual_elements: formatted = [f"- ID: {el.browser_agent_id}, Box: [L:{el.rect.get('left',0)}, T:{el.rect.get('top',0)}, R:{el.rect.get('right',0)}, B:{el.rect.get('bottom',0)}] (Tag: {el.tag_name})" for el in visual_elements[:20]] combined_content += f"# Visual Elements (Detected via CV, Max 20):\n{chr(10).join(formatted)}\n\n"; logger.info(f"Added {len(formatted)} visual elements.") # Use chr(10) for newline else: combined_content += "# Visual Elements:\n(None detected or VLM error)\n\n"; logger.info("No visual elements detected.") except Exception as cv_err: error_messages.append(f"CV Error: {cv_err}"); logger.error("CV Detector Error", exc_info=True); combined_content += f"# Visual Elements Error: {cv_err}\n" else: if not self.detector: logger.info("CV Detector not available.") if not screenshot_b64: logger.info("Screenshot missing.") combined_content += "# Visual Elements:\n(Not Run)\n\n" if len(combined_content) > max_length: logger.warning(f"Combined content ({len(combined_content)}) exceeds limit ({max_length}). Truncating.") reserve = len("\n\n# Content Retrieval Errors:\n- ") + sum(len(str(e)) + 4 for e in error_messages) + 50 trunc_len = max(0, max_length - reserve); combined_content = combined_content[:trunc_len].rstrip() + "\n\n... 
(Content truncated)" if error_messages: combined_content += "\n\n# Content Retrieval Errors:\n- " + "\n- ".join(map(str, error_messages)) logger.info(f"Finished getting content (final length: {len(combined_content)})") return combined_content except Exception as e: logger.error(f"General error in get_content: {e}", exc_info=True); return f"# Page URL:\n{current_url}\n# Error:\nFailed to get content: {e}" # --- Other Methods from Original Code --- async def get_cookies(self) -> list[dict[str, Any]]: """Get cookies from the current browser context.""" if self.context: try: return await self.context.cookies() except Exception as e: logger.error(f"Failed to get cookies: {e}"); return [] return [] async def get_storage_state(self) -> dict[str, Any]: """Get storage state (currently only cookies) from the browser.""" # Playwright's get_storage_state includes local/session storage too, # but might require more careful handling or filtering if large. # Sticking to cookies for simplicity based on original user code structure. if self.context: try: # cookies = await self.context.cookies() # Redundant if get_cookies exists # return {'cookies': cookies} # Or use the full state function if available and needed state = await self.context.storage_state() return state except Exception as e: logger.error(f"Failed to get storage state: {e}") return {} return {} async def get_tabs_info(self) -> list[TabInfo]: """Get information about all open tabs in the current context.""" tabs_info = [] if not self.context: return [] try: # Ensure pages list is accessed correctly pages = self.context.pages for i, page in enumerate(pages): if not page.is_closed(): # Check if page is open try: url = page.url title = await page.title() # Ensure TabInfo model is available tabs_info.append(TabInfo(page_id=i, url=url, title=title)) except Exception as page_err: logger.warning(f"Failed to get info for tab {i}: {page_err}") # Add placeholder if needed? tabs_info.append(TabInfo(page_id=i, url="Error", title="Error retrieving info")) except Exception as e: logger.error(f"Failed to get tabs info: {e}") return tabs_info async def switch_to_tab(self, page_id: int) -> None: """Switch focus to a specific tab by its index.""" if self.context is None: await self._init_browser() pages = self.context.pages if not 0 <= page_id < len(pages): raise BrowserError(f'Invalid page_id: {page_id}. Available pages: {len(pages)}') if pages[page_id].is_closed(): raise BrowserError(f'Page with page_id {page_id} is closed.') logger.info(f"Switching to tab (page_id): {page_id}") self.current_page = pages[page_id] try: await self.current_page.bring_to_front() # Wait briefly for potential state changes after switch await self.current_page.wait_for_load_state('domcontentloaded', timeout=5000) except Exception as e: logger.warning(f"Error during tab switch finalization for page {page_id}: {e}") # Continue anyway, page is switched internally async def create_new_tab(self, url: str | None = None) -> None: """Create a new tab, optionally navigating to a URL, and switch to it.""" if self.context is None: await self._init_browser() logger.info(f"Creating new tab. Navigate to: {url if url else 'about:blank'}") try: new_page = await self.context.new_page() self.current_page = new_page # Switch focus to the new page if url: await self.navigate_to(url) # Reuse navigate method else: await new_page.wait_for_load_state('domcontentloaded') # Wait for about:blank load logger.info(f"Switched to new tab. 
URL: {self.current_page.url}") except Exception as e: logger.error(f"Failed to create new tab: {e}") raise BrowserError(f"Failed to create new tab: {e}") from e async def close_current_tab(self): """Close the currently focused tab.""" if self.current_page is None: logger.warning("No current page to close."); return if len(self.context.pages) <= 1: logger.warning("Cannot close the last remaining tab."); return # Prevent closing last tab? Or allow context close? logger.info(f"Closing current tab: {self.current_page.url}") page_to_close = self.current_page # Find index to switch to after closing (e.g., previous or first) pages = self.context.pages current_index = pages.index(page_to_close) if page_to_close in pages else -1 switch_to_index = 0 if current_index != 0 else 1 # Switch to first unless closing first if switch_to_index >= len(pages): switch_to_index = 0 # Fallback try: await page_to_close.close() logger.info("Tab closed.") # Need to wait briefly for context.pages to update sometimes await asyncio.sleep(0.1) # Switch to another tab if possible if self.context and self.context.pages: new_current_page = self.context.pages[min(switch_to_index, len(self.context.pages)-1)] self.current_page = new_current_page await self.current_page.bring_to_front() logger.info(f"Switched to tab index {min(switch_to_index, len(self.context.pages)-1)} after closing.") else: self.current_page = None # No pages left logger.info("Closed the last tab.") except Exception as e: logger.error(f"Error closing tab or switching: {e}") # Attempt to recover current page if possible if self.context and self.context.pages: self.current_page = self.context.pages[0] else: self.current_page = None async def refresh_page(self): """Refresh the current page.""" page = await self.get_current_page() logger.info(f"Refreshing page: {page.url}") try: await page.reload(wait_until='domcontentloaded') logger.info("Page refreshed.") except Exception as e: logger.error(f"Failed to refresh page: {e}") raise BrowserError(f"Failed to refresh page: {e}") from e async def go_forward(self): """Navigate forward in the current page's history.""" page = await self.get_current_page() logger.info(f"Going forward in history for: {page.url}") try: await page.go_forward(wait_until='domcontentloaded', timeout=10000) # Added timeout logger.info(f"Navigated forward. New URL: {page.url}") except Exception as e: # Often fails if no forward history exists, log as warning logger.warning(f'Failed to go forward (might be end of history): {e}') # raise BrowserError(f"Failed to go forward: {e}") from e # Option: re-raise if needed # --- State Update Methods (using CV potentially) --- def get_state(self) -> Optional[BrowserState]: """Get the last updated internal browser state.""" # Returns the state cached from the last update_state call logger.debug(f"Returning cached browser state (URL: {self._state.url if self._state else 'None'})") return self._state @observe(name='browser.update_state', ignore_output=True) async def update_state(self) -> BrowserState: """Update the internal browser state by re-evaluating the page (incl. 
CV if enabled)."""
        logger.info("Updating browser state...")
        try:
            self._state = await self._update_state()
            logger.info("Browser state updated successfully.")
            if not self._state:
                raise BrowserError("State update returned None unexpectedly.")  # Should not happen if _update_state raises
            return self._state
        except Exception as e:
            logger.error(f"Failed to update browser state: {e}", exc_info=True)
            # Decide whether to return old state or raise error
            # Raising error seems more appropriate if update fails
            raise BrowserError(f"Failed to update state: {e}") from e

    @observe(name='browser._update_state', ignore_output=True)
    async def _update_state(self) -> BrowserState:
        """Internal method to get comprehensive state with retry logic."""
        @retry(
            stop=stop_after_attempt(3),
            wait=wait_exponential(multiplier=0.5, min=0.5, max=2),
            retry=retry_if_exception_type(Exception),  # Retry on any exception during state fetch
            reraise=True  # Re-raise the exception after retries fail
        )
        async def get_stable_state():
            page = await self.get_current_page()  # Ensures page exists
            url = page.url
            detect_sheets = 'docs.google.com/spreadsheets/d' in url
            screenshot_b64 = await self.fast_screenshot()  # Get screenshot
            interactive_elements_data: Optional[InteractiveElementsData] = None

            # Get combined elements using CV if detector is enabled
            if self.detector and screenshot_b64:
                logger.debug("Getting interactive elements with CV...")
                interactive_elements_data = await self.get_interactive_elements_with_cv(screenshot_b64, detect_sheets)
            # Fallback to browser-only if detector disabled or screenshot failed
            elif INTERACTIVE_ELEMENTS_JS_CODE:  # Ensure JS code loaded
                logger.debug("Getting interactive elements with browser JS only...")
                interactive_elements_data = await self.get_interactive_elements_data()
            else:
                logger.error("Cannot get interactive elements: Detector disabled/failed and JS code missing.")
                interactive_elements_data = InteractiveElementsData(viewport={"width": 0, "height": 0}, elements=[])  # Return empty state

            # Check if interactive_elements_data is valid before proceeding
            if interactive_elements_data is None or not hasattr(interactive_elements_data, 'elements'):
                raise BrowserError("Failed to retrieve valid interactive elements data.")

            # Process elements into dictionary for state
            interactive_elements = {element.browser_agent_id: element for element in interactive_elements_data.elements}

            # Generate highlighted screenshot
            screenshot_with_highlights = None
            if screenshot_b64 and 'put_highlight_elements_on_screenshot' in globals():
                try:
                    screenshot_with_highlights = put_highlight_elements_on_screenshot(
                        list(interactive_elements.values()),  # Pass list of elements
                        screenshot_b64
                    )
                except Exception as high_err:
                    logger.warning(f"Failed to generate highlighted screenshot: {high_err}")

            # Get tab info
            tabs = await self.get_tabs_info()

            # Ensure BrowserState model is available (check globals() only: the name is a
            # module-level import, so a locals() check inside this function always fails)
            if 'BrowserState' not in globals():
                raise ImportError("BrowserState model is not defined or imported.")

            # Create and return the state object
            return BrowserState(
                url=url,
                tabs=tabs,
                screenshot_with_highlights=screenshot_with_highlights,
                screenshot=screenshot_b64,
                viewport=interactive_elements_data.viewport,  # Use viewport from data
                interactive_elements=interactive_elements,
            )

        # Execute the retry logic
        try:
            new_state = await get_stable_state()
            self._state = new_state  # Cache the new state
            return new_state
        except Exception as e:
            logger.error(f'Failed to update state after multiple attempts: {e}', exc_info=True)
            # Don't return potentially stale state; let the error propagate
            raise BrowserError(f"Failed to update state definitively: {e}") from e
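    # (Illustrative note, not part of the original file.) The equivalent standalone
    # tenacity pattern used by get_stable_state() above:
    #
    #   @retry(stop=stop_after_attempt(3),
    #          wait=wait_exponential(multiplier=0.5, min=0.5, max=2),
    #          reraise=True)
    #   async def fetch():
    #       ...  # any awaitable perception step; retried up to 3 times with backoff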
    @observe(name='browser.get_interactive_elements')
    async def get_interactive_elements_data(self) -> InteractiveElementsData:
        """Gets interactive elements using only in-browser JavaScript."""
        page = await self.get_current_page()
        if not INTERACTIVE_ELEMENTS_JS_CODE:
            logger.error("INTERACTIVE_ELEMENTS_JS_CODE is empty. Cannot get elements.")
            # Return a default empty structure
            vp = await page.viewport_size() or {"width": 0, "height": 0}
            return InteractiveElementsData(viewport=vp, elements=[])
        try:
            result = await page.evaluate(INTERACTIVE_ELEMENTS_JS_CODE)
            # Validate the basic structure of the result
            if not isinstance(result, dict) or 'viewport' not in result or 'elements' not in result:
                logger.error(f"JS evaluation returned unexpected structure: {type(result)}")
                vp = await page.viewport_size() or {"width": 0, "height": 0}
                return InteractiveElementsData(viewport=vp, elements=[])
            # Parse using the Pydantic model if it was imported successfully.
            # (Checking globals() alone: a module-level import never appears in
            # this function's locals().)
            if 'InteractiveElementsData' in globals():
                return InteractiveElementsData(**result)
            # Fallback if the model is missing (this indicates a setup error)
            logger.error("InteractiveElementsData model missing, returning raw dict.")
            return result  # type: ignore
        except Exception as e:
            logger.error(f"Error evaluating INTERACTIVE_ELEMENTS_JS_CODE: {e}", exc_info=True)
            vp = await page.viewport_size() or {"width": 0, "height": 0}
            return InteractiveElementsData(viewport=vp, elements=[])
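    # Design note (sketch only, using the same names as the surrounding code):
    # the method below starts the in-browser JS scan and the VLM detection as
    # separate tasks so the slow VLM call overlaps with the DOM scan instead of
    # running after it. An equivalent formulation with asyncio.gather would be:
    #
    #   browser_elements_data, cv_elements = await asyncio.gather(
    #       self.get_interactive_elements_data(),
    #       self.detector.detect_from_image(current_screenshot_b64, detect_sheets),
    #   )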
    @observe(name='browser.get_interactive_elements_with_cv')
    async def get_interactive_elements_with_cv(
        self,
        screenshot_b64: Optional[str] = None,
        detect_sheets: bool = False,
    ) -> InteractiveElementsData:
        """Combines browser JS element detection with VLM detection."""
        if self.detector is None:
            logger.warning("CV detector not available. Falling back to browser-only detection.")
            return await self.get_interactive_elements_data()

        # Ensure a screenshot exists
        current_screenshot_b64 = screenshot_b64 or await self.fast_screenshot()
        if not current_screenshot_b64:
            logger.error("Screenshot unavailable for CV detection.")
            return await self.get_interactive_elements_data()  # Fallback

        logger.debug("Getting combined browser + CV elements...")
        try:
            # Run browser JS detection and VLM detection concurrently
            browser_elements_data_task = asyncio.create_task(self.get_interactive_elements_data())
            cv_elements_task = asyncio.create_task(
                self.detector.detect_from_image(current_screenshot_b64, detect_sheets)
            )
            browser_elements_data = await browser_elements_data_task
            cv_elements = await cv_elements_task

            # Ensure results are valid before combining
            if not browser_elements_data or not hasattr(browser_elements_data, 'elements'):
                logger.warning("Browser element data invalid or missing for combine step.")
                browser_elements = []
                page = await self.get_current_page()
                viewport = await page.viewport_size() or {"width": 0, "height": 0}
            else:
                browser_elements = browser_elements_data.elements
                viewport = browser_elements_data.viewport  # Use the viewport from browser data

            if not isinstance(cv_elements, list):
                logger.warning("CV elements result is not a list.")
                cv_elements = []

            # Combine the two result sets using the utility function
            if 'combine_and_filter_elements' in globals():
                combined_elements = combine_and_filter_elements(browser_elements, cv_elements)
                logger.info(
                    f"Combined browser ({len(browser_elements)}) and CV ({len(cv_elements)}) "
                    f"elements into {len(combined_elements)}."
                )
            else:
                logger.error("combine_and_filter_elements utility function not found. Returning only browser elements.")
                combined_elements = browser_elements  # Fallback

            # Return the combined data in the expected structure
            if 'InteractiveElementsData' in globals():
                return InteractiveElementsData(viewport=viewport, elements=combined_elements)
            # This fallback is problematic: downstream code expects the model structure
            logger.error("InteractiveElementsData model missing, returning raw combined list.")
            return {"viewport": viewport, "elements": combined_elements}  # type: ignore
        except Exception as e:
            logger.error(f"Error during combined CV+Browser element detection: {e}", exc_info=True)
            # Fall back gracefully to browser-only detection if possible
            try:
                return await self.get_interactive_elements_data()
            except Exception:
                # Final fallback: an empty element set
                return InteractiveElementsData(viewport={"width": 0, "height": 0}, elements=[])


================================================
FILE: super_agents/browser_use/browser/detector.py
================================================
# super_agents/browser_use/browser/detector.py
import os
import json
import logging
import base64
from typing import List, Optional, Dict, Any

# LangChain Core imports
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables.base import RunnableSerializable

# Pydantic for the VLM output schema
try:
    from pydantic.v1 import BaseModel
except ImportError:
    from pydantic import BaseModel

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

# Module logger, defined up front so the import fallbacks below can use it
# without a NameError. A basic config is applied in the fallback branches only,
# in case the main app has not configured logging yet.
logger = logging.getLogger(__name__)

# Local imports (ensure they exist)
try:
    from .observe_helper import observe
except ImportError:
    def observe(name, ignore_input=False, ignore_output=False):
        def decorator(func):
            return func
        return decorator

    logging.basicConfig(level=logging.WARNING)
    logger.warning("observe_helper not found, using dummy decorator.")

try:
    from .models import InteractiveElement

    # Expected VLM output schema. It could instead be imported from
    # agent.schemas; it is defined here for clarity in this step.
    class VLMJsonOutput(BaseModel):
        detected_elements: List[Dict[str, Any]] = []
except ImportError:
    class InteractiveElement:
        pass

    class VLMJsonOutput(BaseModel):
        detected_elements: List = []

    logging.basicConfig(level=logging.WARNING)
    logger.error("Failed to import InteractiveElement or define VLMJsonOutput! Detector parsing will fail.")

# Import the ChatOpenRouter class from the updated llm.py.
# Adjust the path if llm.py lives elsewhere relative to detector.py.
try:
    from ..llm import ChatOpenRouter  # Assumes llm.py is one level up
except ImportError:
    logger.error("Failed to import ChatOpenRouter from ..llm. Ensure llm.py is in the parent directory.")

    # Dummy class so the module still loads; detection will not work with it
    class ChatOpenRouter:
        pass

# --- VLM Configuration (read by the detector's __init__ via ChatOpenRouter) ---
VLM_API_MODEL = os.getenv("VLM_API_MODEL", "openai/gpt-4o")  # Desired VLM model, from .env

# --- VLM Prompt Template ---
VLM_PROMPT_TEMPLATE = """
Analyze the provided screenshot of a webpage. Your task is to identify all significant interactive elements visible on the screen.

Interactive elements include: buttons, links (<a> tags), text input fields (<input type="text">, <input type="search">, etc.), password fields (<input type="password">), text areas (