Repository: stepfun-ai/StepDeepResearch
Branch: main
Commit: 8f6798f750fb
Files: 125
Total size: 695.1 KB
Directory structure:
gitextract_gy_2y08y/
├── .env-example
├── .gitignore
├── LICENSE
├── README.md
├── agentkit/
│ └── trace/
│ ├── __init__.py
│ ├── builder.py
│ ├── context.py
│ ├── default.py
│ ├── local_tracer.py
│ ├── remote_tracer.py
│ ├── span.py
│ ├── tracer.py
│ └── types.py
├── config.yaml
├── cortex/
│ ├── __init__.py
│ ├── agents/
│ │ ├── __init__.py
│ │ ├── agent_factory.py
│ │ ├── base_agent.py
│ │ ├── base_step_agent.py
│ │ ├── checkpoint_agent/
│ │ │ ├── checkpoint_agent.py
│ │ │ ├── checkpointer.py
│ │ │ └── react_agent.py
│ │ ├── input/
│ │ │ └── input.py
│ │ ├── react_agent.py
│ │ └── types.py
│ ├── context/
│ │ ├── __init__.py
│ │ ├── base_context.py
│ │ ├── file_context.py
│ │ └── simple_context.py
│ ├── env.py
│ ├── examples/
│ │ ├── agents/
│ │ │ ├── ask_input_agent.py
│ │ │ ├── deep_reasearch_agent.py
│ │ │ ├── main_agent.py
│ │ │ ├── math_agent.py
│ │ │ ├── plan_agent.py
│ │ │ └── search_agent.py
│ │ ├── demo_agent_cli.py
│ │ ├── demo_agent_with_orchestrator.py
│ │ ├── demo_agent_with_tool.py
│ │ ├── demo_checkpoint.py
│ │ ├── demo_toolset_channel.py
│ │ └── server.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── definition.py
│ │ ├── provider.py
│ │ ├── stepfun_chat.py
│ │ ├── stepfun_provider.py
│ │ └── utils.py
│ ├── orchestrator/
│ │ ├── __init__.py
│ │ ├── local_runner.py
│ │ ├── orchestrator.py
│ │ ├── remote_runner.py
│ │ ├── runner.py
│ │ └── types.py
│ ├── runtime_config.py
│ ├── server/
│ │ ├── channel/
│ │ │ ├── channel.py
│ │ │ ├── error.py
│ │ │ ├── memory_channel.py
│ │ │ └── ws_channel.py
│ │ ├── http_server.py
│ │ └── log/
│ │ ├── log.py
│ │ └── trace.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── agent_tool.py
│ │ ├── base.py
│ │ ├── channel.py
│ │ ├── client_tool.py
│ │ ├── function_tool.py
│ │ ├── mcp.py
│ │ ├── mcp_tool.py
│ │ ├── session_tool.py
│ │ ├── toolset.py
│ │ ├── types.py
│ │ ├── ublock_agent_tool.py
│ │ └── unblock_client_tool.py
│ ├── tui/
│ │ ├── __init__.py
│ │ └── tui.py
│ └── utils/
│ ├── __init__.py
│ ├── generator_merger.py
│ └── generator_merger_examples.py
├── cortex-ui/
│ ├── .gitignore
│ ├── .gitlab-ci.yml
│ ├── index.html
│ ├── package.json
│ ├── src/
│ │ ├── App.tsx
│ │ ├── components/
│ │ │ ├── EndpointConfig.tsx
│ │ │ ├── ErrorBoundary.tsx
│ │ │ ├── FilePanel.tsx
│ │ │ ├── FinalAnswer.tsx
│ │ │ ├── SearchResultsPanel.tsx
│ │ │ ├── ShellPanel.tsx
│ │ │ ├── TodoPanel.tsx
│ │ │ └── WebPagePanel.tsx
│ │ ├── index.css
│ │ ├── main.tsx
│ │ ├── pages/
│ │ │ ├── AgentList.tsx
│ │ │ └── ChatPage.tsx
│ │ ├── services/
│ │ │ └── api.ts
│ │ ├── types/
│ │ │ ├── citation.ts
│ │ │ └── index.ts
│ │ ├── utils/
│ │ │ └── citationParser.ts
│ │ └── vite-env.d.ts
│ ├── tsconfig.json
│ ├── tsconfig.node.json
│ └── vite.config.ts
├── demo/
│ ├── __init__.py
│ ├── dr_agent/
│ │ ├── __init__.py
│ │ └── dr_agent.py
│ ├── server.py
│ └── tools/
│ ├── __init__.py
│ ├── batch_open.py
│ ├── batch_search.py
│ ├── batch_web_surfer.py
│ ├── file.py
│ ├── open.py
│ ├── search.py
│ ├── shell.py
│ ├── text_truncator.py
│ ├── todo.py
│ └── utils.py
├── pyproject.toml
└── scripts/
├── configs/
│ ├── prompt.py
│ ├── runner_example.yaml
│ └── tasks.example.json
└── runner.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .env-example
================================================
MODEL_PROVIDER=stepfun
MODEL_BASE=https://api.stepfun.com
STEP_MODEL_API_KEY=your-model-api-key
STEP_SEARCH_API_BASE=https://api.stepfun.com
STEP_SEARCH_API_KEY=your-search-api-key
================================================
FILE: .gitignore
================================================
# common file types
*.log
*.ipynb
#*.json
# Allow json files in frontend directory
!agent_cortex/adapter/frontend/**/*.json
*.csv
*.jsonl
*.png
*.jpg
# local dev files/directories
**/_version.py
dist/
.env
.venv
__pycache__/
downloads*
logs
deprecated/
*_debug.yaml
scripts/*_debug.py
tests/*_debug.py
*.png
!assets/**
trash/
traces/
.idea/
.cursor/
*.db
eval/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Step-DeepResearch
## News
* Feb 2, 2026: 👋 We have released **Step 3.5 Flash**, achieving **65.27** on ResearchRubrics. Try it out by setting the environment variable `MODEL_NAME=step-3.5-flash`. [Details](https://static.stepfun.com/blog/step-3.5-flash/)
* Dec 25, 2025: 👋 You can join our group chat to get updates on your beta API application status and the latest project developments.
* Dec 24, 2025: 👋 We have made our technical report available. [Read](https://arxiv.org/pdf/2512.20491)
## Introduction
### Model Summary
**Step-DeepResearch** is a cost-effective, end-to-end deep research agent model designed for autonomous information exploration and professional report generation in open-ended research scenarios.
- **Atomic Capability Integration**: By decomposing complex research tasks into trainable atomic capabilities—including planning, information seeking, reflection and cross-validation, and professional report generation—and achieving deep internalization at the model level, the system ensures closed-loop reflection and dynamic correction within a single inference pass.
- **Progressive Training Pipeline**: We establish a complete optimization path from Agentic Mid-Training to Supervised Fine-Tuning (SFT) and Reinforcement Learning (RL), reshaping the training objective from "predicting the next token" to "deciding the next atomic action." This approach effectively enhances the model's adaptive capabilities and generalization performance in complex environments.
- **Strong Performance Across Model Scales**: With only 32B parameters, Step-DeepResearch achieves 61.4% on Scale AI ResearchRubrics, matching OpenAI Deep Research and Gemini Deep Research. In expert human evaluations on ADR-Bench, its Elo score significantly outperforms larger models including DeepSeek-v3.2 and GLM-4.6, and rivals top-tier closed-source models.
- **New model (Step 3.5 Flash)**: **Step 3.5 Flash** reaches **65.27** on ResearchRubrics, delivering research quality that competes with OpenAI and Gemini Deep Research while maintaining significantly higher inference efficiency. [Details](https://static.stepfun.com/blog/step-3.5-flash/)
- **Superior Cost-Effectiveness**: With extremely low deployment and inference costs while maintaining expert-level research capabilities, Step-DeepResearch stands as the most cost-effective deep research agent solution currently available in the industry.
- **Access**: Available via StepFun Open Platform API, free for the first month.
(left) Cost-Efficiency on ResearchRubrics: Step-DeepResearch achieves near-top performance (61.42) while significantly reducing inference costs (RMB), positioned at the high-efficiency frontier.
(right) Expert Evaluation on ADR-Bench: Step-DeepResearch consistently leads in Elo ratings across all dimensions, rivaling top-tier closed-source models.
Performance on ResearchRubrics: Step 3.5 Flash achieves 65.27 on the benchmark.
### System Architecture
Step-DeepResearch adopts a single-agent architecture based on the ReAct paradigm, enabling autonomous deep research through a dynamic cycle of reasoning, action, and reflection.
Building upon our internal proprietary toolset, we provide a streamlined toolset with local implementation in the framework, including `batch_web_surfer` (batch web search and browsing), `file` (file reading, writing, and editing), `todo` (task state management), and `shell` (interactive command execution), to support the complete research workflow. For a comprehensive introduction to our full toolset, please refer to Section 5.2 of the technical report.
Step-DeepResearch System Architecture: The agent operates within a ReAct loop, utilizing a specialized toolset (e.g., batch_web_surfer, todo, shell) for planning, execution, and reflection to generate comprehensive research reports.
## Case
We provide a sample case below, including the query and the model-generated report.
### Query
```text
我是个 AI 产品经理,正在考虑给产品增加个性化记忆功能。想系统了解一下现在 Agent 里的 Memory 模块都在往哪些方向发展,包括短期和长期记忆是怎么做的、各种技术路线的差异和取舍。你帮我写个调研报告,从工程落地和未来演进的角度分析一下 哪些方向更值得投入
```
### Model Output (Report)
Click to expand the report
# AI Agent记忆模块深度调研报告:技术演进与产品化路径
## 一、引言与核心问题界定
### 1.1 记忆能力为何成为Agent时代的分水岭
在Agentic AI加速渗透的2025年,智能体已从"能对话、会调用工具"的基础形态,向"具备长期记忆、可自主进化"的高阶阶段迈进[大模型进阶之路:AI Agent记忆能力构建技术详解(值得收藏)](https://blog.csdn.net/xxue345678/article/details/150983939)。正如OpenAI CEO山姆·奥尔特曼在2025年12月的Big Technology访谈中所言:"这是整个系统里我个人最期待的部分之一。AI的下一个重大飞跃并非是更敏锐的推理能力,而是更为根本的记忆"[超级Agent重要拼图?奥尔特曼点名“AI记忆” 存储环节迎来新叙事](https://finance.sina.com.cn/stock/t/2025-12-22/doc-inhcsmfc4526733.shtml)。他进一步指出,"人类本身是有局限的:即使你拥有世界上最好的私人助理,他们也不可能记住你说过的每一句话,不可能读过你的每一封邮件,不可能看过你写的每一份文件"——而这正是AI能够做到的[超级Agent重要拼图?奥尔特曼点名“AI记忆” 存储环节迎来新叙事](http://m.cls.cn/detail/2236511)。
当前,大多数人以为通过更大的上下文窗口或巧妙的提示词工程,AI就拥有了"记忆",但真相是,大多数AI Agent仍是无状态的,无法从过去的交互中学习,也无法随时间适应用户需求[探寻AI Agent 中隐秘的角落:记忆(Memory) - 定义、价值与实践](https://developer.volcengine.com/articles/7540134113190412324)。要从一次性工具迈向真正智能的伙伴,我们需要为AI赋予真正的记忆,而非仅仅依赖更大的提示或更强的检索。
### 1.2 记忆的本质定义
在AI Agent中,记忆是指跨时间、任务和多次用户交互,保留并调用相关信息的能力[探寻AI Agent 中隐秘的角落:记忆(Memory) - 定义、价值与实践](https://developer.volcengine.com/articles/7540134113190412324)。它让AI能记住过去发生的事,并利用这些信息优化未来的行为。记忆不是简单地存储聊天记录,也不是把更多数据塞进提示框——它是一种持久的内部状态,随着每次交互不断进化,哪怕间隔数周或数月,依然能为AI提供连续的上下文[探寻AI Agent 中隐秘的角落:记忆(Memory) - 定义、价值与实践](https://developer.volcengine.com/articles/7540134113190412324)。
记忆的三大支柱包括:**状态**(了解当前情境,掌握正在发生的事情)、**持久性**(跨会话保留知识,确保信息不因对话结束而丢失)、**选择性**(判断哪些信息值得记住,哪些可以忽略)[探寻AI Agent 中隐秘的角落:记忆(Memory) - 定义、价值与实践](https://developer.volcengine.com/articles/7540134113190412324)。这三者共同赋予AI一种前所未有的能力——连续性。
### 1.3 报告研究范围与目标
本报告旨在系统梳理Agent记忆模块的技术演进方向,深入分析短期与长期记忆的实现机制、各类技术路线的差异与取舍,并从业务价值、技术成熟度和工程可行性三个维度,为产品团队提供明确的技术选型建议和投入优先级排序。
---
## 二、Agent记忆模块的分类体系与核心架构
### 2.1 从认知科学到工程实践的记忆分层
人类记忆遵循从感觉记忆到短期记忆再到长期记忆的一般性进程[A survey on large language model based autonomous agents](https://link.springer.com/content/pdf/10.1007/s11704-024-40231-1.pdf)。当设计Agent记忆结构时,研究者从中汲取灵感:
**短期记忆(Working Memory)**:指AI正在进行的对话、脑中即时活跃的上下文。它容量有限(通常为数千tokens),但访问速度极快[从理论到落地:分层记忆架构在AI Agent中的应用实践](https://blog.csdn.net/whitehat_zhou/article/details/150269603)。例如,在MemGPT架构中,短期工作上下文由系统指令、工作上下文和FIFO队列组成[9.4k Star!MemGPT:伯克利大学最新开源、将LLM作为操作系统、无限上下文记忆、服务化部署自定义Agent](http://www.wehelpwin.com/m_article/5363)。
**中期记忆(Episodic Memory)**:指最近读完的书籍核心内容或近期发生的重大事件。它们比短期记忆更持久,但不如长期记忆根深蒂固[从理论到落地:分层记忆架构在AI Agent中的应用实践](https://blog.csdn.net/whitehat_zhou/article/details/150269603)。这类记忆通常通过向量数据库或结构化存储实现语义检索。
**长期记忆(Long-term Memory)**:指用户的个人经历、学到的技能、世界观以及那些已掌握的知识。它容量近乎无限,但检索可能需要更长时间[从理论到落地:分层记忆架构在AI Agent中的应用实践](https://blog.csdn.net/whitehat_zhou/article/details/150269603)。长期记忆是Agent个性化和持续学习的基础。
### 2.2 三种记忆形式的技术实现
根据2025年最新综述研究,Agent记忆可从形式(Forms)、功能(Functions)和动态(Dynamics)三个正交维度进行重构[2025年Memory最全综述!AI Agent记忆统一分类体系](https://zhuanlan.zhihu.com/p/1985435669187825983):
**Token级记忆(Token-level Memory)**:信息被处理在模型的上下文窗口内。这是最直接的记忆方式,包括滑动窗口注意力(Sliding Window Attention)和分块处理(Chunking)等技术[EPISODIC MEMORIES GENERATION AND EVALUATION BENCHMARK FOR LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2501.13121v1.pdf)。优点是实现简单、无额外存储开销;缺点是在长序列上存在"上下文腐烂"问题[AI Agent 性能优化:核心策略与实战技巧(超详细)从零基础入门到精通!](https://m.blog.csdn.net/xiaoganbuaiuk/article/details/154012804)。
**参数级记忆(Parameter-level Memory)**:通过微调将知识直接编码到模型参数中。这种方式使信息成为模型"先天具有的知识",可在任意上下文中激活。优点是检索速度快、无需额外存储;缺点是训练成本高昂、难以增量更新[EPISODIC MEMORIES GENERATION AND EVALUATION BENCHMARK FOR LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2501.13121v1.pdf)。
**潜在级记忆(Latent-level Memory)**:通过检索增强生成(RAG)等外部存储机制实现。信息以向量或结构化形式存储在外存,需要时通过语义检索召回[EPISODIC MEMORIES GENERATION AND EVALUATION BENCHMARK FOR LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2501.13121v1.pdf)。这是当前商业化产品中最广泛采用的方式。
### 2.3 记忆模块在Agent架构中的位置
典型的AI Agent包括LLM用于推理和生成答案、策略或规划模块(如ReAct或AutoGPT风格)、工具或API访问权限以及检索器用于获取文档或历史数据[探寻AI Agent 中隐秘的角落:记忆(Memory) - 定义、价值与实践](https://developer.volcengine.com/articles/7540134113190412324)。然而,这些组件都无法记住昨天发生了什么——它们没有内部状态,也没有随时间进化的理解力。
加入记忆后,AI Agent的架构会发生质的改变:记忆层成为串联其他模块的"数据中枢"[大模型进阶之路:AI Agent记忆能力构建技术详解(值得收藏)](https://blog.csdn.net/xxue345678/article/details/150983939)。Agent架构的四大核心模块(感知、决策、记忆、行动)逐渐清晰,其中记忆模块(Brain-Memory & Knowledge)存储Agent的知识、用户偏好和历史交互记录,是决策与行动的"数据基础"[大模型进阶之路:AI Agent记忆能力构建技术详解(值得收藏)](https://blog.csdn.net/xxue345678/article/details/150983939)。
---
## 三、短期记忆的技术实现与优化策略
### 3.1 上下文窗口的根本限制
大语言模型在单次推理中能够接收和处理的Token数量上限被称为上下文窗口(Context Window),它直接决定了模型的"短期记忆容量"[LLM面试50问精解:从基础到进阶,掌握大语言模型核心知识,助你面试一战成名!大模型面试](https://m.blog.csdn.net/qkh1234567/article/details/155600143)。当前主流模型的上下文长度从4K token(早期GPT系列)发展到128K token(GPT-4o)甚至更长。然而,上下文窗口的扩大存在明显的"性能代价":
由于Transformer的自注意力机制需要计算序列中任意两个位置之间的相关性,其计算复杂度为O(n²),内存复杂度同样为O(n²)[Hi,我是 OceanBase PowerMem,了解一下?](http://cdn.modb.pro/db/1994227212866904064)。这意味着当上下文窗口从4K增加到128K(32倍)时,理论上的计算量和内存需求的增长不是32倍,而是1024倍[Hi,我是 OceanBase PowerMem,了解一下?](http://cdn.modb.pro/db/1994227212866904064)。
更重要的是,研究发现有效上下文长度远低于最大支持长度[LLM的长期记忆系统-成为你的个性化聊天陪伴](https://zhuanlan.zhihu.com/p/1917652340346946388)。已有研究表明,LLM的有效上下文长度在达到约2048 tokens后就开始显著下降,模型精度会随着上下文长度的增加而显著下降[LLM的长期记忆系统-成为你的个性化聊天陪伴](https://zhuanlan.zhihu.com/p/1917652340346946388)。
### 3.2 KV Cache优化技术
KV Cache(Key-Value Cache)的有效利用对提高LLM的运行效率至关重要[AI Agent优化技术深度解析:从Prompt到架构的全面指南(珍藏版)](https://m.blog.csdn.net/2401_85375186/article/details/155227353)。当前KV Cache优化可分为三个层次:
**Token级策略**:包括键值缓存选择、预算分配、合并、量化和低秩分解[【文献阅读】A Survey on Large Language Model Acceleration based on KV Cache Management](https://m.blog.csdn.net/Toky_min/article/details/146019523)。代表性技术包括:
- **SnapKV**:无需微调即可有效最小化KV缓存大小。在处理16K令牌输入时,实现一致的解码速度,与基线相比生成速度提高3.6倍,内存效率提高8.2倍[SnapKV: LLM在生成内容之前就知道您在寻找什么](https://blog.csdn.net/qq_36931982/article/details/139118015)
- **LM-Infinite**和**StreamingLLM**:保留初始token和近期token以实现无限长度上下文
- **H₂O**:基于注意力分数选择重要token进行缓存
**模型级优化**:通过架构创新和注意力机制增强键值重用[【文献阅读】A Survey on Large Language Model Acceleration based on KV Cache Management](https://m.blog.csdn.net/Toky_min/article/details/146019523)。例如Layer-Condensed KV Cache通过减少KV Cache层数而非保留所有层来降低内存消耗[Layer-Condensed KV Cache for Efficient Inference of Large Language Models](https://arxiv.org/pdf/2405.10637v2.pdf)。
**系统级方法**:解决内存管理、调度和硬件感知设计等问题[【文献阅读】A Survey on Large Language Model Acceleration based on KV Cache Management](https://m.blog.csdn.net/Toky_min/article/details/146019523)。NVIDIA Dynamo框架就采用了LLM-aware路由器来跟踪GPU集群中KV Cache的位置,当请求到达时计算新请求与缓存KV块之间的重叠度,将流量导向能最大化缓存复用的GPU[NVIDIA Dynamo Addresses Multi-Node LLM Inference Challenges](https://www.infoq.com/news/2025/12/nvidia-dynamo-kubernetes/?utm_campaign=infoq_content&utm_source=infoq&utm_medium=feed&utm_term=global)。
### 3.3 短期记忆的工程实现方案
**LangChain的记忆实现**提供了多种短期记忆方案\cite{web_bd759108}:
- **ConversationBufferMemory**:存储所有历史消息
- **ConversationBufferWindowMemory**:仅保留最近N轮对话(如k=3表示保留最近3轮)
- **ConversationSummaryMemory**:对对话历史生成摘要以节省Token
这种分级设计体现了"选择性"原则——判断哪些信息值得记住,哪些可以忽略[探寻AI Agent 中隐秘的角落:记忆(Memory) - 定义、价值与实践](https://developer.volcengine.com/articles/7540134113190412324)。
---
## 四、长期记忆的技术路线全景
### 4.1 RAG与Memory的区别与联系
RAG(检索增强生成)是OpenAI、Meta等公司提出的一种框架,用来增强语言模型的知识能力[让大模型“记住”更多:RAG与长期记忆](https://m.blog.csdn.net/2401_82452722/article/details/148744627)。然而,需要明确区分RAG与真正的Agent Memory:
RAG主要用于知识问答场景,在这些场景下知识点之间相对独立,一个用户问题往往只需涉及一两个知识点[LLM的长期记忆系统-成为你的个性化聊天陪伴](https://zhuanlan.zhihu.com/p/1917652340346946388)。企业级RAG系统通常会通过人工干预优化chunk划分,因此能够更好地支持知识类检索。
但在"个人记忆"这一场景下存在独特挑战:首先很难依赖人工完成chunk划分;其次用户的提问往往需要LLM进行推理、联想才能找到真正相关的记忆[LLM的长期记忆系统-成为你的个性化聊天陪伴](https://zhuanlan.zhihu.com/p/1917652340346946388)。当前RAG技术将文本按chunk进行切分,对每个chunk独立生成特征向量,在检索阶段query也是独立与每个chunk进行相似度计算——这种做法不仅容易造成信息"割裂",也难以保证检索到的是上下文完整的语义单元[LLM的长期记忆系统-成为你的个性化聊天陪伴](https://zhuanlan.zhihu.com/p/1917652340346946388)。
### 4.2 主流长期记忆框架对比
#### 4.2.1 MemGPT:操作系统灵感的虚拟内存管理
MemGPT是由UC Berkeley研究团队开发的记忆增强LLM框架,其核心创新是借鉴操作系统内存管理的概念[Meet MemGPT: The Future of Memory-Augmented AI - Medium](https://medium.com/@harshavasana1/meet-memgpt-the-future-of-memory-augmented-ai-9b547bbd3879):
**主上下文(Main Context)**:类比RAM的短期内存,用于即时处理。包含系统指令、工作上下文和FIFO队列[9.4k Star!MemGPT:伯克利大学最新开源、将LLM作为操作系统、无限上下文记忆、服务化部署自定义Agent](http://www.wehelpwin.com/m_article/5363)。
**外部上下文(External Context)**:类比硬盘的长期存储,用于存放不太常用的信息。MemGPT引入两个外部记忆模块——Recall Memory和Archival Memory,分别对应短期交互记录与长期知识存储[LLM的长期记忆系统-成为你的个性化聊天陪伴](https://zhuanlan.zhihu.com/p/1917652340346946388)。
MemGPT通过函数在主上下文和外部上下文之间移动数据,LLM可以通过在其输出中生成特殊关键字参数来请求立即后续推理,以将函数调用链接在一起;函数链允许MemGPT执行多步骤检索来回答用户查询[9.4k Star!MemGPT:伯克利大学最新开源、将LLM作为操作系统、无限上下文记忆、服务化部署自定义Agent](http://www.wehelpwin.com/m_article/5363)。这种架构不仅允许处理更大的数据集和更长的对话,而且还提高了模型在扩展交互中保持一致性能力[9.4k Star!MemGPT:伯克利大学最新开源、将LLM作为操作系统、无限上下文记忆、服务化部署自定义Agent](http://www.wehelpwin.com/m_article/5363)。
MemGPT的核心优势在于它证明了LLM可以被教会管理自己的内存[9.4k Star!MemGPT:伯克利大学最新开源、将LLM作为操作系统、无限上下文记忆、服务化部署自定义Agent](http://www.wehelpwin.com/m_article/5363)。其后商业化的Letta公司已获得1000万美元种子轮融资,在7000万美元估值下由Felicis Ventures领投[Letta, one of UC Berkeley's most anticipated AI startups, has just ...](https://techcrunch.com/2024/09/23/letta-one-of-uc-berkeleys-most-anticipated-ai-startups-has-just-come-out-of-stealth/)。
#### 4.2.2 Mem0:轻量级实用主义路线
Mem0被官方定义为"The Memory Layer for Personalized AI"(个性化AI的记忆层)[Mem0 获2400万美元融资,开发AI内存层技术-搜索](https://cn.bing.com/search?q=Mem0%20%E8%8E%B72400%E4%B8%87%E7%BE%8E%E5%85%83%E8%9E%8D%E8%B5%84%EF%BC%8C%E5%BC%80%E5%8F%91AI%E5%86%85%E5%AD%98%E5%B1%82%E6%8A%80%E6%9C%AF&rdr=1&rdrig=7UV8W6XI2C2HFZFXE23T7R5222E2IX6F&first=1)。其核心设计理念包括:
- **记忆是可搜索和可管理的**:通过自然语言索引+向量化混合检索
- **以用为先**:轻量级实现适合实际部署
- **结构化存储**:使用slot存储结构化长期信息(如角色设定、兴趣偏好)
Mem0在LOCOmo Benchmark中达到了业界开源Memory解决方案的SOTA水平[Hi,我是 OceanBase PowerMem,了解一下?](http://cdn.modb.pro/db/1994227212866904064)。其设计哲学是"足够好而非完美"——通过可恢复的压缩策略(如仅保留URL而非完整内容)在不永久丢失信息的情况下减少上下文开销[AI Agent 性能优化:核心策略与实战技巧(超详细)从零基础入门到精通!](https://m.blog.csdn.net/xiaoganbuaiuk/article/details/154012804)。
#### 4.2.3 OpenMemory:分层记忆分解架构
OpenMemory是一个自托管、模块化的AI记忆引擎,采用Hierarchical Memory Decomposition(HMD)架构[CaviraOSS/OpenMemory: Add long-term memory to any AI ...](https://github.com/CaviraOSS/OpenMemory):
- **单一规范节点每记忆**(无数据复制)
- **多扇区嵌入**(情节、语义、程序性、情感、反思性)
- **单一点位链接**(稀疏、生物启发式图)
- **复合相似性检索**(扇区融合+激活传播)
性能对比数据显示[CaviraOSS/OpenMemory: Add long-term memory to any AI ...](https://github.com/CaviraOSS/OpenMemory):
| 指标 | OpenMemory(自托管) | Zep(云服务) | Supermemory(SaaS) | Mem0 | 向量DB(平均) |
|------|---------------------|---------------|--------------------|-------|----------------|
| 成本 | 自托管 | 按API调用付费 | 按API调用付费 | 按API调用付费 | 按存储+查询付费 |
#### 4.2.4 A-MEM:卡片盒笔记法驱动的记忆系统
A-MEM是一种新型面向LLM Agent的智能记忆系统,遵循卡片盒笔记法(Zettelkasten)的基本原理[A-MEM智能记忆系统,让你的大模型拥有“学习”能力(收藏版)](https://blog.csdn.net/xiaoganbuaiuk/article/details/151215108):
当添加新记忆时,系统会生成包含多个结构化属性的综合笔记,包括上下文描述、关键词和标签。随后分析历史记忆以识别相关联系,并在存在有意义的相似性时建立链接——这一过程还支持记忆进化——随着新记忆的整合,它们会触发对现有历史记忆的上下文表示和属性更新[A-MEM智能记忆系统,让你的大模型拥有“学习”能力(收藏版)](https://blog.csdn.net/xiaoganbuaiuk/article/details/151215108)。
A-MEM的成本效益分析显示:通过选择性top-k检索机制,每次记忆操作约需1200个tokens,与基线方法(LoCoMo和MemGPT为16900个tokens)相比,Token使用量减少了85%-93%[A-MEM智能记忆系统,让你的大模型拥有“学习”能力(收藏版)](https://blog.csdn.net/xiaoganbuaiuk/article/details/151215108)。
### 4.3 多模态记忆的发展前沿
随着多模态LLM的发展,记忆系统也开始向多模态方向演进:
**M3-Agent**是字节跳动Seed团队提出的新型多模态代理框架,配备长期记忆能力。与人类类似,M3-Agent可以处理实时视觉和听觉输入来构建和更新其长期记忆[Research](https://seed.bytedance.com/en/public_papers)。
**VideoAgent**探讨如何将基础模型(大型语言模型和视觉语言模型)与统一记忆机制协调起来解决视频理解问题。其构建结构化内存来存储视频的通用时间事件描述和以对象为中心的跟踪状态,在NExTQA上平均提升6.6%,在EgoSchema上平均提升26.0%[VideoAgent: A Memory-augmented Multimodal Agent for Video Understanding(翻译)](https://m.blog.csdn.net/weixin_56764022/article/details/139862630)。
---
## 五、OpenAI ChatGPT记忆功能解析:行业标杆案例
### 5.1 发展历程与产品演进
OpenAI的记忆功能经历了三个重要阶段:
**第一阶段(2024年2月)**:小范围测试ChatGPT的记忆功能——记住用户在聊天中讨论过的事情,并避免重复信息。用户可以要求它记住特定的内容或让它自行获取详细信息。用得越多,ChatGPT的记忆力就会越好[ChatGPT要有记忆力了,OpenAI宣布小范围测试“记忆”功能!](http://m.nbd.com.cn/articles/2024-02-14/3245164.html)。
**第二阶段(2025年4月)**:向所有ChatGPT Plus和Pro用户推出记忆提升功能。这次更新后,ChatGPT的记忆功能能够参考用户过去的所有聊天以提供更个性化的回复[OpenAI:ChatGPT的记忆提升功能今天起对所有Plus和Pro用户推出](https://finance.jrj.com.cn/2025/04/11072749477951.shtml)。例如,如果用户曾经提到喜欢泰国菜,下次问"中午应该吃什么"时,ChatGPT可能会考虑到这一点[刚刚!OpenAI发布ChatGPT记忆功能,秒变私人助理!](https://m.blog.csdn.net/Leinwin/article/details/148428926)。
**第三阶段(2025年6月)**:为免费版ChatGPT推出轻量级记忆功能。该功能根据用户过往一段时间内的对话习惯、写作风格、提问方式进行个性化回答[OpenAI发布ChatGPT记忆功能](https://c.m.163.com/news/a/K16NIQ1D0511A6N9.html)。
### 5.2 记忆功能的核心机制
ChatGPT的记忆功能通过两个设置来控制[刚刚!OpenAI发布ChatGPT记忆功能,秒变私人助理!](https://m.blog.csdn.net/Leinwin/article/details/148428926):
- **"引用已保存的记忆"**:用户明确要求ChatGPT记住的细节(如名字、喜欢的颜色或饮食偏好)
- **"引用聊天历史"**:ChatGPT可以利用过去的对话信息来使未来的对话更有帮助
OpenAI称用户可以随时关闭新功能,在设置菜单中找到个性化选项即可将记忆设置为关闭。用户也可以在设置中找到"管理内存"选项,在其中查看和删除特定的记忆内容或清除所有记忆[ChatGPT要有记忆力了,OpenAI宣布小范围测试“记忆”功能!](http://m.nbd.com.cn/articles/2024-02-14/3245164.html)。
### 5.3 记忆功能的战略意义
投资人Allie K. Miller在X平台上感叹:"这相当于ChatGPT全天候在'偷听'——不管你有没有叫它记住,它都在默默收集"[OpenAI最近推出的ChatGPT更新简直像给AI打了“记忆芯片”](https://blog.csdn.net/2301_79342058/article/details/147156180)。她还表示,在平台功能越来越趋同的今天,真正拉开差距的关键就是"记忆+个性化"——AI的记忆就是平台的护城河[OpenAI最近推出的ChatGPT更新简直像给AI打了“记忆芯片”](https://blog.csdn.net/2301_79342058/article/details/147156180)。
OpenAI研究还表明,ChatGPT记住什么或不记住什么与人们如何体验ChatGPT的人格密切相关——许多Plus和Pro订阅者告诉OpenAI,更好的记忆是体验中最有价值的部分之一[超越“一刀切”:ChatGPT如何为8亿用户定制个性化体验| 图文 ...](https://www.zhihu.com/pin/1972634281496057035)。
---
## 六、评测基准与性能分析
### 6.1 主要评测基准概述
当前用于评估Agent Memory的主要数据集包括[Agent 又又失忆了!我来做一次记忆体检](https://juejin.cn/post/7564589641258139648)[Survey on Evaluation of LLM-based Agents](http://arxiv.org/pdf/2503.16416v1):
**LoCoMo Benchmark**:专门测试长上下文任务的记忆能力,在该基准中Mem0和Letta曾出现过评测分值分歧[2025 AI 记忆系统大横评:从插件到操作系统,谁在定义下一代Agent Infra?](https://m.toutiao.com/a7578151050135044643/)。
**LongMemEval**:UCLA团队开发的系统性基准,更像是一个"记忆体检表"——从信息提取、跨会话推理、知识更新到拒答未知,一共五项指标[Agent 又又失忆了!我来做一次记忆体检](https://juejin.cn/post/7564589641258139648)。
**Reflection-Bench**:上海人工智能实验室开发的认知心理学导向评测平台,围绕七个认知维度设计了354个任务——预测能力与决策能力、感知能力与记忆能力、反事实思维、信念更新、元反思能力等[AI大模型代理评测全攻略:从入门到精通,一篇就够了!](https://m.blog.csdn.net/m0_56255097/article/details/154832865)。
这些评测指标超越了传统的自然语言处理性能指标——聚焦事实信息的存储与利用,通过准确性指标(基于历史信息生成响应的正确性)和召回率@5指标(前5条检索结果中相关信息的占比)来衡量[上下文工程最新综述A Survey of Context Engineering for ...](https://zhuanlan.zhihu.com/p/1956121336968647834)。
### 6.2 性能数据对比
根据MemOS横向评测数据,在实际任务中的表现差异显著[2025 AI 记忆系统大横评:从插件到操作系统,谁在定义下一代Agent Infra?](https://m.toutiao.com/a7578151050135044643/):
| 技术路线 | 上下文扩展能力 | 检索准确性 | 响应延迟 | 开发复杂度 | 成本效益 |
|---------|--------------|-----------|---------|----------|---------|
| 基础RAG | 中等 | 较低 | 较低 | 低 | 高 |
| MemGPT | 高 | 中等 | 较高 | 高 | 中等 |
| Mem0 | 高 | 较高 | 中等 | 中等 | 高 |
| 向量数据库+自定义 | 可定制 | 高 | 中等 | 高 | 中等 |
---
## 七、工程实践案例与落地挑战
### 7.1 成功案例分析
来自硅谷一线AI创业者的数据显示:只有5%的AI Agent成功部署到生产环境[为什么只有5%的AI Agent落地成功?](https://www.huxiu.com/comment/4792610.html)。这5%的成功案例都有一个共同点——都采用了"human-in-the-loop"的设计:
- 将AI定位为辅助工具而非自主决策者
- 构建反馈循环让系统能从人类修正中学习
- 让人类可以轻松验证和否决AI的输出
这些成功案例在记忆设计上的共同特点是:精细调整模型的需求其实非常少见——一个设计完善的RAG系统通常就能满足需求。但大多数RAG系统的设计都太过初级——将所有内容编入索引导致检索信息过量反而迷惑模型;编入索引的内容过少导致模型缺乏有效信号;不加区分地混合结构化与非结构化数据会破坏嵌入向量的语义[为什么只有5%的AI Agent落地成功?](https://www.huxiu.com/comment/4792610.html)。
### 7.2 记忆功能的产品化架构
亚马逊云科技发布的Amazon Bedrock AgentCore为开发者打通了AI Agents从概念验证到生产部署的关键环节[跨越“演示”到“生产”鸿沟,亚马逊云科技开启AI Agents新纪元](http://app.myzaker.com/news/article.php?pk=6878d44a8e9f091282757f81)。其核心能力包括:
- **会话管理**:处理多用户并发场景下的状态隔离
- **身份权限控制**:确保AI访问敏感数据时的身份与权限可控
- **记忆系统**:支持分级存储和版本控制
- **可观测性机制**:监控Agent行为并进行调试
该平台的一个关键洞察是:大多数创始人以为自己在打造AI产品,实际上构建的是上下文选择系统[为什么只有5%的AI Agent落地成功?](https://www.huxiu.com/comment/4792610.html)。真正的工程工作应该得到应有的重视——精细调整模型的需求其实非常少见。
### 7.3 记忆层级的设计原则
产品化设计中应该考虑多层次的记忆架构[为什么只有5%的AI Agent落地成功?](https://www.huxiu.com/comment/4792610.html):
**用户级**:个人偏好设置(如图表类型、风格、写作语气)
**团队级**:高频查询、仪表盘、标准操作手册(runbooks)
**企业级**:知识库、政策文档、历史决策记录
这种分层设计既能满足个性化需求又能保护敏感信息不被泄露给其他用户。
---
## 八、技术路线对比与产品化建议
### 8.1 不同技术路线的核心差异矩阵
基于技术成熟度、落地难度、成本效益和业务价值四个维度进行综合评估:
| 技术路线 | 技术成熟度 | 落地难度 | 成本效益 | 业务价值 |
|---------|----------|---------|---------|---------|
| Token级短期记忆 | ★★★★★ | ★★☆☆☆ | ★★★★★ | ★★★☆☆ |
| 基础RAG方案 | ★★★★☆ | ★★★☆☆ | ★★★★☆ | ★★★★☆ |
| MemGPT虚拟内存 | ★★★☆☆ | ★★★★★ | ★★☆☆☆ | ★★★★★ |
| Mem0轻量级方案 | ★★★★☆ | ★★★☆☆ | ★★★★★ | ★★★★☆ |
| A-MEM卡片盒方案 | ★★★☆☆ | ★★★★☆ | ★★★★☆ | ★★★★★ |
### 8.2 各类技术路线优劣势详解
#### 基础RAG方案(推荐指数:⭐⭐⭐⭐⭐)
**优势**:
- 技术成熟度高,生态系统完善
- 成本可控,可使用开源组件构建
- 易于理解和调试
- 可快速上线验证业务价值
**劣势**:
- 检索准确性受chunk划分影响大
- 难以支持复杂推理关联
- 对非结构化个人记忆支持有限
**适用场景**:知识问答系统、文档检索助手、企业知识库查询
#### Mem0轻量级方案(推荐指数:⭐⭐⭐⭐⭐)
**优势**:
- 实现简单,有现成SDK
- 性能达到开源方案SOTA水平
- 资源占用低适合中小规模部署
- 社区活跃支持良好
**劣势**:
- 功能相对基础,缺乏深度定制能力
- 对复杂多模态场景支持有限
**适用场景**:个性化聊天助手、客服机器人升级、中小型企业AI应用
#### MemGPT虚拟内存方案(推荐指数:⭐⭐⭐⭐☆)
**优势**:
- 理论架构先进,支持无限扩展
- 内存利用率高
- 可学习管理自身内存
**劣势**:
- 实现复杂度高
- 运维难度大
- 对开发团队要求高
- 生产环境稳定性待验证
**适用场景**:研发型团队探索前沿技术、学术研究、复杂任务型Agent
#### 定制化向量数据库方案(推荐指数:⭐⭐⭐⭐☆)
**优势**:
- 可高度定制化
- 支持复杂查询逻辑
- 数据完全可控(隐私合规)
**劣势**:
- 开发周期长
- 运维成本高
- 需要专业的数据工程团队
**适用场景**:大型企业私有化部署、高隐私要求行业(医疗、金融)
### 8.3 投入优先级建议
基于您的产品定位和团队能力,建议采取以下投入策略:
#### 第一优先级(立即投入)
**基础RAG+短期记忆组合方案**
理由:
1. 技术风险最低——有成熟的开源框架和丰富的社区支持
2. 成本可控——可使用开源向量数据库(如FAISS、Milvus)
3. 快速验证价值——几周内即可上线原型
4. 底座稳固——为后续升级奠定基础
实施方案建议:
```
短期记忆层 → LangChain/LLamaIndex内置memory模块
长期记忆层 → FAISS/Pinecone向量数据库 + 自定义Chunk策略
检索增强 → 多模态Embedding + Top-k筛选 + 人工反馈闭环
```
#### 第二优先级(3个月内)
**引入Mem0或类似轻量级方案**
理由:
1. 在第一优先级方案验证价值后逐步升级
2. 成本效益比高——开源方案可快速迭代
3. 技术风险可控——有成熟产品参考
#### 第三优先级(6-12个月)
**探索差异化方案**
根据业务场景可选择:
- **如果是知识密集型应用**:深入定制向量数据库方案
- **如果是任务型Agent**:探索MemGPT虚拟内存方案
- **如果是多模态应用**:关注M3-Agent等前沿框架演进
---
## 九、未来演进方向与趋势判断
### 9.1 短期趋势(2025-2026)
根据奥尔特曼的说法:"这是2026年要考虑的事"[超级Agent重要拼图?奥尔特曼点名“AI记忆” 存储环节迎来新叙事](https://finance.sina.com.cn/stock/t/2025-12-22/doc-inhcsmfc4526733.shtml)——真正成熟的长期记忆系统将在明年成为焦点。预计的发展方向包括:
**KV Cache优化技术成熟化**
目前KV Cache优化仍在快速演进中。SnapKV已在单个A100-80GB GPU上实现处理多达38万个上下文令牌的能力[SnapKV: LLM在生成内容之前就知道您在寻找什么](https://blog.csdn.net/qq_36931982/article/details/139118015)。随着技术成熟,这一能力将成为标准配置而非差异化优势。
**多模态记忆标准化**
M3-Agent等多模态记忆框架正在快速发展[Research](https://seed.bytedance.com/en/public_papers)。预计到2026年将出现标准化的多模态记忆接口和评估基准。
### 9.2 中期趋势(2026-2027)
**操作系统级别集成**
Agent的记忆系统将从独立组件演进为操作系统级别的原生能力——类似MemGPT的理念但更加成熟稳定。届时可能出现专门针对AI应用优化的操作系统发行版或云服务层。
**个性化与隐私平衡**
随着监管加强和技术进步,个性化记忆将需要更严格的隐私保护机制。联邦学习、差分隐私等技术将与记忆系统深度整合。
### 9.3 长期展望(2027+)
真正的类人记忆能力将成为区分优秀AI产品与普通工具的关键标志。正如行业共识所言:"在AI都拥有相似模型和工具的未来,记忆将成为决定胜负的关键"[探寻AI Agent 中隐秘的角落:记忆(Memory) - 定义、价值与实践](https://developer.volcengine.com/articles/7540134113190412324)。
---
## 十、结论与行动建议
### 10.1 核心结论
通过对当前Agent Memory模块的技术全景调研,可以得出以下核心结论:
第一,记忆模块已从"锦上添花"转变为"不可或缺"——它是Agent从工具到伙伴的关键分水岭[大模型进阶之路:AI Agent记忆能力构建技术详解(值得收藏)](https://blog.csdn.net/xxue345678/article/details/150983939)[超级Agent重要拼图?奥尔特曼点名“AI记忆” 存储环节迎来新叙事](https://finance.sina.com.cn/stock/t/2025-12-22/doc-inhcsmfc4526733.shtml)。
第二,在技术成熟度和工程可行性方面,基础RAG方案配合短期记忆是最务实的起点;Mem0等轻量级方案提供了良好的平衡点;而MemGPT等前沿方案更适合有充足技术储备的研发型团队。
第三,从成本效益角度分析,A-MEM等优化方案可将Token消耗降低85%-93%[A-MEM智能记忆系统,让你的大模型拥有“学习”能力(收藏版)](https://blog.csdn.net/xiaoganbuaiuk/article/details/151215108)——这说明优化空间巨大但仍需谨慎选择合适的技术路线。
第四,在OpenAI的带领下,个性化记忆已成为行业共识方向[ChatGPT要有记忆力了,OpenAI宣布小范围测试“记忆”功能!](http://m.nbd.com.cn/articles/2024-02-14/3245164.html)[超级Agent重要拼图?奥尔特曼点名“AI记忆” 存储环节迎来新叙事](http://m.cls.cn/detail/2236511)——不跟进将面临竞争劣势风险。
### 10.2 行动建议清单
针对您的产品规划,建议按以下步骤推进:
✅ **Week 1-2**: 明确产品对记忆的具体需求——是知识检索还是个性化对话?需要短期还是长期?对准确性和延迟的要求是多少?
✅ **Week 3**: 基于需求选择技术路线——优先评估LangChain/LlamaIndex+FAISS的基础组合方案
✅ **Week 4**: 构建最小可行原型——用两周时间验证核心功能是否满足业务需求
✅ **Week 5**: 设置评估指标——参考LoCoMo或LongMemEval建立内部评测体系
✅ **Month 2**: 根据原型反馈决定是否引入Mem0或其他轻量级方案进行升级
✅ **Month 3+**: 持续优化并关注行业演进——特别是OpenAI的记忆功能迭代和其他竞品动态
最终建议采取"敏捷验证、快速迭代"策略:不要追求一步到位的理想方案,在验证业务价值的基础上逐步升级技术栈。正如一位资深AI产品经理所言:"大多数创始人以为自己在打造AI产品,但实际上他们构建的是上下文选择系统"[为什么只有5%的AI Agent落地成功?](https://www.huxiu.com/comment/4792610.html)——聚焦用户价值而非技术炫技才是成功的关键。
## Quick Start
You can get beta access to the model API through StepFun Open Platform.
### Requirements
- Python >= 3.10
- Node.js >= 18 (for frontend)
- npm or yarn
### 1. Environment Setup
**Install dependencies (backend & frontend)**
```bash
# Using uv (recommended)
uv sync
source .venv/bin/activate
# Or using pip
pip install -e .
```
```bash
cd cortex-ui
# Using npm
npm install
# Or using yarn
yarn install
```
**Configure environment variables**
Get your StepFun API key(s) from [StepFun Open Platform](https://platform.stepfun.com/interface-key); the StepFun API key is used for both the model and search services.
Create a `.env` file in the repo root:
```bash
# Model provider
MODEL_PROVIDER=stepfun
# Model name for Deep Research Agent (optional: step-dr-1, step-3.5-flash; default: step-dr-1)
MODEL_NAME=step-dr-1
# Model API base URL (StepFun)
MODEL_BASE=https://api.stepfun.com
# StepFun model API key
STEP_MODEL_API_KEY=your-model-api-key
# Search API base URL (StepFun search service)
STEP_SEARCH_API_BASE=https://api.stepfun.com
# StepFun search API key
STEP_SEARCH_API_KEY=your-search-api-key
```
**Recommended System Prompt**
We recommend using the system prompt in [prompt.py](scripts/configs/prompt.py) to ensure optimal performance.
### 2. Run with Demo UI
**Start backend service**
```bash
python -m demo.server [OPTIONS]
```
Options:
- `--port PORT` Server port (default: `8001`)
The service runs on `http://localhost:8001` by default.
**Main endpoints:**
- `GET /agents` - Get list of all available Agents
- `WebSocket /ws` - Real-time communication endpoint
**Start frontend service**
```bash
cd cortex-ui
# Development mode
npm run dev
# Or
yarn dev
```
The frontend runs on `http://localhost:3000` by default and automatically proxies API requests to the backend service.
### 3. Run with Offline Runner
Use `python -m scripts.runner` to run DeepResearchAgent on ad-hoc tasks without the UI.
- **Direct args**: pass a single prompt or a tasks file.
```bash
python -m scripts.runner \
--task "列出最近的 AI 安全新闻" \
--output-dir scripts/results
```
- **Config file**: provide defaults via YAML.
```bash
python -m scripts.runner --config scripts/configs/runner_example.yaml
```
Inputs:
- `--task` / `--task-id`: one-off prompt and optional id.
- `--tasks-file`: JSON task list (relative to project root if not absolute), e.g. `scripts/configs/tasks.example.json`.
- Optional flags: `--output-dir` (default `scripts/results`), `--mode` (`multi` default), `--no-stream`, `--request-timeout`, `--context-upper-limit`, `--context-lower-limit`, `--overwrite`.
Outputs:
- One JSON trace per task in `output_dir`, containing the final answer, full message/tool events, metadata, and status.
## Our paper
We have made our technical report available. [Read](https://arxiv.org/pdf/2512.20491)
## Contact Us
You can join our group chat to get updates on your beta API application status and the latest project developments.
## License
The code in the repository is licensed under [Apache 2.0](LICENSE) License.
## Citation
```
@misc{hu2025stepdeepresearchtechnicalreport,
title={Step-DeepResearch Technical Report},
author={Chen Hu and Haikuo Du and Heng Wang and Lin Lin and Mingrui Chen and Peng Liu and Ruihang Miao and Tianchi Yue and Wang You and Wei Ji and Wei Yuan and Wenjin Deng and Xiaojian Yuan and Xiaoyun Zhang and Xiangyu Liu and Xikai Liu and Yanming Xu and Yicheng Cao and Yifei Zhang and Yongyao Wang and Yubo Shu and Yurong Zhang and Yuxiang Zhang and Zheng Gong and Zhichao Chang and Binyan Li and Dan Ma and Furong Jia and Hongyuan Wang and Jiayu Liu and Jing Bai and Junlan Liu and Manjiao Liu and Na Wang and Qiuping Wu and Qinxin Du and Shiwei Li and Wen Sun and Yifeng Gong and Yonglin Chen and Yuling Zhao and Yuxuan Lin and Ziqi Ren and Zixuan Wang and Aihu Zhang and Brian Li and Buyun Ma and Kang An and Li Xie and Mingliang Li and Pan Li and Shidong Yang and Xi Chen and Xiaojia Liu and Yuchu Luo and Yuan Song and YuanHao Ding and Yuanwei Liang and Zexi Li and Zhaoning Zhang and Zixin Zhang and Binxing Jiao and Daxin Jiang and Jiansheng Chen and Jing Li and Xiangyu Zhang and Yibo Zhu},
year={2025},
eprint={2512.20491},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2512.20491},
}
```
## Star History
[](https://www.star-history.com/#stepfun-ai/StepDeepResearch&type=date&legend=top-left)
================================================
FILE: agentkit/trace/__init__.py
================================================
from .builder import EventBuilder, FunctionSpanBuilder, HTTPSpanBuilder, SpanBuilder
from .context import (
SpanContext,
create_span,
get_current_context,
get_default_app_name,
get_default_tracer,
record_event,
set_default_app_name,
set_default_tracer,
start_trace,
trace_function,
)
from .local_tracer import LocalStorageTracer
from .remote_tracer import HybridTracer, RemoteTracer
from .span import (
DeltaEventPayload,
Error,
Event,
EventType,
FunctionSpanPayload,
HTTPSpanPayload,
LLMSpanPayload,
OtherEventPayload,
OtherSpanPayload,
Span,
SpanType,
ToolSpanPayload,
)
from .tracer import Tracer
# Public API of the trace package; controls `from agentkit.trace import *`
# and documents which names are considered stable for downstream users.
__all__ = [
    # Core types
    "Span",
    "Event",
    "SpanType",
    "EventType",
    "LLMSpanPayload",
    "ToolSpanPayload",
    "FunctionSpanPayload",
    "HTTPSpanPayload",
    "OtherSpanPayload",
    "DeltaEventPayload",
    "OtherEventPayload",
    "Error",
    "Tracer",
    "LocalStorageTracer",
    "RemoteTracer",
    "HybridTracer",
    # Context management
    "SpanContext",
    "get_current_context",
    "get_default_tracer",
    "get_default_app_name",
    "set_default_tracer",
    "set_default_app_name",
    "start_trace",
    "create_span",
    "record_event",
    "trace_function",
    # Builders
    "SpanBuilder",
    "FunctionSpanBuilder",
    "HTTPSpanBuilder",
    "EventBuilder",
]
================================================
FILE: agentkit/trace/builder.py
================================================
"""
Builder pattern for constructing Span and Event objects.
"""
from datetime import datetime
from typing import Any, Optional
from ulid import ULID
from .default import get_default
from .span import (
Event,
EventType,
FunctionSpanPayload,
HTTPSpanPayload,
OtherEventPayload,
Span,
SpanType,
)
class SpanBuilder:
    """
    Span builder providing a fluent API to construct complex Spans.

    Usage:
        span = (SpanBuilder()
                .with_name("my_operation")
                .with_app_name("my_app")
                .with_tag("user", "alice")
                .with_parent(parent_span)
                .build())
    """

    def __init__(self, trace_id: Optional[str] = None, app_name: Optional[str] = None):
        # Fresh ULIDs are generated so the builder is usable stand-alone;
        # with_trace_id()/with_parent() can override the trace id later.
        self._trace_id = trace_id or str(ULID())
        self._app_name = app_name or get_default("app_name")
        self._id = str(ULID())
        self._name = ""
        self._start_time = datetime.now()
        self._end_time = None
        self._tags: dict[str, str] = {}
        self._payload: Optional[Any] = None
        self._parent_id: Optional[str] = None

    def with_id(self, span_id: str) -> "SpanBuilder":
        """Set span ID."""
        self._id = span_id
        return self

    def with_name(self, name: str) -> "SpanBuilder":
        """Set span name."""
        self._name = name
        return self

    def with_trace_id(self, trace_id: str) -> "SpanBuilder":
        """Set trace ID."""
        self._trace_id = trace_id
        return self

    def with_start_time(self, start_time: datetime) -> "SpanBuilder":
        """Set start time."""
        self._start_time = start_time
        return self

    def with_end_time(self, end_time: datetime) -> "SpanBuilder":
        """Set end time."""
        self._end_time = end_time
        return self

    def with_tag(self, key: str, value: str) -> "SpanBuilder":
        """Add a tag."""
        self._tags[key] = value
        return self

    def with_tags(self, tags: dict[str, str]) -> "SpanBuilder":
        """Add multiple tags."""
        self._tags.update(tags)
        return self

    def with_payload(self, payload: Any) -> "SpanBuilder":
        """Set payload."""
        self._payload = payload
        return self

    def with_parent(self, parent: Span) -> "SpanBuilder":
        """Set parent span and inherit its trace_id.

        Fix: the previous implementation dereferenced ``parent.id`` before an
        ``if parent:`` truthiness check, so the check was dead code. The
        trace_id is now inherited unconditionally from the parent span.
        """
        self._parent_id = parent.id
        self._trace_id = parent.trace_id
        return self

    def with_parent_id(self, parent_id: str) -> "SpanBuilder":
        """Set parent span ID (without inheriting a trace_id)."""
        self._parent_id = parent_id
        return self

    def with_app_name(self, app_name: str) -> "SpanBuilder":
        """Set application name."""
        self._app_name = app_name
        return self

    def build(self) -> Span:
        """Build the Span object from the accumulated state."""
        span = Span(
            id=self._id,
            name=self._name,
            trace_id=self._trace_id,
            app_name=self._app_name,
            start_time=self._start_time,
            end_time=self._end_time,
            tags=self._tags,
            payload=self._payload,
            parent_id=self._parent_id,
        )
        return span
class FunctionSpanBuilder(SpanBuilder):
    """
    Builder specialised for spans describing a plain function call.

    Usage:
        span = (FunctionSpanBuilder()
                .with_name("calculate")
                .with_function_name("calculate_total")
                .with_arguments({"x": 1, "y": 2})
                .with_return_value(3)
                .build())
    """

    def __init__(self, trace_id: Optional[str] = None, app_name: Optional[str] = None):
        super().__init__(trace_id, app_name)
        self._function_name = ""
        self._arguments: dict[str, Any] = {}
        self._ret: Any = None
        self._error = None

    def with_function_name(self, name: str) -> "FunctionSpanBuilder":
        """Record the traced function's name; also used as span name if unset."""
        self._function_name = name
        self._name = self._name or name
        return self

    def with_arguments(self, arguments: dict[str, Any]) -> "FunctionSpanBuilder":
        """Attach the call's argument mapping."""
        self._arguments = arguments
        return self

    def with_return_value(self, ret: Any) -> "FunctionSpanBuilder":
        """Attach the call's return value."""
        self._ret = ret
        return self

    def with_error(self, code: int, message: str) -> "FunctionSpanBuilder":
        """Attach error information as a {code, message} mapping."""
        self._error = {"code": code, "message": message}
        return self

    def build(self) -> Span:
        """Assemble the FunctionSpanPayload, then delegate to SpanBuilder.build."""
        self._payload = FunctionSpanPayload(
            type=SpanType.FUNCTION,
            name=self._function_name,
            arguments=self._arguments,
            ret=self._ret,
            error=self._error,
        )
        return super().build()
class HTTPSpanBuilder(SpanBuilder):
    """
    Builder specialised for spans describing an HTTP request.

    Usage:
        span = (HTTPSpanBuilder()
                .with_url("https://api.example.com/users")
                .with_method("GET")
                .with_header("Authorization", "Bearer token")
                .with_response('{"users": []}')
                .build())
    """

    def __init__(self, trace_id: Optional[str] = None, app_name: Optional[str] = None):
        super().__init__(trace_id, app_name)
        self._url = ""
        self._method = "GET"
        self._headers: dict[str, list[str]] = {}
        self._body: Optional[str | bytes] = None
        self._response: Optional[str | bytes] = None
        self._error = None

    def with_url(self, url: str) -> "HTTPSpanBuilder":
        """Set the request URL; derives the span name when none was given."""
        self._url = url
        if not self._name:
            self._name = f"{self._method} {url}"
        return self

    def with_method(self, method: str) -> "HTTPSpanBuilder":
        """Set the HTTP method; refreshes the span name once a URL is known."""
        self._method = method
        if self._url:
            self._name = f"{method} {self._url}"
        return self

    def with_header(self, key: str, value: str | list[str]) -> "HTTPSpanBuilder":
        """Add one header; a bare string is wrapped into a single-item list."""
        self._headers[key] = [value] if isinstance(value, str) else value
        return self

    def with_headers(self, headers: dict[str, list[str]]) -> "HTTPSpanBuilder":
        """Merge multiple headers into the builder."""
        self._headers.update(headers)
        return self

    def with_body(self, body: str | bytes) -> "HTTPSpanBuilder":
        """Attach the request body."""
        self._body = body
        return self

    def with_response(self, response: str | bytes) -> "HTTPSpanBuilder":
        """Attach the response body."""
        self._response = response
        return self

    def with_error(self, code: int, message: str) -> "HTTPSpanBuilder":
        """Attach error information as a {code, message} mapping."""
        self._error = {"code": code, "message": message}
        return self

    def build(self) -> Span:
        """Assemble the HTTPSpanPayload, then delegate to SpanBuilder.build."""
        self._payload = HTTPSpanPayload(
            type=SpanType.HTTP,
            url=self._url,
            method=self._method,
            headers=self._headers,
            body=self._body,
            response=self._response,
            error=self._error,
        )
        return super().build()
class EventBuilder:
    """
    Event builder providing a fluent API to construct Events.

    Usage:
        event = (EventBuilder()
                 .with_name("user_input")
                 .with_data({"text": "hello"})
                 .with_parent(span)
                 .build())
    """

    def __init__(self, trace_id: Optional[str] = None, app_name: Optional[str] = None):
        # Fresh ULIDs keep the builder usable stand-alone; with_trace_id/
        # with_parent can override the trace id later.
        self._trace_id = trace_id or str(ULID())
        self._app_name = app_name or get_default("app_name")
        self._id = str(ULID())
        self._name = ""
        self._timestamp = datetime.now()
        self._tags: dict[str, str] = {}
        self._data: Any = None
        self._parent_id: Optional[str] = None

    def with_id(self, event_id: str) -> "EventBuilder":
        """Set event ID."""
        self._id = event_id
        return self

    def with_name(self, name: str) -> "EventBuilder":
        """Set event name."""
        self._name = name
        return self

    def with_trace_id(self, trace_id: str) -> "EventBuilder":
        """Set trace ID."""
        self._trace_id = trace_id
        return self

    def with_timestamp(self, timestamp: datetime) -> "EventBuilder":
        """Set timestamp."""
        self._timestamp = timestamp
        return self

    def with_tag(self, key: str, value: str) -> "EventBuilder":
        """Add a tag."""
        self._tags[key] = value
        return self

    def with_tags(self, tags: dict[str, str]) -> "EventBuilder":
        """Add multiple tags."""
        self._tags.update(tags)
        return self

    def with_data(self, data: Any) -> "EventBuilder":
        """Set the event's data payload."""
        self._data = data
        return self

    def with_parent(self, parent: Span) -> "EventBuilder":
        """Set parent span and inherit its trace_id.

        Fix: the previous implementation dereferenced ``parent.id`` before an
        ``if parent:`` truthiness check, so the check was dead code. The
        trace_id is now inherited unconditionally from the parent span.
        """
        self._parent_id = parent.id
        self._trace_id = parent.trace_id
        return self

    def with_parent_id(self, parent_id: str) -> "EventBuilder":
        """Set parent span ID (without inheriting a trace_id)."""
        self._parent_id = parent_id
        return self

    def build(self) -> Event:
        """Build the Event; the data is wrapped into an OtherEventPayload."""
        payload = OtherEventPayload(type=EventType.OTHER, data=self._data)
        event = Event(
            id=self._id,
            name=self._name,
            trace_id=self._trace_id,
            timestamp=self._timestamp,
            tags=self._tags,
            payload=payload,
            parent_id=self._parent_id,
            app_name=self._app_name,
        )
        return event
================================================
FILE: agentkit/trace/context.py
================================================
"""
Context management and utility functions for Span and Event.
"""
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime
from functools import wraps
try:  # Python 3.11+ has HTTPMethod in stdlib
    from http import HTTPMethod
except ImportError:  # pragma: no cover - fallback for Python 3.10
    from enum import Enum

    # Minimal stand-in mirroring stdlib http.HTTPMethod member names/values,
    # so the rest of this module can type-annotate against HTTPMethod on 3.10.
    class HTTPMethod(str, Enum):
        GET = "GET"
        POST = "POST"
        PUT = "PUT"
        DELETE = "DELETE"
        PATCH = "PATCH"
        HEAD = "HEAD"
        OPTIONS = "OPTIONS"
        CONNECT = "CONNECT"
        TRACE = "TRACE"
from typing import Any, Callable, Literal, Optional
from ulid import ULID
from .default import get_default, get_default_settings, set_default
from .span import (
Event,
EventType,
FunctionSpanPayload,
HTTPSpanPayload,
LLMSpanPayload,
OtherEventPayload,
OtherSpanPayload,
Span,
SpanType,
ToolSpanPayload,
)
from .tracer import Tracer
from .types import Error
# Use contextvars to manage the current context
_current_context: ContextVar[Optional["SpanContext"]] = ContextVar(
"current_context", default=None
)
class SpanContext:
    """
    Manages the current active span context, supporting nested parent-child
    relationships.

    Supports cross-service span reconstruction: trace_id and parent_id can be
    specified at creation time.
    """

    def __init__(
        self,
        app_name: Optional[str] = None,
        tags: Optional[dict[str, str]] = None,
        trace_id: Optional[str] = None,
        parent_id: Optional[str] = None,
        tracer: Optional[Tracer] = None,
    ):
        # Fall back to process-wide defaults when no explicit tracer/app_name
        # is supplied.
        if tracer is None:
            self.tracer = get_default("tracer")
        else:
            self.tracer = tracer
        self.app_name = app_name or get_default("app_name")
        self.tags = tags
        self._span_stack: list[Span] = []
        self._current_trace_id: Optional[str] = trace_id
        # Root parent node for cross-service reconstruction.
        self._root_parent_id: Optional[str] = parent_id

    def get_current_span(self) -> Optional[Span]:
        """Get the current active span (top of the stack), or None."""
        return self._span_stack[-1] if self._span_stack else None

    def get_current_trace_id(self) -> str:
        """Get the current trace_id, lazily creating one on first use."""
        if self._current_trace_id is None:
            self._current_trace_id = str(ULID())
        return self._current_trace_id

    def set_trace_id(self, trace_id: str):
        """Set the current trace_id."""
        self._current_trace_id = trace_id

    def merge_tags(self, tags: Optional[dict[str, str]]) -> dict[str, str]:
        """Merge span-local tags over the context-level tags.

        Fix: always returns a fresh dict. The previous implementation called
        ``update`` on ``self.tags`` itself, so every span's tags leaked into
        the shared context tags and contaminated all subsequent spans/events.
        """
        merged_tags: dict[str, str] = dict(self.tags) if self.tags else {}
        if tags:
            merged_tags.update(tags)
        return merged_tags

    @contextmanager
    def span(
        self,
        name: str,
        tags: Optional[dict[str, str]] = None,
        payload: Optional[
            HTTPSpanPayload
            | LLMSpanPayload
            | ToolSpanPayload
            | FunctionSpanPayload
            | OtherSpanPayload
        ] = None,
    ):
        """
        Context manager for creating a span.

        Usage:
            with ctx.span("my_operation"):
                # your code
                pass
        """
        parent_span = self.get_current_span()
        trace_id = self.get_current_trace_id()
        # Determine parent_id: prioritize the span on the current stack,
        # otherwise use the root parent_id (cross-service scenario).
        parent_id = None
        if parent_span:
            parent_id = parent_span.id
        elif self._root_parent_id:
            parent_id = self._root_parent_id
        span = Span(
            name=name,
            trace_id=trace_id,
            app_name=self.app_name,
            tags=self.merge_tags(tags),
            payload=payload,
            parent_id=parent_id,
        )
        self._span_stack.append(span)
        token = _current_context.set(self)
        try:
            yield span
        except Exception as e:
            # Record the error on the span, then propagate.
            span.tags["error"] = str(e)
            raise
        finally:
            # Close the span, unwind the stack, flush to the tracer and
            # restore the previous contextvar state.
            span.end_time = datetime.now()
            self._span_stack.pop()
            if self.tracer:
                self.tracer.record_span(span)
            _current_context.reset(token)

    def record_event(
        self,
        name: str,
        data: Any,
        tags: Optional[dict[str, str]] = None,
    ):
        """
        Record an event attached to the current span (if any).

        Usage:
            ctx.record_event("user_input", {"text": "hello"})
        """
        parent_span = self.get_current_span()
        trace_id = self.get_current_trace_id()
        payload = OtherEventPayload(type=EventType.OTHER, data=data)
        event = Event(
            name=name,
            trace_id=trace_id,
            tags=self.merge_tags(tags),
            payload=payload,
            parent_id=parent_span.id if parent_span else None,
            app_name=self.app_name,
        )
        if self.tracer:
            self.tracer.record_event(event)
        return event

    @contextmanager
    def function_span(
        self,
        name: str,
        arguments: dict[str, Any],
        tags: Optional[dict[str, str]] = None,
    ):
        """
        Create a span for a function call.

        Usage:
            with ctx.function_span("calculate", {"x": 1, "y": 2}) as span:
                result = calculate(1, 2)
                span.update_payload_data(return_value=result)
        """
        payload = FunctionSpanPayload(
            type=SpanType.FUNCTION,
            name=name,
            arguments=arguments,
            return_value=None,
        )
        with self.span(name, tags, payload) as span:
            try:
                yield span
            except Exception as e:
                # Attach the failure to the payload before re-raising.
                if isinstance(span.payload, FunctionSpanPayload):
                    span.payload.error = Error(code=-1, message=str(e))
                raise

    @contextmanager
    def llm_span(
        self,
        name: str = "llm_call",
        request: Any = None,
        tags: Optional[dict[str, str]] = None,
    ):
        """
        Create a span for an LLM call.

        Usage:
            with ctx.llm_span("openai_call", request=messages) as span:
                response = client.chat.completions.create(...)
                span.update_payload_data(response=response)
        """
        payload = LLMSpanPayload(
            type=SpanType.LLM,
            request=request,
        )
        with self.span(name, tags, payload) as span:
            try:
                yield span
            except Exception as e:
                if isinstance(span.payload, LLMSpanPayload):
                    span.payload.error = Error(code=-1, message=str(e))
                raise

    @contextmanager
    def tool_span(
        self,
        name: str = "tool_call",
        request: Any = None,
        tags: Optional[dict[str, str]] = None,
    ):
        """
        Create a span for a tool call.

        Usage:
            with ctx.tool_span("search_tool", request=tool_call) as span:
                result = search(query)
                span.update_payload_data(response=result)
        """
        payload = ToolSpanPayload(type=SpanType.TOOL, request=request)
        with self.span(name, tags, payload) as span:
            try:
                yield span
            except Exception as e:
                if isinstance(span.payload, ToolSpanPayload):
                    span.payload.error = Error(code=-1, message=str(e))
                raise

    @contextmanager
    def http_span(
        self,
        url: str,
        method: Literal[
            HTTPMethod.GET,
            HTTPMethod.POST,
            HTTPMethod.PUT,
            HTTPMethod.DELETE,
            HTTPMethod.PATCH,
            HTTPMethod.HEAD,
            HTTPMethod.OPTIONS,
            HTTPMethod.CONNECT,
            HTTPMethod.TRACE,
        ],
        name: Optional[str] = None,
        headers: Optional[dict[str, list[str]]] = None,
        body: Optional[str | bytes] = None,
        tags: Optional[dict[str, str]] = None,
    ):
        """
        Create a span for an HTTP request.

        Usage:
            with ctx.http_span("https://api.example.com", "POST", body=data) as span:
                response = requests.post(url, data=data)
                span.update_payload_data(response=response.text)
        """
        span_name = name or f"{method} {url}"
        payload = HTTPSpanPayload(
            type=SpanType.HTTP,
            url=url,
            method=method,
            headers=headers or {},
            body=body,
        )
        with self.span(span_name, tags, payload) as span:
            try:
                yield span
            except Exception as e:
                if isinstance(span.payload, HTTPSpanPayload):
                    span.payload.error = Error(code=-1, message=str(e))
                raise
def get_current_context() -> SpanContext:
    """Return the active SpanContext, lazily creating and installing one."""
    ctx = _current_context.get()
    if ctx is None:
        ctx = SpanContext()
        _current_context.set(ctx)
    return ctx
def set_default_tracer(tracer: Tracer):
    """Install ``tracer`` as the process-wide default Tracer."""
    set_default(tracer=tracer)
def set_default_app_name(app_name: str):
    """Install ``app_name`` as the process-wide default application name."""
    set_default(app_name=app_name)
def get_default_tracer() -> Tracer:
    """Return the process-wide default Tracer instance."""
    return get_default_settings().tracer
def get_default_app_name() -> str:
    """Return the process-wide default application name."""
    return get_default_settings().app_name
# Convenient global functions
def trace_function(
    name: Optional[str] = None,
    tags: Optional[dict[str, str]] = None,
    context: Optional[SpanContext] = None,
):
    """
    Decorator: automatically trace function calls (sync or async).

    Usage:
        @trace_function(name="my_function", tags={"category": "business"})
        def my_function(x, y):
            return x + y

    Improvements over the previous version: the ``asyncio`` import is done
    once when the decorator factory runs (not re-executed inside every
    decoration), and only the wrapper matching the target's sync/async
    nature is constructed.
    """
    import asyncio

    def decorator(func: Callable) -> Callable:
        func_name = name or func.__name__

        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                ctx = context or get_current_context()
                # Capture the raw call arguments on the span payload.
                arguments = {"args": args, "kwargs": kwargs}
                with ctx.function_span(func_name, arguments, tags) as span:
                    result = await func(*args, **kwargs)
                    # Record the return value on the payload when possible.
                    if isinstance(span.payload, FunctionSpanPayload):
                        span.payload.return_value = result
                    return result

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            ctx = context or get_current_context()
            # Capture the raw call arguments on the span payload.
            arguments = {"args": args, "kwargs": kwargs}
            with ctx.function_span(func_name, arguments, tags) as span:
                result = func(*args, **kwargs)
                # Record the return value on the payload when possible.
                if isinstance(span.payload, FunctionSpanPayload):
                    span.payload.return_value = result
                return result

        return wrapper

    return decorator
def start_trace(trace_id: Optional[str] = None, context: Optional[SpanContext] = None):
    """
    Start a new trace and return its id.

    When ``trace_id`` is provided it becomes the context's trace id;
    otherwise the context's current (possibly freshly generated) id is used.
    """
    ctx = context or get_current_context()
    if trace_id is None:
        return ctx.get_current_trace_id()
    ctx.set_trace_id(trace_id)
    return trace_id
def record_event(
    name: str,
    data: Any,
    tags: Optional[dict[str, str]] = None,
    context: Optional[SpanContext] = None,
):
    """
    Record an event on the supplied (or currently active) span context.

    Usage:
        record_event("user_action", {"action": "click", "button": "submit"})
    """
    active = context or get_current_context()
    return active.record_event(name, data, tags=tags)
@contextmanager
def create_span(
    name: str,
    tags: Optional[dict[str, str]] = None,
    context: Optional[SpanContext] = None,
):
    """
    Convenience context manager that opens a span on the active context.

    Usage:
        with create_span("my_operation"):
            # your code
            pass
    """
    active = context or get_current_context()
    with active.span(name, tags) as span:
        yield span
================================================
FILE: agentkit/trace/default.py
================================================
from dataclasses import dataclass, field

from .local_tracer import LocalStorageTracer
from .tracer import Tracer
@dataclass
class DefaultSettings:
    """Process-wide default configuration for tracing."""

    app_name: str = "default"
    # default_factory defers constructing the LocalStorageTracer (which
    # creates ./traces directories on disk) until an instance is built,
    # instead of at class-definition/import time, and avoids sharing one
    # tracer object across all DefaultSettings instances.
    tracer: Tracer = field(
        default_factory=lambda: LocalStorageTracer(storage_dir="./traces")
    )
# Module-level singleton holding the process-wide default settings;
# mutated via set_default() and read via get_default()/get_default_settings().
_settings = DefaultSettings()
def set_default(**kwargs):
    """Update known default settings by keyword; reject unknown keys.

    Raises:
        ValueError: if a keyword does not match an existing settings field.
    """
    for key, value in kwargs.items():
        if not hasattr(_settings, key):
            raise ValueError(f"Unknown setting: {key}")
        setattr(_settings, key, value)
def get_default_settings() -> DefaultSettings:
    """Expose the module-level settings singleton."""
    return _settings
def get_default(key: str):
    """Return one default setting by attribute name (e.g. "tracer")."""
    return getattr(_settings, key)
================================================
FILE: agentkit/trace/local_tracer.py
================================================
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
from .span import Event, Span
from .tracer import Tracer
class LocalStorageTracer(Tracer):
def __init__(self, storage_dir: str = "./traces"):
    """Create the tracer and ensure the span/event directories exist."""
    base = Path(storage_dir)
    base.mkdir(parents=True, exist_ok=True)
    self.storage_dir = base
    self.spans_dir = base / "spans"
    self.events_dir = base / "events"
    self.spans_dir.mkdir(exist_ok=True)
    self.events_dir.mkdir(exist_ok=True)
def _get_trace_spans_file(self, trace_id: str) -> Path:
    """Path of the JSONL file holding spans for ``trace_id``."""
    return self.spans_dir / (trace_id + ".jsonl")
def _get_trace_events_file(self, trace_id: str) -> Path:
    """Path of the JSONL file holding events for ``trace_id``."""
    return self.events_dir / (trace_id + ".jsonl")
def record_span(self, span: Span) -> None:
    """Append one span as a JSON line to its trace's span file.

    Fix: pydantic v2's ``model_dump_json`` does not accept an
    ``ensure_ascii`` keyword (passing it raises TypeError at runtime);
    it already emits non-ASCII text unescaped as UTF-8 by default,
    which is what the removed ``ensure_ascii=False`` intended.
    """
    spans_file = self._get_trace_spans_file(span.trace_id)
    with open(spans_file, "a", encoding="utf-8") as f:
        span_data = span.model_dump_json(exclude_none=True)
        f.write(span_data + "\n")
def record_event(self, event: Event) -> None:
    """Append one event as a JSON line to its trace's event file.

    Fix: pydantic v2's ``model_dump_json`` does not accept an
    ``ensure_ascii`` keyword (passing it raises TypeError at runtime);
    it already emits non-ASCII text unescaped as UTF-8 by default.
    """
    events_file = self._get_trace_events_file(event.trace_id)
    with open(events_file, "a", encoding="utf-8") as f:
        event_data = event.model_dump_json(exclude_none=True)
        f.write(event_data + "\n")
def get_spans(self, trace_id: str) -> list[Span]:
    """Load and validate all spans recorded for ``trace_id``."""
    spans_file = self._get_trace_spans_file(trace_id)
    if not spans_file.exists():
        return []
    spans: list[Span] = []
    with open(spans_file, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                spans.append(Span(**json.loads(line)))
            except Exception as e:
                # One corrupt line must not hide the rest of the trace:
                # warn and keep reading.
                print(
                    f"Warning: Failed to parse span at line {line_num} in {spans_file}: {e}"
                )
    return spans
def get_events(self, trace_id: str) -> list[Event]:
    """Load and validate all events recorded for ``trace_id``."""
    events_file = self._get_trace_events_file(trace_id)
    if not events_file.exists():
        return []
    events: list[Event] = []
    with open(events_file, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                events.append(Event(**json.loads(line)))
            except Exception as e:
                # One corrupt line must not hide the rest of the trace:
                # warn and keep reading.
                print(
                    f"Warning: Failed to parse event at line {line_num} in {events_file}: {e}"
                )
    return events
def get_trace(self, trace_id: str) -> Optional[dict]:
    """Return a validated summary dict for ``trace_id``, or None when the
    trace has neither spans nor events on disk."""
    spans = self.get_spans(trace_id)
    events = self.get_events(trace_id)
    if not (spans or events):
        return None
    return {
        "trace_id": trace_id,
        "spans": [s.model_dump(mode="json") for s in spans],
        "events": [e.model_dump(mode="json") for e in events],
        "span_count": len(spans),
        "event_count": len(events),
    }
def get_trace_raw(self, trace_id: str) -> Optional[dict]:
    """
    Get raw trace data (without Pydantic model validation).
    Used for frontend display to avoid serialization/deserialization issues.

    Improvement: the span and event files were read with two copy-pasted
    loops; both now share one local best-effort JSONL reader. Warning
    messages are unchanged.
    """

    def _read_jsonl(path: Path, label: str) -> list:
        # Best-effort reader: invalid lines are reported and skipped.
        records: list = []
        if not path.exists():
            return records
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as e:
                    # Skip invalid line
                    print(f"Warning: Failed to parse {label} line: {e}")
        return records

    spans_file = self._get_trace_spans_file(trace_id)
    events_file = self._get_trace_events_file(trace_id)
    if not spans_file.exists() and not events_file.exists():
        return None
    spans = _read_jsonl(spans_file, "span")
    events = _read_jsonl(events_file, "event")
    return {
        "trace_id": trace_id,
        "spans": spans,
        "events": events,
        "span_count": len(spans),
        "event_count": len(events),
    }
def list_traces(self, limit: int = 100, offset: int = 0) -> list[dict]:
    """List known traces, most recently modified first.

    A trace is discovered from either its spans file or its events file;
    ``mtime`` is the newer of the two. Counts are produced by fully
    parsing each listed trace's files.
    """
    index: dict[str, dict] = {}
    for path in self.spans_dir.glob("*.jsonl"):
        index[path.stem] = {
            "trace_id": path.stem,
            "spans_file": path,
            "mtime": path.stat().st_mtime,
        }
    for path in self.events_dir.glob("*.jsonl"):
        entry = index.get(path.stem)
        if entry is None:
            index[path.stem] = {
                "trace_id": path.stem,
                "events_file": path,
                "mtime": path.stat().st_mtime,
            }
        else:
            entry["events_file"] = path
            entry["mtime"] = max(entry["mtime"], path.stat().st_mtime)
    ordered = sorted(index.values(), key=lambda item: item["mtime"], reverse=True)
    summaries = []
    for info in ordered[offset : offset + limit]:
        tid = info["trace_id"]
        summaries.append(
            {
                "trace_id": tid,
                "span_count": len(self.get_spans(tid)),
                "event_count": len(self.get_events(tid)),
                "last_modified": datetime.fromtimestamp(info["mtime"]),
            }
        )
    return summaries
================================================
FILE: agentkit/trace/remote_tracer.py
================================================
import logging
from datetime import datetime
from typing import Optional
import httpx
from pydantic import BaseModel
from .span import DataType, Event, Span
from .tracer import Tracer
logger = logging.getLogger(__name__)
class RemoteEvent(BaseModel):
    # Envelope wrapped around a Span or Event before posting it to the
    # remote /trace/agent/event endpoint (see RemoteTracer below).
    id: str  # id of the wrapped span/event
    data_type: DataType  # SPAN or EVENT discriminator for the receiver
    timestamp: datetime  # span start time or event timestamp
    app_name: str
    data: Span | Event  # the actual trace record being shipped
class RemoteTracer(Tracer):
    """
    A tracer that sends trace data to a remote API.
    According to API docs: /trace/agent/event

    Write-only: the remote API exposes no query interface, so all read
    methods log a warning and return empty results.
    """

    def __init__(
        self,
        base_url: str = "",
        timeout: float = 10.0,
        enable_span: bool = True,
        enable_event: bool = True,
    ):
        """
        Initialize RemoteTracer.
        Args:
            base_url: API base URL.
            timeout: Request timeout in seconds.
            enable_span: Whether to enable span sending (default True).
            enable_event: Whether to enable event sending (default True).
        """
        self.base_url = base_url.rstrip("/")
        self.event_endpoint = f"{self.base_url}/trace/agent/event"
        self.timeout = timeout
        self.enable_span = enable_span
        self.enable_event = enable_event
        # Single shared sync client, reused for connection pooling.
        self.client = httpx.Client(timeout=timeout)

    def __del__(self):
        """Clean up resources (best-effort; interpreter may be shutting down)."""
        try:
            self.client.close()
        except Exception:
            pass

    def _send_to_api(self, data: str) -> bool:
        """
        Send data to remote API.
        Args:
            data: Pre-serialized JSON string to send.
        Returns:
            Whether the send was successful (HTTP 200 and API code == 0).
        """
        try:
            response = self.client.post(
                self.event_endpoint,
                # Fix: httpx deprecates `data=` for raw str bodies; `content=`
                # is the documented parameter for a pre-serialized payload.
                content=data,
                headers={"Content-Type": "application/json"},
            )
            if response.status_code == 200:
                result = response.json()
                if result.get("code") == 0:
                    logger.debug(f"Successfully sent trace data: {data}")
                    return True
                else:
                    logger.error(
                        f"API returned error code {result.get('code')}: {result.get('msg')}"
                    )
                    return False
            else:
                logger.error(f"HTTP error {response.status_code}: {response.text}")
                return False
        except httpx.TimeoutException:
            logger.error(f"Request timeout when sending trace data: {data}")
            return False
        except Exception as e:
            logger.error(f"Failed to send trace data: {e}")
            return False

    def record_span(self, span: Span) -> None:
        """Record a span to the remote service (no-op when spans disabled)."""
        if not self.enable_span:
            return
        remote_event = RemoteEvent(
            id=span.id,
            data_type=DataType.SPAN,
            timestamp=span.start_time,
            app_name=span.app_name,
            data=span,
        )
        self._send_to_api(remote_event.model_dump_json(exclude_none=True))

    def record_event(self, event: Event) -> None:
        """Record an event to the remote service (no-op when events disabled)."""
        if not self.enable_event:
            return
        remote_event = RemoteEvent(
            id=event.id,
            data_type=DataType.EVENT,
            timestamp=event.timestamp,
            app_name=event.app_name,
            data=event,
        )
        self._send_to_api(remote_event.model_dump_json(exclude_none=True))

    def get_spans(self, trace_id: str) -> list[Span]:
        """
        RemoteTracer does not support read operations.
        Note: The current remote API only provides write interface, query is not supported.
        """
        logger.warning("RemoteTracer does not support reading spans")
        return []

    def get_events(self, trace_id: str) -> list[Event]:
        """
        RemoteTracer does not support read operations.
        Note: The current remote API only provides write interface, query is not supported.
        """
        logger.warning("RemoteTracer does not support reading events")
        return []

    def get_trace(self, trace_id: str) -> Optional[dict]:
        """
        RemoteTracer does not support read operations.
        Note: The current remote API only provides write interface, query is not supported.
        """
        logger.warning("RemoteTracer does not support reading traces")
        return None

    def list_traces(self, limit: int = 100, offset: int = 0) -> list[dict]:
        """
        RemoteTracer does not support read operations.
        Note: The current remote API only provides write interface, query is not supported.
        """
        logger.warning("RemoteTracer does not support listing traces")
        return []
class HybridTracer(Tracer):
    """
    Hybrid tracer: sends data to both remote service and local storage.

    Writes fan out to both backends; all reads are served from the local
    tracer (the remote API is write-only).

    Usage:
        from agentkit.trace.local_tracer import LocalStorageTracer
        from agentkit.trace.remote_tracer import RemoteTracer, HybridTracer
        local_tracer = LocalStorageTracer("./traces")
        remote_tracer = RemoteTracer("xxx")
        hybrid_tracer = HybridTracer(local_tracer, remote_tracer)
    """

    def __init__(self, local_tracer: Tracer, remote_tracer: RemoteTracer):
        """
        Initialize hybrid tracer.
        Args:
            local_tracer: Local tracer (for reading and local storage).
            remote_tracer: Remote tracer (for remote reporting).
        """
        self.local_tracer = local_tracer
        self.remote_tracer = remote_tracer

    def record_span(self, span: Span) -> None:
        """Fan the span out to local storage first, then the remote API."""
        for sink in (self.local_tracer, self.remote_tracer):
            sink.record_span(span)

    def record_event(self, event: Event) -> None:
        """Fan the event out to local storage first, then the remote API."""
        for sink in (self.local_tracer, self.remote_tracer):
            sink.record_event(event)

    def get_spans(self, trace_id: str) -> list[Span]:
        """Spans are always read from the local tracer."""
        return self.local_tracer.get_spans(trace_id)

    def get_events(self, trace_id: str) -> list[Event]:
        """Events are always read from the local tracer."""
        return self.local_tracer.get_events(trace_id)

    def get_trace(self, trace_id: str) -> Optional[dict]:
        """Trace summaries are always read from the local tracer."""
        return self.local_tracer.get_trace(trace_id)

    def get_trace_raw(self, trace_id: str) -> Optional[dict]:
        """Raw trace data is always read from the local tracer."""
        return self.local_tracer.get_trace_raw(trace_id)

    def list_traces(self, limit: int = 100, offset: int = 0) -> list[dict]:
        """Trace listings are always served by the local tracer."""
        return self.local_tracer.list_traces(limit, offset)
================================================
FILE: agentkit/trace/span.py
================================================
from datetime import datetime
from enum import Enum
try:  # Python 3.11+ has HTTPMethod in stdlib
    from http import HTTPMethod
except ImportError:  # pragma: no cover - fallback for Python 3.10
    # Minimal stand-in mirroring http.HTTPMethod's members as a str enum,
    # so HTTPSpanPayload below works identically on older interpreters.
    class HTTPMethod(str, Enum):
        GET = "GET"
        POST = "POST"
        PUT = "PUT"
        DELETE = "DELETE"
        PATCH = "PATCH"
        HEAD = "HEAD"
        OPTIONS = "OPTIONS"
        CONNECT = "CONNECT"
        TRACE = "TRACE"
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field
from ulid import ULID
from .types import Error
class DataType(str, Enum):
    # Discriminator used by tracers and RemoteEvent to distinguish record kinds.
    SPAN = "span"
    EVENT = "event"
class SpanType(str, Enum):
    # Tags which *SpanPayload variant a Span carries (see payload union on Span).
    LLM = "llm_span"
    TOOL = "tool_span"
    HTTP = "http_span"
    FUNCTION = "function_span"
    OTHER = "other_span"
class LLMSpanPayload(BaseModel):
    # Payload for a model-call span; request/response shapes are
    # provider-specific, hence Any.
    type: Literal[SpanType.LLM]
    request: Any = None
    response: Any = None
    error: Optional[Error] = None  # set when the call failed
class ToolSpanPayload(BaseModel):
    # Payload for a tool-invocation span; request/response shapes vary per tool.
    type: Literal[SpanType.TOOL]
    request: Any = None
    response: Any = None
    error: Optional[Error] = None  # set when the tool call failed
class FunctionSpanPayload(BaseModel):
    # Payload for an in-process function call span.
    type: Literal[SpanType.FUNCTION]
    name: str = ""  # function name
    arguments: dict[str, Any] = Field(default_factory=dict)  # keyword arguments
    return_value: Any = None
    error: Optional[Error] = None  # set when the function raised
class HTTPSpanPayload(BaseModel):
    # Payload for an outbound HTTP request span.
    type: Literal[SpanType.HTTP]
    url: str
    # Literal over every HTTPMethod member — equivalent to accepting the whole
    # enum, spelled out so the allowed values are explicit in the schema.
    method: Literal[
        HTTPMethod.GET,
        HTTPMethod.POST,
        HTTPMethod.PUT,
        HTTPMethod.DELETE,
        HTTPMethod.PATCH,
        HTTPMethod.HEAD,
        HTTPMethod.OPTIONS,
        HTTPMethod.CONNECT,
        HTTPMethod.TRACE,
    ]
    headers: dict[str, list[str]] = Field(default_factory=dict)  # multi-valued headers
    body: Optional[str | bytes] = None  # request body, if captured
    response: Optional[str | bytes] = None  # response body, if captured
    error: Optional[Error] = None  # set when the request failed
class OtherSpanPayload(BaseModel):
    # Catch-all payload for spans that fit none of the typed variants.
    type: Literal[SpanType.OTHER]
    data: Any
class Span(BaseModel):
    # A timed unit of work inside a trace. Identified by a ULID (sortable,
    # unique); linked to siblings via trace_id and to a parent via parent_id.

    # use ulid
    id: str = Field(default_factory=lambda: str(ULID()))
    name: str = Field(default="")
    data_type: Literal[DataType.SPAN] = DataType.SPAN
    start_time: datetime = Field(default_factory=datetime.now)
    end_time: Optional[datetime] = Field(default=None)
    tags: dict[str, str] = Field(default_factory=dict)
    payload: (
        LLMSpanPayload
        | ToolSpanPayload
        | FunctionSpanPayload
        | OtherSpanPayload
        | HTTPSpanPayload
        | None
    ) = Field(default=None)
    parent_id: Optional[str] = Field(default=None)  # Parent span ID
    trace_id: str
    app_name: str

    def update_payload(
        self,
        payload: (
            LLMSpanPayload
            | ToolSpanPayload
            | FunctionSpanPayload
            | HTTPSpanPayload
            | OtherSpanPayload
            | None
        ),
    ) -> "Span":
        """Replace the span's payload; returns self for chaining."""
        self.payload = payload
        return self

    def update_payload_data(self, **kwargs) -> "Span":
        """
        Update specific fields in payload.
        Usage:
            span.update_payload_data(ret=result, error=None)
        """
        current = self.payload
        if current is not None and hasattr(current, "model_copy"):
            # Use pydantic's model_copy to update fields
            self.payload = current.model_copy(update=kwargs)
        return self

    def add_tag(self, key: str, value: str) -> "Span":
        """Set a single tag; returns self for chaining."""
        self.tags[key] = value
        return self

    def add_tags(self, tags: dict[str, str]) -> "Span":
        """Merge several tags at once; returns self for chaining."""
        self.tags.update(tags)
        return self
class EventType(str, Enum):
    # Tags which *EventPayload variant an Event carries.
    DELTA = "delta_event"
    OTHER = "other_event"
class DeltaEventPayload(BaseModel):
    # Payload for an incremental (streaming) output chunk.
    type: Literal[EventType.DELTA]
    delta: Any
class OtherEventPayload(BaseModel):
    # Catch-all payload for events that are not streaming deltas.
    type: Literal[EventType.OTHER]
    data: Any
class Event(BaseModel):
    # A point-in-time record inside a trace (unlike Span, it has no duration).
    id: str = Field(default_factory=lambda: str(ULID()))  # ULID: sortable unique id
    name: str = Field(default="")
    data_type: Literal[DataType.EVENT] = DataType.EVENT
    timestamp: datetime = Field(default_factory=datetime.now)
    tags: dict[str, str] = Field(default_factory=dict)
    payload: DeltaEventPayload | OtherEventPayload
    parent_id: Optional[str] = Field(default=None)  # Parent span ID
    trace_id: str
    app_name: str
================================================
FILE: agentkit/trace/tracer.py
================================================
from abc import ABC, abstractmethod
from typing import Optional
from .span import Event, Span
class Tracer(ABC):
    """Abstract storage/transport interface for trace spans and events.

    Implementations may be write-only (remote), read/write (local files),
    or composites of both.
    """

    @abstractmethod
    def record_span(self, span: Span) -> None:
        """Persist or forward a span."""
        pass

    @abstractmethod
    def record_event(self, event: Event) -> None:
        """Persist or forward an event."""
        pass

    @abstractmethod
    def get_spans(self, trace_id: str) -> list[Span]:
        """Return all spans recorded under ``trace_id``."""
        pass

    @abstractmethod
    def get_events(self, trace_id: str) -> list[Event]:
        """Return all events recorded under ``trace_id``."""
        pass

    @abstractmethod
    def get_trace(self, trace_id: str) -> Optional[dict]:
        """Return a summary dict for the trace, or None when unknown."""
        pass

    @abstractmethod
    def list_traces(self, limit: int = 100, offset: int = 0) -> list[dict]:
        """Return summaries for known traces, paginated."""
        pass

    def get_trace_raw(self, trace_id: str) -> Optional[dict]:
        """
        Get raw trace data (optional implementation).
        Default behavior is to call get_trace().
        """
        return self.get_trace(trace_id)
================================================
FILE: agentkit/trace/types.py
================================================
from pydantic import BaseModel
class Error(BaseModel):
    # Minimal error envelope attached to span payloads.
    code: int
    message: str
# class ModelParams(BaseModel):
# name: str
# response_format: Optional[Any] = None
# toolcall_parser_version: Optional[str] = None
# parallel_tool_calls: bool = True
# infer_kwargs: dict[str, Any] = {}
#
#
# class Function(BaseModel):
# arguments: str | None = None
# """
# The arguments to call the function with, as generated by the model in JSON
# format. Note that the model does not always generate valid JSON, and may
# hallucinate parameters not defined by your function schema. Validate the
# arguments in your code before calling your function.
# """
#
# name: str | None = None
# """The name of the function to call."""
#
#
# class ChatToolCall(BaseModel):
# index: int | None = None
# """The index of the tool call."""
#
# id: str | None = None
# """The ID of the tool call."""
#
# function: Function
# """The function that the model called."""
#
# type: str | None = "function"
# """The type of the tool. Currently, only `function` is supported."""
#
#
# class ChatMessage(BaseModel):
# id: Optional[str] = None
# role: Optional[str] = None
# content: str | list[dict] | None = None
# tool_call_id: Optional[str] = None
# tool_calls: Optional[list[ChatToolCall]] = None
#
# # for train use
# extra_info: Optional[dict] = None
#
#
# class LLMRequest(BaseModel):
# messages: list[ChatMessage]
# model: ModelParams
================================================
FILE: config.yaml
================================================
# Global runtime defaults for StepDeepResearch.
#
# Used as the default context budget for the force-final-answer workflow in BaseStepAgent.
# `scripts.runner` can override these via its own `--config` YAML or CLI flags.
context_upper_limit: 80000
context_lower_limit: 60000
================================================
FILE: cortex/__init__.py
================================================
from .env import load_env as _load_env

# Load environment configuration as a package-import side effect so every
# cortex submodule sees the populated environment.
_load_env()
================================================
FILE: cortex/agents/__init__.py
================================================
"""Agent components module."""
from cortex.agents.base_agent import BaseAgent
from cortex.agents.base_step_agent import BaseStepAgent
from cortex.agents.react_agent import ReActAgent
from cortex.agents.types import (
AgentConfig,
AgentMessageType,
AgentResponse,
AgentRunningStatus,
)
__all__ = [
"BaseAgent",
"BaseStepAgent",
"AgentConfig",
"AgentResponse",
"AgentRunningStatus",
"AgentMessageType",
"ReActAgent",
]
================================================
FILE: cortex/agents/agent_factory.py
================================================
"""
AgentFactory is the factory class for Agent, responsible for creating and managing Agents
"""
from typing import Awaitable, Callable
from cortex.agents.base_agent import BaseAgent
from cortex.agents.types import AgentConfig
class AgentFactory:
    """
    AgentFactory is the factory class for Agent, responsible for creating and managing Agents
    """

    # NOTE(review): these registries are class-level attributes, so every
    # AgentFactory instance shares one process-wide registry. That looks
    # intentional (global registration) — confirm before converting them to
    # per-instance dicts.
    agent_make_func: dict[str, Callable[[AgentConfig, str], Awaitable[BaseAgent]]] = {}
    default_agent_configs: dict[str, AgentConfig] = {}

    def list_agents(self) -> list[AgentConfig]:
        """
        Return all registered Agent configurations
        """
        return list(self.default_agent_configs.values())

    def get_default_agent_config(self, name: str) -> AgentConfig:
        """
        Get the default Agent configuration registered for ``name``.

        Raises:
            ValueError: when no default configuration was registered.
        """
        config = self.default_agent_configs.get(name)
        if config is None:
            raise ValueError(
                f"AgentConfig not provided, and no default configuration set for '{name}' in factory"
            )
        return config

    def register_agent(
        self,
        name: str,
        make_agent_func: Callable[[AgentConfig, str], Awaitable[BaseAgent]],
        default_config: AgentConfig | None = None,
    ) -> None:
        """
        Register an agent constructor (and optional default config) under ``name``.
        """
        self.agent_make_func[name] = make_agent_func
        if default_config is not None:
            self.default_agent_configs[name] = default_config

    async def make_agent(
        self, name: str, context_id: str, agent_config: AgentConfig | None
    ) -> BaseAgent:
        """
        Create an Agent instance.

        Uses ``agent_config`` when provided; otherwise falls back to the
        registered default via get_default_agent_config (deduplicating the
        lookup-and-raise logic that was previously copy-pasted here).

        Raises:
            ValueError: unknown agent name, or no configuration available.
        """
        if name not in self.agent_make_func:
            raise ValueError(f"Agent {name} not found")
        config = (
            agent_config
            if agent_config is not None
            else self.get_default_agent_config(name)
        )
        make_agent_func = self.agent_make_func[name]
        return await make_agent_func(config, context_id)
================================================
FILE: cortex/agents/base_agent.py
================================================
"""Base Agent class, provides run() interface as the base class for all Agents."""
import asyncio
import copy
import logging
from abc import abstractmethod
from typing import Any, AsyncGenerator
from cortex.model.definition import ChatMessage, ChatToolCall, ContentBlockType
from cortex.agents.input.input import InputChannel
from cortex.agents.types import AgentConfig, AgentResponse
from cortex.model import ModelAPI
from cortex.model.provider import ModelProvider
from cortex.model.stepfun_provider import StepFunModelProvider
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
class BaseAgent:
    """Base Agent class, provides run() interface.

    Subclasses implement _run(); run() wraps it and stamps each response
    with the agent's name. Tool execution helpers (run_tool_call /
    run_tool_call_concurrency) operate on the toolset passed at init.
    """

    # Class-level declarations; concrete values are copied from config in
    # update_from_config().
    model: Any
    name: str | None = None
    description: str | None = None
    system_prompt: str | None = None
    max_steps: int | None = 5
    _input_channel: InputChannel[ChatMessage] | None = None
    provider: ModelAPI

    def update_from_config(self):
        """Update agent properties from config.

        NOTE(review): this dereferences self.config unconditionally, so
        constructing BaseAgent with config=None (the default) would raise
        AttributeError here — confirm whether callers always pass a config.
        """
        # Iterate config attributes and update self if attribute exists
        for key, _ in self.config.model_dump().items():
            setattr(self, key, copy.deepcopy(getattr(self.config, key)))
        # Fallbacks: class name as agent name, generic description.
        self.name = self.name or self.__class__.__name__
        self.description = (
            self.description
            or f"Call the {self.name} sub-agent to handle specific tasks"
        )

    def __init__(
        self,
        provider: ModelProvider | None = None,
        config: AgentConfig | None = None,
        toolset: ToolSet | None = None,
    ):
        """Initialize the agent from config; builds a StepFun provider when
        none is supplied (using self.model copied from config)."""
        self.config = config
        self.update_from_config()
        self._toolset = toolset
        if provider is None:
            provider = StepFunModelProvider(model_params=self.model)
        self.provider = ModelAPI(provider)

    async def __aenter__(self):
        # Async context manager entry: no setup needed, returns self.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Async context manager exit."""
        # Intentionally a no-op; subclasses may override to release resources.

    async def run(
        self,
        messages: list[ChatMessage] | InputChannel[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """Run agent, returns AgentResponse generator."""
        async for response in self._run(messages, additional_kwargs):
            # Stamp responses that the subclass left unattributed.
            if response.agent_name is None:
                response.agent_name = self.name
            yield response

    @abstractmethod
    async def _run(
        self,
        messages: list[ChatMessage] | InputChannel[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Run agent, subclasses must implement this method.
        Args:
            messages: Input message list or input channel
            additional_kwargs: Additional parameters
        Yields:
            AgentResponse: Agent response object
        """
        raise NotImplementedError("Subclasses must implement this method")

    def model_api(self) -> ModelAPI:
        # Accessor for the wrapped model provider.
        return self.provider

    def toolset(self) -> ToolSet | None:
        """
        Get toolset.
        Returns:
            ToolSet | None: Toolset
        """
        return self._toolset

    def as_tool(self, timeout: float | None = None) -> dict[str, Any]:
        """
        Convert Agent to parameters needed for Tool creation.
        Returns a dictionary containing parameters for creating AgentTool:
        - name: Tool name (uses agent's name)
        - description: Tool description (uses agent's description)
        - agent_name: Agent name (used to specify which agent to call, as metadata)
        - timeout: Timeout (optional, if provided)
        Note: When creating AgentTool, channel parameter is also required, which is not included in this method.
        Args:
            timeout: Timeout in seconds, if None it won't be included in the returned dictionary
        Returns:
            dict: Dictionary containing parameters for creating AgentTool:
                {
                    "name": str,  # Tool name
                    "description": str,  # Tool description
                    "agent_name": str,  # Agent name (metadata)
                    "timeout": float,  # Timeout (optional)
                }
        Example:
            >>> agent = MathAgent(config)
            >>> tool_params = agent.as_tool(timeout=60.0)
            >>> # Channel is also required when creating AgentTool
            >>> tool = AgentTool(**tool_params, channel=channel)
            >>> toolset.register(tool)
        """
        agent_name = self.name or self.__class__.__name__
        tool_params: dict[str, Any] = {
            "name": agent_name,
            "description": self.description
            or f"Call {agent_name} Agent to handle specific tasks",
            "agent_name": agent_name,  # As metadata, caller knows which agent this tool corresponds to
        }
        if timeout is not None:
            tool_params["timeout"] = timeout
        logger.debug(
            "BaseAgent.as_tool returns tool params: name=%s, agent_name=%s, has_timeout=%s",
            tool_params["name"],
            tool_params["agent_name"],
            "timeout" in tool_params,
        )
        return tool_params

    @staticmethod
    def has_tool_call(message: ChatMessage) -> bool:
        """
        Check if message contains tool calls.
        Args:
            message: ChatMessage object
        Returns:
            bool: True if message contains tool calls, False otherwise
        """
        if not message:
            return False
        tool_calls = getattr(message, "tool_calls", None)
        # For list/tuple tool_calls, require at least one entry; any other
        # non-None value counts as "has a tool call".
        return (
            tool_calls is not None and len(tool_calls) > 0
            if isinstance(tool_calls, (list, tuple))
            else tool_calls is not None
        )

    async def _execute_single_tool(self, tool_call: ChatToolCall) -> ChatMessage | None:
        """
        Execute a single tool call.
        Args:
            tool_call: Tool call object containing function.name, function.arguments, id, etc.
        Returns:
            ChatMessage: Tool call result message with role "tool";
            None when the toolset returned no result.
        """
        tool_name = tool_call.function.name
        tool_args = tool_call.function.arguments
        tool_call_id = tool_call.id
        try:
            # Execute tool call
            result = await self._toolset.call(
                tool_name=tool_name, parameters=tool_args, tool_call_id=tool_call_id
            )
            logger.info(f"@{self.name} Tool {tool_name} result: {result}")
            if result is None:
                return None
            return ChatMessage(
                role="tool",
                content=result,
                tool_call_id=tool_call_id,
            )
        except Exception as e:
            # Tool failures are surfaced to the model as a text block rather
            # than raised, so the conversation can continue.
            error_msg = f"Error calling tool {tool_name}: {str(e)}"
            logger.error(f"@{self.name} {error_msg}")
            tool_result_content = [
                {
                    "type": ContentBlockType.TEXT.value,
                    ContentBlockType.TEXT.value: error_msg,
                }
            ]
            return ChatMessage(
                role="tool",
                content=tool_result_content,
                tool_call_id=tool_call_id,
            )

    async def run_tool_call(self, message: ChatMessage) -> list[ChatMessage]:
        """
        Extract tool calls from message and execute them, returning list of tool call result messages.
        Args:
            message: ChatMessage object
        Returns:
            list[ChatMessage]: List of tool call result messages, each with role "tool"
        """
        if not message:
            return []
        if not self._toolset:
            logger.warning(f"@{self.name} run_tool_call: toolset not initialized")
            return []
        tool_calls = getattr(message, "tool_calls", None)
        if not tool_calls:
            return []
        # Ensure tool_calls is a list
        if isinstance(tool_calls, (list, tuple)):
            toolcalls_list = list(tool_calls)
        else:
            toolcalls_list = [tool_calls]
        # Execute all tool calls sequentially
        tool_result_messages = []
        for tool_call in toolcalls_list:
            result_message = await self._execute_single_tool(tool_call)
            if result_message is not None:
                tool_result_messages.append(result_message)
        return tool_result_messages

    async def run_tool_call_concurrency(
        self, message: ChatMessage
    ) -> list[ChatMessage]:
        """
        Extract concurrent tool calls from message and execute them (when there are multiple tool calls).
        Args:
            message: ChatMessage object
        Returns:
            list[ChatMessage]: List of concurrent tool call result messages, only returned when tool call count > 1
        """
        if not message:
            return []
        if not self._toolset:
            logger.warning(f"@{self.name} run_tool_call_concurrency: toolset not initialized")
            return []
        tool_calls = getattr(message, "tool_calls", None)
        if not tool_calls:
            return []
        # Ensure tool_calls is a list
        if isinstance(tool_calls, (list, tuple)):
            toolcalls_list = list(tool_calls)
        else:
            toolcalls_list = [tool_calls]
        # Only return when there are multiple tool calls (concurrent scenario)
        if len(toolcalls_list) <= 1:
            return []
        # Execute all tool calls concurrently
        tool_result_messages = await asyncio.gather(
            *[self._execute_single_tool(tc) for tc in toolcalls_list]
        )
        # Filter out None results
        return [msg for msg in tool_result_messages if msg is not None]
================================================
FILE: cortex/agents/base_step_agent.py
================================================
import copy
import json
import logging
import re
from abc import abstractmethod
from typing import Any, AsyncGenerator, Callable
from agentkit.trace import create_span
from cortex.agents.base_agent import BaseAgent
from cortex.agents.input.input import InputChannel
from cortex.agents.types import (
AgentConfig,
AgentMessageType,
AgentResponse,
AgentRunningStatus,
)
from cortex.context import BaseContext
from cortex.model.definition import ChatMessage
from cortex.model.provider import ModelProvider
from cortex.tools.toolset import ToolSet
from cortex.runtime_config import get_context_limit_overrides
try:
import tiktoken
except Exception: # noqa: BLE001
tiktoken = None
logger = logging.getLogger(__name__)
# Token-budget defaults for the force-final-answer workflow (overridable via
# extra_config or runtime config; see BaseStepAgent.__init__).
DEFAULT_FORCE_FINAL_ANSWER_UPPER_LIMIT = 100_000
# Alias kept for backward compatibility with the older "threshold" name.
DEFAULT_FORCE_FINAL_ANSWER_THRESHOLD = DEFAULT_FORCE_FINAL_ANSWER_UPPER_LIMIT
DEFAULT_FORCE_FINAL_ANSWER_LOWER_LIMIT_RATIO = 0.9
DEFAULT_FORCE_FINAL_ANSWER_LOWER_LIMIT = max(
    int(DEFAULT_FORCE_FINAL_ANSWER_UPPER_LIMIT * DEFAULT_FORCE_FINAL_ANSWER_LOWER_LIMIT_RATIO),
    1,
)
# System prompt injected when the context budget is exhausted. It is runtime
# model input (Chinese) — do not translate or reformat.
DEFAULT_FORCE_FINAL_ANSWER_PROMPT = (
    "你现在已经达到了你所能处理的最大上下文长度。你应该停止进行工具调用,"
    "并基于以上所有信息重新思考,然后按照以下格式提供你认为最可能的答案:"
    "你的最终思考\n你的答案"
)
# Rough chars-per-token ratio used by the fallback estimator when tiktoken
# is unavailable.
_AVG_CHARS_PER_TOKEN = 3
def _get_encoding(model_name: str | None):
    """Resolve a tiktoken encoding for ``model_name``.

    Falls back to ``cl100k_base`` when the model is unknown or unnamed, and
    returns None whenever tiktoken is missing or every lookup fails.
    """
    if not tiktoken:
        return None
    if model_name:
        try:
            return tiktoken.encoding_for_model(model_name)
        except Exception:  # noqa: BLE001
            pass  # unknown model — fall through to the generic encoding
    try:
        return tiktoken.get_encoding("cl100k_base")
    except Exception:  # noqa: BLE001
        return None
def _estimate_token_length(messages: list[ChatMessage], model_name: str | None) -> int:
    """Estimate the total token footprint of ``messages``.

    Each message is serialized to JSON and counted with tiktoken when an
    encoder is available; if the encoder fails mid-stream it is abandoned
    and the remaining messages fall back to a chars-per-token heuristic.
    """
    encoder = _get_encoding(model_name)
    total = 0
    for msg in messages:
        try:
            dumped = msg.model_dump(exclude_none=True)
        except Exception:
            # Non-pydantic (or broken) message: keep just role/content.
            dumped = {
                "role": getattr(msg, "role", None),
                "content": getattr(msg, "content", None),
            }
        text = json.dumps(dumped, ensure_ascii=False)
        if encoder:
            try:
                total += len(encoder.encode(text))
                continue
            except Exception:  # noqa: BLE001
                encoder = None  # stop trying the encoder for the rest
        total += len(text) // _AVG_CHARS_PER_TOKEN
    return total
def _compress_batch_search_result(content: str) -> str:
    """Strip verbose content and mark compressed.

    NOTE(review): the regex pattern and both ``replace`` argument pairs below
    appear to have lost their literal tag text (likely XML-ish markers) during
    repository extraction. As written, ``replace("", "", 1)`` is a no-op and
    the pattern only matches (possibly empty) whitespace runs — confirm the
    intended literals against the upstream source before relying on this.
    """
    compressed = re.sub(r".*?\s*", "", content, flags=re.S)
    compressed = compressed.replace(
        "", "", 1
    )
    compressed = compressed.replace(
        "", "", 1
    )
    return compressed
class BaseStepAgent(BaseAgent):
"""
Base class for step-based Agent
Implements run() method, executes tasks by calling step() in a loop
step() method returns a tuple containing a flag indicating whether to stop
"""
def __init__(
    self,
    context: BaseContext,
    config: AgentConfig,
    provider: ModelProvider | None = None,
    toolset: ToolSet | None = None,
):
    """Initialize the step agent and resolve the force-final-answer budget.

    Budget precedence (upper limit): extra_config upper limit →
    runtime-config upper → legacy extra_config "threshold" → module default.
    The lower limit comes from extra_config, else (only when the upper limit
    was not explicitly overridden) runtime config, else it is derived as a
    ratio of the upper limit. Invariant enforced at the end: 1 <= lower < upper.
    """
    super().__init__(config=config, toolset=toolset, provider=provider)
    self.current_round = 0
    self.context = context
    extra_cfg = config.extra_config if config and config.extra_config else {}
    self._force_final_answer_enabled = extra_cfg.get("force_final_answer", False)
    threshold_override = extra_cfg.get("final_answer_context_threshold")
    upper_override = extra_cfg.get("final_answer_context_upper_limit")
    lower_override = extra_cfg.get("final_answer_context_lower_limit")
    upper_from_extra = upper_override is not None
    lower_from_extra = lower_override is not None
    if not upper_from_extra or not lower_from_extra:
        runtime_upper, runtime_lower = get_context_limit_overrides()
        if not upper_from_extra and runtime_upper is not None:
            upper_override = runtime_upper
        # Only apply runtime lower default when upper isn't explicitly overridden;
        # otherwise keep the existing "derive lower from upper" behavior.
        if (
            not lower_from_extra
            and not upper_from_extra
            and runtime_lower is not None
        ):
            lower_override = runtime_lower

    def _normalize_limit(value: Any) -> int | None:
        # Accept ints/floats (e.g. from YAML/JSON); anything else means "unset".
        if isinstance(value, (int, float)):
            return int(value)
        return None

    upper_limit = _normalize_limit(upper_override)
    if upper_limit is None:
        # Legacy single-threshold knob acts as the upper limit.
        upper_limit = _normalize_limit(threshold_override)
    if upper_limit is None or upper_limit <= 0:
        upper_limit = DEFAULT_FORCE_FINAL_ANSWER_UPPER_LIMIT
    upper_limit = max(upper_limit, 2)
    lower_limit = _normalize_limit(lower_override)
    if lower_limit is None or lower_limit <= 0:
        derived = int(upper_limit * DEFAULT_FORCE_FINAL_ANSWER_LOWER_LIMIT_RATIO)
        lower_limit = derived if derived > 0 else DEFAULT_FORCE_FINAL_ANSWER_LOWER_LIMIT
    lower_limit = max(lower_limit, 1)
    if lower_limit >= upper_limit:
        # Keep the invariant lower < upper even with odd overrides.
        lower_limit = max(upper_limit - 1, 1)
    self._force_final_answer_upper_limit = upper_limit
    self._force_final_answer_lower_limit = lower_limit
    self._force_final_answer_prompt = extra_cfg.get(
        "final_answer_prompt", DEFAULT_FORCE_FINAL_ANSWER_PROMPT
    )
    # Set once the prompt has been activated; checked by _ensure_final_prompt.
    self._force_prompt_inserted = False
    self._model_name = getattr(config.model, "name", None) if config and config.model else None
def _insert_final_prompt(self) -> None:
    """Activate the force-final-answer prompt for subsequent model calls.

    Kept for backward compatibility; this does not mutate persisted context
    history — it only flips the activation flag (idempotent).
    """
    if self._force_prompt_inserted or not self._force_final_answer_enabled:
        return
    self._force_prompt_inserted = True
    logger.info("@%s Final answer prompt activated", self.name)
def _make_force_final_answer_message(self) -> ChatMessage:
    """Build the system message that instructs the model to stop tool calls
    and produce its final answer."""
    text = self._force_final_answer_prompt or DEFAULT_FORCE_FINAL_ANSWER_PROMPT
    return ChatMessage(role="system", content=text)
def _ensure_final_prompt(self, messages: list[ChatMessage]) -> None:
    """Append the force-final-answer prompt to the model input when active.

    Skipped entirely unless the feature is enabled AND already activated;
    also skipped when the identical prompt is already the last message.
    This mutates only the model-input list, never persisted history.
    """
    if not (self._force_final_answer_enabled and self._force_prompt_inserted):
        return
    prompt_message = self._make_force_final_answer_message()
    if messages:
        tail = messages[-1]
        same_role = getattr(tail, "role", None) == prompt_message.role
        same_text = getattr(tail, "content", None) == prompt_message.content
        if same_role and same_text:
            return  # already appended — avoid duplicates
    messages.append(prompt_message)
@staticmethod
def _copy_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
    """Deep-copy every message so model-input mutation never touches the
    stored history; prefers pydantic's model_copy, falls back to deepcopy."""
    cloned: list[ChatMessage] = []
    for msg in messages:
        try:
            duplicate = msg.model_copy(deep=True)
        except Exception:
            duplicate = copy.deepcopy(msg)
        cloned.append(duplicate)
    return cloned
@classmethod
def _compress_batch_search_in_block(cls, block: dict) -> bool:
"""Compress batch_search XML in the first text block found (in-place)."""
if block.get("type") == "text":
text_value = block.get("text")
if (
isinstance(text_value, str)
and " tuple[Any, bool]:
if isinstance(content, str):
if (
" list[ChatMessage]:
"""Build the message list sent to the model, without mutating stored history."""
raw_messages = self.context.get_all()
messages: list[ChatMessage] = list(raw_messages)
if not self._force_final_answer_enabled:
return messages
# If final-answer mode has been activated before, always include the prompt in model input.
self._ensure_final_prompt(messages)
token_estimate = _estimate_token_length(messages, self._model_name)
if token_estimate < self._force_final_answer_upper_limit:
return messages
# We are going to mutate messages for context management: deep-copy first.
messages = self._copy_messages(messages)
self._handle_context_overflow(messages)
self._ensure_final_prompt(messages)
self._ensure_context_within_upper_limit(messages)
return messages
def _shrink_batch_search_results(self, messages: list[ChatMessage]) -> bool:
"""Compress earliest uncompressed batch_search_result content (model-input only)."""
for message in messages:
content = getattr(message, "content", None)
new_content, changed = self._compress_batch_search_in_content(content)
if changed:
message.content = new_content
logger.info("@%s Compressed batch_search_results to save tokens", self.name)
return True
return False
@staticmethod
def _parse_tool_call_arguments(raw_arguments: Any) -> dict[str, Any]:
"""Safely parse tool call arguments JSON."""
if not isinstance(raw_arguments, str) or not raw_arguments.strip():
return {}
try:
parsed = json.loads(raw_arguments)
return parsed if isinstance(parsed, dict) else {}
except Exception: # noqa: BLE001
return {}
def _is_search_tool_call(self, tool_name: str | None, tool_args: dict[str, Any]) -> bool:
"""Identify whether a tool call is search-related."""
if not tool_name:
return False
lowered = tool_name.lower()
if "search" in lowered:
return True
if lowered == "batch_web_surfer":
action = tool_args.get("action")
if isinstance(action, str) and action.lower() == "batch_search":
return True
return False
    def _drop_oldest_tool_cycle(
        self,
        messages: list[ChatMessage],
        predicate: Callable[[str | None, dict[str, Any]], bool] | None = None,
        log_context: str = "tool",
    ) -> bool:
        """Drop the earliest tool-call message plus all of its tool results.

        Scans ``messages`` front-to-back for the first message whose tool
        calls match ``predicate`` (any tool call matches when ``predicate``
        is None), then removes that message together with every
        ``role == "tool"`` message whose ``tool_call_id`` answers one of its
        calls. At most one cycle is removed per invocation.

        Args:
            messages: Model-input message list, mutated in place.
            predicate: Optional filter ``(tool_name, parsed_args) -> bool``.
            log_context: Label interpolated into the warning log line.

        Returns:
            True if a tool cycle was removed, False otherwise.
        """
        drop_indices: set[int] = set()
        for idx, message in enumerate(messages):
            tool_calls = getattr(message, "tool_calls", None)
            if not tool_calls:
                continue
            # The message is dropped when at least one of its tool calls
            # satisfies the predicate (or unconditionally without one).
            matched_any = False
            for tc in tool_calls:
                try:
                    tool_name = tc.function.name
                    raw_arguments = tc.function.arguments
                except Exception:
                    # Malformed tool-call object: treat as an unnamed call.
                    tool_name = None
                    raw_arguments = None
                parsed_args = self._parse_tool_call_arguments(raw_arguments)
                if predicate and not predicate(tool_name, parsed_args):
                    continue
                matched_any = True
            if not matched_any:
                continue
            drop_indices.add(idx)
            # Find corresponding tool results for all tool calls in this message.
            for tc in tool_calls:
                tc_id = getattr(tc, "id", None)
                if not tc_id:
                    continue
                # Only the first matching tool result per id is collected.
                for j in range(idx + 1, len(messages)):
                    tool_msg = messages[j]
                    if getattr(tool_msg, "role", None) != "tool":
                        continue
                    if getattr(tool_msg, "tool_call_id", None) == tc_id:
                        drop_indices.add(j)
                        break
            # Stop after the earliest matching cycle.
            break
        if not drop_indices:
            return False
        # Delete from the back so earlier indices stay valid while popping.
        for del_idx in sorted(drop_indices, reverse=True):
            messages.pop(del_idx)
        logger.warning(
            "@%s Dropped earliest %s tool call/results to shrink context", self.name, log_context
        )
        return True
def _trim_oldest_messages(self, messages: list[ChatMessage]) -> bool:
"""Drop oldest non-system messages until under threshold."""
if not messages or len(messages) <= 1:
return False
removed = False
idx = 0
while (
idx < len(messages)
and _estimate_token_length(messages, self._model_name)
> self._force_final_answer_upper_limit
):
if getattr(messages[idx], "role", None) == "system":
idx += 1
continue
messages.pop(idx)
removed = True
if removed:
logger.warning("@%s Trimmed oldest messages to satisfy context budget", self.name)
return removed
def _ensure_context_within_upper_limit(self, messages: list[ChatMessage]) -> None:
"""Ensure context is below the configured upper limit before forcing final answer."""
if not self._force_final_answer_enabled:
return
while True:
token_estimate = _estimate_token_length(messages, self._model_name)
if token_estimate <= self._force_final_answer_upper_limit:
return
if self._drop_oldest_tool_cycle(messages, log_context="any"):
continue
if self._trim_oldest_messages(messages):
continue
break
final_tokens = _estimate_token_length(messages, self._model_name)
if final_tokens > self._force_final_answer_upper_limit:
logger.warning(
"@%s Unable to trim context below upper limit (%s tokens remaining)",
self.name,
final_tokens,
)
    def _handle_context_overflow(self, messages: list[ChatMessage]) -> None:
        """Enforce two-threshold hysteresis for search results before triggering final answer.

        Once the token estimate reaches the upper limit, batch_search payloads
        are compressed and old search tool cycles dropped until the estimate
        falls below the lower limit. If cleanup cannot get there, the
        force-final-answer prompt flag is latched for subsequent rounds.
        """
        if not self._force_final_answer_enabled:
            return
        token_estimate = _estimate_token_length(messages, self._model_name)
        # Hysteresis entry point: only start cleanup once the upper limit is hit.
        if token_estimate < self._force_final_answer_upper_limit:
            return
        logger.warning(
            (
                "@%s Context %s tokens reached upper limit %s; processing search tool payloads "
                "until below lower limit %s"
            ),
            self.name,
            token_estimate,
            self._force_final_answer_upper_limit,
            self._force_final_answer_lower_limit,
        )
        while True:
            token_estimate = _estimate_token_length(messages, self._model_name)
            if token_estimate < self._force_final_answer_lower_limit:
                break
            # Cheapest first: compress payloads before dropping whole cycles.
            if self._shrink_batch_search_results(messages):
                continue
            if self._drop_oldest_tool_cycle(messages, self._is_search_tool_call, "search"):
                continue
            break
        final_tokens = _estimate_token_length(messages, self._model_name)
        if final_tokens < self._force_final_answer_lower_limit:
            return
        logger.warning(
            (
                "@%s Exhausted search tool cleanup but context still %s tokens (>= lower limit %s); "
                "forcing final answer workflow"
            ),
            self.name,
            final_tokens,
            self._force_final_answer_lower_limit,
        )
        # Latch the flag once; _ensure_final_prompt reads it on later calls.
        if not self._force_prompt_inserted:
            self._force_prompt_inserted = True
            logger.info("@%s Final answer prompt activated", self.name)
    async def _run(
        self,
        messages: list[ChatMessage] | InputChannel[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Run agent, calls step() method in a loop.

        Drives up to ``self.max_steps`` rounds of ``_step``. Input either
        comes as a fixed list or is pulled from an ``InputChannel``; in
        ``unfinished_mode`` a finished agent blocks on the channel for new
        input instead of terminating.

        Args:
            messages: Input message list or input channel
            additional_kwargs: Additional parameters forwarded to ``_step``
        Yields:
            AgentResponse: Agent response object (metadata["step_count"] is
            set to the current round on every yielded response)
        """
        if additional_kwargs is None:
            additional_kwargs = {}
        input_messages = []
        # Handle input messages
        if isinstance(messages, list):
            input_messages = messages
        elif isinstance(messages, InputChannel):
            # Blocks until the first batch of channel input arrives.
            input_messages = await messages.get()
        self.current_round = 0
        # Initialize history messages as member variable
        self.context.add(input_messages)
        should_stop = False
        # Loop calling step() until stop or reach max steps
        while self.current_round < self.max_steps:
            with create_span(
                name=f"@{self.name} Round {self.current_round}/{self.max_steps}"
            ):
                if should_stop and not self.config.unfinished_mode:
                    break
                if should_stop:
                    # unfinished_mode: wait for fresh channel input, then resume.
                    if isinstance(messages, InputChannel):
                        input_messages = await messages.get()
                        self.context.add(input_messages)
                    should_stop = False
                else:
                    # Drain any channel input that arrived mid-run without blocking.
                    if isinstance(messages, InputChannel):
                        input_messages = await messages.get_no_wait()
                        self.context.add(input_messages)
                self.current_round += 1
                logger.info(f"@{self.name} Round {self.current_round}/{self.max_steps}")
                try:
                    # Call step() method (now an async generator)
                    last_response = None
                    model_messages = self._prepare_messages_for_model()
                    async for response in self._step(model_messages, additional_kwargs):
                        # Update history messages (using member variable)
                        # Only add complete messages to history (has role field and not STREAM type)
                        logger.info(f"@{self.name} Response: {response}")
                        if response is None:
                            continue
                        if response.message:
                            # STREAM type messages are incremental updates, should not be added to history
                            if response.message_type == AgentMessageType.STREAM.value:
                                # Skip streaming incremental messages, they will be handled in accumulated messages
                                pass
                            else:
                                # For non-streaming messages, ensure role field exists
                                message_to_add = None
                                if isinstance(response.message, ChatMessage):
                                    if (
                                        hasattr(response.message, "role")
                                        and response.message.role
                                    ):
                                        message_to_add = response.message
                                elif isinstance(response.message, dict):
                                    if response.message.get("role"):
                                        message_to_add = ChatMessage(**response.message)
                                # Only add message to history when it has a valid role
                                if message_to_add and message_to_add.role:
                                    self.context.add([message_to_add])
                                else:
                                    logger.warning(
                                        f"@{self.name} Skipping message without role: {response.message_type}"
                                    )
                        # Set metadata (only includes round, history_messages uses member variable)
                        if response.metadata is None:
                            response.metadata = {}
                        response.metadata["step_count"] = self.current_round
                        # Return response
                        yield response
                        last_response = response
                        # Check if error occurred
                        if response.status == AgentRunningStatus.ERROR.value:
                            logger.error(
                                f"@{self.name} Error at round {self.current_round}: {response.error_msg}"
                            )
                            should_stop = True
                            break
                    # Check last response status, decide whether to stop
                    if last_response:
                        if last_response.status == AgentRunningStatus.FINISHED.value:
                            should_stop = True
                            logger.info(
                                f"@{self.name} Finished at round {self.current_round}"
                            )
                        elif last_response.status == AgentRunningStatus.ERROR.value:
                            should_stop = True
                        elif last_response.status == AgentRunningStatus.RUNNING.value:
                            # If still running, may need to continue (e.g., has tool calls)
                            # Here can decide whether to stop based on actual situation
                            pass
                except Exception as e:
                    err_text = str(e) or repr(e)
                    logger.error(
                        f"@{self.name} Exception at round {self.current_round}: {err_text}",
                        exc_info=True,
                    )
                    error_response = AgentResponse(
                        status=AgentRunningStatus.ERROR.value,
                        error_msg=err_text,
                        metadata={
                            "step_count": self.current_round,
                        },
                    )
                    yield error_response
                    should_stop = True
                    break
        # If max steps reached and not stopped yet
        if self.current_round >= self.max_steps and not should_stop:
            logger.warning(
                f"@{self.name} Reached max steps ({self.max_steps}) without stopping"
            )
            # Generate final response
            final_response = AgentResponse(
                status=AgentRunningStatus.STOPPED.value,
                metadata={
                    "step_count": self.current_round,
                },
            )
            yield final_response
    @abstractmethod
    async def _step(
        self,
        messages: list[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Execute a single step, subclasses must implement this method.

        ``_run`` passes the model-ready message list built by
        ``_prepare_messages_for_model()``, not the raw stored history.

        Args:
            messages: Current message history.
            additional_kwargs: Additional parameters.
        Yields:
            AgentResponse: Response for current step, can yield multiple responses.
            - Last response with status FINISHED indicates completion.
            - Status ERROR indicates an error occurred.
            - Status RUNNING indicates continue execution.
        """
        raise NotImplementedError("Subclasses must implement this method")
================================================
FILE: cortex/agents/checkpoint_agent/checkpoint_agent.py
================================================
import logging
from abc import abstractmethod
from typing import AsyncGenerator
from cortex.model.definition import ChatMessage, ContentBlockType
from pydantic import BaseModel
from cortex.agents.base_agent import BaseAgent
from cortex.agents.checkpoint_agent.checkpointer import CheckPointer, CheckpointStorage
from cortex.agents.types import (
AgentConfig,
AgentMessageType,
AgentResponse,
AgentRunningStatus,
)
from cortex.model.provider import ModelProvider
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
class PendingToolCall(BaseModel):
    """Pending tool call: an assistant request plus the results gathered so far."""

    # The assistant message that issued the tool calls.
    request: ChatMessage | None
    # Results keyed by tool_call_id; filled in incrementally as clients respond.
    results: dict[str, ChatMessage] | None
class CheckpointState(BaseModel):
    """Checkpoint state definition: the full serializable state of a CheckpointAgent."""

    # Conversation history accumulated so far.
    messages: list[ChatMessage] | None
    # Tool calls issued but not yet fully answered (e.g. client-side tools).
    pending_tool_calls: list[PendingToolCall] | None
    # Completed tool results awaiting processing on the next loop iteration.
    tool_call_results: list[ChatMessage] | None
    # Agent configuration snapshot stored alongside the checkpoint.
    config: AgentConfig | None
    # Number of steps executed so far.
    current_step: int
    # Step budget; execution stops when current_step reaches it.
    max_steps: int | None
    # True once a FINISHED/ERROR response (or an exception) ended the run.
    finished: bool
    # Error message captured from the failing step, if any.
    error: str | None
class CheckpointAgent(BaseAgent):
    """Checkpoint-based single-step execution Agent (without LangGraph).

    State (messages, pending tool calls, step counter) lives in a
    ``CheckpointState`` that is loaded from / saved to a ``CheckpointStorage``
    around every ``_run`` invocation, so a run interrupted while waiting on
    client-side tool results can be resumed later under the same thread_id.
    """

    def __init__(
        self,
        config: AgentConfig,
        storage: CheckpointStorage,
        provider: ModelProvider | None = None,
        toolset: ToolSet | None = None,
        thread_id: str | None = None,
    ):
        """
        Args:
            config: Agent configuration.
            storage: Checkpoint storage backend.
            provider: Optional model provider.
            toolset: Optional toolset.
            thread_id: Checkpoint identifier; defaults to "<name>_main".
        """
        super().__init__(config=config, toolset=toolset, provider=provider)
        # thread_id doubles as the checkpoint id for CheckPointer.
        self.thread_id = thread_id or f"{self.name}_main"
        self.storage = storage
        self._state: CheckpointState | None = None

    def _init_state(
        self,
    ) -> CheckpointState:
        """Initialize a fresh (empty) checkpoint state."""
        return CheckpointState(
            messages=[],
            pending_tool_calls=[],
            tool_call_results=[],
            config=self.config,
            current_step=0,
            max_steps=getattr(self.config, "max_steps", 10),
            finished=False,
            error=None,
        )

    async def _process_tool_call_results(
        self, state: CheckpointState
    ) -> AsyncGenerator[AgentResponse, None]:
        """Feed completed tool results through the subclass handler.

        Every response message yielded by ``_tool_call_handler`` is appended
        to the state's history; the result queue is cleared afterwards.
        """
        if not state.tool_call_results or len(state.tool_call_results) == 0:
            return
        logger.debug(
            "@%s Processing %d tool call results", self.name, len(state.tool_call_results)
        )
        async for response in self._tool_call_handler(
            state.messages, state.tool_call_results
        ):
            yield response
            if response.message:
                state.messages.append(response.message)
        # Clear processed results
        state.tool_call_results = []

    async def _execute_step(
        self, state: CheckpointState, additional_kwargs: dict | None = None
    ) -> AsyncGenerator[AgentResponse, None]:
        """Execute a single ``_step`` call and fold its output into state.

        Tool-call requests are executed immediately via ``run_tool_call`` and
        parked in ``state.pending_tool_calls``; completed tool results go to
        ``state.tool_call_results``; plain messages join the history.
        On exception the error is recorded in state and re-raised.
        """
        step_responses = []
        try:
            async for response in self._step(state.messages, additional_kwargs):
                step_responses.append(response)
                yield response
                # Handle message history update
                if response.message:
                    # STREAM type messages are incremental updates, should not be added to history
                    if response.message_type == AgentMessageType.STREAM.value:
                        pass
                    else:
                        # For non-streaming messages, ensure role field exists
                        message_to_add = None
                        if isinstance(response.message, ChatMessage):
                            if (
                                response.message.tool_calls
                                and len(response.message.tool_calls) > 0
                            ):
                                # Execute tool calls
                                results = await self.run_tool_call(response.message)
                                # Convert result list to dict with tool_call_id as key
                                state.pending_tool_calls.append(
                                    PendingToolCall(
                                        request=response.message,
                                        results={
                                            result.tool_call_id: result
                                            for result in results
                                        },
                                    )
                                )
                                message_to_add = response.message
                            elif response.message.tool_call_id:
                                # Tool call result
                                state.tool_call_results.append(response.message)
                            else:
                                message_to_add = response.message
                        elif isinstance(response.message, dict):
                            if response.message.get("role"):
                                message_to_add = ChatMessage(**response.message)
                        # Only add message to history when it has a valid role
                        # NOTE(review): tool-result messages handled above leave
                        # message_to_add as None, so this warning also fires for
                        # messages that WERE handled — confirm intended.
                        if message_to_add and message_to_add.role:
                            state.messages.append(message_to_add)
                        else:
                            logger.warning(
                                "@%s Skipping message without role: %s",
                                self.name,
                                response.message_type,
                            )
            state.current_step += 1
            # Check if finished
            if step_responses:
                last_response = step_responses[-1]
                if last_response.status == AgentRunningStatus.FINISHED.value:
                    state.finished = True
                elif last_response.status == AgentRunningStatus.ERROR.value:
                    state.error = last_response.error_msg
                    state.finished = True
        except Exception as e:
            logger.error("@%s Error during step execution: %s", self.name, str(e))
            state.error = str(e)
            state.finished = True
            raise

    async def _update_client_tool_results(self, state: CheckpointState) -> bool:
        """
        Update client tool call results.

        Polls the toolset for each outstanding tool_call_id; fully answered
        pending entries are moved into ``state.tool_call_results``, partially
        answered ones stay pending.

        Returns True if there are still pending tool calls.
        """
        if not state.pending_tool_calls or len(state.pending_tool_calls) == 0:
            return False
        new_pending_tool_calls = []
        for pending_item in state.pending_tool_calls:
            request = pending_item.request
            results = pending_item.results
            # Check if each tool_call has a corresponding result
            all_matched = True
            for tool_call in request.tool_calls:
                tool_call_id = tool_call.id
                # Check if result already exists
                if results.get(tool_call_id):
                    continue
                result_content = self.toolset().get_client_tool_call_result(
                    tool_call_id
                )
                if result_content is None:
                    all_matched = False
                    continue
                # Wrap the raw client result as a tool-role message.
                results[tool_call_id] = ChatMessage(
                    role="tool",
                    content=[
                        {
                            "type": ContentBlockType.TEXT.value,
                            ContentBlockType.TEXT.value: str(result_content),
                        }
                    ],
                    tool_call_id=tool_call_id,
                )
            if all_matched:
                # All tool_calls have corresponding results, move results to tool_call_results
                state.tool_call_results.extend(results.values())
            else:
                # Still have incomplete tool calls, update results and keep in pending
                pending_item.results = results
                new_pending_tool_calls.append(pending_item)
        state.pending_tool_calls = new_pending_tool_calls
        return len(new_pending_tool_calls) > 0

    def _should_continue(self, state: CheckpointState) -> bool:
        """Determine whether execution should continue.

        Order matters: pending tool calls force a wait even when finished;
        queued tool results force one more processing pass.
        """
        if len(state.pending_tool_calls) > 0:
            return False  # Has pending tool calls, need to wait
        if len(state.tool_call_results) > 0:
            return True
        if state.finished:
            return False  # Already finished
        if state.error:
            return False  # Has error
        if state.current_step >= state.max_steps:
            return False  # Reached max steps
        return True

    async def _run(
        self,
        messages: list[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Checkpoint-based run method (without LangGraph).

        State is loaded (or initialized) under ``self.thread_id`` and saved
        back automatically when the ``CheckPointer`` context exits — including
        the early return while waiting on client tool results.

        Args:
            messages: Input message list or input channel
            additional_kwargs: Additional parameters
        Yields:
            AgentResponse: Agent response object
        """
        # Initialize or load state
        async with CheckPointer[CheckpointState](
            self.thread_id, self.storage, self._init_state(), CheckpointState
        ) as state:
            state.messages.extend(messages)
            # Main execution loop
            while True:
                # 1. Update client tool call results and check if there are pending tool calls
                has_pending = await self._update_client_tool_results(state)
                if has_pending:
                    logger.info("@%s Has pending tool calls, waiting for client response", self.name)
                    return
                # 2. Process completed tool call results
                if len(state.tool_call_results) > 0:
                    async for response in self._process_tool_call_results(state):
                        yield response
                # 3. Check whether execution should continue
                if not self._should_continue(state):
                    logger.info("@%s Execution completed", self.name)
                    break
                # 4. Execute next step
                logger.info("@%s Executing step %d", self.name, state.current_step + 1)
                async for response in self._execute_step(state, additional_kwargs):
                    yield response

    @abstractmethod
    async def _step(
        self,
        messages: list[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Execute a single step, subclasses must implement this method.
        Args:
            messages: Current message history.
            additional_kwargs: Additional parameters.
        Yields:
            AgentResponse: Response for current step, can yield multiple responses.
            - Last response with status FINISHED indicates completion.
            - Status ERROR indicates an error occurred.
            - Status RUNNING indicates continue execution.
        """
        raise NotImplementedError("Subclasses must implement this method")

    @abstractmethod
    async def _tool_call_handler(
        self,
        messages: list[ChatMessage],
        tool_calls: list[ChatMessage],
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Handle tool calls, subclasses must implement this method.
        Args:
            messages: Current message history.
            tool_calls: List of tool call messages.
        Yields:
            AgentResponse: Response for current step, can yield multiple responses.
            - Last response with status FINISHED indicates completion.
            - Status ERROR indicates an error occurred.
            - Status RUNNING indicates continue execution.
        """
        raise NotImplementedError("Subclasses must implement this method")
================================================
FILE: cortex/agents/checkpoint_agent/checkpointer.py
================================================
"""Checkpoint storage implementation."""
import json
import logging
import os
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T", bound=BaseModel)
class CheckpointStorage:
    """Abstract storage backend for checkpoints.

    Subclasses persist plain dict states keyed by checkpoint_id, providing
    both sync and async ("a"-prefixed) variants of each operation.
    """

    def save_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Save state."""
        raise NotImplementedError("Subclass must implement this method")

    def load_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Load state from storage; None when no checkpoint exists."""
        raise NotImplementedError("Subclass must implement this method")

    def delete_state(self, checkpoint_id: str):
        """Delete the specified checkpoint."""
        raise NotImplementedError("Subclass must implement this method")

    async def asave_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Save state asynchronously."""
        raise NotImplementedError("Subclass must implement this method")

    async def aload_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Load state from storage asynchronously; None when no checkpoint exists."""
        raise NotImplementedError("Subclass must implement this method")

    async def adelete_state(self, checkpoint_id: str):
        """Delete the specified checkpoint asynchronously."""
        raise NotImplementedError("Subclass must implement this method")
class CheckPointer(Generic[T]):
    """Context manager that loads checkpointed state on entry and persists it on exit.

    Works with both ``with`` and ``async with``; the async variants simply
    delegate to the synchronous implementations.
    """

    def __init__(
        self,
        checkpoint_id: str,
        storage: CheckpointStorage,
        init_state: T | None,
        state_type: type[T],
    ):
        """
        Initialize CheckPointer.

        Args:
            checkpoint_id: Checkpoint ID for identifying and loading specific checkpoints
            storage: Storage backend
            init_state: State used when nothing is stored yet
            state_type: State type, must be a class inheriting from BaseModel
        """
        self.checkpoint_id = checkpoint_id
        self.storage = storage
        self.init_state = init_state
        self.state_type = state_type
        self._state: T | None = None

    def __enter__(self) -> T:
        stored = self.storage.load_state(self.checkpoint_id)
        # Fresh run: fall back to the caller-provided initial state,
        # otherwise re-hydrate the stored dict into the pydantic model.
        self._state = (
            self.init_state
            if stored is None
            else self.state_type.model_validate(stored)
        )
        return self._state

    def __exit__(self, exc_type, exc_value, traceback):
        # Persist whatever state we hold, even when the body raised.
        if self._state is not None and self.checkpoint_id is not None:
            try:
                self.storage.save_state(self.checkpoint_id, self._state.model_dump())
                logger.debug("Auto-saved state on exit: checkpoint_id=%s", self.checkpoint_id)
            except Exception as e:
                logger.error("Failed to save state on exit: %s", str(e))
        # todo handle cancel error
        return False

    async def __aenter__(self):
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_value, traceback):
        return self.__exit__(exc_type, exc_value, traceback)
class MemoryCheckPointer(CheckpointStorage):
    """Memory-based Checkpoint storage (process-local, non-persistent)."""

    def __init__(self):
        # Maps checkpoint_id -> state dict.
        self._storage: dict[str, dict[str, Any]] = {}

    def save_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Store the state dict under checkpoint_id."""
        self._storage[checkpoint_id] = state
        logger.debug("State saved to memory: checkpoint_id=%s", checkpoint_id)

    def load_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Return the stored state, or None when absent."""
        state = self._storage.get(checkpoint_id)
        if state:
            logger.debug("State loaded from memory: checkpoint_id=%s", checkpoint_id)
        return state

    def delete_state(self, checkpoint_id: str):
        """Remove the checkpoint if present (no-op otherwise)."""
        if checkpoint_id in self._storage:
            del self._storage[checkpoint_id]
            logger.debug("Deleted checkpoint from memory: checkpoint_id=%s", checkpoint_id)

    async def asave_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Async wrapper over save_state."""
        self.save_state(checkpoint_id, state)

    async def aload_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Async wrapper over load_state."""
        return self.load_state(checkpoint_id)

    async def adelete_state(self, checkpoint_id: str):
        """Async wrapper over delete_state."""
        self.delete_state(checkpoint_id)
class FileCheckPointer(CheckpointStorage):
    """File-based Checkpoint storage.

    Each checkpoint is persisted as ``<checkpoint_dir>/<checkpoint_id>.json``.
    """

    def __init__(self, checkpoint_dir: str):
        """
        Initialize file storage.

        Args:
            checkpoint_dir: Checkpoint storage directory (created if missing)
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        # Ensure directory exists
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _get_checkpoint_file(self, checkpoint_id: str) -> Path:
        """Get checkpoint file path for a checkpoint id."""
        return self.checkpoint_dir / f"{checkpoint_id}.json"

    def save_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Save state to file as pretty-printed UTF-8 JSON.

        Raises:
            Exception: re-raises any I/O or serialization failure after logging.
        """
        checkpoint_file = self._get_checkpoint_file(checkpoint_id)
        try:
            # Write to file
            with open(checkpoint_file, "w", encoding="utf-8") as f:
                json.dump(state, f, indent=2, ensure_ascii=False)
            logger.debug("State saved to file: %s", checkpoint_file)
        except Exception as e:
            logger.error("Failed to save state: %s", str(e))
            raise

    def load_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Load state from file; returns None when missing or unreadable."""
        checkpoint_file = self._get_checkpoint_file(checkpoint_id)
        if not checkpoint_file.exists():
            logger.debug("Checkpoint file not found: %s", checkpoint_file)
            return None
        try:
            with open(checkpoint_file, "r", encoding="utf-8") as f:
                state_dict = json.load(f)
            logger.debug("State loaded from file: %s", checkpoint_file)
            return state_dict
        except (OSError, json.JSONDecodeError, ValueError) as e:
            logger.error("Failed to load state: %s", str(e))
            return None

    def delete_state(self, checkpoint_id: str):
        """Delete checkpoint file if it exists.

        Uses ``Path.unlink(missing_ok=True)`` so a concurrent deletion between
        an existence check and the removal cannot raise FileNotFoundError
        (the original ``exists()`` + ``os.remove`` pair was racy).
        """
        checkpoint_file = self._get_checkpoint_file(checkpoint_id)
        checkpoint_file.unlink(missing_ok=True)
        logger.debug("Deleted checkpoint file (if present): %s", checkpoint_file)

    async def asave_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Save state to file (async wrapper)."""
        self.save_state(checkpoint_id, state)

    async def aload_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Load state from file (async wrapper)."""
        return self.load_state(checkpoint_id)

    async def adelete_state(self, checkpoint_id: str):
        """Delete checkpoint file (async wrapper)."""
        self.delete_state(checkpoint_id)
class SqliteCheckPointer(CheckpointStorage):
    """SQLite-based Checkpoint storage.

    States are stored as JSON strings in a single ``checkpoints`` table.
    NOTE(review): save/load/delete each open their own short-lived connection;
    ``self.conn`` set by ``__enter__`` is never reused by them — confirm the
    context-manager support is intentional.
    """

    def __init__(self, db_path: str):
        """
        Initialize SQLite storage.
        Args:
            db_path: SQLite database file path
        """
        import sqlite3

        self.db_path = db_path
        self.conn: sqlite3.Connection | None = None
        # Ensure database directory exists
        db_dir = Path(db_path).parent
        db_dir.mkdir(parents=True, exist_ok=True)
        # Initialize database table
        self._init_db()

    def _init_db(self):
        """Initialize database table (idempotent)."""
        import sqlite3

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS checkpoints (
                checkpoint_id TEXT PRIMARY KEY,
                state TEXT NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
        )
        conn.commit()
        conn.close()

    def __enter__(self):
        import sqlite3

        self.conn = sqlite3.connect(self.db_path)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.conn:
            self.conn.close()
            self.conn = None
        return False

    async def __aenter__(self):
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_value, traceback):
        return self.__exit__(exc_type, exc_value, traceback)

    def save_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Save state to SQLite (upsert keyed on checkpoint_id).

        Raises:
            Exception: re-raises serialization/DB failures after logging.
        """
        import sqlite3

        conn = sqlite3.connect(self.db_path)
        try:
            # Convert to JSON string
            state_json = json.dumps(state, ensure_ascii=False)
            # Insert or update
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT OR REPLACE INTO checkpoints (checkpoint_id, state, updated_at)
                VALUES (?, ?, CURRENT_TIMESTAMP)
                """,
                (checkpoint_id, state_json),
            )
            conn.commit()
            logger.debug("State saved to SQLite: checkpoint_id=%s", checkpoint_id)
        except Exception as e:
            logger.error("Failed to save state: %s", str(e))
            raise
        finally:
            conn.close()

    def load_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Load state from SQLite; None when absent or the stored JSON is invalid."""
        import sqlite3

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT state FROM checkpoints WHERE checkpoint_id = ?",
                (checkpoint_id,),
            )
            row = cursor.fetchone()
            if not row:
                logger.debug("Checkpoint not found: checkpoint_id=%s", checkpoint_id)
                return None
            state_json = row[0]
            state_dict = json.loads(state_json)
            logger.debug("State loaded from SQLite: checkpoint_id=%s", checkpoint_id)
            return state_dict
        except (json.JSONDecodeError, ValueError) as e:
            # Only decode errors are swallowed; DB errors propagate.
            logger.error("Failed to load state: %s", str(e))
            return None
        finally:
            conn.close()

    def delete_state(self, checkpoint_id: str):
        """Delete checkpoint row (no error when it does not exist)."""
        import sqlite3

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM checkpoints WHERE checkpoint_id = ?", (checkpoint_id,)
            )
            conn.commit()
            logger.debug("Deleted checkpoint: checkpoint_id=%s", checkpoint_id)
        finally:
            conn.close()

    async def asave_state(self, checkpoint_id: str, state: dict[str, Any]):
        """Save state to SQLite (async wrapper)."""
        self.save_state(checkpoint_id, state)

    async def aload_state(self, checkpoint_id: str) -> dict[str, Any] | None:
        """Load state from SQLite (async wrapper)."""
        return self.load_state(checkpoint_id)

    async def adelete_state(self, checkpoint_id: str):
        """Delete checkpoint (async wrapper)."""
        self.delete_state(checkpoint_id)
================================================
FILE: cortex/agents/checkpoint_agent/react_agent.py
================================================
"""ReActAgent - An Agent specifically designed to execute tasks in ReAct (Reasoning + Acting) mode, capable of calling tools to complete tasks."""
import logging
from typing import AsyncGenerator
from uuid import uuid4
from cortex.model.definition import ChatMessage
from cortex.agents.checkpoint_agent.checkpoint_agent import CheckpointAgent
from cortex.agents.checkpoint_agent.checkpointer import CheckpointStorage
from cortex.agents.react_agent import process_messages
from cortex.agents.types import (
AgentConfig,
AgentMessageType,
AgentResponse,
AgentRunningStatus,
)
from cortex.model.provider import ModelProvider
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
class CheckpointReActAgent(CheckpointAgent):
    """ReActAgent - An Agent specifically designed to execute tasks in ReAct (Reasoning + Acting) mode.

    Features:
    - Can call tools to complete tasks
    - Can call tools multiple times to complete complex tasks
    - Can provide detailed task execution process description
    - Can provide detailed task execution results
    """

    def __init__(
        self,
        storage: CheckpointStorage,
        context_id: str | None = None,
        provider: ModelProvider | None = None,
        config: AgentConfig | None = None,
        toolset: ToolSet | None = None,
    ):
        """
        Args:
            storage: Checkpoint storage backend.
            context_id: Checkpoint/thread id; a random hex id when omitted.
            provider: Optional model provider.
            config: Agent configuration.
            toolset: Required toolset; a ValueError is raised when missing.
        """
        # If no toolset is provided, create a default math toolset
        if toolset is None:
            # Note: Cannot call async functions directly here, need to initialize externally
            raise ValueError(
                "ReActAgent requires a toolset, please use init_react_tools() to create one"
            )
        if not context_id:
            context_id = uuid4().hex
        super().__init__(
            storage=storage,
            thread_id=context_id,
            provider=provider,
            config=config,
            toolset=toolset,
        )

    async def _step(
        self,
        messages: list[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Execute a single step, can yield multiple responses.

        Delegates to the shared ``process_messages`` pipeline; any exception
        is converted into a FINAL response with ERROR status instead of
        propagating.

        Args:
            messages: Current message history
            additional_kwargs: Additional parameters (currently unused here)
        Yields:
            AgentResponse: Response for the current step
        """
        try:
            async for response_message in process_messages(
                self.system_prompt,
                messages,
                self.toolset(),
                self.model_api(),
                getattr(self.model, "infer_kwargs", {}).get("stream", False),
                trace_messages=list(messages) if messages else [],
            ):
                yield response_message
        except Exception as e:
            err_text = str(e) or repr(e)
            logger.error("@%s execution error: %s", self.name, err_text, exc_info=True)
            error_response = AgentResponse(
                message=None,
                status=AgentRunningStatus.ERROR.value,
                error_msg=err_text,
                message_type=AgentMessageType.FINAL.value,
            )
            yield error_response

    async def _tool_call_handler(
        self, messages: list[ChatMessage], tool_results: list[ChatMessage]
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Handle tool calls.

        Wraps each completed tool-result message in a RUNNING/ACCUMULATED
        response so the base loop appends it to history.
        """
        for result in tool_results:
            yield AgentResponse(
                message=result,
                status=AgentRunningStatus.RUNNING.value,
                message_type=AgentMessageType.ACCUMULATED.value,
            )
================================================
FILE: cortex/agents/input/input.py
================================================
import asyncio
import logging
from typing import Generic, TypeVar
from pydantic import BaseModel
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
class InputChannel(Generic[T]):
    """Thin wrapper around an ``asyncio.Queue`` that drains messages in batches."""

    queue: asyncio.Queue[T]

    def __init__(self, queue: asyncio.Queue[T]) -> None:
        self.queue = queue

    def _drain(self) -> list[T]:
        """Collect every immediately-available item without waiting."""
        drained: list[T] = []
        while True:
            try:
                drained.append(self.queue.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    async def get(self) -> list[T]:
        """Blocks until at least one message is available, then returns all available messages."""
        logger.debug("InputChannel waiting for first data")
        batch: list[T] = [await self.queue.get()]
        batch.extend(self._drain())
        count = len(batch)
        logger.debug(f"InputChannel returning {count} data(s)")
        return batch

    async def get_no_wait(self) -> list[T]:
        """Returns whatever is queued right now (possibly an empty list) without blocking."""
        batch = self._drain()
        count = len(batch)
        logger.debug(f"InputChannel returning {count} data(s)")
        return batch
================================================
FILE: cortex/agents/react_agent.py
================================================
"""ReActAgent - Specialized Agent for executing ReAct pattern, capable of calling tools to complete tasks."""
import logging
from typing import AsyncGenerator
from uuid import uuid4
from cortex.agents.base_agent import BaseAgent
from cortex.agents.base_step_agent import BaseStepAgent
from cortex.agents.types import (
AgentConfig,
AgentMessageType,
AgentResponse,
AgentRunningStatus,
)
from cortex.context import BaseContext
from cortex.context.simple_context import SimpleContext
from cortex.model import ChatMessage, MessageType, ModelAPI
from cortex.model.provider import ModelProvider
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
def _check_if_finished(response_message: ChatMessage | None) -> bool:
    """Decide whether the model output ends the ReAct loop.

    Finished means: a message exists and it carries no tool calls. A missing
    message or a message requesting tools means the loop must continue.
    """
    if not response_message:
        return False
    # Tool calls present -> keep iterating; otherwise we are done.
    return not BaseAgent.has_tool_call(response_message)
async def process_messages(
    system_prompt: str | None,
    messages: list[ChatMessage],
    toolset: ToolSet,
    model_api: ModelAPI,
    use_stream: bool,
    trace_messages: list[ChatMessage] | None = None,
) -> AsyncGenerator[AgentResponse, None]:
    """Run one model call with the toolset's function schemas attached.

    In streaming mode, a STREAM AgentResponse is yielded per delta; the last
    message seen in the stream is treated as the accumulated full message.
    In both modes, a final response is yielded whose status/message_type are
    FINISHED/FINAL when the model produced no tool calls, otherwise
    RUNNING/ACCUMULATED.

    Args:
        system_prompt: System prompt; a default task-execution prompt is used
            when None or empty.
        messages: Conversation history to send to the model.
        toolset: Tools exposed to the model as OpenAI-style function specs.
        model_api: Chat-completion API wrapper.
        use_stream: Whether to use the streaming completion endpoint.
        trace_messages: Optional alternate history recorded on the trace
            request only (not what is actually sent to the model).

    Yields:
        AgentResponse: Zero or more STREAM deltas, then one final response.
    """
    if not system_prompt:
        system_prompt = """You are a professional task execution assistant capable of calling tools to complete tasks.
Your task is:
1. Understand the task proposed by the user
2. Use the provided tools to complete the task
3. Call tools multiple times to complete complex tasks
4. Provide detailed explanations of task execution process
5. Provide detailed task execution results
Please ensure:
- For complex tasks, use tools step by step
- Provide clear explanations of task execution steps
- Verify the correctness of task execution results
- When the task is complete, clearly state "Task completed"
"""
    # Prepare message list (only insert system if none present, to avoid duplicate when runner/caller already added one)
    infer_messages = messages.copy()
    if system_prompt and not any(msg.role == "system" for msg in infer_messages):
        infer_messages.insert(0, ChatMessage(role="system", content=system_prompt))
    # Get all tool schemas
    tool_schemas = toolset.get_all_schemas()
    # Convert tool schemas to model-usable format (OpenAI function-calling shape)
    tools_for_model = []
    for tool_name, schema in tool_schemas.items():
        tools_for_model.append(
            {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": schema.description,
                    "parameters": schema.parameters,
                },
            }
        )
    # Call model
    trace_request = None
    if trace_messages is not None:
        # Mirror the system-prompt insertion on the trace-only history so the
        # trace shows the same prompt the model effectively received.
        trace_infer_messages = trace_messages.copy()
        if system_prompt and not any(msg.role == "system" for msg in trace_infer_messages):
            trace_infer_messages.insert(
                0, ChatMessage(role="system", content=system_prompt)
            )
        trace_request = {
            "messages": trace_infer_messages,
            "sent_messages": infer_messages,
            "tools": tools_for_model if tools_for_model else None,
        }
    if use_stream:
        # Streaming output mode
        delta_count = 0
        response_message = None
        async for model_msg in model_api.chat_completion_stream(
            messages=infer_messages,
            tools=tools_for_model if tools_for_model else None,
            trace_request=trace_request,
        ):
            delta_count += 1
            event = model_msg.message
            response_message = (
                event  # Save last message (model yields accumulated_message at end)
            )
            if model_msg.message_type == MessageType.DELTA:
                # This is a delta event, yield directly
                delta_response = AgentResponse(
                    message=event,
                    status=AgentRunningStatus.RUNNING.value,
                    message_type=AgentMessageType.STREAM.value,
                )
                yield delta_response
    else:
        # Non-streaming output mode
        model_msg = await model_api.chat_completion(
            messages=infer_messages,
            tools=tools_for_model if tools_for_model else None,
            trace_request=trace_request,
        )
        response_message = model_msg.message
    # Finished only when the final message carries no tool calls.
    is_finished = _check_if_finished(response_message)
    response_status = (
        AgentRunningStatus.FINISHED.value
        if is_finished
        else AgentRunningStatus.RUNNING.value
    )
    message_type = (
        AgentMessageType.FINAL.value
        if is_finished
        else AgentMessageType.ACCUMULATED.value
    )
    model_response = AgentResponse(
        message=response_message,
        status=response_status,
        message_type=message_type,
    )
    yield model_response
class ReActAgent(BaseStepAgent):
    """ReActAgent - Specialized Agent for executing ReAct pattern, capable of calling tools to complete tasks.

    Features:
    - Can call tools to complete tasks
    - Can call tools multiple times to complete complex tasks
    - Can provide detailed explanations of task execution process
    - Can provide detailed task execution results
    """

    def __init__(
        self,
        context: BaseContext | None = None,
        provider: ModelProvider | None = None,
        config: AgentConfig | None = None,
        toolset: ToolSet | None = None,
    ):
        """Create a ReActAgent.

        Args:
            context: Message context; an in-memory SimpleContext with a random
                session id is created when omitted.
            provider: Model provider passed through to BaseStepAgent.
            config: Agent configuration passed through to BaseStepAgent.
            toolset: Required toolset; must be built asynchronously by the
                caller beforehand.

        Raises:
            ValueError: When no toolset is provided.
        """
        if context is None:
            context = SimpleContext(uuid4().hex)
        # If toolset is not provided, create default math toolset
        if toolset is None:
            # Note: Cannot call async functions directly here, need to initialize externally
            raise ValueError(
                "ReActAgent requires a toolset, please use init_react_tools() to create one"
            )
        super().__init__(
            context=context, provider=provider, config=config, toolset=toolset
        )

    async def _step(
        self,
        messages: list[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Execute single step operation, can yield multiple responses.

        Args:
            messages: Current message history
            additional_kwargs: Additional parameters

        Yields:
            AgentResponse: Response for current step
        """
        # Build a trace-only snapshot of the full context history (best effort:
        # tracing must never break a step).
        trace_messages: list[ChatMessage] | None = None
        try:
            trace_messages = list(self.context.get_all())
        except Exception:  # noqa: BLE001
            trace_messages = None
        # _force_final_answer_enabled / _force_prompt_inserted /
        # _make_force_final_answer_message come from BaseStepAgent (not visible
        # in this file) — presumably a "wrap up now" prompt; confirm semantics
        # against base_step_agent.py.
        if trace_messages is not None and self._force_final_answer_enabled and self._force_prompt_inserted:
            prompt_message = self._make_force_final_answer_message()
            # Skip appending when the last trace entry already equals the
            # force-final-answer prompt (avoids duplicates across steps).
            if not trace_messages or not (
                getattr(trace_messages[-1], "role", None) == prompt_message.role
                and getattr(trace_messages[-1], "content", None) == prompt_message.content
            ):
                trace_messages.append(prompt_message)
        async for response_message in process_messages(
            self.system_prompt,
            messages,
            self.toolset(),
            self.model_api(),
            getattr(self.model, "infer_kwargs", {}).get("stream", False),
            trace_messages=trace_messages,
        ):
            try:
                yield response_message
                # Only check and execute tool calls for non-STREAM types (i.e., ACCUMULATED or FINAL)
                # In streaming output, STREAM type delta messages don't contain complete tool_call
                if response_message.message_type == AgentMessageType.STREAM.value:
                    continue
                # Check for tool calls and execute
                tool_result_messages = []
                message = response_message.message
                if self.has_tool_call(message):
                    tool_result_messages = await self.run_tool_call(message)
                if tool_result_messages:
                    logger.info(
                        "@%s Detected %s tool call results",
                        self.name,
                        len(tool_result_messages),
                    )
                    # Yield tool result responses
                    for tool_result_msg in tool_result_messages:
                        tool_response = AgentResponse(
                            message=tool_result_msg,
                            status=AgentRunningStatus.RUNNING.value,
                            message_type=AgentMessageType.FINAL.value,
                        )
                        yield tool_response
            except Exception as e:
                # NOTE(review): this except sits *inside* the async-for, so
                # after yielding an ERROR response the loop keeps consuming
                # model output — confirm this is intended.
                err_text = str(e) or repr(e)
                logger.error("@%s Execution error: %s", self.name, err_text, exc_info=True)
                error_response = AgentResponse(
                    message=None,
                    status=AgentRunningStatus.ERROR.value,
                    error_msg=err_text,
                    message_type=AgentMessageType.FINAL.value,
                )
                yield error_response
================================================
FILE: cortex/agents/types.py
================================================
from enum import Enum
from cortex.model.definition import ChatMessage, ModelParams
from cortex.model.utils import merge_delta_message
from pydantic import BaseModel, Field
from cortex.tools.base import ToolSchema
class RunnerType(str, Enum):
    """Agent runner type enum."""

    LOCAL = "local"    # run the agent in-process
    REMOTE = "remote"  # run the agent via a remote runner (see AgentConfig.endpoint)
class AgentRunningStatus(str, Enum):
    """Agent running status enum."""

    FINISHED = "finished"  # terminal: completed normally
    STOPPED = "stopped"    # terminal: halted before completion
    ERROR = "error"        # terminal: failed with an error (see AgentResponse.error_msg)
    RUNNING = "running"    # still producing responses
class AgentMessageType(str, Enum):
    """Agent message type enum (granularity of an AgentResponse message)."""

    STREAM = "stream"  # Streaming output (a single delta chunk)
    ACCUMULATED = "accumulated"  # Accumulated output (full message so far)
    FINAL = "final"  # Final output
class AgentResponseType(str, Enum):
    """Agent response type enum (see AgentResponse.get_type)."""

    RESPONSE = "response"
    TOOL_CALL = "tool_call"
    # NOTE(review): TOOL_RESULT is never returned by AgentResponse.get_type in
    # this file — confirm whether other call sites produce it.
    TOOL_RESULT = "tool_result"
class AgentConfig(BaseModel):
    """Declarative configuration for Agent."""

    model: ModelParams  # model name + inference kwargs
    name: str = Field(default="")  # display/registry name of the agent
    agent_type: str | None = None  # optional factory key used to pick an agent class
    system_prompt: str | None = None  # None lets the agent fall back to its built-in prompt
    description: str | None = None  # human-readable summary (used when the agent is exposed as a tool)
    tools: list[ToolSchema | str] = Field(default_factory=list)  # tool schemas or tool names
    max_steps: int = 10  # upper bound on ReAct iterations
    extra_config: dict | None = None  # free-form extension point
    runner_type: RunnerType = RunnerType.LOCAL  # local in-process vs remote execution
    endpoint: str | None = None  # remote runner endpoint — presumably only used when runner_type == REMOTE; confirm
    unfinished_mode: bool = False  # NOTE(review): semantics not visible in this file — confirm against runner code
    use_share_context: bool = False  # NOTE(review): presumably share context across agents — confirm
class AgentResponse(BaseModel):
    """Agent response model."""

    agent_name: str | None = None  # name of the agent that produced this response
    message: ChatMessage | None = None  # payload; None for pure status/error responses
    message_type: AgentMessageType = (
        AgentMessageType.FINAL
    )  # delta: streaming output, accumulated: accumulated output, final: final output
    status: AgentRunningStatus = AgentRunningStatus.RUNNING
    error_msg: str | None = None  # populated when status == ERROR
    metadata: dict[str, object] | None = None  # optional free-form extras

    def get_type(self) -> AgentResponseType:
        """Get response type.

        Returns TOOL_CALL when the message carries tool_calls or a
        tool_call_id, otherwise RESPONSE.

        NOTE(review): a tool_call_id usually marks a tool *result* message,
        yet this returns TOOL_CALL for it and TOOL_RESULT is never produced
        here — confirm this is intentional.
        """
        if self.message is None:
            return AgentResponseType.RESPONSE
        if self.message.tool_call_id is not None:
            return AgentResponseType.TOOL_CALL
        if self.message.tool_calls is not None:
            return AgentResponseType.TOOL_CALL
        return AgentResponseType.RESPONSE

    def __add__(self, other: "AgentResponse") -> "AgentResponse":
        """Merge two (typically streaming-delta) responses into a new one.

        Messages are merged via merge_delta_message; every other field takes
        other's value when truthy, else keeps self's value.
        """
        if not isinstance(other, AgentResponse):
            return NotImplemented
        # Merge delta_message dictionaries
        merged_delta_dict = merge_delta_message(
            self.message.model_dump() if self.message else None,
            other.message.model_dump() if other.message else None,
        )
        merged_delta = ChatMessage(**merged_delta_dict)
        # Create new field dictionary, default to self's fields
        new_fields = self.model_dump()
        new_fields["message"] = merged_delta
        # Iterate other fields, override with other's value if present
        for field, _ in self.model_dump().items():
            if field == "message":
                continue  # Already processed
            other_value = getattr(other, field)
            # Falsy values (None, "", empty dict) in `other` never override.
            if other_value:
                new_fields[field] = other_value
        return AgentResponse(**new_fields)
================================================
FILE: cortex/context/__init__.py
================================================
"""Context module for managing conversation and session context."""
from cortex.context.base_context import BaseContext
from cortex.context.file_context import FileContext
from cortex.context.simple_context import SimpleContext
# Public API of the context package. The concrete classes are exported too,
# since they are imported above and used directly elsewhere in the codebase.
__all__ = [
    "BaseContext",
    "FileContext",
    "SimpleContext",
    "make_simple_context",
    "make_file_context",
]


def make_simple_context(session_id: str) -> BaseContext:
    """Build an in-memory SimpleContext keyed by ``session_id``."""
    return SimpleContext(session_id)


def make_file_context(path: str, session_id: str) -> BaseContext:
    """Build a file-backed FileContext that stores messages under ``path``."""
    return FileContext(session_id, path)
================================================
FILE: cortex/context/base_context.py
================================================
"""Base context management class."""
from abc import ABC, abstractmethod
from typing import List
from cortex.model.definition import ChatMessage
class BaseContext(ABC):
    """Base context management class, providing basic interface for session message management."""

    def __init__(self, session_id: str):
        """
        Initialize base context.

        Args:
            session_id: Session ID
        """
        # Identifier concrete implementations use to key stored messages.
        self.session_id = session_id

    @abstractmethod
    def add(self, messages: list[ChatMessage]) -> None:
        """
        Add chat messages to context.

        Args:
            messages: List of chat messages to add
        """
        ...

    @abstractmethod
    def get_all(self) -> list[ChatMessage]:
        """
        Get all chat messages.

        Returns:
            list[ChatMessage]: List of all chat messages
        """
        ...
================================================
FILE: cortex/context/file_context.py
================================================
"""File-based context management class."""
import asyncio
import json
import os
from pathlib import Path
from typing import List
from cortex.model.definition import ChatMessage
from cortex.context.base_context import BaseContext
class FileContext(BaseContext):
    """File-based context management class, each session_id corresponds to a file.

    Messages are buffered in memory and flushed to
    ``<storage_dir>/<session_id>.jsonl`` (one JSON object per line), either
    immediately when ``batch_size`` pending messages accumulate or after
    ``delay_seconds`` via a background asyncio task.
    """

    def __init__(
        self,
        session_id: str,
        storage_dir: str = "contexts",
        batch_size: int = 5,
        delay_seconds: float = 2.0,
    ):
        """
        Initialize file context.

        Args:
            session_id: Session ID
            storage_dir: Storage directory, defaults to "contexts"
            batch_size: Batch write size, write immediately when this count is reached
            delay_seconds: Delay write time (seconds)
        """
        super().__init__(session_id)
        self.storage_dir = Path(storage_dir)
        # parents=True so a nested storage_dir (e.g. "var/contexts") does not
        # raise FileNotFoundError on first use.
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        self.file_path = self.storage_dir / f"{session_id}.jsonl"
        self._messages: List[ChatMessage] = []          # messages already persisted
        self._pending_messages: List[ChatMessage] = []  # messages awaiting write
        self._batch_size = batch_size
        self._delay_seconds = delay_seconds
        self._write_task = None  # in-flight delayed-write asyncio task, if any
        self._load_messages()

    def __del__(self):
        """Automatically call flush when object is garbage collected"""
        try:
            self.flush()
        except Exception:
            # Ignore exceptions in destructor to avoid affecting garbage collection
            pass

    def _load_messages(self) -> None:
        """Load messages from file (one JSON object per line)."""
        if self.file_path.exists():
            try:
                with open(self.file_path, "r", encoding="utf-8") as f:
                    self._messages = []
                    for line in f:
                        line = line.strip()
                        if line:
                            msg_data = json.loads(line)
                            self._messages.append(ChatMessage(**msg_data))
            except (json.JSONDecodeError, KeyError, TypeError):
                # If file is corrupted or format is incorrect, start fresh
                self._messages = []

    def _save_messages(self) -> None:
        """Save messages to file"""
        # Rewrite entire file (persisted + pending) atomically from memory.
        all_messages = self._messages + self._pending_messages
        with open(self.file_path, "w", encoding="utf-8") as f:
            for msg in all_messages:
                json_line = json.dumps(msg.model_dump(), ensure_ascii=False)
                f.write(json_line + "\n")
        # Move pending messages to main message list
        self._messages.extend(self._pending_messages)
        self._pending_messages = []

    async def _delayed_write(self) -> None:
        """Delayed write task: sleep, then flush whatever is still pending."""
        await asyncio.sleep(self._delay_seconds)
        if self._pending_messages:
            self._save_messages()
        self._write_task = None

    def _schedule_write(self) -> None:
        """Schedule write task"""
        # If batch size is reached, write immediately
        if len(self._pending_messages) >= self._batch_size:
            if self._write_task:
                self._write_task.cancel()
                self._write_task = None
            self._save_messages()
        else:
            # If no write task in progress, create a delayed write task.
            # NOTE(review): asyncio.create_task requires a running event loop;
            # calling add() from plain sync code raises RuntimeError — confirm
            # all callers run inside a loop.
            if not self._write_task:
                self._write_task = asyncio.create_task(self._delayed_write())

    def add(self, messages: list[ChatMessage]) -> None:
        """Add chat messages to context (buffered; persisted per batch/delay policy).

        Args:
            messages: List of chat messages to add
        """
        self._pending_messages.extend(messages)
        self._schedule_write()

    def get_all(self) -> List[ChatMessage]:
        """Get all chat messages (persisted + pending), as a copy.

        Returns:
            List[ChatMessage]: List of all chat messages
        """
        return (self._messages + self._pending_messages).copy()

    def clear(self) -> None:
        """Clear context messages and delete the backing file."""
        if self._write_task:
            self._write_task.cancel()
            self._write_task = None
        self._messages = []
        self._pending_messages = []
        if self.file_path.exists():
            os.remove(self.file_path)

    def flush(self) -> None:
        """Force write all pending messages"""
        if self._write_task:
            self._write_task.cancel()
            self._write_task = None
        if self._pending_messages:
            self._save_messages()
================================================
FILE: cortex/context/simple_context.py
================================================
from typing import List
from cortex.model.definition import ChatMessage
from cortex.context import BaseContext
simple_contexts: dict[str, list[ChatMessage]] = {}
class SimpleContext(BaseContext):
    """Simple context management class for managing session messages.

    Storage is the module-level ``simple_contexts`` dict, keyed by session_id,
    so all SimpleContext instances with the same session share history.
    """

    def __init__(self, session_id: str):
        super().__init__(session_id)

    def add(self, msg: list[ChatMessage]) -> None:
        """Add chat message list to context.

        Copies the incoming messages into the stored list instead of keeping a
        reference to the caller's list, so later caller-side mutation of ``msg``
        cannot corrupt the stored history.
        """
        simple_contexts.setdefault(self.session_id, []).extend(msg)

    def get_all(self) -> list[ChatMessage]:
        """Get all chat messages.

        Returns a copy (consistent with FileContext.get_all) so callers cannot
        mutate the stored history through the returned list.
        """
        return list(simple_contexts.get(self.session_id, []))
================================================
FILE: cortex/env.py
================================================
from __future__ import annotations
import os
from pathlib import Path
def _repo_root() -> Path:
return Path(__file__).resolve().parent.parent
def _parse_env_line(line: str) -> tuple[str, str] | None:
stripped = line.strip()
if not stripped or stripped.startswith("#"):
return None
if stripped.startswith("export "):
stripped = stripped[7:].lstrip()
if "=" not in stripped:
return None
key, value = stripped.split("=", 1)
key = key.strip()
if not key:
return None
value = value.strip()
if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}:
value = value[1:-1]
return key, value
def _load_env_fallback(path: Path) -> None:
    """Minimal .env loader used when python-dotenv is unavailable.

    Best-effort: an unreadable file is silently ignored. Parsed KEY=VALUE
    pairs are set into os.environ without overriding existing variables.
    """
    try:
        content = path.read_text(encoding="utf-8")
    except Exception:
        return
    for raw_line in content.splitlines():
        entry = _parse_env_line(raw_line)
        if entry is None:
            continue
        key, value = entry
        # setdefault: never clobber variables already present in the env.
        os.environ.setdefault(key, value)
def load_env() -> None:
    """Load the repo-root .env into os.environ, never overriding existing vars.

    Prefers python-dotenv when installed; otherwise falls back to the minimal
    in-house parser.
    """
    env_path = _repo_root() / ".env"
    if not env_path.exists():
        return
    try:
        from dotenv import load_dotenv  # type: ignore
    except Exception:
        # python-dotenv not installed: use the fallback parser.
        _load_env_fallback(env_path)
    else:
        load_dotenv(env_path, override=False)
================================================
FILE: cortex/examples/agents/ask_input_agent.py
================================================
"""AskInputAgent - Fixed-flow Agent for asking user input and repeating what the user says"""
import json
import logging
import uuid
from typing import AsyncGenerator
from uuid import uuid4
from cortex.model.definition import ChatMessage, ChatToolCall, ContentBlockType, Function
from cortex.agents.base_agent import BaseAgent
from cortex.agents.input.input import InputChannel
from cortex.agents.types import (
AgentConfig,
AgentMessageType,
AgentResponse,
AgentRunningStatus,
)
from cortex.context import BaseContext, make_simple_context
from cortex.model import ModelParams
from cortex.tools.client_tool import ClientTool
from cortex.tools.toolset import ToolSet
from cortex.tools.types import ToolType
logger = logging.getLogger(__name__)
async def init_ask_input_tools() -> ToolSet:
    """Initialize ask_input toolset.

    Returns:
        ToolSet: A toolset with a single client-side ``ask_input`` tool whose
        execution is delegated over the toolset's channel.
    """
    toolset = ToolSet()
    # Register ask_input tool (ClientTool)
    ask_input_tool = ClientTool(
        name="ask_input",
        description="Ask user for input. Used for scenarios requiring user interaction such as obtaining user feedback, confirmation, modification suggestions, etc. Parameters: prompt (required) - prompt message to display to the user; context (optional) - context information to help users understand the current situation.",
        tool_type=ToolType.ASK_INPUT,
        channel=toolset.channel,
        timeout=300.0,  # User input may take a long time
        # JSON-schema-style parameter spec sent to the client.
        client_params={
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "Prompt message to display to the user, explaining what the user needs to do (confirm, modify, provide information, etc.)",
                },
                "context": {
                    "type": "string",
                    "description": "Context information to help users understand the current situation, such as current plan content, items that need confirmation, etc.",
                },
            },
            "required": ["prompt"],
        },
    )
    toolset.register(ask_input_tool)
    logger.info("Registered ask_input tool")
    return toolset
class AskInputAgent(BaseAgent):
    """AskInputAgent - Fixed-flow Agent for asking user input and repeating what the user says

    Fixed flow:
    1. Send a function call with tool_name as ask_input
    2. After receiving ask_input result, send a message repeating what the user said
    """

    def __init__(
        self, context: BaseContext, config: AgentConfig, toolset: ToolSet | None = None
    ):
        super().__init__(config=config, toolset=toolset)
        self.context = context

    async def _run(
        self,
        messages: list[ChatMessage] | InputChannel[ChatMessage],
        additional_kwargs: dict | None = None,
    ) -> AsyncGenerator[AgentResponse, None]:
        """
        Run agent, execute fixed flow.

        Note: ``messages`` is ignored — the flow is fixed and does not depend
        on conversation history.

        Args:
            messages: Input message list or input channel
            additional_kwargs: Additional parameters

        Yields:
            AgentResponse: Agent response object
        """
        try:
            # Step 1: Create a ChatMessage containing tool_calls, call ask_input tool
            tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
            # Create tool call arguments
            tool_args = json.dumps({"prompt": "Please enter some content"})
            # Create ChatToolCall
            tool_call = ChatToolCall(
                id=tool_call_id,
                function=Function(
                    name="ask_input",
                    arguments=tool_args,
                ),
            )
            # Create ChatMessage containing tool_calls
            tool_call_message = ChatMessage(
                role="assistant",
                tool_calls=[tool_call],
            )
            # Yield tool call response
            tool_call_response = AgentResponse(
                message=tool_call_message,
                status=AgentRunningStatus.RUNNING.value,
                message_type=AgentMessageType.FINAL.value,
            )
            yield tool_call_response
            # Execute tool call (blocks until the client answers or times out)
            tool_result_messages = await self.run_tool_call(tool_call_message)
            if not tool_result_messages:
                error_response = AgentResponse(
                    message=None,
                    status=AgentRunningStatus.ERROR.value,
                    error_msg="ask_input tool call failed, no result returned",
                    message_type=AgentMessageType.FINAL.value,
                )
                yield error_response
                return
            # Yield tool result response (only the first result is used)
            tool_result_message = tool_result_messages[0]
            tool_result_response = AgentResponse(
                message=tool_result_message,
                status=AgentRunningStatus.RUNNING.value,
                message_type=AgentMessageType.FINAL.value,
            )
            yield tool_result_response
            # Step 2: Get user input from tool result and repeat what the user said
            user_input = None
            if tool_result_message.content:
                # Extract user input: content may be a list of blocks or a str
                if isinstance(tool_result_message.content, list):
                    for block in tool_result_message.content:
                        if isinstance(block, dict):
                            # First text-like field found wins.
                            text_content = block.get(
                                ContentBlockType.TEXT.value
                            ) or block.get("text")
                            if text_content:
                                user_input = str(text_content)
                                break
                        elif isinstance(block, str):
                            user_input = block
                            break
                elif isinstance(tool_result_message.content, str):
                    user_input = tool_result_message.content
            if user_input is None:
                user_input = "Failed to get user input"
            # Create message repeating what the user said
            repeat_message = ChatMessage(
                role="assistant",
                content=[
                    {
                        "type": ContentBlockType.TEXT.value,
                        ContentBlockType.TEXT.value: f"You said: {user_input}",
                    }
                ],
            )
            # Yield final response
            final_response = AgentResponse(
                message=repeat_message,
                status=AgentRunningStatus.FINISHED.value,
                message_type=AgentMessageType.FINAL.value,
            )
            yield final_response
        except Exception as e:
            logger.error(f"@{self.name} execution error: {e}", exc_info=True)
            error_response = AgentResponse(
                message=None,
                status=AgentRunningStatus.ERROR.value,
                error_msg=str(e),
                message_type=AgentMessageType.FINAL.value,
            )
            yield error_response
async def make_ask_input_agent(
    config: AgentConfig, context_id: str | None = None
) -> BaseAgent:
    """Create AskInputAgent.

    Args:
        config: Agent configuration.
        context_id: Session id for the context; a random hex id is generated
            when omitted.

    Returns:
        A ready-to-run AskInputAgent.
    """
    toolset = await init_ask_input_tools()
    # Generate the fallback id *before* building the context; previously the
    # context was created first, with session_id=None, and the generated uuid
    # was silently discarded.
    if context_id is None:
        context_id = uuid4().hex
    context = make_simple_context(context_id)
    return AskInputAgent(context=context, config=config, toolset=toolset)
def get_ask_input_agent_config() -> AgentConfig:
    """Get AskInputAgent configuration.

    The model block is a placeholder only: AgentConfig requires one, but this
    agent's fixed flow never calls the model.
    """
    return AgentConfig(
        name="AskInputAgent",
        description="Fixed-flow Agent for asking user input and repeating what the user says. Does not require model calls.",
        system_prompt=None,  # No system prompt needed
        model=ModelParams(
            name="gpt-4o-mini",  # Although model calls are not needed, configuration requires it
            infer_kwargs={"max_tokens": 100, "temperature": 0.7, "stream": False},
        ),
    )
================================================
FILE: cortex/examples/agents/deep_reasearch_agent.py
================================================
"""DeepResearchAgent - Deep research Agent"""
from cortex.agents.base_agent import BaseAgent
from cortex.agents.react_agent import ReActAgent
from cortex.agents.types import AgentConfig
from cortex.context import make_simple_context
from cortex.model import ModelParams
from cortex.tools.agent_tool import AgentTool
from cortex.tools.toolset import ToolSet
async def init_deep_research_tools() -> ToolSet:
    """Initialize deep research tools.

    Returns:
        ToolSet with a web_search tool (from an MCP server) and a PlanAgent
        tool registered.
    """
    # Register web_search tool.
    # NOTE(review): "http://xxx/mcp" is a placeholder endpoint — it must be
    # replaced with a real MCP server URL before this example can run.
    toolset = ToolSet()
    await toolset.register_from_mcp_server(
        mcp_server="http://xxx/mcp",
        tool_names=["web_search"],
    )
    # plan agent tool (delegates planning to a separate agent over the channel)
    plan_agent_tool = AgentTool(
        name="PlanAgent",
        description="PlanAgent - Planning Agent responsible for creating plans",
        channel=toolset.channel,
        timeout=3000.0,
    )
    toolset.register(plan_agent_tool)
    return toolset
def get_deep_research_agent_config() -> AgentConfig:
    """Get deep research Agent configuration.

    The system prompt below is runtime behavior and is kept verbatim.
    """
    return AgentConfig(
        name="DeepResearchAgent",
        description="DeepResearchAgent - Deep research Agent responsible for deep research tasks",
        system_prompt="""You are a powerful "Deep Search Agent", an intelligent system with capabilities for initial thinking, reflection, and adaptation. When facing a complex problem, you don't simply execute searches, but rather think deeply, plan, and flexibly utilize tools (such as web search, webpage access, code execution) to find answers like a smart assistant.
The key lies in "depth" and "intelligence", which means the Agent needs to have the following characteristics:
1. **Tool Coordination**: Able to carefully analyze problems, reference plans already made in history, and strategically use tools to collect, process, and present information.
2. **Data Extraction and Visualization**: Able to extract data relevant to the problem from massive amounts of information, and use the visualization (visualize_data) tool to transform data into intuitive charts.
3. **Reflection and Adaptation**: This is the most important capability! During the search process, if problems are encountered (such as unsatisfactory search results, insufficient information, or uncertainty), the Agent won't give up easily, but will proactively reflect and adjust strategies. For example:
- Change search keywords.
- View more search results.
- Determine whether current information is sufficient to answer the question; if not, continue searching for missing information.
- Evaluate the reliability of information sources.
- Use different tools or information sources for cross-validation.
- Determine whether there is critical data that needs to be displayed through charts; if so, use visualization tools to present it.
- Check whether the collected information fully meets all requirements of the original question.
You must always follow these rules to complete tasks:
1. Always provide tool calls, otherwise it will fail
2. Always use correct tool parameters. Don't use variable names in action parameters, use specific values instead
3. Only call tools when needed: if information is not needed, don't call search tools, try to solve the problem yourself
4. Never repeat a tool call with exactly the same parameters that has already been used
5. Only use the visualize_data tool for visualization, while the execute_python_code tool can only be used for complex data calculations or file processing
6. If information involves key numerical values, data presentation, data comparison, process chains, multiple stages, entity relationships, timelines, etc., you must use the visualize_data tool to generate charts
7. For data that has already been visualized, don't call the visualization tool repeatedly
8. Do not use the execute_python_code tool to output large amounts of text, do not use the execute_python_code tool to output reports
9. Never express gratitude for any tool call results (such as search results)
10. Multi-language support: You support responding in Chinese, English, Japanese, Korean, Traditional Chinese, Spanish, and Portuguese, automatically identifying the "user's input language" and matching the output.
11. Search additional requirements: When the problem you need to search is related to travel/public opinion, you need to generate queries in the language corresponding to the travel/public opinion location for searching, and also generate an identical query in the user's language for searching.
Start now! If you complete the task correctly, you will receive a $1,000,000 reward.
""",
        model=ModelParams(
            name="gpt-5.1",
            infer_kwargs={"max_tokens": 2000, "temperature": 0.7, "stream": False},
        ),
        max_steps=10,
    )
async def make_deep_research_agent(session_id: str, config: AgentConfig) -> BaseAgent:
    """Create deep research Agent.

    Args:
        session_id: Session id used to key the agent's conversation context.
        config: Agent configuration.

    Returns:
        A ReActAgent wired with the deep-research toolset.
    """
    toolset = await init_deep_research_tools()
    # Use the caller-supplied session_id for the context; previously the
    # parameter was accepted but silently ignored and ReActAgent generated a
    # random session id instead.
    return ReActAgent(
        context=make_simple_context(session_id), config=config, toolset=toolset
    )
================================================
FILE: cortex/examples/agents/main_agent.py
================================================
"""MainAgent - Main coordination Agent."""
import logging
from uuid import uuid4
from cortex.agents.base_agent import BaseAgent
from cortex.agents.react_agent import ReActAgent
from cortex.agents.types import AgentConfig
from cortex.context import make_simple_context
from cortex.examples.agents.math_agent import get_math_agent_config, make_math_agent
from cortex.examples.agents.search_agent import (
get_search_agent_config,
make_search_agent,
)
from cortex.model import ModelParams
from cortex.tools.agent_tool import AgentTool
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
async def init_main_tools() -> ToolSet:
    """Initialize main coordination tools.

    Builds the search and math sub-agents, wraps each agent's as_tool() spec
    in an AgentTool, and registers both (search first, then math) on a fresh
    ToolSet so MainAgent can delegate to them.
    """
    toolset = ToolSet()
    search_agent = await make_search_agent(config=get_search_agent_config())
    math_agent = await make_math_agent(config=get_math_agent_config())
    search_tool_params = search_agent.as_tool()
    math_tool_params = math_agent.as_tool()
    # dict.get with a default replaces the previous `x["k"] if "k" in x` form;
    # 300s is the fallback timeout when the spec omits one.
    for tool_params in (search_tool_params, math_tool_params):
        toolset.register(
            AgentTool(
                name=tool_params["name"],
                description=tool_params["description"],
                timeout=tool_params.get("timeout", 300),
                channel=toolset.channel,
            )
        )
    return toolset
async def make_main_agent(
    config: AgentConfig, context_id: str | None = None
) -> BaseAgent:
    """Create MainAgent wired to the coordination toolset and a simple context.

    Args:
        config: Agent configuration.
        context_id: Session id for the context; a random hex id is generated
            when omitted.
    """
    toolset = await init_main_tools()
    session_id = context_id if context_id is not None else uuid4().hex
    return ReActAgent(
        context=make_simple_context(session_id),
        config=config,
        toolset=toolset,
    )
def get_main_agent_config() -> AgentConfig:
    """Get MainAgent configuration.

    Returns:
        AgentConfig for the coordinator agent that delegates to SearchAgent
        and MathAgent (see init_main_tools).
    """
    return AgentConfig(
        name="MainAgent",
        description="Main coordination Agent responsible for coordinating and calling other specialized Agents to complete tasks. Can select appropriate Agents based on task requirements (e.g., MathAgent for mathematical calculations, SearchAgent for information search) and coordinate multiple Agents to complete complex tasks.",
        system_prompt="You are a main coordination Agent responsible for coordinating and calling other specialized Agents to complete tasks.",
        model=ModelParams(
            name="gpt-4o-mini",
            infer_kwargs={"max_tokens": 2000, "temperature": 0.7, "stream": False},
        ),
    )
================================================
FILE: cortex/examples/agents/math_agent.py
================================================
"""MathAgent - An Agent that can solve mathematical problems."""
import logging
import math
from uuid import uuid4
from cortex.agents.base_agent import BaseAgent
from cortex.agents.react_agent import ReActAgent
from cortex.agents.types import AgentConfig
from cortex.context import make_simple_context
from cortex.model import ModelParams
from cortex.tools.function_tool import FunctionTool
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
# Basic math operation tool functions
def add(a: float, b: float) -> float:
    """Return the sum of two numbers.

    Args:
        a: First number
        b: Second number

    Returns:
        Sum of the two numbers
    """
    total = a + b
    return total
def subtract(a: float, b: float) -> float:
    """Return the difference of two numbers.

    Args:
        a: Minuend
        b: Subtrahend

    Returns:
        Difference of the two numbers
    """
    difference = a - b
    return difference
def multiply(a: float, b: float) -> float:
    """Return the product of two numbers.

    Args:
        a: First number
        b: Second number

    Returns:
        Product of the two numbers
    """
    product = a * b
    return product
def divide(a: float, b: float) -> float:
    """Return the quotient of two numbers.

    Args:
        a: Dividend
        b: Divisor (cannot be 0)

    Returns:
        Quotient of the two numbers

    Raises:
        ValueError: When divisor is 0
    """
    if b == 0:
        raise ValueError("Divisor cannot be 0")
    quotient = a / b
    return quotient
def power(base: float, exponent: float) -> float:
    """Return ``base`` raised to ``exponent`` (uses the ** operator).

    Args:
        base: Base number
        exponent: Exponent

    Returns:
        Base raised to the power of exponent
    """
    result = base ** exponent
    return result
def sqrt(number: float) -> float:
    """Return the square root of a non-negative number.

    Args:
        number: Number to calculate square root (must be >= 0)

    Returns:
        Square root of the number

    Raises:
        ValueError: When number is less than 0
    """
    # Reject negatives up front so math.sqrt never sees an invalid input.
    if number < 0:
        raise ValueError("Cannot calculate square root of negative number")
    return math.sqrt(number)
def calculate_expression(expression: str) -> float:
    """Safely evaluate a simple arithmetic expression string.

    Supports +, -, *, /, // (floor division), ** (power), unary +/-,
    parentheses, and numeric literals — nothing else. Unlike the previous
    eval()-based version, this walks the parsed AST against an operator
    whitelist, so no names, calls, or attribute access can ever execute.

    Args:
        expression: Mathematical expression string, e.g., "2 + 3 * 4"

    Returns:
        Calculation result as a float.

    Raises:
        ValueError: If the expression contains disallowed characters or
            cannot be evaluated.
    """
    # Local imports keep this helper self-contained.
    import ast
    import operator

    # First-line defense: only digits, arithmetic operators, dot,
    # parentheses, and spaces are permitted.
    allowed_chars = set("0123456789+-*/.() ")
    if not all(c in allowed_chars for c in expression):
        raise ValueError("Expression contains disallowed characters")

    # Whitelisted AST operator nodes -> their runtime implementations.
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _eval(node):
        """Recursively evaluate a whitelisted AST node."""
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        raise ValueError("Unsupported expression element")

    try:
        return float(_eval(ast.parse(expression, mode="eval")))
    except ValueError:
        raise
    except Exception as e:
        # Wrap syntax/arithmetic errors the same way the old code did.
        raise ValueError(f"Expression calculation error: {str(e)}")
async def init_math_tools() -> ToolSet:
    """Build and return the ToolSet containing all basic math tools."""
    # (name, callable, description) triples for every math tool we expose.
    registry = [
        ("add", add, "Addition: add two numbers"),
        ("subtract", subtract, "Subtraction: subtract two numbers"),
        ("multiply", multiply, "Multiplication: multiply two numbers"),
        ("divide", divide, "Division: divide two numbers"),
        ("power", power, "Power: calculate a number raised to a power"),
        ("sqrt", sqrt, "Square root: calculate the square root of a number"),
        (
            "calculate_expression",
            calculate_expression,
            "Calculate expression: evaluate a simple mathematical expression string",
        ),
    ]

    toolset = ToolSet()
    for name, func, description in registry:
        toolset.register(
            FunctionTool(
                name=name,
                func=func,
                description=description,
            )
        )
        logger.info(f"Registered math tool: {name}")
    return toolset
async def make_math_agent(
    config: AgentConfig, context_id: str | None = None
) -> BaseAgent:
    """Create a MathAgent: a ReActAgent wired with the math toolset."""
    toolset = await init_math_tools()
    # Generate a fresh context id when the caller did not supply one.
    ctx_id = context_id if context_id is not None else uuid4().hex
    return ReActAgent(
        context=make_simple_context(ctx_id),
        config=config,
        toolset=toolset,
    )
def get_math_agent_config() -> AgentConfig:
    """Get MathAgent configuration."""
    model_params = ModelParams(
        name="gpt-4o-mini",
        infer_kwargs={"max_tokens": 2000, "temperature": 0.7, "stream": False},
    )
    return AgentConfig(
        name="MathAgent",
        description="An Agent specialized for mathematical calculations. Supports basic math operations (addition, subtraction, multiplication, division, power, square root, etc.), can handle complex mathematical expressions, supports multi-step calculations, and provides detailed calculation process explanations. Suitable for arithmetic, algebra, geometry calculations, and mathematical expression solving.",
        system_prompt="You are a professional mathematical calculation assistant.",
        model=model_params,
    )
================================================
FILE: cortex/examples/agents/plan_agent.py
================================================
"""PlanAgent - Creates plans based on task objectives, and finalizes plans after user confirmation or modification"""
import logging
from cortex.agents.base_agent import BaseAgent
from cortex.agents.react_agent import ReActAgent
from cortex.agents.types import AgentConfig
from cortex.context import make_simple_context
from cortex.model import ModelParams
from cortex.tools.client_tool import ClientTool
from cortex.tools.toolset import ToolSet
from cortex.tools.types import ToolType
logger = logging.getLogger(__name__)
async def init_plan_tools() -> ToolSet:
    """Initialize plan toolset"""
    toolset = ToolSet()

    # web_search is provided by the external MCP server.
    await toolset.register_from_mcp_server(
        mcp_server="http://xxx/mcp",
        tool_names=["web_search"],
    )

    # JSON-schema-style parameter spec for the ask_input ClientTool.
    ask_input_params = {
        "properties": {
            "prompt": {
                "type": "string",
                "description": "Prompt message to display to the user, explaining what the user needs to do (confirm, modify, provide information, etc.)",
            },
            "context": {
                "type": "string",
                "description": "Context information to help users understand the current situation, such as current plan content, items that need confirmation, etc.",
            },
        },
        "required": ["prompt"],
    }
    # ask_input relays the call to the client side and waits for a human reply.
    toolset.register(
        ClientTool(
            name="ask_input",
            description="Ask users for input, confirmation, or modification suggestions. Used for scenarios requiring user interaction such as obtaining user feedback, confirming plans, modifying suggestions, etc. Parameters: prompt (required) - prompt message to display to the user; context (optional) - context information to help users understand the current situation.",
            tool_type=ToolType.ASK_INPUT,
            channel=toolset.channel,
            timeout=300.0,  # User input may take a long time
            client_params=ask_input_params,
        )
    )
    logger.info("Registered ask_input tool")
    return toolset
async def make_plan_agent(context_id: str, config: AgentConfig) -> BaseAgent:
    """Create PlanAgent"""
    toolset = await init_plan_tools()
    return ReActAgent(
        context=make_simple_context(context_id),
        config=config,
        toolset=toolset,
    )
def get_plan_agent_config() -> AgentConfig:
    """Get PlanAgent configuration."""
    # The full workflow prompt, kept verbatim; hoisted into a local for clarity.
    plan_system_prompt = """You are a professional planning assistant. Your responsibility is to create detailed, feasible plans based on the task objectives provided by the user.
Workflow:
1. **Understand Task Objectives**: Carefully analyze the task objectives provided by the user, understand the core requirements, expected results, and constraints of the task.
2. **Information Collection** (if needed):
- If the task involves the need for latest information or professional knowledge, use the web_search tool to search for relevant information
- Collect background knowledge, best practices, case references, etc. related to the task
3. **User Confirmation and Modification**:
- Break down the task into clear steps
- Use the ask_input tool to show the preliminary plan to the user
- Clearly explain the plan content in the prompt and ask the user if modifications are needed
- Provide detailed plan content in context for the user to review
- Adjust the plan based on user feedback (confirmation, modification suggestions, etc.)
4. **Create Final Plan**:
- Create the final plan based on user confirmation or modification feedback
- Ensure the plan is complete, clear, and executable
- Provide a summary of the plan and execution suggestions
Important Principles:
- Be sure to use the ask_input tool to show the preliminary plan to the user and adjust the plan based on user feedback
- Be sure to output the final plan; the final plan does not need modification, just output it directly
- Plans should be specific and executable, avoiding being too abstract
- Consider practical feasibility and resource constraints"""
    return AgentConfig(
        name="PlanAgent",
        description="An Agent specialized in creating plans based on task objectives. Capable of analyzing task requirements, searching for relevant information, creating detailed plans, and finalizing plans after user confirmation or modification. Suitable for project planning, task decomposition, action plan creation, and similar scenarios.",
        system_prompt=plan_system_prompt,
        model=ModelParams(
            name="gpt-5.1",
            infer_kwargs={"max_tokens": 2000, "temperature": 0.7, "stream": False},
        ),
    )
================================================
FILE: cortex/examples/agents/search_agent.py
================================================
"""SearchAgent - An Agent that can integrate search information"""
import logging
from uuid import uuid4
from cortex.agents.base_agent import BaseAgent
from cortex.agents.react_agent import ReActAgent
from cortex.agents.types import AgentConfig
from cortex.context import make_simple_context
from cortex.model import ModelParams
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
async def init_search_tools() -> ToolSet:
    """Initialize search tools"""
    toolset = ToolSet()
    # Pull the web_search tool definition from the MCP server.
    await toolset.register_from_mcp_server(
        mcp_server="http://xxx/mcp",
        tool_names=["web_search"],
    )
    return toolset
async def make_search_agent(
    config: AgentConfig, context_id: str | None = None
) -> BaseAgent:
    """Create a SearchAgent: a ReActAgent wired with the search toolset."""
    toolset = await init_search_tools()
    # Generate a fresh context id when the caller did not supply one.
    ctx_id = context_id if context_id is not None else uuid4().hex
    return ReActAgent(
        context=make_simple_context(ctx_id),
        config=config,
        toolset=toolset,
    )
def get_search_agent_config() -> AgentConfig:
    """Get SearchAgent configuration."""
    model_params = ModelParams(
        name="gpt-4o-mini",
        infer_kwargs={"max_tokens": 2000, "temperature": 0.7, "stream": False},
    )
    return AgentConfig(
        name="SearchAgent",
        description="An Agent specialized in searching and integrating information. Uses search tools to find latest information, integrates multiple search results, provides comprehensive and accurate answers with citations. Suitable for finding latest information, integrating information from multiple sources, answering questions requiring real-time data, and information retrieval and organization.",
        system_prompt="You are a professional information search and integration assistant.",
        model=model_params,
    )
================================================
FILE: cortex/examples/demo_agent_cli.py
================================================
"""Demo Agent CLI - TUI interface built using TUI module to display AgentEvent."""
import asyncio
import logging
from pathlib import Path
from agentkit.trace import LocalStorageTracer
from cortex.agents.agent_factory import AgentFactory
from cortex.examples.agents.ask_input_agent import (
get_ask_input_agent_config,
make_ask_input_agent,
)
from cortex.examples.agents.deep_reasearch_agent import (
get_deep_research_agent_config,
make_deep_research_agent,
)
from cortex.examples.agents.main_agent import get_main_agent_config, make_main_agent
from cortex.examples.agents.math_agent import get_math_agent_config, make_math_agent
from cortex.examples.agents.plan_agent import get_plan_agent_config, make_plan_agent
from cortex.examples.agents.search_agent import (
get_search_agent_config,
make_search_agent,
)
from cortex.orchestrator.orchestrator import Orchestrator
from cortex.tui import AgentTUIApp
logger = logging.getLogger(__name__)
async def main():
    """
    Main function.

    Example:
        python -m cortex.examples.demo_agent_cli
    """
    # All log output goes to ./logs/demo_agent_cli.log so the TUI owns stdout.
    logs_dir = Path("./logs")
    logs_dir.mkdir(exist_ok=True)

    file_handler = logging.FileHandler(
        logs_dir / "demo_agent_cli.log", encoding="utf-8"
    )
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(message)s"
        )
    )
    # Route the root logger exclusively to the file handler.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    root_logger.handlers.clear()  # Drop the default console handler.
    root_logger.addHandler(file_handler)

    # Register every demo agent with the factory in one table-driven pass.
    agent_factory = AgentFactory()
    registrations = [
        ("DeepResearchAgent", make_deep_research_agent, get_deep_research_agent_config),
        ("PlanAgent", make_plan_agent, get_plan_agent_config),
        ("AskInputAgent", make_ask_input_agent, get_ask_input_agent_config),
        ("MainAgent", make_main_agent, get_main_agent_config),
        ("SearchAgent", make_search_agent, get_search_agent_config),
        ("MathAgent", make_math_agent, get_math_agent_config),
    ]
    for agent_name, maker, config_getter in registrations:
        agent_factory.register_agent(
            name=agent_name,
            make_agent_func=maker,
            default_config=config_getter(),
        )

    orchestrator = Orchestrator(agent_factory)
    tracer = LocalStorageTracer(storage_dir="./traces")
    app = AgentTUIApp(orchestrator=orchestrator, workdir="./logs", tracer=tracer)
    await app.run_async()
# Entry point: run the async main() under a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: cortex/examples/demo_agent_with_orchestrator.py
================================================
"""Demo Agent with Orchestrator - Demonstrates how to use the Orchestrator pattern to coordinate multiple Agents"""
import argparse
import asyncio
import logging
from agentkit.trace import LocalStorageTracer, SpanContext
from cortex.model.definition import ChatMessage
from cortex.agents.agent_factory import AgentFactory
from cortex.examples.agents.main_agent import get_main_agent_config, make_main_agent
from cortex.examples.agents.math_agent import get_math_agent_config, make_math_agent
from cortex.examples.agents.search_agent import (
get_search_agent_config,
make_search_agent,
)
from cortex.orchestrator import AgentEvent
from cortex.orchestrator.orchestrator import Orchestrator
from cortex.orchestrator.types import AgentEventType, AgentRequest
logger = logging.getLogger(__name__)
async def main(
    request: str | None = None,
    output_file: str | None = None,
):
    """
    Main function demonstrating how to use Orchestrator.

    Args:
        request: User request content. If None, uses default test message
        output_file: Output file path. If None, outputs to stdout
    """
    # Bug fix: the docstring promises a default test message when request is
    # None, but the old code passed None straight into ChatMessage.
    if request is None:
        request = "Please calculate the result of 123 + 456 + 789"

    # Register the coordinator plus its two specialized sub-agents.
    agent_factory = AgentFactory()
    agent_factory.register_agent(
        name="MainAgent",
        make_agent_func=make_main_agent,
        default_config=get_main_agent_config(),
    )
    agent_factory.register_agent(
        name="SearchAgent",
        make_agent_func=make_search_agent,
        default_config=get_search_agent_config(),
    )
    agent_factory.register_agent(
        name="MathAgent",
        make_agent_func=make_math_agent,
        default_config=get_math_agent_config(),
    )
    orchestrator = Orchestrator(agent_factory)

    messages = [ChatMessage(role="user", content=request)]
    async for event in orchestrator.run(
        agent_name="MainAgent",
        event=AgentEvent(
            type=AgentEventType.REQUEST,
            request=AgentRequest(
                agent_name="MainAgent",
                messages=messages,
            ),
        ),
        agent_config=None,
    ):
        logger.info("------ event: %s", event.model_dump_json())
        # Bug fix: pydantic v2's model_dump_json() accepts no ensure_ascii
        # argument (the old ensure_ascii=False kwarg raised TypeError) and
        # already emits unescaped UTF-8, so the kwarg is simply dropped.
        serialized = event.model_dump_json()
        if output_file:
            with open(output_file, "a", encoding="utf-8") as f:
                f.write(serialized + "\n")
        else:
            print(serialized)
if __name__ == "__main__":
    # Parse command line arguments.
    parser = argparse.ArgumentParser(
        description="Demo Agent with Orchestrator - Demonstrates how to use the Orchestrator pattern to coordinate multiple Agents",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
# Use default request, output to stdout
python cortex/examples/demo_agent_with_orchestrator.py
# Specify request content
python cortex/examples/demo_agent_with_orchestrator.py --request "Please calculate the result of 123 + 456 + 789"
# Output to file
python cortex/examples/demo_agent_with_orchestrator.py --request "Please calculate the result of 123 + 456 + 789" --output result.txt
# Use grep to filter responses (stdout with prefix)
python cortex/examples/demo_agent_with_orchestrator.py | grep "\\[RESPONSE\\]"
Demo examples (can be uncommented in code or specified via --request parameter):
1. Math calculation - Simple addition: "Please calculate the result of 123 + 456 + 789"
2. Math calculation - Complex operation: "Please calculate the result of (123 + 456) * 789 / 100"
3. Math calculation - Multi-step: "Please calculate 15 + 27 + 39 + 41 step by step, first calculate the sum of the first two numbers, then add the remaining numbers"
4. Information search: "Please search for the latest information about China's central bank gold reserves in 2025"
5. Mixed task: "Please first search for China's GDP growth rate in 2024, then calculate what the GDP would be if it grows by 5% in 2025"
""",
    )
    parser.add_argument(
        "--request",
        "-r",
        type=str,
        default=None,
        help="User request content (if not specified, uses default test message)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=None,
        help="Output file path (if not specified, outputs to stdout with prefix for grep)",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="DEBUG",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        # Bug fix: the help text previously claimed "(default: INFO)" while
        # the actual default above is DEBUG.
        help="Log level (default: DEBUG)",
    )
    args = parser.parse_args()

    # Configure log level plus source file and line number in each record.
    log_level = getattr(logging, args.log_level.upper())
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(message)s",
        encoding="utf-8",  # Add UTF-8 encoding support for proper character display
    )
    print("-----------------------------------------------------")

    # Run the whole demo inside one root trace span so all agent activity
    # shares a single trace_id.
    tracer = LocalStorageTracer(storage_dir="./traces")
    ctx = SpanContext(tracer=tracer, app_name="demo_agent_with_orchestrator")
    with ctx.span(name="demo_agent_with_orchestrator") as span:
        trace_id = ctx.get_current_trace_id()
        logger.info("agent_cortex trace_id %s", trace_id)
        asyncio.run(main(request=args.request, output_file=args.output))
================================================
FILE: cortex/examples/demo_agent_with_tool.py
================================================
"""Demo Agent with Tool implementation."""
import argparse
import asyncio
import logging
import math
import random
import string
import uuid
from agentkit.trace import LocalStorageTracer, SpanContext
from cortex.model.definition import ChatMessage
from cortex.agents.react_agent import ReActAgent
from cortex.agents.types import AgentConfig
from cortex.context.simple_context import SimpleContext
from cortex.model import ModelParams
from cortex.tools.function_tool import FunctionTool
from cortex.tools.toolset import ToolSet
logger = logging.getLogger(__name__)
def add_numbers(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: First number
        b: Second number

    Returns:
        Sum of the two numbers
    """
    total = a + b
    return total
async def multiply_numbers(a: int, b: int) -> int:
    """Multiply two numbers after a deliberate delay.

    The 5-second sleep simulates a slow asynchronous tool call.

    Args:
        a: First number
        b: Second number

    Returns:
        Product of the two numbers
    """
    # Intentional delay: demonstrates how long-running async tools behave.
    await asyncio.sleep(5)
    product = a * b
    return product
def get_random_string(length: int = 10) -> str:
    """Generate a random alphanumeric string.

    Args:
        length: String length, defaults to 10

    Returns:
        Randomly generated string of letters and digits
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))
def calculate_area(radius: float) -> float:
    """Calculate the area of a circle (pi * r^2).

    Args:
        radius: Circle radius

    Returns:
        Area of the circle
    """
    area = math.pi * radius * radius
    return area
async def init_tools():
    """Initialize toolset."""
    toolset = ToolSet()

    # Local python functions exposed as FunctionTools.
    local_tools = [
        ("add_numbers", add_numbers, "Addition tool"),
        ("multiply_numbers", multiply_numbers, "Multiplication tool"),
        ("get_random_string", get_random_string, "Random string generator tool"),
        ("calculate_area", calculate_area, "Circle area calculator tool"),
    ]
    for name, func, description in local_tools:
        toolset.register(
            FunctionTool(
                name=name,
                func=func,
                description=description,
            )
        )
        logger.info("Registered tool: %s", name)

    # Remote web_search tool pulled from the MCP server.
    await toolset.register_from_mcp_server(
        mcp_server="http://xxx/mcp",
        tool_names=["web_search"],
    )
    return toolset
async def main(user_input: str):
    """Main function demonstrating how to use DemoAgentWithTool."""
    # Agent configuration for the demo ReActAgent.
    agent_config = AgentConfig(
        name="demo_agent:ReActAgent",
        description="Agent demonstrating how to use ToolSet and anymodel.",
        system_prompt="You are a helpful assistant that can use the provided tools to help users.",
        model=ModelParams(
            name="gpt-4o-mini",
            infer_kwargs={"max_tokens": 2000, "temperature": 0.7, "stream": False},
        ),
        max_steps=10,
    )

    # Assemble the agent: toolset + fresh per-run context.
    toolset = await init_tools()
    context = SimpleContext(session_id=str(uuid.uuid4()))
    agent = ReActAgent(context=context, config=agent_config, toolset=toolset)

    logger.info("=== %s ===\n", agent.name)
    logger.info("Registered tools: %s\n", agent.toolset().list_tools())
    logger.info("Tool schema: %s\n", agent.toolset().get_all_schemas())

    # Stream the agent's responses for the single user message.
    messages = [
        ChatMessage(role="user", content=user_input),
    ]
    async for response in agent.run(messages):
        logger.info("------ response: %s", response.model_dump_json())
# Script entry point: parse the single optional query flag, configure
# logging, and run the demo inside one root trace span.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo Agent with Tool")
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        default="China central bank gold reserves 2025",
        help="User input query (optional, defaults to 'China central bank gold reserves 2025')",
    )
    args = parser.parse_args()
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(message)s",
        encoding="utf-8",  # Add UTF-8 encoding support
    )
    # Persist the trace under ./traces; the span wraps the whole run.
    tracer = LocalStorageTracer(storage_dir="./traces")
    ctx = SpanContext(tracer=tracer, app_name="demo_agent_with_tool")
    with ctx.span(name="demo_agent_with_tool") as span:
        asyncio.run(main(args.input))
================================================
FILE: cortex/examples/demo_checkpoint.py
================================================
import asyncio
import logging
from uuid import uuid4
from cortex.agents.agent_factory import AgentFactory
from cortex.agents.base_agent import BaseAgent
from cortex.agents.checkpoint_agent.checkpointer import (
CheckpointStorage,
SqliteCheckPointer,
)
from cortex.agents.checkpoint_agent.react_agent import CheckpointReActAgent
from cortex.agents.types import (
AgentConfig,
)
from cortex.examples.agents.math_agent import init_math_tools
from cortex.examples.agents.search_agent import init_search_tools
from cortex.model import ModelParams
from cortex.orchestrator.orchestrator import Orchestrator
from cortex.server.http_server import HttpServer
from cortex.tools.toolset import ToolSet
from cortex.tools.ublock_agent_tool import UnblockAgentTool
logger = logging.getLogger(__name__)
def make_storage() -> CheckpointStorage:
    # Each call returns a SQLite-backed checkpointer on cp.db in the CWD;
    # every demo agent in this module shares the same database file.
    return SqliteCheckPointer(db_path="cp.db")
def get_search_agent_config() -> AgentConfig:
    """Build the default SearchAgent configuration for this demo."""
    model_params = ModelParams(
        name="gpt-4o-mini",
        infer_kwargs={"max_tokens": 10000, "temperature": 0.7, "stream": False},
    )
    return AgentConfig(
        name="SearchAgent",
        description="Agent specialized in searching and aggregating information. Uses search tools to find the latest information, aggregates multiple search results, provides comprehensive and accurate answers with source citations. Suitable for finding latest information, aggregating information from multiple sources, answering questions requiring real-time data, and information retrieval organization.",
        system_prompt="You are a professional information search and aggregation assistant.",
        model=model_params,
    )
def get_math_agent_config() -> AgentConfig:
    """Build the default MathAgent configuration for this demo."""
    model_params = ModelParams(
        name="gpt-4o-mini",
        infer_kwargs={"max_tokens": 10000, "temperature": 0.7, "stream": False},
    )
    return AgentConfig(
        name="MathAgent",
        description="Agent specialized in mathematical calculations. Supports basic math operations (addition, subtraction, multiplication, division, exponentiation, square root, etc.), can handle complex mathematical expressions, supports multi-step calculations, and provides detailed calculation process explanations. Suitable for arithmetic operations, algebraic calculations, geometric calculations, and mathematical expression solving.",
        system_prompt="You are a professional mathematical calculation assistant.",
        model=model_params,
    )
async def make_search_agent(
    config: AgentConfig, context_id: str | None = None
) -> BaseAgent:
    """Create a checkpointed SearchAgent backed by the search toolset."""
    toolset = await init_search_tools()
    return CheckpointReActAgent(
        storage=make_storage(),
        # Fall back to a random hex id when no context id is supplied.
        context_id=context_id if context_id is not None else uuid4().hex,
        config=config,
        toolset=toolset,
    )
async def make_math_agent(
    config: AgentConfig, context_id: str | None = None
) -> BaseAgent:
    """Create a checkpointed MathAgent backed by the math toolset."""
    toolset = await init_math_tools()
    return CheckpointReActAgent(
        storage=make_storage(),
        # Fall back to a random hex id when no context id is supplied.
        context_id=context_id if context_id is not None else uuid4().hex,
        config=config,
        toolset=toolset,
    )
async def make_main_agent(
    config: AgentConfig, context_id: str | None = None
) -> BaseAgent:
    """Create MainAgent: a checkpointed coordinator that exposes the
    Search and Math sub-agents as unblocking tools.

    Args:
        config: Agent configuration for the main agent.
        context_id: Optional context id; a random hex id is generated
            when omitted.

    Returns:
        A CheckpointReActAgent wired with SearchAgent/MathAgent tools.
    """
    toolset = ToolSet()
    # NOTE: an ask_for_user UnblockClientTool (ToolType.ASK_INPUT) could also
    # be registered here to collect user confirmation/feedback; it is omitted
    # in this demo (a commented-out registration used to live here).

    # Build each sub-agent and expose it as an UnblockAgentTool. The loop
    # replaces the previous copy-pasted registration for the two agents.
    search_agent = await make_search_agent(config=get_search_agent_config())
    math_agent = await make_math_agent(config=get_math_agent_config())
    for tool_params in (search_agent.as_tool(), math_agent.as_tool()):
        toolset.register(
            UnblockAgentTool(
                name=tool_params["name"],
                description=tool_params["description"],
                # Fall back to 300s when the agent declares no timeout.
                timeout=tool_params["timeout"] if "timeout" in tool_params else 300,
                channel=toolset.channel,
            )
        )

    if context_id is None:
        context_id = uuid4().hex
    return CheckpointReActAgent(
        storage=make_storage(),
        context_id=context_id,
        config=config,
        toolset=toolset,
    )
async def main():
    """Register the demo agents and serve them over HTTP."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(message)s",
        encoding="utf-8",  # Add UTF-8 encoding support
    )

    # Register the coordinator plus its two specialized sub-agents.
    agent_factory = AgentFactory()
    agent_factory.register_agent(
        name="MainAgent",
        make_agent_func=make_main_agent,
        default_config=AgentConfig(
            name="MainAgent",
            description="Main coordination Agent responsible for coordinating and calling other specialized Agents to complete tasks. Can select appropriate Agents based on task requirements (e.g., MathAgent for mathematical calculations, SearchAgent for information search) and coordinate multiple Agents to complete complex tasks.",
            system_prompt="You are a main coordination Agent responsible for coordinating and calling other specialized Agents to complete tasks. You can select appropriate Agents based on task requirements (e.g., MathAgent for mathematical calculations, SearchAgent for information search) and coordinate multiple Agents to complete complex tasks.",
            model=ModelParams(
                name="gpt-4o-mini",
                infer_kwargs={
                    "max_tokens": 10000,
                    "temperature": 0.7,
                    "stream": False,
                },
            ),
            unfinished_mode=False,
        ),
    )
    for agent_name, maker, config_getter in (
        ("SearchAgent", make_search_agent, get_search_agent_config),
        ("MathAgent", make_math_agent, get_math_agent_config),
    ):
        agent_factory.register_agent(
            name=agent_name,
            make_agent_func=maker,
            default_config=config_getter(),
        )

    # (A commented-out GraphReActAgent mermaid-export snippet used to live
    # here; see VCS history if needed.)
    orch = Orchestrator(agent_factory)
    http_server = HttpServer(orch)
    await http_server.start()
# Entry point: start the HTTP server under a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: cortex/examples/demo_toolset_channel.py
================================================
"""Channel usage example: Demonstrates complete async communication flow."""
import asyncio
import logging
import time
from typing import Any, Dict
from cortex.tools.base import ToolSchema
from cortex.tools.channel import Channel
from cortex.tools.types import ToolParameters
logger = logging.getLogger(__name__)
# Mock external processing system (server, message queue, etc.)
class MockExternalServer:
    """Mock external server that records incoming requests and answers them."""

    def __init__(self):
        """Initialize with an empty request log keyed by request id."""
        self.received_requests: Dict[str, Dict[str, Any]] = {}

    async def process_request(self, tool_name: str, request_id: str, data: Any) -> Any:
        """
        Process a request (simulated async processing).

        Args:
            tool_name: Tool name
            request_id: Request ID
            data: Request data

        Returns:
            Any: Processing result
        """
        # Simulated network latency.
        await asyncio.sleep(0.5)

        # Record the request so demos can inspect what arrived.
        self.received_requests[request_id] = {
            "tool_name": tool_name,
            "data": data,
            "timestamp": time.time(),
        }

        # Dispatch per tool; anything unrecognized is echoed back.
        if tool_name == "calculator" and "operation" in data and "operands" in data:
            return self._calculate(data["operation"], data["operands"])
        if tool_name == "data_processor" and "data" in data:
            payload = data["data"]
            return {"output": f"Processed: {payload}", "length": len(payload)}
        return {"status": "success", "data": data}

    @staticmethod
    def _calculate(operation: str, operands: Any) -> Dict[str, Any]:
        """Apply a calculator operation ('add' or 'multiply') to its operands."""
        if operation == "add":
            result = sum(operands)
        elif operation == "multiply":
            result = 1
            for factor in operands:
                result *= factor
        else:
            result = f"Unknown operation: {operation}"
        return {"result": result, "operation": operation}
async def demo_basic_usage():
    """Demo 1: Basic usage flow.

    A Channel is created with an on_send callback that forwards the request
    to the mock server and posts the result back via set_response, so
    send_request resolves within the same call.
    """
    logger.info("=" * 60)
    logger.info("Demo 1: Basic usage flow - Send and wait for response")
    logger.info("=" * 60)
    # Create mock server
    server = MockExternalServer()

    # Define send callback function
    async def send_to_server(
        tool_name: str, tool_schema: ToolSchema, tool_parameters: ToolParameters
    ):
        """Callback function to send data to server."""
        # Extract request_id from tool_parameters.kwargs
        # (presumably injected as "_request_id" by Channel.send_request —
        # verify against Channel's implementation; popped so it is not
        # forwarded as a tool argument)
        request_id = tool_parameters.kwargs.pop("_request_id", None)
        if request_id is None:
            request_id = f"req_{hash(tool_name)}_{hash(tool_parameters.parameters)}"
        # Build request data
        request_data = {
            "parameters": tool_parameters.parameters,
            **tool_parameters.kwargs,
        }
        logger.info(
            f" [Send] tool={tool_name}, request_id={request_id}, data={request_data}"
        )
        result = await server.process_request(tool_name, request_id, request_data)
        # Simulate setting response via Channel
        # In real scenarios, this might be done asynchronously via network, message queue, etc.
        # NOTE: `channel` is a closure variable assigned below; it exists by
        # the time this callback is invoked.
        channel.set_response(request_id, result)

    # Create Channel and register send callback
    channel = Channel(on_send=send_to_server)

    # Send request 1: Calculator tool
    logger.info("\nSend request 1: Calculator - Addition")
    request_id1, response1 = await channel.send_request(
        tool_name="calculator",
        data=ToolParameters(
            parameters="", kwargs={"operation": "add", "operands": [10, 20, 30]}
        ),
        tool_schema=ToolSchema(name="calculator", description="Calculator tool"),
        timeout=5.0,
    )
    logger.info(f" Request ID: {request_id1}")
    logger.info(f" Response: {response1}")

    # Send request 2: Data processor
    logger.info("\nSend request 2: Data processor")
    request_id2, response2 = await channel.send_request(
        tool_name="data_processor",
        data=ToolParameters(parameters="", kwargs={"data": "Hello, World!"}),
        tool_schema=ToolSchema(name="data_processor", description="Data processor tool"),
        timeout=5.0,
    )
    logger.info(f" Request ID: {request_id2}")
    logger.info(f" Response: {response2}")
    logger.info(f"\nNumber of requests received by server: {len(server.received_requests)}")
async def demo_custom_request_id():
    """Demo 2: Using custom request_id."""
    logger.info("\n" + "=" * 60)
    logger.info("Demo 2: Using custom request_id")
    logger.info("=" * 60)

    server = MockExternalServer()
    channel = Channel()

    async def send_handler(
        tool_name: str, tool_schema: ToolSchema, tool_parameters: ToolParameters
    ):
        """Forward the request to the mock server and post back its response."""
        # The request id travels inside kwargs as "_request_id"; pop it out
        # so it is not forwarded as a tool argument.
        request_id = tool_parameters.kwargs.pop("_request_id", None)
        if request_id is None:
            request_id = f"req_{hash(tool_name)}_{hash(tool_parameters.parameters)}"
        request_data = {
            "parameters": tool_parameters.parameters,
            **tool_parameters.kwargs,
        }
        logger.info(f" [Send] request_id={request_id}, data={request_data}")
        outcome = await server.process_request(tool_name, request_id, request_data)
        channel.set_response(request_id, outcome)

    # Supply our own request_id instead of letting the channel generate one.
    custom_id = "req_custom_001"
    logger.info(f"\nUsing custom request_id: {custom_id}")
    request_id, response = await channel.send_request(
        tool_name="calculator",
        data=ToolParameters(
            parameters="", kwargs={"operation": "multiply", "operands": [2, 3, 4]}
        ),
        tool_schema=ToolSchema(name="calculator", description="Calculator tool"),
        request_id=custom_id,
        on_send=send_handler,
        timeout=5.0,
    )
    logger.info(f" Returned request_id: {request_id}")
    logger.info(f" Response: {response}")
async def demo_manual_response():
    """Demo 3: Manually set response.

    Shows sending and responding fully decoupled: the on_send callback only
    transmits the request, while a separate background task later calls
    channel.set_response() to complete the pending request.
    """
    logger.info("\n" + "=" * 60)
    logger.info("Demo 3: Manually set response (separate send and response)")
    logger.info("=" * 60)
    channel = Channel()
    saved_request_id = {"id": None}  # Used to save request_id
    # Define send callback (only send, don't set response immediately)
    async def send_only(
        tool_name: str, tool_schema: ToolSchema, tool_parameters: ToolParameters
    ):
        """Only send data, don't set response."""
        # Extract request_id from tool_parameters.kwargs
        request_id = tool_parameters.kwargs.pop("_request_id", None)
        if request_id is None:
            request_id = f"req_{hash(tool_name)}_{hash(tool_parameters.parameters)}"
        saved_request_id["id"] = request_id  # Save request_id for later use
        # Build request data
        request_data = {
            "parameters": tool_parameters.parameters,
            **tool_parameters.kwargs,
        }
        logger.info(f" [Send only] request_id={request_id}, data={request_data}")
        # In real scenarios, this might be sent to message queue, WebSocket, etc.
        # Response is set asynchronously by another thread/process/service
    # Start a background task to simulate async response
    async def delayed_response(delay: float = 1.0):
        """Delayed response setting."""
        await asyncio.sleep(delay)
        req_id = saved_request_id["id"]
        if req_id:
            result = {
                "status": "completed",
                "request_id": req_id,
                "data": "processed_data",
            }
            logger.info(f" [Set response] request_id={req_id}, response={result}")
            channel.set_response(req_id, result)
    # Create send request task
    send_task = asyncio.create_task(
        channel.send_request(
            tool_name="async_processor",
            data=ToolParameters(parameters="", kwargs={"task": "process_data"}),
            tool_schema=ToolSchema(
                name="async_processor", description="Async processor tool"
            ),
            on_send=send_only,
            timeout=5.0,
        )
    )
    # Start delayed response task (auto-processed after request is sent)
    # NOTE(review): the 0.8s delay assumes send_only has stored the request id
    # by then — confirm send_request invokes on_send before awaiting a response.
    response_task = asyncio.create_task(delayed_response(delay=0.8))
    # Wait for request to complete
    request_id, response = await send_task
    logger.info(f" Request ID: {request_id}")
    logger.info(f" Response: {response}")
    await response_task
async def demo_error_handling():
    """Demo 4: Error handling.

    Covers two failure modes: an explicit error response delivered via
    set_response(..., error=...), and a request that times out because no
    response is ever set.
    """
    logger.info("\n" + "=" * 60)
    logger.info("Demo 4: Error handling")
    logger.info("=" * 60)
    channel = Channel()
    async def send_with_error(
        tool_name: str, tool_schema: ToolSchema, tool_parameters: ToolParameters
    ):
        """Send and set error response."""
        # Extract request_id from tool_parameters.kwargs
        request_id = tool_parameters.kwargs.pop("_request_id", None)
        if request_id is None:
            request_id = f"req_{hash(tool_name)}_{hash(tool_parameters.parameters)}"
        logger.info(f" [Send] request_id={request_id}")
        await asyncio.sleep(0.3)  # short delay, well within the 5s timeout
        # Simulate processing failure, set error response
        channel.set_response(request_id, None, error="Processing failed: Invalid input")
    logger.info("\nSending request (will return error)...")
    try:
        request_id, response = await channel.send_request(
            tool_name="failing_tool",
            data=ToolParameters(parameters="", kwargs={"invalid": "data"}),
            tool_schema=ToolSchema(name="failing_tool", description="Failing tool"),
            on_send=send_with_error,
            timeout=5.0,
        )
        logger.info(f" Response: {response}")
    except Exception as e:
        logger.info(f" Caught error: {type(e).__name__}: {e}")
    # Demo timeout
    logger.info("\nSending request (will timeout)...")
    async def slow_sender(
        tool_name: str, tool_schema: ToolSchema, tool_parameters: ToolParameters
    ):
        """Slow sender, won't set response."""
        # Extract request_id from tool_parameters.kwargs
        request_id = tool_parameters.kwargs.pop("_request_id", None)
        if request_id is None:
            request_id = f"req_{hash(tool_name)}_{hash(tool_parameters.parameters)}"
        logger.info(f" [Send] request_id={request_id} (but won't set response)")
        await asyncio.sleep(2.0)  # Exceeds timeout
    # NOTE(review): on Python < 3.11 asyncio.TimeoutError is distinct from the
    # builtin TimeoutError — confirm Channel raises the builtin, otherwise this
    # except clause may not catch it.
    try:
        request_id, response = await channel.send_request(
            tool_name="slow_tool",
            data=ToolParameters(parameters="", kwargs={"slow": "data"}),
            tool_schema=ToolSchema(name="slow_tool", description="Slow tool"),
            on_send=slow_sender,
            timeout=1.0,  # 1 second timeout
        )
    except TimeoutError as e:
        logger.info(f" Caught timeout error: {e}")
async def demo_concurrent_requests():
    """Demo 5: Concurrent requests."""
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("Demo 5: Concurrent request handling")
    logger.info(banner)

    server = MockExternalServer()
    channel = Channel()

    async def concurrent_sender(
        tool_name: str, tool_schema: ToolSchema, tool_parameters: ToolParameters
    ):
        """Concurrent send handler."""
        # The channel passes its id via the reserved "_request_id" kwarg;
        # derive one deterministically when it is missing.
        request_id = tool_parameters.kwargs.pop("_request_id", None)
        if request_id is None:
            request_id = f"req_{hash(tool_name)}_{hash(tool_parameters.parameters)}"
        request_data = {
            "parameters": tool_parameters.parameters,
            **tool_parameters.kwargs,
        }
        outcome = await server.process_request(tool_name, request_id, request_data)
        channel.set_response(request_id, outcome)

    # Fan out five calculator requests and let the channel multiplex them.
    logger.info("\nSending 5 concurrent requests...")
    tasks = [
        channel.send_request(
            tool_name="calculator",
            data=ToolParameters(
                parameters="",
                kwargs={"operation": "add", "operands": [i, i + 1, i + 2]},
            ),
            tool_schema=ToolSchema(name="calculator", description="Calculator tool"),
            on_send=concurrent_sender,
            timeout=5.0,
        )
        for i in range(5)
    ]
    results = await asyncio.gather(*tasks)

    logger.info(f"\nCompleted {len(results)} requests:")
    for i, (req_id, response) in enumerate(results, 1):
        logger.info(
            f" Request {i}: request_id={req_id}, result={response.get('result', 'N/A')}"
        )
async def main():
    """Main function: Run all demos."""
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("Channel Complete Usage Flow Demo")
    logger.info(banner)

    # Run every demo sequentially, in the documented order.
    demos = (
        demo_basic_usage,
        demo_custom_request_id,
        demo_manual_response,
        demo_error_handling,
        demo_concurrent_requests,
    )
    for demo in demos:
        await demo()

    logger.info("\n" + banner)
    logger.info("All demos completed!")
    logger.info(banner)


if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: cortex/examples/server.py
================================================
import asyncio
import logging
from agentkit.trace import LocalStorageTracer
from cortex.agents.agent_factory import AgentFactory
from cortex.agents.types import AgentConfig
from cortex.examples.agents.main_agent import make_main_agent
from cortex.examples.agents.math_agent import get_math_agent_config, make_math_agent
from cortex.examples.agents.search_agent import (
get_search_agent_config,
make_search_agent,
)
from cortex.model import ModelParams
from cortex.orchestrator.orchestrator import Orchestrator
from cortex.server.http_server import HttpServer
logger = logging.getLogger(__name__)
async def main():
    """Register the demo agents, then start the orchestrated HTTP server."""
    factory = AgentFactory()

    # The main agent's config is declared inline; the specialized agents ship
    # their own config factories.
    main_agent_config = AgentConfig(
        name="MainAgent",
        description="Main coordination Agent responsible for coordinating and calling other specialized Agents to complete tasks. Can select appropriate Agents based on task requirements (e.g., MathAgent for mathematical calculations, SearchAgent for information search) and coordinate multiple Agents to complete complex tasks.",
        system_prompt="You are a main coordination Agent responsible for coordinating and calling other specialized Agents to complete tasks.",
        model=ModelParams(
            name="gpt-4o-mini",
            infer_kwargs={"max_tokens": 2000, "temperature": 0.7, "stream": False},
        ),
    )
    factory.register_agent(
        name="MainAgent",
        make_agent_func=make_main_agent,
        default_config=main_agent_config,
    )
    factory.register_agent(
        name="SearchAgent",
        make_agent_func=make_search_agent,
        default_config=get_search_agent_config(),
    )
    factory.register_agent(
        name="MathAgent",
        make_agent_func=make_math_agent,
        default_config=get_math_agent_config(),
    )

    orchestrator = Orchestrator(factory)
    # Persist trace spans under ./traces so runs can be inspected offline.
    tracer = LocalStorageTracer(storage_dir="./traces")
    server = HttpServer(orchestrator, tracer=tracer)
    await server.start()


if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: cortex/model/__init__.py
================================================
"""Model system for Agent components."""
from enum import Enum
from typing import AsyncGenerator
from agentkit.trace import get_current_context
from .definition import ChatMessage, ModelParams
from cortex.model.utils import merge_delta_message
from pydantic import BaseModel
from cortex.model.provider import ModelProvider
# Public surface re-exported by `from cortex.model import *`.
__all__ = [
    "ModelAPI",
    "ModelMessage",
    "MessageType",
    "ModelParams",
]
class MessageType(str, Enum):
    """Message type enumeration.

    Distinguishes incremental stream chunks (DELTA) from the final merged
    message (ACCUMULATED) yielded by ModelAPI.
    """

    DELTA = "delta"  # Streaming output
    ACCUMULATED = "accumulated"  # Accumulated output
class ModelMessage(BaseModel):
    """Model message"""

    # The chat payload: either one streaming delta or the accumulated message.
    message: ChatMessage
    # Tells consumers how to interpret `message` (DELTA vs ACCUMULATED).
    message_type: MessageType
class ModelAPI:
    """Model API.

    Thin wrapper around a ModelProvider that records every call on an
    `llm_span` of the current trace context.
    """

    def __init__(self, provider: ModelProvider):
        self.provider = provider

    async def chat_completion(
        self,
        messages: list[ChatMessage],
        tools: list | None = None,
        log_file: str | None = None,
        trace_request: dict | None = None,
    ) -> ModelMessage:
        """Call model API, return accumulated message.

        Args:
            messages: Conversation history forwarded to the provider.
            tools: Optional tool schemas forwarded to the provider.
            log_file: Optional provider-side request log file.
            trace_request: Optional payload recorded on the trace span instead
                of the default {messages, tools} dict.
        """
        ctx = get_current_context()
        with ctx.llm_span(name="ModelAPI.chat_completion") as span:
            span.update_payload_data(
                request=trace_request
                or {
                    "messages": messages,
                    "tools": tools,
                },
                tools=tools,
            )
            response = await self.provider.chat_completion(
                messages=messages, tools=tools, log_file=log_file
            )
            # Build the result once; record it on the span and return it.
            result = ModelMessage(
                message=response, message_type=MessageType.ACCUMULATED
            )
            span.update_payload_data(response=result)
            return result

    async def chat_completion_stream(
        self,
        messages: list[ChatMessage],
        tools: list | None = None,
        log_file: str | None = None,
        trace_request: dict | None = None,
    ) -> AsyncGenerator[ModelMessage, None]:
        """Call model API, return streaming messages.

        Yields one DELTA message per provider event, followed by a single
        ACCUMULATED message merging all deltas (omitted if the stream was
        empty). The trace span is recorded after the stream finishes.
        """
        accumulated_message = None
        async for event in self.provider.chat_completion_stream(
            messages=messages, tools=tools, log_file=log_file
        ):
            # Fold each delta into the running accumulated message.
            merged_dict = merge_delta_message(
                accumulated_message.model_dump() if accumulated_message else None,
                event.model_dump() if event else None,
            )
            accumulated_message = ChatMessage(**merged_dict)
            yield ModelMessage(message=event, message_type=MessageType.DELTA)
        # Yield the accumulated complete message
        if accumulated_message:
            yield ModelMessage(
                message=accumulated_message, message_type=MessageType.ACCUMULATED
            )
        ctx = get_current_context()
        with ctx.llm_span(name="ModelAPI.chat_completion_stream") as span:
            span.update_payload_data(
                request=trace_request
                or {
                    "messages": messages,
                    "tools": tools,
                },
            )
            # Guard: an empty stream leaves accumulated_message as None, and
            # ModelMessage requires a ChatMessage — skip the response payload
            # instead of raising a pydantic ValidationError here.
            if accumulated_message is not None:
                span.update_payload_data(
                    response=ModelMessage(
                        message=accumulated_message,
                        message_type=MessageType.ACCUMULATED,
                    ),
                )
================================================
FILE: cortex/model/definition.py
================================================
from enum import Enum
from typing import Any, Dict, List, Optional
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
class MessageRole(Enum):
    """Chat message roles; values are the wire-format `role` strings."""

    USER = "user"
    HUMAN = "human"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_DESC = "tool-description"
    TOOL_RESPONSE = "tool-response"
    TOOL = "tool"


class ContentBlockType(Enum):
    """Types of content blocks that may appear in ChatMessage.content lists."""

    TEXT = "text"
    THINK = "thinking"
    REDACTED_THINK = "redacted_thinking"
    TOOLCALL = "toolcall"
    TOOLUSE = "tool_use"
    TOOLRESULT = "tool_result"
    IMAGE = "image"
    IMAGE_URL = "image_url"
    VIDEO = "video"
    VIDEO_URL = "video_url"
    AUDIO = "input_audio"
    AUDIO_URL = "audio_url"
    DOCURL = "doc_url"
class ModelParams(BaseModel):
    """Configuration for a model invocation (name, inference kwargs, endpoint)."""

    name: str
    # NOTE(review): typed as a BaseModel *instance*; presumably meant to carry
    # a response-format schema — confirm the intended type.
    response_format: BaseModel | None = None
    toolcall_parser_version: str | None = None
    parallel_tool_calls: bool = True
    # Extra inference parameters forwarded to the provider
    # (pydantic copies mutable defaults per instance, so {} is safe here).
    infer_kwargs: dict = {}
    user_role: str = MessageRole.USER.value
    tool_role: str = MessageRole.TOOL.value
    explicit_model_family_value: str | None = None
    # Explicit endpoint/credentials override provider defaults when set.
    explicit_api_base: str | None = None
    explicit_api_key: str | None = None


class Function(BaseModel):
    arguments: str | None = None
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str | None = None
    """The name of the function to call."""


class ChatToolCall(BaseModel):
    index: int | None = None
    """The index of the tool call."""

    id: str | None = None
    """The ID of the tool call."""

    function: Function
    """The function that the model called."""

    type: str | None = "function"
    """The type of the tool. Currently, only `function` is supported."""


class ExtraInfo(TypedDict, total=False):
    """Extra information for ChatMessage."""

    cache_msg_id: str
    """The cache message ID.
    Try best to leverage the cache functionality of the backend model service.
    But it's not guaranteed to be successful.
    Currently only Anthropic model of models-proxy backend supports this functionality.
    """

    usage: Dict[str, Any]
    """The usage information of the model. """

    finish_reason: str
    """The finish reason of the model.
    For OpenAI model, it's called "finish_reason".
    For Anthropic model, it's called "stop_reason".
    Here we use the unified name "finish_reason" for all cases.
    """


class ChatMessage(BaseModel):
    """A single chat message; content is a string or a list of content blocks."""

    id: Optional[str] = None
    role: str | None = None
    content: str | List[Dict] | None = None
    tool_call_id: Optional[str] = None
    tool_calls: Optional[List[ChatToolCall]] = Field(default_factory=list)
    # for train use
    extra_info: Optional[ExtraInfo] = None

    @classmethod
    def from_dict(cls, data: dict) -> "ChatMessage":
        """Build a ChatMessage from a plain dict."""
        return cls(**data)

    def to_dict(self) -> dict:
        """Serialize to a plain dict via pydantic's model_dump()."""
        return self.model_dump()
================================================
FILE: cortex/model/provider.py
================================================
from abc import ABC, abstractmethod
from typing import AsyncGenerator
from .definition import ChatMessage
class ModelProvider(ABC):
    """Abstract interface implemented by chat-completion backends."""

    @abstractmethod
    async def chat_completion_stream(
        self,
        messages: list[ChatMessage],
        tools: list | None = None,
        log_file: str | None = None,
    ) -> AsyncGenerator[ChatMessage, None]:
        """Yield incremental ChatMessage deltas for a streaming completion."""
        pass

    @abstractmethod
    async def chat_completion(
        self,
        messages: list[ChatMessage],
        tools: list | None = None,
        log_file: str | None = None,
    ) -> ChatMessage:
        """Return the complete ChatMessage for a non-streaming completion."""
        pass
================================================
FILE: cortex/model/stepfun_chat.py
================================================
"""StepFun Chat API Client with reasoning support."""
import json
from typing import Any, AsyncGenerator, Iterable
import httpx
from pydantic import BaseModel
# StepFun API supported infer_kwargs parameters.
# Keys outside this allowlist are silently dropped by
# StepFunClient._build_request_body.
SUPPORTED_INFER_KWARGS = frozenset([
    "temperature",  # 0.0-2.0, default 0.5
    "top_p",  # default 0.9
    "max_tokens",
    "n",  # default 1
    "stop",
    "frequency_penalty",  # 0.0-1.0, default 0
    "response_format",
    "reasoning_format",  # StepFun specific: "general" or "deepseek-style"
])
# ============ Structured Response Types ============
class Function(BaseModel):
    """Function call in tool_calls."""

    name: str | None = None
    arguments: str | None = None  # JSON-encoded argument string


class ToolCall(BaseModel):
    """Tool call object."""

    id: str | None = None
    index: int | None = None  # position within the tool_calls array (streaming)
    type: str = "function"
    function: Function | None = None


class Message(BaseModel):
    """Chat completion message with reasoning support."""

    role: str | None = None
    content: str | None = None
    reasoning: str | None = None  # StepFun specific
    tool_calls: list[ToolCall] | None = None


class Delta(BaseModel):
    """Streaming delta with reasoning support."""

    role: str | None = None
    content: str | None = None
    reasoning: str | None = None  # StepFun specific
    tool_calls: list[ToolCall] | None = None


class Choice(BaseModel):
    """Chat completion choice."""

    index: int = 0
    message: Message | None = None
    finish_reason: str | None = None


class StreamChoice(BaseModel):
    """Streaming choice with delta."""

    index: int = 0
    delta: Delta | None = None
    finish_reason: str | None = None


class Usage(BaseModel):
    """Token usage information."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ChatCompletion(BaseModel):
    """Chat completion response."""

    id: str | None = None
    object: str = "chat.completion"
    created: int | None = None
    model: str | None = None
    # pydantic copies field defaults per instance, so a mutable [] default is safe
    choices: list[Choice] = []
    usage: Usage | None = None


class ChatCompletionChunk(BaseModel):
    """Streaming chat completion chunk."""

    id: str | None = None
    object: str = "chat.completion.chunk"
    created: int | None = None
    model: str | None = None
    choices: list[StreamChoice] = []
    usage: Usage | None = None
# ============ Client ============
class StepFunClient:
    """StepFun API client using httpx.

    Returns raw OpenAI format data, with an additional reasoning field in message.
    """

    DEFAULT_API_BASE = "https://api.stepfun.com"
    CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"

    def __init__(
        self,
        api_key: str,
        api_base: str | None = None,
        timeout: float = 120.0,
    ):
        self.api_key = api_key
        # Normalize a possible trailing slash so endpoint joins stay clean.
        self.api_base = (api_base or self.DEFAULT_API_BASE).rstrip("/")
        self.timeout = timeout

    def _build_headers(self) -> dict[str, str]:
        """Bearer-auth JSON headers used on every request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _build_request_body(
        self,
        model: str,
        messages: Iterable[dict[str, Any]],
        tools: Iterable[dict[str, Any]] | None = None,
        infer_kwargs: dict[str, Any] | None = None,
        stream: bool = False,
    ) -> dict[str, Any]:
        """Build request body for chat completion."""
        body: dict[str, Any] = {
            "model": model,
            "messages": list(messages),
            "stream": stream,
        }
        if tools:
            body["tools"] = list(tools)
        # Add supported parameters from infer_kwargs
        # (anything outside SUPPORTED_INFER_KWARGS is silently dropped).
        if infer_kwargs:
            for key in SUPPORTED_INFER_KWARGS:
                if key in infer_kwargs:
                    body[key] = infer_kwargs[key]
        return body

    async def chat_completion(
        self,
        model: str,
        messages: Iterable[dict[str, Any]],
        tools: Iterable[dict[str, Any]] | None = None,
        infer_kwargs: dict[str, Any] | None = None,
    ) -> ChatCompletion:
        """Make a non-streaming chat completion request.

        Returns:
            ChatCompletion structured object, choices[0].message contains reasoning field

        Raises:
            httpx.HTTPStatusError: on non-2xx responses (raise_for_status).
        """
        url = f"{self.api_base}{self.CHAT_COMPLETIONS_ENDPOINT}"
        body = self._build_request_body(
            model=model,
            messages=messages,
            tools=tools,
            infer_kwargs=infer_kwargs,
            stream=False,
        )
        # A fresh client per call; connection reuse is not attempted here.
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                url,
                headers=self._build_headers(),
                json=body,
            )
            response.raise_for_status()
            data = response.json()
            return ChatCompletion.model_validate(data)

    async def chat_completion_stream(
        self,
        model: str,
        messages: Iterable[dict[str, Any]],
        tools: Iterable[dict[str, Any]] | None = None,
        infer_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[ChatCompletionChunk, None]:
        """Make a streaming chat completion request.

        Yields:
            ChatCompletionChunk structured object, delta contains reasoning field
        """
        url = f"{self.api_base}{self.CHAT_COMPLETIONS_ENDPOINT}"
        body = self._build_request_body(
            model=model,
            messages=messages,
            tools=tools,
            infer_kwargs=infer_kwargs,
            stream=True,
        )
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            async with client.stream(
                "POST",
                url,
                headers=self._build_headers(),
                json=body,
            ) as response:
                response.raise_for_status()
                # Server-sent events: payload lines carry a "data: " prefix.
                async for line in response.aiter_lines():
                    if not line or not line.startswith("data: "):
                        continue
                    data = line[6:]  # Remove "data: " prefix
                    if data == "[DONE]":
                        # Stream terminator sentinel.
                        break
                    try:
                        chunk = json.loads(data)
                        yield ChatCompletionChunk.model_validate(chunk)
                    except json.JSONDecodeError:
                        # Skip malformed/partial lines rather than abort the stream.
                        continue
================================================
FILE: cortex/model/stepfun_provider.py
================================================
"""StepFun Model Provider implementation."""
import asyncio
import logging
import re
from typing import Any, AsyncGenerator, Callable, Awaitable
import httpx
from .definition import ChatMessage
from cortex.model.definition import ChatToolCall, ContentBlockType, ExtraInfo, MessageRole, ModelParams
from cortex.model.provider import ModelProvider
from .stepfun_chat import (
ChatCompletion,
ChatCompletionChunk,
Message,
Delta,
StepFunClient,
)
# Regex pattern for matching <think>...</think> tags.
# Non-greedy so multiple think blocks in one string are matched separately;
# DOTALL lets reasoning span newlines. (The previous r"(.*?)" pattern matched
# only empty strings — the tag text had been stripped out.)
THINK_TAG_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
class StepFunModelProvider(ModelProvider):
"""Model provider for StepFun API with reasoning support."""
# Constants for matching tags
_THINK_OPEN_TAG = ""
_THINK_CLOSE_TAG = ""
def __init__(self, model_params: ModelParams):
self.model_params = model_params
# Used to track tag state in streaming scenarios
self._stream_in_think_tag = False
# Used to buffer potentially incomplete tags
self._stream_tag_buffer = ""
self._logger = logging.getLogger(__name__)
async def _call_with_retry(
self,
description: str,
func: Callable[[], Awaitable[Any]],
max_attempts: int = 5,
) -> Any:
"""Retry wrapper with linear backoff (2s,4s,6s,8s)."""
last_exc: Exception | None = None
for attempt in range(1, max_attempts + 1):
try:
return await func()
except Exception as exc: # noqa: BLE001
last_exc = exc
detail = ""
if isinstance(exc, httpx.HTTPStatusError):
try:
detail = f" body={exc.response.text}"
except Exception:
detail = ""
if attempt >= max_attempts:
break
delay = 2 * attempt
self._logger.warning(
"StepFun %s failed (attempt %s/%s): %s%s; retrying in %ss",
description,
attempt,
max_attempts,
exc,
detail,
delay,
)
await asyncio.sleep(delay)
assert last_exc is not None
raise last_exc
async def _stream_with_retry(
self,
description: str,
stream_factory: Callable[[], AsyncGenerator[Any, None]],
max_attempts: int = 5,
) -> AsyncGenerator[Any, None]:
"""Retry wrapper for streaming calls."""
last_exc: Exception | None = None
for attempt in range(1, max_attempts + 1):
try:
async for chunk in stream_factory():
yield chunk
return
except Exception as exc: # noqa: BLE001
last_exc = exc
if attempt >= max_attempts:
break
delay = 2 * attempt
self._logger.warning(
"StepFun %s stream failed (attempt %s/%s): %s; retrying in %ss",
description,
attempt,
max_attempts,
exc,
delay,
)
await asyncio.sleep(delay)
assert last_exc is not None
raise last_exc
def _create_client(self) -> StepFunClient:
"""Create a StepFunClient from model params."""
api_key = self.model_params.explicit_api_key
if not api_key:
raise ValueError("StepFun API key is required")
transport_timeout = None
if isinstance(self.model_params.infer_kwargs, dict):
transport_timeout = (
self.model_params.infer_kwargs.get("request_timeout")
or self.model_params.infer_kwargs.get("timeout")
)
return StepFunClient(
api_key=api_key,
api_base=self.model_params.explicit_api_base,
timeout=float(transport_timeout) if transport_timeout else 120.0,
)
def _chat_messages_to_openai(
self,
messages: list[ChatMessage] | list[dict],
) -> list[dict[str, Any]]:
"""Convert ChatMessage list to OpenAI format messages.
References openai_model.py's chat_message_to_openai_messages implementation,
handles complex ChatMessage formats (including content blocks, tool results, etc.).
"""
openai_messages = []
for message in messages:
if isinstance(message, dict):
message = ChatMessage(**message)
if isinstance(message.content, str) or message.content is None:
new_message = message.model_dump()
elif isinstance(message.content, list):
new_content = []
for block in message.content:
if block["type"] == ContentBlockType.TEXT.value:
# Merge consecutive text blocks
if (
new_content
and new_content[-1]["type"] == ContentBlockType.TEXT.value
):
new_content[-1]["text"] += block.get("text", block.get(block["type"], ""))
else:
new_content.append(block)
elif block["type"] == ContentBlockType.THINK.value:
# Convert thinking content to text wrapped in tags
think_content = block.get(block["type"], "")
if (
new_content
and new_content[-1]["type"] == ContentBlockType.TEXT.value
):
new_content[-1]["text"] += f"{think_content}"
else:
new_content.append({
"type": ContentBlockType.TEXT.value,
"text": f"{think_content}",
})
elif block["type"] == ContentBlockType.REDACTED_THINK.value:
redacted_content = block.get("data", "")
if (
new_content
and new_content[-1]["type"] == ContentBlockType.TEXT.value
):
new_content[-1]["text"] += f"{redacted_content}"
else:
new_content.append({
"type": ContentBlockType.TEXT.value,
"text": f"{redacted_content}",
})
elif block["type"] == ContentBlockType.TOOLRESULT.value:
# Tool result requires special handling
tool_block_content = block.get("content", [])
if isinstance(tool_block_content, str):
tool_block_content = [{
"type": ContentBlockType.TEXT.value,
"text": tool_block_content,
}]
# First text as tool role message
openai_messages.append({
"role": message.role,
"content": tool_block_content[0].get("text", "") if tool_block_content else "",
"tool_call_id": block.get("tool_use_id", message.tool_call_id),
})
# Remaining content as user message
if len(tool_block_content) > 1:
extra_message = {
"role": MessageRole.USER.value,
"content": tool_block_content[1:],
}
openai_messages.append(extra_message)
else:
# Other block types (e.g., image_url, etc.) keep as is
new_content.append(block)
new_message = message.model_dump()
new_message["content"] = new_content
# Handle tool_calls
if message.tool_calls:
normalized_tool_calls = []
for tc in message.tool_calls:
tc_dict = tc.model_dump()
# Ensure function and arguments are present for API validation
func = tc_dict.get("function") or {}
if func.get("arguments") in (None, ""):
func["arguments"] = "{}"
if not func.get("name"):
func["name"] = tc_dict.get("id") or "unknown"
tc_dict["function"] = func
normalized_tool_calls.append(tc_dict)
new_message["tool_calls"] = normalized_tool_calls
openai_messages.append(new_message)
elif new_message.get("content"):
openai_messages.append(new_message)
return openai_messages
def _extract_think_from_content(self, content: str | None) -> tuple[str | None, str | None]:
"""Extract content from ... tags in content.
Args:
content: Original content string
Returns:
tuple[reasoning, remaining_content]:
- reasoning: Content inside tags, or None if not present
- remaining_content: Content after removing tags
"""
if not content or not isinstance(content, str):
return None, content
# Find all ... tags
matches = THINK_TAG_PATTERN.findall(content)
if not matches:
return None, content
# Merge all think content
reasoning = "".join(matches)
# Remove all ... tags
remaining_content = THINK_TAG_PATTERN.sub("", content).strip()
return reasoning if reasoning else None, remaining_content if remaining_content else None
def _message_to_chat_message(self, message: Message) -> ChatMessage:
"""Convert StepFun Message (with reasoning) to ChatMessage.
StepFun message format:
- role: str
- content: str | None (may contain ... tags)
- reasoning: str | None (StepFun specific, may be empty)
- tool_calls: list[ToolCall] | None
reasoning may be in the standalone reasoning field or in tags in content.
"""
reasoning = message.reasoning
content = message.content
# If reasoning field is empty, try to extract tag content from content
if not reasoning and isinstance(content, str):
extracted_reasoning, remaining_content = self._extract_think_from_content(content)
if extracted_reasoning:
reasoning = extracted_reasoning
content = remaining_content
# Process reasoning and content separately, build content blocks
new_content: list[dict] = []
# Handle reasoning field, convert to thinking content block
if reasoning:
new_content.append({
"type": ContentBlockType.THINK.value,
ContentBlockType.THINK.value: reasoning,
})
# Handle content field
if isinstance(content, str) and content:
new_content.append({
"type": ContentBlockType.TEXT.value,
"text": content,
})
elif isinstance(content, list):
new_content.extend(content)
# Determine final content: use list if there are blocks, otherwise keep original value
final_content: str | list | None
if new_content:
final_content = new_content
else:
final_content = content # Could be None or empty string
# Handle tool_calls
tool_calls = None
if message.tool_calls:
tool_calls = [
ChatToolCall(**tc.model_dump(exclude_none=True))
for tc in message.tool_calls
]
return ChatMessage(
role=message.role,
content=final_content,
tool_calls=tool_calls,
)
def _process_stream_content_for_think(self, content: str | None) -> tuple[str | None, str | None]:
"""Process content in streaming scenarios to identify tags.
Uses character-level state machine to handle tags spanning multiple chunks:
- Tags may be split across chunks (e.g., '')
- Uses _stream_tag_buffer to cache potentially incomplete tag fragments
- Uses _stream_in_think_tag to track whether inside a tag
Args:
content: Current chunk's content
Returns:
tuple[reasoning, text_content]:
- reasoning: If inside tag, returns current reasoning content
- text_content: If not inside tag, returns plain text content
"""
if not content:
return None, None
# Merge cached content with new content
buffer = self._stream_tag_buffer + content
self._stream_tag_buffer = ""
reasoning_parts = []
text_parts = []
i = 0
while i < len(buffer):
if not self._stream_in_think_tag:
# Not inside tag, look for opening tag
# Check if this could be the start of a tag
if buffer[i] == '<':
# Check if remaining content is sufficient to determine
remaining = buffer[i:]
if remaining.startswith(self._THINK_OPEN_TAG):
# Complete tag
self._stream_in_think_tag = True
i += len(self._THINK_OPEN_TAG)
elif self._THINK_OPEN_TAG.startswith(remaining):
# Possibly incomplete tag (e.g., ' tag, treat as plain text
text_parts.append(buffer[i])
i += 1
else:
text_parts.append(buffer[i])
i += 1
else:
# Inside tag, look for closing tag
if buffer[i] == '<':
# Check if this could be a tag
remaining = buffer[i:]
if remaining.startswith(self._THINK_CLOSE_TAG):
# Complete tag
self._stream_in_think_tag = False
i += len(self._THINK_CLOSE_TAG)
elif self._THINK_CLOSE_TAG.startswith(remaining):
# Possibly incomplete tag
self._stream_tag_buffer = remaining
break
else:
# Not a tag, treat as reasoning content
reasoning_parts.append(buffer[i])
i += 1
else:
reasoning_parts.append(buffer[i])
i += 1
reasoning = "".join(reasoning_parts) if reasoning_parts else None
text_content = "".join(text_parts) if text_parts else None
return reasoning, text_content
def _delta_to_chat_message(
self,
delta: Delta,
chunk_id: str | None = None,
) -> ChatMessage:
"""Convert StepFun Delta (with reasoning) to ChatMessage.
StepFun stream delta format:
- role: str | None
- content: str | None (may contain ... tag fragments)
- reasoning: str | None (StepFun specific, may be empty)
- tool_calls: list[ToolCall] | None
reasoning may be in the standalone reasoning field or in tags in content.
In streaming scenarios, tags may span multiple chunks.
"""
reasoning = delta.reasoning
content = delta.content
# If reasoning field is empty, try to extract tag content from content
if not reasoning and isinstance(content, str):
extracted_reasoning, remaining_content = self._process_stream_content_for_think(content)
reasoning = extracted_reasoning
content = remaining_content
# Process reasoning and content separately
new_content: list[dict] = []
if reasoning:
new_content.append({
"type": ContentBlockType.THINK.value,
ContentBlockType.THINK.value: reasoning,
})
if isinstance(content, str) and content:
new_content.append({
"type": ContentBlockType.TEXT.value,
"text": content,
})
# Determine final content
final_content: str | list | None
if new_content:
final_content = new_content
else:
final_content = None # Return None when streaming has no content
# Handle tool_calls
tool_calls = None
if delta.tool_calls:
tool_calls = [
ChatToolCall(**tc.model_dump(exclude_none=True))
for tc in delta.tool_calls
]
return ChatMessage(
id=chunk_id,
role=delta.role,
content=final_content,
tool_calls=tool_calls,
)
async def chat_completion(
self,
messages: list[ChatMessage],
tools: list | None = None,
log_file: str | None = None,
) -> ChatMessage:
"""Make a non-streaming chat completion request."""
client = self._create_client()
openai_messages = self._chat_messages_to_openai(messages)
infer_kwargs = dict(self.model_params.infer_kwargs or {})
infer_kwargs.pop("request_timeout", None)
infer_kwargs.pop("timeout", None)
response: ChatCompletion = await self._call_with_retry(
description="chat_completion",
func=lambda: client.chat_completion(
model=self.model_params.name,
messages=openai_messages,
tools=tools,
infer_kwargs=infer_kwargs,
),
)
# Extract and convert message
if not response.choices:
return ChatMessage()
choice = response.choices[0]
if not choice.message:
return ChatMessage()
result = self._message_to_chat_message(choice.message)
# Add usage and finish_reason
result.extra_info = ExtraInfo()
if response.usage:
result.extra_info["usage"] = response.usage.model_dump()
if choice.finish_reason:
result.extra_info["finish_reason"] = choice.finish_reason
return result
async def chat_completion_stream(
    self,
    messages: list[ChatMessage],
    tools: list | None = None,
    log_file: str | None = None,
) -> AsyncGenerator[ChatMessage, None]:
    """Make a streaming chat completion request.

    Args:
        messages: Conversation messages to send to the model.
        tools: Optional tool schemas exposed to the model.
        log_file: Accepted for interface compatibility; not used here.

    Yields:
        ChatMessage: One converted message per meaningful chunk; chunks
        carrying no content, role, or tool_calls are skipped.
    """
    client = self._create_client()
    openai_messages = self._chat_messages_to_openai(messages)
    # Strip timeout-style keys before forwarding the inference kwargs.
    infer_kwargs = dict(self.model_params.infer_kwargs or {})
    infer_kwargs.pop("request_timeout", None)
    infer_kwargs.pop("timeout", None)

    async def stream_factory() -> AsyncGenerator[ChatCompletionChunk, None]:
        # Reset streaming state each attempt, so a retried stream does not
        # inherit think-tag parser state from a failed attempt.
        self._stream_in_think_tag = False
        self._stream_tag_buffer = ""
        async for chunk in client.chat_completion_stream(
            model=self.model_params.name,
            messages=openai_messages,
            tools=tools,
            infer_kwargs=infer_kwargs,
        ):
            yield chunk

    async for chunk in self._stream_with_retry(
        description="chat_completion_stream",
        stream_factory=stream_factory,
    ):
        chunk: ChatCompletionChunk
        if not chunk.choices:
            continue
        choice = chunk.choices[0]
        if not choice.delta:
            continue
        result = self._delta_to_chat_message(choice.delta, chunk.id)
        # Skip this chunk if no valid content
        if result.content is None and result.role is None and result.tool_calls is None:
            continue
        # Add usage and finish_reason (if present)
        if chunk.usage or choice.finish_reason:
            result.extra_info = ExtraInfo()
            if chunk.usage:
                result.extra_info["usage"] = chunk.usage.model_dump()
            if choice.finish_reason:
                result.extra_info["finish_reason"] = choice.finish_reason
        yield result
================================================
FILE: cortex/model/utils.py
================================================
import copy
def merge_delta_message(d1: dict | None, d2: dict | None) -> dict:
"""
Merge two message dictionaries.
Args:
d1: Accumulated message dictionary, can be dict or None
d2: Delta message dictionary, can be dict or None
Returns:
Merged dictionary
Note:
role and id fields will retain d1's values (if present), so during streaming
the role from the first chunk can be correctly preserved.
"""
if d1 is None:
return d2
if d2 is None:
return d1
result = copy.deepcopy(d1) # Create a copy of d1
for key, value in d2.items():
if key in result:
if key == "index":
pass
elif key == "role" or key == "id":
# Retain d1's value for role and id (if present), otherwise use d2's value
if not result[key] and value:
result[key] = value
elif key == "tool_calls":
if value and not isinstance(value, list):
value = [value]
toolcall = {}
if result[key]:
for v in result[key]:
toolcall[v["index"]] = v
if value:
for v in value:
toolcall_id = v["index"]
if toolcall_id in toolcall:
toolcall[toolcall_id] = merge_delta_message(
toolcall[toolcall_id], v
)
else:
toolcall[toolcall_id] = v
result[key] = []
for k, v in toolcall.items():
result[key].append(v)
elif isinstance(result[key], dict) and isinstance(value, dict):
# Recursively merge nested dictionaries
result[key] = merge_delta_message(result[key], value)
elif isinstance(result[key], str) and isinstance(value, str) and value:
# type/id are fixed enum-like values, don't concatenate (avoids "functionfunction...")
if key in ("type", "id"):
result[key] = value
else:
result[key] += value
elif isinstance(result[key], list) and isinstance(value, list):
result[key] += value
elif isinstance(result[key], (int, float)) and isinstance(
value, (int, float)
):
result[key] += value
elif value:
# Type mismatch, override with d2's value
result[key] = value
else:
result[key] = value
# Merge data like [{'type': 'text', 'text': '...'}], combining consecutive items with the same type into one
for key, value in result.items():
if isinstance(value, list):
new_value = []
last_type = None
last_index = None
for i in range(len(value)):
if "type" in value[i]:
current_type = value[i]["type"]
current_index = value[i].get("index", None)
if current_type == last_type and current_index == last_index:
# Merge all non-type fields
for k, v in value[i].items():
if k != "type":
if k in new_value[-1]:
new_value[-1][k] += v
else:
new_value[-1][k] = v
else:
new_value.append(value[i])
last_type = current_type
last_index = current_index
result[key] = new_value
return result
================================================
FILE: cortex/orchestrator/__init__.py
================================================
"""Orchestrator module - Provides generator merging and agent coordination functionality."""
from .local_runner import LocalRunner
from .orchestrator import Orchestrator
from .remote_runner import RemoteRunner
from .runner import Runner
from .types import AgentEvent, AgentRequest, ClientToolCall, ClientToolCallType
__all__ = [
"Orchestrator",
"AgentEvent",
"AgentRequest",
"ClientToolCall",
"ClientToolCallType",
"Runner",
"LocalRunner",
"RemoteRunner",
]
================================================
FILE: cortex/orchestrator/local_runner.py
================================================
"""Local Runner implementation"""
import asyncio
from typing import AsyncGenerator, Optional
from uuid import uuid4
from cortex.model.definition import ChatMessage, Function
from cortex.agents.agent_factory import AgentFactory
from cortex.agents.base_agent import BaseAgent
from cortex.agents.input.input import InputChannel
from cortex.agents.types import (
AgentConfig,
AgentMessageType,
AgentResponse,
AgentRunningStatus,
)
from cortex.orchestrator.runner import Runner
from cortex.orchestrator.types import (
AgentEvent,
AgentEventType,
ClientToolCall,
ClientToolCallType,
)
from cortex.tools.base import ToolSchema
from cortex.tools.types import ToolParameters, ToolType
from cortex.utils.generator_merger import GeneratorMerger
class LocalRunner(Runner):
    """Local Agent runner.

    Runs an agent in-process and interleaves its responses with client
    tool-call events through a GeneratorMerger.
    """

    # Whether input messages keep flowing in through a queue after start
    # (enabled when the agent's config has unfinished_mode).
    _continue_input: bool = False

    def __init__(
        self,
        agent_factory: AgentFactory,
        task_id: str,
        parent_task_id: str | None = None,
        root_task_id: str | None = None,
        tool_call_id: str | None = None,
    ):
        """Initialize LocalRunner.

        Args:
            agent_factory: Factory used by init() to build the agent
            task_id: Task ID
            parent_task_id: Parent task ID
            root_task_id: Root task ID
            tool_call_id: Tool call ID; set when this runner serves a client
                tool call, its final result is reported under this ID
        """
        super().__init__(task_id, parent_task_id, root_task_id)
        # Bug fix: agent_factory (and task_id via super()) were previously
        # assigned twice; keep a single assignment each.
        self._agent_factory = agent_factory
        self._tool_call_id = tool_call_id
        self._merger = GeneratorMerger(
            on_generator_complete=self._on_generator_complete
        )
        self._agent: BaseAgent | None = None
        self._agent_name: str | None = None
        self._result: AgentEvent | None = None

    async def init(
        self,
        agent_name: str,
        context_id: str | None = None,
        config: AgentConfig | None = None,
    ) -> None:
        """Initialize Runner.

        Args:
            agent_name: Agent name
            context_id: Context ID; replaced with a fresh id when the config
                does not opt into a shared context
            config: AgentConfig
        """
        if config is not None and not config.use_share_context:
            context_id = uuid4().hex
        self._agent_name = agent_name
        self._agent = await self._agent_factory.make_agent(
            agent_name, context_id, config
        )
        if self._agent.config.unfinished_mode:
            # Unfinished mode: messages arrive continuously via a queue.
            self._continue_input = True
            self._message_queue: asyncio.Queue[ChatMessage] = asyncio.Queue()
            self._messages: InputChannel[ChatMessage] = InputChannel(
                self._message_queue
            )
        else:
            self._messages: list[ChatMessage] = []
        self._agent.toolset().set_on_send(self._on_client_tool_send)

    async def send(self, event: AgentEvent) -> None:
        """Send AgentEvent.

        Args:
            event: AgentEvent to send
        """
        if event.type == AgentEventType.REQUEST:
            # Add messages from request to queue (or the pending list)
            if event.request and event.request.messages:
                for message in event.request.messages:
                    if self._continue_input:
                        await self._message_queue.put(message)
                    else:
                        self._messages.append(message)
        if event.type == AgentEventType.CLIENT_TOOL_RESULT:
            # Feed the tool result back into the agent's toolset.
            result = event.client_tool_result
            if result is not None:
                self._agent.toolset().set_response(
                    result.message.tool_call_id,
                    result.message.content,
                    result.error_msg,
                )

    async def run(self) -> AsyncGenerator[AgentEvent, None]:
        """Run and return AgentEvent generator.

        Yields:
            AgentEvent: Agent event
        """
        if self._task_id is None:
            raise ValueError("task_id not set, please call send() first to send REQUEST event")

        async def agent_generator() -> AsyncGenerator[AgentEvent, None]:
            if self._agent is None:
                raise ValueError("Agent not initialized, please call init() first")
            async for response in self._agent.run(self._messages):
                # Record the terminal response so get_result() can return it.
                self._on_agent_finished(response)
                yield AgentEvent(
                    agent_name=self._agent_name,
                    task_id=self._task_id,
                    parent_task_id=self._parent_task_id,
                    root_task_id=self._root_task_id,
                    type=AgentEventType.RESPONSE,
                    response=response,
                )
            # Consumed; start the next run with an empty message list.
            self._messages = []

        self._merger.add_async_generator(
            agent_generator, generator_id=f"local_runner_{self._task_id}"
        )
        # Merge agent responses with any client tool-call events.
        async for event in self._merger.merge():
            yield event

    def _on_agent_finished(self, response: AgentResponse) -> None:
        """Handle Agent finished event.

        Args:
            response: AgentResponse
        """
        # Only a final, non-stream response counts as the runner's result.
        if response.status != AgentRunningStatus.FINISHED:
            return
        if response.message_type == AgentMessageType.STREAM:
            return
        if self._tool_call_id is None:
            # Plain run: store the result as a RESPONSE event.
            self._result = AgentEvent(
                agent_name=self._agent_name,
                task_id=self._task_id,
                parent_task_id=self._parent_task_id,
                root_task_id=self._root_task_id,
                type=AgentEventType.RESPONSE,
                response=response,
            )
            return
        # This runner serves a tool call: report the result under that id.
        response.message.tool_call_id = self._tool_call_id
        self._result = AgentEvent(
            agent_name=self._agent_name,
            task_id=self._task_id,
            parent_task_id=self._parent_task_id,
            root_task_id=self._root_task_id,
            type=AgentEventType.CLIENT_TOOL_RESULT,
            client_tool_result=response,
        )

    async def _on_generator_complete(
        self, generator_id: str, generator_type: str, error: Optional[Exception]
    ) -> None:
        """Handle generator complete event.

        Args:
            generator_id: Generator ID
            generator_type: Generator Type
            error: Error
        """
        # Generator complete event handling (no special handling needed currently)

    async def _on_client_tool_send(
        self, tool_name: str, tool_schema: ToolSchema, tool_parameters: ToolParameters
    ) -> None:
        """Handle client tool send event.

        Args:
            tool_name: Tool name
            tool_schema: Tool Schema
            tool_parameters: Tool Parameters (contains tool_call_id in kwargs)
        """
        # Extract tool_call_id from tool_parameters.kwargs
        tool_call_id = tool_parameters.kwargs.get("tool_call_id")
        if tool_call_id is None:
            # Fallback id derived from the call content.
            tool_call_id = (
                f"tool_call_{hash(tool_name)}_{hash(tool_parameters.parameters)}"
            )

        async def tool_generator() -> AsyncGenerator[AgentResponse, None]:
            # Map the schema's tool type onto the client-facing call type.
            tool_type = ClientToolCallType.TOOL
            if tool_schema.tool_type == ToolType.AGENT:
                tool_type = ClientToolCallType.AGENT
            elif tool_schema.tool_type == ToolType.ASK_INPUT:
                tool_type = ClientToolCallType.ASK_INPUT
            # tool params event
            yield AgentEvent(
                task_id=self._task_id,
                parent_task_id=self._parent_task_id,
                root_task_id=self._root_task_id,
                agent_name=self._agent_name,
                type=AgentEventType.CLIENT_TOOL_CALL,
                client_tool_call=ClientToolCall(
                    tool_call_id=tool_call_id,
                    function=Function(
                        arguments=tool_parameters.parameters,
                        name=tool_schema.name,
                    ),
                    type=tool_type,
                    extra=tool_parameters.kwargs,
                ),
            )

        self._merger.add_async_generator(tool_generator, generator_id=tool_call_id)

    def get_result(self) -> AgentEvent:
        """Get AgentEvent.

        Returns:
            AgentEvent: final event, or None when the run has not produced a
            final non-stream response yet
        """
        return self._result
================================================
FILE: cortex/orchestrator/orchestrator.py
================================================
"""Orchestrator for coordinating execution of multiple Agents."""
import logging
from enum import Enum
from typing import AsyncGenerator, Optional
from uuid import uuid4
from cortex.agents.agent_factory import AgentFactory
from cortex.agents.types import AgentConfig, AgentResponse, AgentRunningStatus
from cortex.orchestrator.local_runner import LocalRunner
from cortex.orchestrator.runner import Runner
from cortex.orchestrator.types import (
AgentEvent,
AgentEventType,
AgentRequest,
ClientToolCallType,
)
from cortex.tools.agent_tool import AgentTool
from cortex.utils.generator_merger import GeneratorMerger
logger = logging.getLogger(__name__)
class OrchMode(str, Enum):
    """Orchestrator execution mode.

    MULTI:  AGENT-type client tool calls are expanded into child runners.
    SINGLE: events are forwarded as-is without spawning child runners.
    """

    MULTI = "multi"
    SINGLE = "single"
class Orchestrator:
    """Orchestrator for coordinating execution of multiple Agents.

    Keeps one Runner per task_id and one GeneratorMerger per root task so
    that events from a root agent and the child agents it spawns (through
    AGENT-type client tool calls) are interleaved into a single stream.
    """

    def __init__(self, agent_factory: AgentFactory):
        """Initialize Orchestrator.

        Args:
            agent_factory: AgentFactory instance for creating Agents
        """
        self._agent_factory = agent_factory
        # Manage runner by task_id
        self._runners: dict[str, Runner] = {}
        # Manage generator_merger by root_task_id
        self._mergers: dict[str, GeneratorMerger] = {}
        # Record task_id to root_task_id mapping
        self._task_to_root: dict[str, str] = {}
        # Record task_id to parent_task_id mapping
        self._task_to_parent: dict[str, str] = {}
        # Record runner's task_id (for cleanup)
        self._runner_task_ids: dict[str, str] = {}  # runner_id -> task_id

    def list_agents(self) -> list[AgentConfig]:
        """List all Agent configurations."""
        return self._agent_factory.list_agents()

    async def run(
        self,
        agent_name: str,
        event: AgentEvent,
        agent_config: AgentConfig | None = None,
        mode: OrchMode = OrchMode.MULTI,
        context_id: str | None = None,
    ) -> AsyncGenerator[AgentEvent, None]:
        """Run Agent and return response stream.

        Args:
            agent_name: Agent name
            event: Initial event (usually a REQUEST carrying the messages)
            agent_config: Agent configuration
            mode: MULTI expands AGENT tool calls into child runners;
                SINGLE forwards every event unchanged
            context_id: Context ID; defaults to the task id

        Yields:
            AgentEvent: Agent event
        """
        # Fill in missing ids; an absent task_id marks a fresh root task.
        task_id = event.task_id or f"root_{uuid4().hex}"
        root_task_id = event.root_task_id or task_id
        parent_task_id = event.parent_task_id or None
        event.task_id = task_id
        event.root_task_id = root_task_id
        event.parent_task_id = parent_task_id
        if context_id is None:
            context_id = task_id
        # Create root runner
        runner = await self._create_runner(
            task_id=task_id,
            parent_task_id=parent_task_id,
            root_task_id=root_task_id,
        )
        # Initialize runner
        await runner.init(agent_name, context_id, agent_config)

        async def on_generator_complete_with_mode(
            generator_id: str, _generator_type: str, error: Optional[Exception]
        ) -> None:
            await self._on_generator_complete(
                generator_id, _generator_type, error, mode
            )

        # Create generator_merger to merge events
        merger = GeneratorMerger(on_generator_complete=on_generator_complete_with_mode)
        self._mergers[root_task_id] = merger

        need_run_root_runner = True
        while need_run_root_runner:
            need_run_root_runner = False
            await self.run_root_runner(runner, merger, root_task_id, event)
            # Only the first pass forwards the triggering event.
            # Bug fix: the merge loop below previously rebound `event`, so a
            # re-run pass passed the *last merged event* to run_root_runner
            # despite this explicit reset; the loop variable is now distinct.
            event = None
            # Get events from merger and yield them
            async for merged_event in merger.merge():
                if not isinstance(merged_event, AgentEvent):
                    continue
                # Handle client tool call
                if (
                    mode == OrchMode.MULTI
                    and merged_event.type == AgentEventType.CLIENT_TOOL_CALL
                ):
                    # Check if it's AGENT type, only AGENT type needs special handling
                    client_tool_call = merged_event.client_tool_call
                    if (
                        client_tool_call
                        and client_tool_call.type == ClientToolCallType.AGENT
                    ):
                        await self._handle_client_tool_call(
                            context_id, merged_event, root_task_id
                        )
                        need_run_root_runner = True
                    else:
                        # Non-AGENT type CLIENT_TOOL_CALL (e.g., ASK_INPUT, TOOL) yield directly
                        yield merged_event
                else:
                    yield merged_event
        self._cleanup_runner(task_id)

    async def run_root_runner(
        self,
        runner: Runner,
        merger: GeneratorMerger,
        root_task_id: str,
        event: AgentEvent,
    ):
        """(Re)attach the root runner's stream to the merger and forward the
        triggering event to the runner, if one was supplied.

        Args:
            runner: Root runner
            merger: Merger collecting this root task's events
            root_task_id: Root task ID (used as the generator id)
            event: Triggering event, or None on re-run passes
        """
        # Add runner's run() to merger
        async def runner_generator() -> AsyncGenerator[AgentEvent, None]:
            async for runner_event in runner.run():
                yield runner_event

        merger.add_async_generator(runner_generator, generator_id=root_task_id)
        if event is None:
            return
        if event.type == AgentEventType.REQUEST:
            # Bug fix: guard against a REQUEST whose request payload or
            # messages list is None (previously an AttributeError).
            if event.request and event.request.messages:
                await runner.send(event)
        elif event.type == AgentEventType.CLIENT_TOOL_RESULT:
            await runner.send(event)

    async def send_event(self, event: AgentEvent) -> None:
        """Receive external input AgentEvent, find runner by task_id and send.

        Args:
            event: AgentEvent

        Raises:
            ValueError: when no runner is registered for the event's task_id
        """
        task_id = event.task_id
        runner = self._runners.get(task_id)
        if runner is None:
            logger.error("Cannot find runner for task_id=%s", task_id)
            # Bug fix: ValueError was previously given printf-style args,
            # producing an unformatted tuple message.
            raise ValueError(f"Cannot find runner for task_id={task_id}")
        await runner.send(event)

    async def _create_runner(
        self,
        task_id: str,
        parent_task_id: str | None = None,
        root_task_id: str | None = None,
        tool_call_id: str | None = None,
    ) -> Runner:
        """Create Runner instance and register its id mappings.

        Args:
            task_id: Task ID
            parent_task_id: Parent task ID
            root_task_id: Root task ID
            tool_call_id: Tool call ID

        Returns:
            Runner instance
        """
        runner = LocalRunner(
            agent_factory=self._agent_factory,
            task_id=task_id,
            parent_task_id=parent_task_id,
            root_task_id=root_task_id,
            tool_call_id=tool_call_id,
        )
        self._runners[task_id] = runner
        if root_task_id:
            self._task_to_root[task_id] = root_task_id
        if parent_task_id:
            self._task_to_parent[task_id] = parent_task_id
        return runner

    async def _handle_client_tool_call(
        self, context_id: str, event: AgentEvent, root_task_id: str
    ):
        """Handle an AGENT-type client tool call by spawning a child runner.

        Args:
            context_id: Context ID
            event: AgentEvent, type is CLIENT_TOOL_CALL
            root_task_id: Root task ID
        """
        client_tool_call = event.client_tool_call
        if client_tool_call is None:
            return
        # Determine if it's an agent based on client_tool_call.type
        is_agent = client_tool_call.type == ClientToolCallType.AGENT
        if not is_agent:
            # Not an agent, return directly
            return
        agent_name = client_tool_call.function.name
        # Get agent config; an unknown agent name is silently ignored.
        try:
            agent_config = self._agent_factory.get_default_agent_config(agent_name)
        except ValueError:
            agent_config = None
        if agent_config is None:
            return
        child_task_id = f"child_{uuid4().hex}"
        # Create child runner
        child_runner = await self._create_runner(
            task_id=child_task_id,
            parent_task_id=event.task_id,
            root_task_id=root_task_id,
            tool_call_id=client_tool_call.tool_call_id,
        )
        # Initialize child runner
        await child_runner.init(agent_name, context_id, agent_config)
        # Get merger for root_task_id
        merger = self._mergers.get(root_task_id)
        if merger is None:
            raise ValueError(f"Cannot find merger for root_task_id={root_task_id}")
        # Prepare messages (convert tool call arguments into chat messages)
        tool_call_messages = AgentTool.parse_messages(
            client_tool_call.function.arguments
        )
        if tool_call_messages is None:
            tool_call_messages = []

        # Add child runner's run() to merger
        async def child_runner_generator() -> AsyncGenerator[AgentEvent, None]:
            # Send REQUEST event to child runner
            await child_runner.send(
                AgentEvent(
                    task_id=child_task_id,
                    parent_task_id=event.task_id,
                    root_task_id=root_task_id,
                    type=AgentEventType.REQUEST,
                    request=AgentRequest(
                        agent_name=agent_name,
                        config=agent_config,
                        messages=tool_call_messages,
                    ),
                )
            )
            # Run child runner and yield events
            async for child_event in child_runner.run():
                yield child_event

        merger.add_async_generator(child_runner_generator, generator_id=child_task_id)
        self._runner_task_ids[child_task_id] = child_task_id

    async def _on_generator_complete(
        self,
        generator_id: str,
        _generator_type: str,
        error: Optional[Exception],
        mode: OrchMode,
    ) -> None:
        """Handle generator complete event (callback).

        Propagates a finished child runner's result to its parent as a
        CLIENT_TOOL_RESULT, then cleans up the runner.

        Args:
            generator_id: Generator ID (usually task_id)
            _generator_type: Generator Type (unused, kept for interface compatibility)
            error: Error (if any)
            mode: Orchestrator mode of the owning run()
        """
        # Find corresponding runner
        task_id = generator_id
        runner = self._runners.get(task_id)
        if runner is None:
            return
        if mode == OrchMode.SINGLE:
            # SINGLE mode never spawns children; just clean up.
            self._cleanup_runner(task_id)
            return
        # Get parent_task_id
        parent_task_id = self._task_to_parent.get(task_id)
        if parent_task_id is None:
            # No parent, means this is root runner, no need to send result
            return
        # Has parent, need to send result to parent runner
        parent_runner = self._runners.get(parent_task_id)
        if parent_runner is None:
            # Parent runner doesn't exist, cleanup current runner
            self._cleanup_runner(task_id)
            return
        # Create result response
        if error:
            result_response = AgentResponse(
                status=AgentRunningStatus.ERROR,
                error_msg=str(error),
            )
        else:
            # Completed normally
            runner_result = runner.get_result()
            if runner_result is None:
                result_response = AgentResponse(
                    status=AgentRunningStatus.ERROR,
                    error_msg=f"runner {task_id} did not return result",
                )
            else:
                result_response = runner_result.client_tool_result
        # Send CLIENT_TOOL_RESULT event to parent runner
        await parent_runner.send(
            AgentEvent(
                task_id=parent_task_id,
                parent_task_id=parent_runner.get_parent_task_id(),
                root_task_id=parent_runner.get_root_task_id(),
                type=AgentEventType.CLIENT_TOOL_RESULT,
                client_tool_result=result_response,
            )
        )
        # Cleanup completed runner
        self._cleanup_runner(task_id)

    def _cleanup_runner(self, task_id: str) -> None:
        """Cleanup completed runner and all its bookkeeping entries.

        Args:
            task_id: Task ID
        """
        # Remove from runners
        if task_id in self._runners:
            del self._runners[task_id]
        # Remove from mappings
        if task_id in self._task_to_root:
            del self._task_to_root[task_id]
        if task_id in self._task_to_parent:
            del self._task_to_parent[task_id]
        if task_id in self._runner_task_ids:
            del self._runner_task_ids[task_id]
        # If it's root_task_id, cleanup merger
        if task_id in self._mergers:
            del self._mergers[task_id]
================================================
FILE: cortex/orchestrator/remote_runner.py
================================================
from typing import AsyncGenerator
from cortex.agents.types import AgentConfig, AgentRunningStatus
from cortex.orchestrator.runner import Runner
from cortex.orchestrator.types import AgentEvent, AgentEventType
from cortex.server.channel.channel import Channel
class RemoteRunner(Runner):
    """Remote Agent runner that communicates with remote services through Channel."""

    def __init__(self, channel: Channel):
        """
        Initialize remote runner.

        Args:
            channel: Channel instance for communicating with remote services
        """
        # NOTE(review): the task_id is only known once the first REQUEST
        # event is sent, so Runner.__init__ is not called here; the id
        # attributes are initialized directly instead.
        self.channel = channel
        self._task_id: str | None = None
        # Bug fix: the inherited get_parent_task_id()/get_root_task_id()
        # read these attributes, which were previously never set.
        self._parent_task_id: str | None = None
        self._root_task_id: str | None = None
        self._running: bool = False
        # Last terminal event observed by run(); returned by get_result().
        self._result: AgentEvent | None = None

    async def init(
        self,
        agent_name: str,
        context_id: str | None = None,
        config: AgentConfig | None = None,
    ) -> None:
        """
        Initialize Runner.

        Args:
            agent_name: Agent name
            context_id: Context ID
            config: AgentConfig
        """
        # Remote runner doesn't need local configuration, configuration is on the remote server side
        # If needed, you can save config or perform other initialization here

    async def send(self, event: AgentEvent) -> None:
        """
        Send AgentEvent to remote service.

        Args:
            event: The AgentEvent to send
        """
        if event.type == AgentEventType.REQUEST:
            # Save ids from the first REQUEST so run() can filter events.
            if self._task_id is None:
                self._task_id = event.task_id
                self._parent_task_id = event.parent_task_id
                self._root_task_id = event.root_task_id
        # Send event to remote service
        await self.channel.send(event.model_dump())

    async def run(self) -> AsyncGenerator[AgentEvent, None]:
        """
        Run and return an AgentEvent generator, receiving events from remote service.

        Yields:
            AgentEvent: Agent events
        """
        if self._task_id is None:
            raise ValueError("task_id is not set, please call send() to send REQUEST event first")
        self._running = True
        try:
            while self._running:
                # Receive events from remote service
                data = await self.channel.receive()
                event = AgentEvent.model_validate(data)
                # Only yield events related to current task_id
                if event.task_id != self._task_id:
                    continue
                yield event
                # Stop receiving if error or completion signal is received
                if event.type == AgentEventType.ERROR:
                    self._result = event
                    break
                if event.type == AgentEventType.RESPONSE and event.response:
                    if event.response.status in (
                        AgentRunningStatus.FINISHED,
                        AgentRunningStatus.STOPPED,
                        AgentRunningStatus.ERROR,
                    ):
                        self._result = event
                        break
        finally:
            self._running = False

    def get_result(self) -> AgentEvent:
        """
        Get the final AgentEvent observed by run().

        Bug fix: this abstract method was previously not implemented, making
        RemoteRunner impossible to instantiate (ABC TypeError).

        Returns:
            AgentEvent: last terminal ERROR/RESPONSE event, or None when the
            run has not reached a terminal event yet
        """
        return self._result
================================================
FILE: cortex/orchestrator/runner.py
================================================
"""Runner interface definition."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator
from cortex.agents.types import AgentConfig
from cortex.orchestrator.types import AgentEvent
class Runner(ABC):
    """Abstract interface for executing a single agent task.

    A Runner accepts inbound AgentEvents via send() and produces outbound
    AgentEvents via run(); implementations may execute locally or remotely.
    """

    def __init__(
        self,
        task_id: str,
        parent_task_id: str | None = None,
        root_task_id: str | None = None,
    ):
        """
        Initialize Runner.

        Args:
            task_id: Task ID of this runner
            parent_task_id: Task ID of the spawning task, if any
            root_task_id: Task ID of the task tree's root, if any
        """
        self._task_id: str = task_id
        self._parent_task_id: str | None = parent_task_id
        self._root_task_id: str | None = root_task_id

    def get_parent_task_id(self) -> str | None:
        """Return the parent task ID (None when this is a root task)."""
        return self._parent_task_id

    def get_root_task_id(self) -> str | None:
        """Return the root task ID of this runner's task tree."""
        return self._root_task_id

    @abstractmethod
    async def init(
        self,
        agent_name: str,
        context_id: str | None = None,
        config: AgentConfig | None = None,
    ) -> None:
        """
        Prepare the runner to execute the named agent.

        Args:
            agent_name: Agent name
            context_id: Context ID
            config: AgentConfig
        """
        raise NotImplementedError

    @abstractmethod
    async def send(self, event: AgentEvent) -> None:
        """
        Deliver an inbound AgentEvent to the runner.

        Args:
            event: The AgentEvent to send
        """
        raise NotImplementedError

    @abstractmethod
    async def run(self) -> AsyncGenerator[AgentEvent, None]:
        """
        Execute the task and stream resulting events.

        Yields:
            AgentEvent: Agent events
        """
        raise NotImplementedError

    @abstractmethod
    def get_result(self) -> AgentEvent:
        """
        Return the runner's final AgentEvent.

        Returns:
            AgentEvent: AgentEvent
        """
        raise NotImplementedError
================================================
FILE: cortex/orchestrator/types.py
================================================
from enum import Enum
from uuid import uuid4
from cortex.model.definition import ChatMessage, Function
from pydantic import BaseModel, Field
from cortex.agents.types import AgentConfig, AgentResponse
class AgentEventType(str, Enum):
    """Kinds of events exchanged between runners, orchestrator and clients."""

    REQUEST = "request"
    RESPONSE = "response"
    ERROR = "error"
    SIGNAL = "signal"
    # The agent asks the client side to execute a tool / the client replies.
    CLIENT_TOOL_CALL = "client_tool_call"
    CLIENT_TOOL_RESULT = "client_tool_result"
class AgentRequest(BaseModel):
    """Agent request model: a message batch addressed to a named agent."""

    # Name of the agent that should handle this request.
    agent_name: str
    # Optional per-request configuration override.
    config: AgentConfig | None = None
    # Conversation messages to feed into the agent.
    messages: list[ChatMessage] | None = None
class ClientToolCallType(str, Enum):
    """Classification of a client-side tool call."""

    AGENT = "agent"          # delegate to another agent
    TOOL = "tool"            # plain client-executed tool
    ASK_INPUT = "ask_input"  # request input from the user
class ClientToolCall(BaseModel):
    """Client tool call model: a tool invocation to be executed client-side."""

    # Correlation id; the eventual result is reported under this id.
    tool_call_id: str
    # Function name and serialized arguments of the call.
    function: Function
    # AGENT / TOOL / ASK_INPUT — decides how the call is dispatched.
    type: ClientToolCallType
    # Extra keyword data carried along with the call.
    extra: dict | None = None
class AgentEvent(BaseModel):
    """Agent event model.

    Envelope exchanged between runners, orchestrator and clients; which
    payload field is populated depends on `type`.
    """

    # Unique id generated per event instance.
    event_id: str = Field(default_factory=lambda: f"{uuid4().hex}")
    task_id: str | None = None
    parent_task_id: str | None = None
    root_task_id: str | None = None
    type: AgentEventType
    metadata: dict | None = None
    agent_name: str | None = None
    # Input
    request: AgentRequest | None = None
    # Output
    response: AgentResponse | None = None
    error: str | None = None
    # client tool call
    client_tool_call: ClientToolCall | None = None
    client_tool_result: AgentResponse | None = None
================================================
FILE: cortex/runtime_config.py
================================================
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
try:
import yaml
except Exception: # noqa: BLE001
yaml = None
_RUNTIME_CONFIG_CACHE: dict[str, Any] | None = None
def _repo_root() -> Path:
# `cortex/` lives under the project root in this repo.
return Path(__file__).resolve().parent.parent
def _pick_first_present(cfg: dict[str, Any], keys: tuple[str, ...]) -> Any:
for key in keys:
if key in cfg:
return cfg.get(key)
return None
def _as_int(value: Any) -> int | None:
if value is None:
return None
if isinstance(value, bool):
return None
if isinstance(value, (int, float)):
return int(value)
if isinstance(value, str):
text = value.strip()
if not text:
return None
try:
return int(text)
except Exception: # noqa: BLE001
return None
return None
def load_runtime_config() -> dict[str, Any]:
    """Load runtime config from YAML, cached for process lifetime.

    Resolution order:
      1) `$STEP_DEEPRESEARCH_CONFIG` (if set)
      2) `config.yaml` under the repo root (if exists)
    """
    global _RUNTIME_CONFIG_CACHE  # noqa: PLW0603
    if _RUNTIME_CONFIG_CACHE is not None:
        return _RUNTIME_CONFIG_CACHE
    if yaml is None:
        # PyYAML not importable: fall back to (and cache) an empty config.
        _RUNTIME_CONFIG_CACHE = {}
        return _RUNTIME_CONFIG_CACHE

    search_paths: list[Path] = []
    env_path = (os.getenv("STEP_DEEPRESEARCH_CONFIG") or "").strip()
    if env_path:
        search_paths.append(Path(env_path))
    search_paths.append(_repo_root() / "config.yaml")

    for candidate in search_paths:
        try:
            if not candidate.exists() or not candidate.is_file():
                continue
            with candidate.open("r", encoding="utf-8") as fh:
                parsed = yaml.safe_load(fh) or {}
        except Exception:  # noqa: BLE001
            continue
        # Only accept a mapping; anything else falls through to the next
        # candidate (and eventually to the empty default).
        if isinstance(parsed, dict):
            _RUNTIME_CONFIG_CACHE = parsed
            return _RUNTIME_CONFIG_CACHE

    _RUNTIME_CONFIG_CACHE = {}
    return _RUNTIME_CONFIG_CACHE
def get_context_limit_overrides() -> tuple[int | None, int | None]:
    """Return (upper, lower) context limit overrides from runtime config."""
    cfg = load_runtime_config()
    upper_raw = _pick_first_present(
        cfg, ("context_upper_limit", "final_answer_context_upper_limit")
    )
    lower_raw = _pick_first_present(
        cfg, ("context_lower_limit", "final_answer_context_lower_limit")
    )
    # Non-numeric or missing values come back as None.
    return _as_int(upper_raw), _as_int(lower_raw)
================================================
FILE: cortex/server/channel/channel.py
================================================
from abc import ABC, abstractmethod
class Channel(ABC):
    """Abstract bidirectional event channel identified by a context id."""

    def __init__(self, context_id: str) -> None:
        # Identifier tying this channel to a conversation/session context.
        self.context_id: str = context_id

    @abstractmethod
    async def send(self, event: dict[str, object]):
        """Send one event dict to the peer."""

    @abstractmethod
    async def receive(self) -> dict[str, object]:
        """Receive the next event dict from the peer."""

    @abstractmethod
    async def heartbeat(self):
        """Keep-alive hook; semantics are implementation-defined."""

    @abstractmethod
    async def close(self):
        """Close the channel and release its resources."""
================================================
FILE: cortex/server/channel/error.py
================================================
class ChannelClosedError(Exception):
    """Raised when sending or receiving on a channel that has been closed."""
================================================
FILE: cortex/server/channel/memory_channel.py
================================================
import asyncio
from typing import Dict
from loguru import logger
from cortex.server.channel.channel import Channel
from cortex.server.channel.error import ChannelClosedError
class MemoryChannel(Channel):
    """In-process Channel backed by two asyncio queues.

    send() enqueues onto `send_queue`, receive() dequeues from
    `receive_queue`; the pairing of the two queues with a peer is done by
    the owner of the channel.
    """

    def __init__(self, context_id: str) -> None:
        super().__init__(context_id)
        # Outbound events written by send().
        self.send_queue = asyncio.Queue()
        # Inbound events consumed by receive().
        self.receive_queue = asyncio.Queue()
        self.is_closed = False
        logger.debug(f"MemoryChannel created with context_id: {context_id}")

    async def send(self, event: Dict[str, object]):
        # Reject operations once the channel has been flagged closed.
        if self.is_closed:
            raise ChannelClosedError("Channel is closed")
        await self.send_queue.put(event)
        logger.debug(f"MemoryChannel {self.context_id}: sent event")

    async def receive(self) -> Dict[str, object]:
        if self.is_closed:
            raise ChannelClosedError("Channel is closed")
        # NOTE(review): close() does not wake a receive() that is already
        # awaiting here; it keeps blocking until an event arrives.
        data = await self.receive_queue.get()
        logger.debug(f"MemoryChannel {self.context_id}: received event")
        return data

    async def heartbeat(self):
        # In-memory transport needs no keep-alive.
        pass

    async def close(self):
        logger.debug(f"MemoryChannel {self.context_id}: closing")
        # Flag only; already-queued events are left in place.
        self.is_closed = True
================================================
FILE: cortex/server/channel/ws_channel.py
================================================
"""WebSocket channel implementation for agent communication."""
import asyncio
import logging
import time
from typing import cast
from fastapi import WebSocket, WebSocketDisconnect
from cortex.server.channel.channel import Channel
from cortex.server.channel.error import ChannelClosedError
logger = logging.getLogger(__name__)
class WebSocketChannel(Channel):
    """Channel over a FastAPI WebSocket.

    Ping/pong heartbeat frames are handled transparently inside receive().
    """

    def __init__(self, ws: WebSocket) -> None:
        # NOTE(review): Channel.__init__ (context_id) is intentionally not
        # called here, mirroring the original behavior — confirm whether a
        # context_id is needed for WebSocket channels.
        self.ws: WebSocket = ws
        # Bug fix: _closed and last_heartbeat_time used to be class
        # attributes; last_heartbeat_time was evaluated once at import time
        # and shared across every instance. Both are now per-instance.
        self._closed: bool = False
        self.last_heartbeat_time: float = time.time()

    async def send(self, event: dict[str, object]) -> None:
        """Send one JSON event; raises ChannelClosedError when closed."""
        if self._closed:
            raise ChannelClosedError
        await self.ws.send_json(event)

    async def receive(self) -> dict[str, object]:
        """Receive the next non-heartbeat JSON event.

        Ping frames are answered with a pong and pong frames are swallowed;
        raises ChannelClosedError when the socket disconnects.
        """
        if self._closed:
            raise ChannelClosedError
        while True:
            try:
                data = cast(dict[str, object], await self.ws.receive_json())
                # Any inbound frame counts as liveness.
                self.last_heartbeat_time = time.time()
                if data.get("type") == "ping":
                    await self.ws.send_json({"type": "pong"})
                    continue
                if data.get("type") == "pong":
                    continue
                return data
            except WebSocketDisconnect:
                self._closed = True
                raise ChannelClosedError

    async def heartbeat(self) -> None:
        """Send a ping every 10 seconds until the channel is closed."""
        await asyncio.sleep(10)
        while not self._closed:
            await self.ws.send_json({"type": "ping"})
            await asyncio.sleep(10)
            # TODO Check if last_heartbeat_time has timed out

    async def close(self) -> None:
        """Mark the channel closed and close the socket (idempotent)."""
        logger.debug("WebSocketChannel closing")
        if not self._closed:
            self._closed = True
            await self.ws.close()
================================================
FILE: cortex/server/http_server.py
================================================
import asyncio
import logging
import uuid
import uvicorn
from agentkit.trace import SpanContext, Tracer
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocket
from pydantic import ValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from cortex.agents.types import AgentConfig
from cortex.orchestrator.orchestrator import Orchestrator, OrchMode
from cortex.orchestrator.types import AgentEvent, AgentEventType
from cortex.server.channel.ws_channel import WebSocketChannel
from cortex.server.log.log import setup_logging
from cortex.server.log.trace import set_trace_id
setup_logging()
logger = logging.getLogger(__name__)
def extract_and_set_trace_id(headers: dict) -> str:
    """
    Extract the "Step-Trace-ID" header, auto-generate a UUID if not present,
    and store it in the trace_id ContextVar so subsequent log records carry it.
    Args:
        headers: HTTP headers mapping. NOTE(review): lookup uses the literal
            key "Step-Trace-ID"; this is case-insensitive only when a
            Starlette Headers object is passed, not a plain dict — confirm
            callers' header casing.
    Returns:
        trace_id string (extracted or freshly generated)
    """
    trace_id = headers.get("Step-Trace-ID")
    if not trace_id:
        trace_id = str(uuid.uuid4())
    set_trace_id(trace_id)
    return trace_id
class TraceMiddleware(BaseHTTPMiddleware):
    """Bind a trace id to the log context for every HTTP request.

    The id is read from the request headers (auto-generated when missing)
    before the downstream handler runs, so all logs emitted while handling
    the request carry it.
    """

    async def dispatch(self, request: Request, call_next):
        extract_and_set_trace_id(request.headers)
        return await call_next(request)
class HttpServer:
    """FastAPI-based server exposing Orchestrator agents over WebSocket and SSE."""

    orch: Orchestrator
    tracer: Tracer

    def __init__(self, orch: Orchestrator, tracer: Tracer):
        self.orch = orch
        self.tracer = tracer

    async def start(self, host: str = "0.0.0.0", port: int = 8001) -> None:
        """Build the FastAPI app and serve it with uvicorn until shutdown."""
        app = self._build_app()
        # Log the actual bind address (was hard-coded to 0.0.0.0:8001, which
        # was wrong whenever non-default host/port arguments were passed).
        logger.info(f"Starting HTTP server, listening on {host}:{port}")
        # Configure uvicorn to use Python logging
        server_config = uvicorn.Config(
            app,
            host=host,
            port=port,
            log_config=None,  # Disable uvicorn default log config
            access_log=True,  # Enable access log
        )
        await uvicorn.Server(server_config).serve()

    def _build_app(self) -> FastAPI:
        """Assemble the FastAPI application: middleware, routes, and handlers."""
        app = FastAPI(title="Agent Server")
        # Add CORS middleware, support cross-origin requests from any domain
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # Allow all domains
            allow_credentials=True,
            allow_methods=["*"],  # Allow all HTTP methods
            allow_headers=["*"],  # Allow all request headers
        )
        # Add Trace middleware to extract trace_id from the request header and inject into log
        app.add_middleware(TraceMiddleware)

        @app.get("/agents")
        async def list_agents() -> list[AgentConfig]:
            return self.orch.list_agents()

        @app.websocket("/multi/ws/{agent_name}/{context_id}")
        async def multi_call(websocket: WebSocket, agent_name: str, context_id: str):
            await websocket_handler(websocket, agent_name, context_id, OrchMode.MULTI)

        @app.websocket("/single/ws/{agent_name}/{context_id}")
        async def single_call(websocket: WebSocket, agent_name: str, context_id: str):
            await websocket_handler(websocket, agent_name, context_id, OrchMode.SINGLE)

        @app.post("/multi/sse/{agent_name}/{context_id}")
        async def multi_call_sse(agent_name: str, context_id: str, request: AgentEvent):
            return StreamingResponse(
                sse_handler(
                    agent_name, context_id, OrchMode.MULTI, request
                ),  # Async generator (core)
                media_type="text/event-stream",  # SSE response type
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    # Cross-origin support (needed if frontend and backend have different domains)
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Credentials": "true",
                },
            )

        @app.post("/single/sse/{agent_name}/{context_id}")
        async def single_call_sse(
            agent_name: str, context_id: str, request: AgentEvent
        ):
            return StreamingResponse(
                sse_handler(
                    agent_name, context_id, OrchMode.SINGLE, request
                ),  # Async generator (core)
                media_type="text/event-stream",  # SSE response type
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    # Cross-origin support (needed if frontend and backend have different domains)
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Credentials": "true",
                },
            )

        async def websocket_handler(
            websocket: WebSocket, agent_name: str, context_id: str, mode: OrchMode
        ):
            """Bridge one WebSocket connection to an orchestrator run."""
            # Use unified function to extract and set trace_id from headers
            extract_and_set_trace_id(websocket.headers)
            await websocket.accept()
            channel = WebSocketChannel(websocket)
            task_id = str(uuid.uuid4())
            # The first frame must be a REQUEST event that starts the run.
            request = AgentEvent.model_validate(await channel.receive())
            if request.task_id is None:
                request.task_id = task_id
            if request.type != AgentEventType.REQUEST:
                raise ValueError("first event type must be request")

            async def send_to_agent_loop():
                # Forward subsequent client events to the orchestrator,
                # skipping malformed payloads instead of tearing down.
                while True:
                    try:
                        data = await channel.receive()
                        event = AgentEvent.model_validate(data)
                        await self.orch.send_event(event)
                    except ValidationError as e:
                        logger.error(f"Pydantic validation error: {e}")
                        continue

            async def send_to_client_loop():
                ctx = SpanContext(tracer=self.tracer, app_name=agent_name)
                with ctx.span(name=f"orchestrator_{agent_name}_{context_id}"):
                    # Guard the payload like sse_handler does: a REQUEST event
                    # may arrive without a request body.
                    config = (
                        request.request.config if request.request is not None else None
                    )
                    async for event in self.orch.run(
                        agent_name, request, config, mode, context_id
                    ):
                        # Routine per-event logging belongs at debug level.
                        logger.debug(f'event: {event.model_dump()}')
                        await channel.send(event.model_dump())

            agent_task = asyncio.create_task(
                send_to_client_loop(), name="send_to_client_loop"
            )
            recv_task = asyncio.create_task(
                send_to_agent_loop(), name="send_to_agent_loop"
            )
            # Run both directions until one finishes (run complete or client gone).
            done, pending = await asyncio.wait(
                [recv_task, agent_task], return_when=asyncio.FIRST_COMPLETED
            )
            for task in pending:
                logger.info(f"Cancelling pending task: {task.get_name()}")
                task.cancel()
            for task in done:
                try:
                    task.result()
                    logger.info(f"Completed task: {task.get_name()}")
                except Exception as e:
                    logger.error(f"Task error in {task.get_name()}: {e}", exc_info=True)
            await channel.close()
            logger.info("WebSocket connection closed")

        async def sse_handler(
            agent_name: str, context_id: str, mode: OrchMode, request: AgentEvent
        ):
            """Yield orchestrator events as Server-Sent Events frames."""
            ctx = SpanContext(tracer=self.tracer, app_name=agent_name)
            with ctx.span(name=f"http_sse_handler_{agent_name}_{context_id}"):
                config = None
                if request.request is not None:
                    config = request.request.config
                async for event in self.orch.run(
                    agent_name, request, config, mode, context_id
                ):
                    yield f"data: {event.model_dump_json()}\n\n"

        @app.get("/health")
        async def health_check() -> dict[str, str]:
            return {"status": "ok"}

        return app
================================================
FILE: cortex/server/log/log.py
================================================
# context.py
import logging
import os
from datetime import datetime
from pathlib import Path
from pythonjsonlogger import jsonlogger
from cortex.server.log.trace import TraceIdFilter
def setup_logging(log_dir: str = "./logs", log_level: int = logging.WARNING):
    """
    Configure logging system to output to both console and file.
    Args:
        log_dir: Log file directory, defaults to ./logs
        log_level: Log level, defaults to logging.WARNING
    """
    # Use rename_fields parameter to alias field names
    formatter = jsonlogger.JsonFormatter(
        "%(asctime)s %(levelname)s %(trace_id)s %(name)s %(message)s",
        rename_fields={
            "asctime": "time",  # Timestamp
            "levelname": "level",  # Log level
            "trace_id": "traceid",  # Trace ID (camelCase)
            "name": "name",  # Logger name
            "message": "msg",  # Log message
        },
    )
    # Set StreamHandler encoding to UTF-8 for proper Unicode handling
    import sys
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    # Stamp every record with the contextvar-held trace id
    stream_handler.addFilter(TraceIdFilter())
    handlers = [stream_handler]
    # Add file Handler (one file per day, e.g. server_20240101.log)
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)
    log_file = log_path / f"server_{datetime.now().strftime('%Y%m%d')}.log"
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(formatter)
    file_handler.addFilter(TraceIdFilter())
    handlers.append(file_handler)
    # Ensure basicConfig uses UTF-8 encoding.
    # NOTE(review): basicConfig is a no-op if the root logger already has
    # handlers — confirm this is only called once at startup.
    logging.basicConfig(
        level=log_level,
        handlers=handlers,
        encoding="utf-8",  # Add UTF-8 encoding support
    )
================================================
FILE: cortex/server/log/trace.py
================================================
import logging
from contextvars import ContextVar
trace_id_var: ContextVar[str | None] = ContextVar("trace_id", default=None)
def set_trace_id(trace_id: str):
trace_id_var.set(trace_id)
def get_trace_id() -> str | None:
return trace_id_var.get()
class TraceIdFilter(logging.Filter):
"""Automatically inject trace_id from contextvars into LogRecord."""
def filter(self, record: logging.LogRecord) -> bool:
record.trace_id = get_trace_id() or "-"
return True
================================================
FILE: cortex/tools/__init__.py
================================================
"""Tool system for Agent components."""
from .agent_tool import AgentTool
from .base import Tool, ToolSchema
from .channel import Channel
from .client_tool import ClientTool
from .function_tool import FunctionTool
from .mcp_tool import MCPTool
from .session_tool import SessionTool
from .toolset import ToolSet
from .types import ExecutionType, ToolConfig, ToolType
# Public API of the cortex.tools package; keep this list in sync with the
# imports above when adding new tool implementations.
__all__ = [
    "Tool",
    "ToolSchema",
    "ExecutionType",
    "ToolType",
    "ToolConfig",
    "MCPTool",
    "FunctionTool",
    "SessionTool",
    "ClientTool",
    "AgentTool",
    "ToolSet",
    "Channel",
]
================================================
FILE: cortex/tools/agent_tool.py
================================================
"""Agent Tool implementation - calls agents through Orchestrator."""
import json
import logging
from typing import Any, Optional
from cortex.model.definition import ChatMessage
from .base import Tool, ToolSchema
from .channel import Channel
from .types import ToolParameters, ToolType
logger = logging.getLogger(__name__)
class AgentTool(Tool):
    """Agent tool implementation for calling other Agents, communicates with Orchestrator through Channel."""

    share_context: bool = False

    def __init__(
        self,
        name: str,
        description: str = "",
        channel: Optional[Channel] = None,
        timeout: Optional[float] = None,
        share_context: bool = False,
        **kwargs,
    ):
        """
        Initialize Agent tool.
        Args:
            name: Tool name
            description: Tool description
            channel: Channel instance (for async communication)
            timeout: Default timeout in seconds
            share_context: Whether to share context
            **kwargs: Additional parameters
        """
        super().__init__(
            name=name, description=description, tool_type=ToolType.AGENT, **kwargs
        )
        if channel is None:
            raise ValueError("AgentTool requires a Channel instance")
        self.channel = channel
        self.timeout = timeout or 30.0
        # Bug fix: the share_context argument was previously accepted but
        # never stored, so the class-level default False always won.
        self.share_context = share_context
        logger.debug(
            "AgentTool initialized: name=%s, description=%s, timeout=%s",
            name,
            description,
            self.timeout,
        )

    def _define_schema(self) -> ToolSchema:
        """Define Agent tool schema."""
        schema = ToolSchema(
            name=self.name,
            description=self.description,
            parameters={
                "type": "object",
                "properties": {
                    "content": {
                        "type": "string",
                        "description": "Specific instruction content to send to the Agent (optional, ignored if messages is provided)",
                    },
                    "messages": {
                        "type": "array",
                        "description": "List of messages to send to the Agent (optional, ChatMessage format with role and content fields)",
                        "items": {
                            "type": "object",
                            "properties": {
                                "role": {
                                    "type": "string",
                                    "description": "Message role, e.g., 'user', 'assistant', 'system'",
                                    "enum": ["user", "assistant", "system", "tool"],
                                },
                                "content": {
                                    "type": "string",
                                    "description": "Message content (string format)",
                                },
                            },
                            "required": ["role", "content"],
                        },
                    },
                    "timeout": {
                        "type": "number",
                        "description": "Timeout in seconds (optional)",
                    },
                },
                "required": [],
            },
            return_type="agent_response",
            tool_type=self.tool_type,
        )
        logger.debug(
            "AgentTool define schema: name=%s, parameters=%s", self.name, schema.parameters
        )
        return schema

    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Call the Agent tool.
        Send request to Orchestrator through Channel, which creates a new Agent and executes the task.
        Args:
            parameters: Tool parameters (JSON string format), including:
                - content: Specific instruction content (optional, ignored if messages is provided)
                - messages: List of messages (optional, ChatMessage format with role and content fields)
                - timeout: Timeout in seconds (optional)
            **kwargs: Additional parameters (merged into the request payload)
        Returns:
            Any: Agent execution result
        Raises:
            ValueError: Invalid parameters format
            TimeoutError: Request timeout
        """
        tool_call_id = kwargs.get("tool_call_id")
        agent_name = self.name
        # Process messages: prefer messages, if not provided use content to create ChatMessage
        messages = AgentTool.parse_messages(parameters)
        # Build request data in format expected by orchestrator
        request_data = {
            "agent_name": agent_name,
        }
        if messages:
            request_data["messages"] = messages
        request_data.update(kwargs)
        logger.debug(
            "AgentTool._call: tool=%s, agent=%s, messages=%d, tool_call_id=%s",
            self.name,
            agent_name,
            len(messages) if messages else 0,
            tool_call_id,
        )
        # Send request through Channel and wait for response
        try:
            tool_parameters = ToolParameters(parameters=parameters, kwargs=request_data)
            _, response = await self.channel.send_request(
                request_id=tool_call_id,
                tool_name=self.name,
                data=tool_parameters,
                tool_schema=self.get_schema(),
                timeout=self.timeout,
            )
            logger.debug(
                "AgentTool._call completed: tool=%s, agent=%s, response_type=%s",
                self.name,
                agent_name,
                type(response).__name__,
            )
            return response
        except Exception as e:
            logger.error(
                "AgentTool._call failed: tool=%s, agent=%s, error=%s",
                self.name,
                agent_name,
                e,
                exc_info=True,
            )
            raise

    @staticmethod
    def parse_messages(parameters: str) -> list[ChatMessage]:
        """
        Parse and prepare message list.
        Args:
            parameters: Tool parameters (JSON string format)
        Returns:
            list[ChatMessage]: Processed message list
        Raises:
            ValueError: If parameters is not valid JSON.
        """
        # Parse parameters JSON string
        try:
            params_dict = json.loads(parameters) if parameters else {}
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid parameters JSON format: {e}") from e
        # Prioritize using messages
        messages = params_dict.get("messages", [])
        if messages:
            processed = []
            for msg in messages:
                if isinstance(msg, dict):
                    processed.append(ChatMessage(**msg))
                elif isinstance(msg, ChatMessage):
                    processed.append(msg)
                else:
                    # Pass through unknown representations unchanged
                    processed.append(msg)
            return processed
        # If no messages, try using content
        content = params_dict.get("content")
        if content:
            return [ChatMessage(role="user", content=content)]
        return []
================================================
FILE: cortex/tools/base.py
================================================
"""Base Tool class."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from .types import ToolType
@dataclass
class ToolSchema:
    """Tool schema definition."""

    name: str  # Tool name as advertised to callers
    description: str  # Human-readable description of the tool
    parameters: Dict[str, Any] = field(default_factory=dict)  # Parameter spec (JSON-schema-style dict)
    return_type: Optional[str] = None  # Free-form label for the return payload
    tool_type: Optional[ToolType] = None  # Tool category, when known
class Tool(ABC):
    """Abstract base class for tools.

    Subclasses supply a schema via _define_schema() and execution logic via
    _call(); the schema is computed lazily and cached on first access.
    """

    def __init__(
        self,
        name: str,
        description: str = "",
        tool_type: Optional[ToolType] = None,
        **kwargs,  # noqa: ARG002
    ):
        """
        Initialize the tool.
        Args:
            name: Tool name
            description: Tool description
            tool_type: Tool type
            **kwargs: Additional parameters (passed to subclasses)
        """
        self.name = name
        self.description = description
        self.tool_type = tool_type
        # Lazily-built schema cache; populated by get_schema().
        self._schema: Optional[ToolSchema] = None

    def get_schema(self) -> ToolSchema:
        """
        Get tool schema, building and caching it on first access.
        Returns:
            ToolSchema: Tool schema object
        """
        if self._schema is None:
            self._schema = self._define_schema()
        return self._schema

    @abstractmethod
    def _define_schema(self) -> ToolSchema:
        """
        Define tool schema (must be implemented by subclasses).
        Returns:
            ToolSchema: Tool schema object
        """
        ...  # pragma: no cover

    async def call(self, parameters: str, **kwargs) -> Any:
        """
        Call the tool (async); delegates to the subclass _call().
        Args:
            parameters: Tool parameters (string format)
            **kwargs: Additional parameters
        Returns:
            Any: Tool execution result
        """
        return await self._call(parameters, **kwargs)

    @abstractmethod
    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Tool call implementation (must be implemented by subclasses).
        Args:
            parameters: Tool parameters (string format)
            **kwargs: Additional parameters
        Returns:
            Any: Tool execution result
        """
================================================
FILE: cortex/tools/channel.py
================================================
"""Channel for async tool execution communication."""
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple
from .base import ToolSchema
from .types import ToolParameters
class MessageType(Enum):
    """Message type.

    NOTE(review): not referenced by Channel within this module — presumably
    consumed by external code; verify before removing.
    """

    REQUEST = "request"
    RESPONSE = "response"
    ERROR = "error"
@dataclass
class ChannelMessage:
    """Channel message."""

    message_type: MessageType  # Kind of message (request/response/error)
    tool_name: str  # Tool this message concerns
    request_id: str  # Correlation id matching requests to responses
    data: Any  # Message payload
    error: Optional[str] = None  # Error description, if any
class Channel:
    """Channel for async tool execution communication.

    Outbound requests are delivered through an on_send callback; responses
    arrive asynchronously via set_response(), matched by request_id.
    """

    def __init__(
        self,
        on_send: Optional[
            Callable[[str, ToolSchema, ToolParameters], Awaitable[None]]
        ] = None,
    ):
        """
        Initialize Channel.
        Args:
            on_send: Callback function for sending requests,
                receives (tool_name, tool_schema, data) and sends data asynchronously
        """
        # request_id -> Future resolved by set_response()
        self._pending_requests: Dict[str, asyncio.Future] = {}
        self._request_counter = 0
        self._on_send = on_send

    def set_on_send(
        self,
        on_send: Optional[Callable[[str, ToolSchema, ToolParameters], Awaitable[None]]],
    ):
        """
        Set the send callback function.
        Args:
            on_send: Callback function for sending requests,
                receives (tool_name, tool_schema, tool_parameters) and sends data asynchronously
        """
        self._on_send = on_send

    def create_request_id(self) -> str:
        """
        Create a request ID.
        Returns:
            str: Request ID
        """
        self._request_counter += 1
        return f"req_{self._request_counter}"

    async def send_request(
        self,
        tool_name: str,
        data: ToolParameters,
        tool_schema: ToolSchema,
        request_id: Optional[str] = None,
        timeout: Optional[float] = None,
        on_send: Optional[
            Callable[[str, ToolSchema, ToolParameters], Awaitable[None]]
        ] = None,
    ) -> Tuple[str, Any]:
        """
        Send request and wait for response.
        Args:
            tool_name: Tool name
            data: Request data (ToolParameters object)
            tool_schema: Tool schema
            request_id: Request ID (auto-generated if not provided)
            timeout: Timeout in seconds
            on_send: Callback function for sending requests,
                receives (tool_name, tool_schema, tool_parameters) and sends data asynchronously.
                If provided, overrides the on_send set during initialization.
        Returns:
            Tuple[str, Any]: (request_id, response data)
        Raises:
            TimeoutError: Request timeout
            Exception: Tool execution error
        """
        # Auto-generate request_id if not provided
        if request_id is None:
            request_id = self.create_request_id()
        future = asyncio.Future()
        self._pending_requests[request_id] = future
        try:
            # Send data to external system
            send_handler = on_send or self._on_send
            if send_handler:
                # Create a copy of data to avoid modifying the original object
                tool_parameters = ToolParameters(
                    parameters=data.parameters, kwargs={**data.kwargs}
                )
                # Add request_id to kwargs so the responder can correlate
                tool_parameters.kwargs["_request_id"] = request_id
                await send_handler(tool_name, tool_schema, tool_parameters)
            # Wait for response (response is set via set_response)
            if timeout:
                response = await asyncio.wait_for(future, timeout=timeout)
            else:
                response = await future
            return (request_id, response)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Request {request_id} for tool {tool_name} timed out"
            ) from exc
        finally:
            # Single cleanup point: finally runs on success, timeout and any
            # other error alike, so the per-except pop calls that previously
            # duplicated this cleanup have been removed.
            self._pending_requests.pop(request_id, None)

    def set_response(self, request_id: str, data: Any, error: Optional[str] = None):
        """
        Set response data, resolving the matching pending request (if any).
        Args:
            request_id: Request ID
            data: Response data
            error: Error message (if any)
        """
        if request_id not in self._pending_requests:
            return
        future = self._pending_requests[request_id]
        # Check if future has already been set
        if future.done():
            return
        try:
            if error:
                future.set_exception(Exception(error))
            else:
                future.set_result(data)
        except Exception:
            # Ignore if future has already been set or cancelled
            # This may happen in concurrent scenarios
            pass
================================================
FILE: cortex/tools/client_tool.py
================================================
"""Client Tool implementation - uses Channel for communication."""
import logging
from typing import Any, Optional
from .base import Tool, ToolSchema
from .channel import Channel
from .types import ToolParameters, ToolType
logger = logging.getLogger(__name__)
class ClientTool(Tool):
    """Client tool implementation, uses Channel for async communication."""

    def __init__(
        self,
        name: str,
        description: str = "",
        tool_type: Optional[ToolType] = None,
        channel: Optional[Channel] = None,
        timeout: Optional[float] = None,
        **kwargs,
    ):
        """
        Initialize client tool.
        Args:
            name: Tool name
            description: Tool description
            tool_type: Tool type (defaults to CLIENT; ASK_INPUT for name "ask_input")
            channel: Channel instance (for async communication)
            timeout: Default timeout in seconds
            **kwargs: Additional parameters
        Raises:
            ValueError: If no channel is provided.
        """
        # If tool_type is not specified, default to CLIENT
        # But if name is "ask_input", use ASK_INPUT
        if tool_type is None:
            if name == "ask_input":
                tool_type = ToolType.ASK_INPUT
            else:
                tool_type = ToolType.CLIENT
        super().__init__(
            name=name, description=description, tool_type=tool_type, **kwargs
        )
        if channel is None:
            raise ValueError("ClientTool requires a Channel instance")
        self.channel = channel
        self.timeout = timeout or 30.0
        self._client_params = kwargs.get("client_params", {})

    def _define_schema(self) -> ToolSchema:
        """Define client tool schema."""
        properties = self._client_params.get("properties", {})
        required = self._client_params.get("required", [])
        return ToolSchema(
            name=self.name,
            description=self.description,
            parameters={
                "type": "object",
                "properties": properties,
                "required": required,
            },
            return_type="client_response",
            tool_type=self.tool_type,
        )

    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Call the client tool.
        Send request through Channel and wait for response.
        Args:
            parameters: Tool parameters (string format)
            **kwargs: Additional parameters
        Returns:
            Any: Execution result
        """
        # Build request data, convert parameters and kwargs to ToolParameters
        tool_parameters = ToolParameters(parameters=parameters, kwargs=kwargs)
        tool_call_id = kwargs.get("tool_call_id")
        if tool_call_id is None:
            tool_call_id = f"tool_call_{hash(self.name)}_{hash(parameters)}"
        logger.debug(
            "ClientTool._call: tool=%s, parameters=%s, tool_call_id=%s",
            self.name,
            parameters,
            tool_call_id,
        )
        # Send request through Channel and wait for response.
        # Bug fix: pass the resolved tool_call_id (including its deterministic
        # fallback) as the request_id — the original re-read kwargs here, so
        # the fallback computed above was never actually used.
        _, response = await self.channel.send_request(
            tool_name=self.name,
            data=tool_parameters,
            tool_schema=self.get_schema(),
            request_id=tool_call_id,
            timeout=kwargs.get("timeout") or self.timeout,
        )
        return response
================================================
FILE: cortex/tools/function_tool.py
================================================
"""Function Tool implementation."""
import inspect
import json
from typing import Any, Callable, Optional
from agents.function_schema import function_schema
from .base import Tool, ToolSchema
from .types import ToolType
class FunctionTool(Tool):
    """Function tool implementation."""

    def __init__(
        self, name: str, func: Callable, description: Optional[str] = None, **kwargs
    ):
        """
        Initialize function tool.
        Args:
            name: Tool name
            func: Function to wrap
            description: Tool description (falls back to func.__doc__)
            **kwargs: Additional parameters
        """
        self.func = func
        description = description or func.__doc__ or ""
        super().__init__(
            name=name, description=description, tool_type=ToolType.FUNCTION, **kwargs
        )

    def _define_schema(self) -> ToolSchema:
        """Define function tool schema."""
        # Infer parameters from function signature
        # strict_json_schema=False to preserve default value behavior for optional parameters
        data = function_schema(self.func, strict_json_schema=False)
        return ToolSchema(
            # Bug fix: advertise the registered tool name rather than the
            # wrapped function's own name — callers resolve tools by the
            # name given at construction, and a mismatch made the tool
            # uncallable whenever `name` differed from func.__name__.
            name=self.name,
            description=self.description or data.description,
            parameters=data.params_json_schema,
            return_type="any",
            # Populate tool_type for consistency with ClientTool/AgentTool.
            tool_type=self.tool_type,
        )

    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Call the function tool.
        Args:
            parameters: Tool parameters (JSON string of keyword arguments)
            **kwargs: Additional parameters (unused)
        Returns:
            Any: Function execution result
        """
        # Convert parameters to dictionary
        parameters_dict = json.loads(parameters)
        # If the function is async
        if inspect.iscoroutinefunction(self.func):
            return await self.func(**parameters_dict)
        else:
            # Call sync function in async context
            import asyncio

            return await asyncio.to_thread(self.func, **parameters_dict)
================================================
FILE: cortex/tools/mcp.py
================================================
import logging
from contextlib import AsyncExitStack
from typing import Any, final
from mcp import ClientSession, Tool
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult
logger = logging.getLogger(__name__)
class MCPNotInitializedError(Exception):
    """Raised when an MCP operation is attempted before initialize()."""

    _MESSAGE = "MCP client is not initialized"

    def __init__(self) -> None:
        super().__init__(self._MESSAGE)
@final
class MCPClient:
    """Thin async wrapper around an MCP streamable-HTTP client session."""

    def __init__(self, server_url: str) -> None:
        self.server_url = server_url
        self.session = None
        # Owns the transport and session contexts; see aclose().
        self.exit_stack = AsyncExitStack()

    async def initialize(self) -> None:
        """Open the HTTP transport and initialize the MCP session."""
        streams = await self.exit_stack.enter_async_context(
            streamablehttp_client(self.server_url)
        )
        read_stream, write_stream = streams[0], streams[1]
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        logger.info(f"Connected to server: {self.server_url}")
        await self.session.initialize()
        logger.info("Initialized session")

    async def aclose(self) -> None:
        """Close all resources opened via the exit stack.
        This must be called from the same task that created the contexts to
        avoid AnyIO cancel scope errors during shutdown.
        """
        try:
            await self.exit_stack.aclose()
        finally:
            self.session = None

    async def list_tools(self) -> list[Tool]:
        """Return the tools advertised by the connected server."""
        session = self.session
        if session is None:
            raise MCPNotInitializedError()
        return (await session.list_tools()).tools

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> CallToolResult:
        """Invoke the named tool on the server with the given arguments."""
        session = self.session
        if session is None:
            raise MCPNotInitializedError()
        return await session.call_tool(tool_name, arguments)
================================================
FILE: cortex/tools/mcp_tool.py
================================================
"""MCP (Model Context Protocol) Tool implementation."""
import json
from typing import Any
from .base import Tool, ToolSchema
from .mcp import MCPClient
from .types import ToolType
class MCPTool(Tool):
    """MCP tool implementation."""

    def __init__(
        self, name: str, description: str = "", mcp_server: str | None = None, **kwargs
    ):
        """
        Initialize MCP tool.
        Args:
            name: Tool name
            description: Tool description
            mcp_server: MCP server URL (required before _call can succeed)
            **kwargs: Additional parameters
        """
        super().__init__(
            name=name, description=description, tool_type=ToolType.MCP, **kwargs
        )
        # Annotation fixed: the original hint said `str` but the default was
        # None, which deferred the failure to call time with an obscure error.
        self.mcp_server = mcp_server
        self._mcp_params = kwargs.get("mcp_params", {})
        # MCP tool can handle directly, no Channel needed

    def _define_schema(self) -> ToolSchema:
        """Define MCP tool schema."""
        return ToolSchema(
            name=self.name,
            description=self.description,
            parameters={
                "type": "object",
                "properties": self._mcp_params.get("properties", {}),
                "required": self._mcp_params.get("required", []),
            },
            return_type="mcp_response",
        )

    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Call MCP tool.
        Args:
            parameters: Tool parameters (JSON string format)
            **kwargs: Additional parameters (unused)
        Returns:
            Any: Execution result
        Raises:
            ValueError: If no MCP server was configured.
        """
        if self.mcp_server is None:
            # Fail fast with a clear message instead of an obscure transport
            # error from MCPClient(None).
            raise ValueError(f"MCPTool '{self.name}' has no mcp_server configured")
        arguments = json.loads(parameters)
        mcp_client = MCPClient(self.mcp_server)
        await mcp_client.initialize()
        try:
            result = await mcp_client.call_tool(self.name, arguments)
            return result
        finally:
            # Always release the transport, even if the call failed.
            await mcp_client.aclose()
================================================
FILE: cortex/tools/session_tool.py
================================================
"""Session Tool implementation."""
from typing import Any, Dict, Optional
from .base import Tool, ToolSchema
from .types import ToolType
class SessionTool(Tool):
    """Session tool implementation for maintaining session state."""

    def __init__(
        self,
        name: str,
        description: str = "",
        session_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize session tool.
        Args:
            name: Tool name
            description: Tool description
            session_id: Session ID
            **kwargs: Additional parameters
        """
        super().__init__(
            name=name, description=description, tool_type=ToolType.SESSION, **kwargs
        )
        self.session_id = session_id
        # In-memory per-instance key/value store backing the actions below.
        self._session_state: Dict[str, Any] = {}

    def _define_schema(self) -> ToolSchema:
        """Define session tool schema."""
        return ToolSchema(
            name=self.name,
            description=self.description,
            parameters={
                "type": "object",
                "properties": {
                    "action": {
                        "type": "string",
                        "description": "Session action (get, set, update, clear)",
                        "enum": ["get", "set", "update", "clear"],
                    },
                    "key": {"type": "string", "description": "Session key"},
                    "value": {"type": "any", "description": "Session value"},
                },
                "required": ["action"],
            },
            return_type="session_response",
        )

    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Call session tool.
        Args:
            parameters: Tool parameters (JSON string with action/key/value,
                as declared by the schema)
            **kwargs: Additional parameters (backward-compatible fallback
                source for action/key/value)
        Returns:
            Any: Execution result
        Raises:
            ValueError: On invalid JSON or an unknown action.
        """
        import json

        # Bug fix: the schema declares action/key/value inside `parameters`
        # (a JSON string), but the original implementation only read **kwargs,
        # so schema-conformant calls were always treated as "get". Parse the
        # JSON payload first and keep kwargs as a fallback for old callers.
        try:
            params = json.loads(parameters) if parameters else {}
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid parameters JSON format: {e}") from e
        if not isinstance(params, dict):
            params = {}
        action = params.get("action", kwargs.get("action", "get"))
        key = params.get("key", kwargs.get("key"))
        value = params.get("value", kwargs.get("value"))
        if action == "get":
            if key:
                return self._session_state.get(key)
            return self._session_state
        elif action == "set":
            if key:
                self._session_state[key] = value
            return {"status": "ok", "key": key, "value": value}
        elif action == "update":
            updates = value if value is not None else {}
            self._session_state.update(updates)
            return {"status": "ok", "updated": list(updates.keys())}
        elif action == "clear":
            self._session_state.clear()
            return {"status": "ok", "message": "session cleared"}
        else:
            raise ValueError(f"Unknown action: {action}")
================================================
FILE: cortex/tools/toolset.py
================================================
"""ToolSet for managing and executing tools."""
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type
from agentkit.trace import get_current_context
from .agent_tool import AgentTool
from .base import Tool, ToolSchema
from .channel import Channel
from .client_tool import ClientTool
from .function_tool import FunctionTool
from .mcp import MCPClient
from .mcp_tool import MCPTool
from .session_tool import SessionTool
from .types import ToolConfig, ToolParameters, ToolType
logger = logging.getLogger(__name__)
class ToolSet:
    """Tool collection that manages tool registration, initialization, and invocation.

    A ToolSet owns a shared :class:`Channel` used by client/agent tools for
    asynchronous request/response communication, and keeps a per-instance map
    of ClientTool call results keyed by ``tool_call_id``.
    """

    def __init__(
        self,
        channel: Optional[Channel] = None,
        on_send: Optional[
            Callable[[str, ToolSchema, ToolParameters], Awaitable[None]]
        ] = None,
    ):
        """
        Initialize ToolSet.

        Args:
            channel: Shared Channel instance (optional)
            on_send: Callback function for sending requests, passed to Channel
                (if channel is not provided, a new Channel will be created
                with this callback)
        """
        self._tools: Dict[str, Tool] = {}
        # Results of ClientTool calls, keyed by tool_call_id. Kept per-instance
        # so independent ToolSets (e.g. different sessions) never see or clobber
        # each other's results. (Was previously a class attribute shared by all
        # instances, which leaked results across ToolSets.)
        self._client_tool_results: dict[str, Any] = {}
        # Maps each ToolType to the Tool subclass used by register_from_config.
        self._tool_factories: Dict[ToolType, Type[Tool]] = {
            ToolType.MCP: MCPTool,
            ToolType.FUNCTION: FunctionTool,
            ToolType.SESSION: SessionTool,
            ToolType.CLIENT: ClientTool,
            ToolType.AGENT: AgentTool,
            ToolType.ASK_INPUT: ClientTool,  # ASK_INPUT uses ClientTool implementation
        }
        if channel:
            self.channel = channel
            # If on_send parameter is provided, set it (overrides existing)
            if on_send:
                self.channel.set_on_send(on_send)
        else:
            # If no channel is provided, create a new channel with on_send
            self.channel = Channel(on_send=on_send)

    def set_on_send(
        self, on_send: Callable[[str, ToolSchema, ToolParameters], Awaitable[None]]
    ) -> None:
        """
        Set the on_send callback function for ToolSet.

        Args:
            on_send: Callback function for sending requests,
                receives (tool_name, tool_schema, tool_parameters) and sends
                data asynchronously

        Returns:
            None
        """
        self.channel.set_on_send(on_send)

    def register(self, tool: Tool, name: Optional[str] = None) -> None:
        """
        Register a tool.

        Args:
            tool: Tool instance
            name: Tool name (if not provided, uses tool.name)

        Raises:
            ValueError: If a tool with the same name is already registered.
        """
        tool_name = name or tool.name
        if tool_name in self._tools:
            raise ValueError(f"Tool '{tool_name}' is already registered")
        # If it's a ClientTool or AgentTool not yet bound to this ToolSet's
        # channel, rebind it to the shared channel so responses route here.
        if isinstance(tool, (ClientTool, AgentTool)) and tool.channel is not self.channel:
            tool.channel = self.channel
        self._tools[tool_name] = tool
        logger.info(f"✓ Registered tool: {tool_name} ({tool.tool_type.value})")

    async def register_from_mcp_server(
        self, mcp_server: str, tool_names: list[str] | None = None
    ) -> None:
        """
        Register tools from an MCP server.

        Args:
            mcp_server: The MCP server URL
            tool_names: Optional list of tool names to register. If None, all
                available tools will be registered.
        """
        # Construct the client before entering the try block; otherwise a
        # constructor failure would make the `finally` clause reference an
        # unbound name (NameError masking the real error).
        mcp_client = MCPClient(mcp_server)
        try:
            await mcp_client.initialize()
            mcp_tools = await mcp_client.list_tools()
            if tool_names is not None:
                mcp_tools = [tool for tool in mcp_tools if tool.name in tool_names]
            for mcp_tool in mcp_tools:
                # Extract properties and required from inputSchema
                input_schema = mcp_tool.inputSchema or {}
                mcp_params = {
                    "properties": input_schema.get("properties", {}),
                    "required": input_schema.get("required", []),
                }
                server_tool = MCPTool(
                    name=mcp_tool.name,
                    description=mcp_tool.description,
                    mcp_server=mcp_server,
                    mcp_params=mcp_params,
                )
                self.register(server_tool, mcp_tool.name)
        finally:
            await mcp_client.aclose()

    def register_from_config(self, config: ToolConfig) -> Tool:
        """
        Initialize and register a tool from configuration.

        Args:
            config: Tool configuration

        Returns:
            Tool: Created tool instance

        Raises:
            ValueError: Unknown tool type, or a FUNCTION tool without 'func'.
        """
        factory = self._tool_factories.get(config.tool_type)
        if not factory:
            raise ValueError(f"Unknown tool type: {config.tool_type}")
        # Copy so neither the caller's dict nor the config is mutated.
        init_params = dict(config.params or {})
        if config.tool_type == ToolType.MCP:
            # Ensure the "endpoint" key exists (possibly None) for MCPTool.
            init_params.setdefault("endpoint", init_params.get("endpoint"))
        elif config.tool_type in (
            ToolType.CLIENT,
            ToolType.ASK_INPUT,
            ToolType.AGENT,
        ):
            # Client, ask-input and agent tools must use a channel; fall back
            # to the ToolSet's shared channel if none was provided.
            if "channel" not in init_params:
                init_params["channel"] = self.channel
        elif config.tool_type == ToolType.FUNCTION:
            if "func" not in init_params:
                raise ValueError("Function tool requires 'func' parameter")
        # Pop keys passed explicitly below so they are not supplied twice via
        # **init_params (which would raise a TypeError).
        description = init_params.pop("description", "")
        init_params.pop("name", None)
        # Create tool instance
        tool = factory(
            name=config.name,
            description=description,
            **init_params,
        )
        # Register tool
        self.register(tool, config.name)
        return tool

    def get_tool(self, name: str) -> Optional[Tool]:
        """
        Get a tool by name.

        Args:
            name: Tool name

        Returns:
            Tool: Tool instance, or None if not found
        """
        return self._tools.get(name)

    def list_tools(self) -> List[str]:
        """
        List all registered tool names.

        Returns:
            List[str]: List of tool names
        """
        return list(self._tools.keys())

    async def call(self, tool_name: str, parameters: str, **kwargs) -> Any:
        """
        Call a tool.

        Args:
            tool_name: Tool name
            parameters: Tool parameters
            **kwargs: Additional parameters

        Returns:
            Any: Tool execution result

        Raises:
            ValueError: Tool not found
        """
        ctx = get_current_context()
        # Wrap the call in a tracing span so request/response are recorded.
        with ctx.tool_span(name=f"ToolSet.call {tool_name}") as span:
            span.update_payload_data(
                request=kwargs,
            )
            tool = self.get_tool(tool_name)
            if not tool:
                raise ValueError(f"Tool '{tool_name}' is not registered")
            logger.info(
                f"ToolSet.call {tool_name} parameters: {parameters} kwargs: {kwargs}"
            )
            resp = await tool.call(parameters, **kwargs)
            span.update_payload_data(
                response=resp,
            )
            return resp

    def get_schema(self, tool_name: str) -> Any:
        """
        Get tool schema.

        Args:
            tool_name: Tool name

        Returns:
            Any: Tool schema

        Raises:
            ValueError: Tool not found
        """
        tool = self.get_tool(tool_name)
        if not tool:
            raise ValueError(f"Tool '{tool_name}' is not registered")
        return tool.get_schema()

    def get_all_schemas(self) -> Dict[str, Any]:
        """
        Get schemas for all tools.

        Returns:
            Dict[str, Any]: Mapping of tool names to schemas
        """
        return {name: tool.get_schema() for name, tool in self._tools.items()}

    def get_client_tool_call_result(self, tool_call_id: str) -> Any:
        """
        Get the result of a ClientTool call.

        Args:
            tool_call_id: Tool call ID

        Returns:
            Any: Tool call result (None if no result has been recorded)

        Raises:
            RuntimeError: If the recorded result carries an error message.
        """
        result, error = self._client_tool_results.get(tool_call_id, (None, None))
        if error:
            # RuntimeError is still caught by existing `except Exception` callers.
            raise RuntimeError(error)
        return result

    def set_client_tool_call_result(
        self, tool_call_id: str, result: Any, error: Optional[str] = None
    ):
        """
        Set the result of a ClientTool call.

        Args:
            tool_call_id: Tool call ID
            result: Result payload
            error: Error message, if the call failed
        """
        self._client_tool_results[tool_call_id] = (result, error)

    def set_response(self, request_id: str, data: Any, error: Optional[str] = None):
        """
        Set ClientTool response.

        Routes the response to the shared Channel (unblocking any waiter)
        and records it for later retrieval via get_client_tool_call_result.

        Args:
            request_id: Request ID
            data: Response data
            error: Error message (if any)
        """
        self.channel.set_response(request_id, data, error)
        self.set_client_tool_call_result(request_id, data, error)
================================================
FILE: cortex/tools/types.py
================================================
"""Tool types and enums."""
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class ToolParameters(BaseModel):
    """Tool parameters model.

    Attributes:
        parameters: Raw tool parameters, typically a JSON-encoded string.
        kwargs: Additional keyword parameters forwarded with the call.
    """

    parameters: str = Field(description="Tool parameters", default="")
    # default_factory gives each instance its own dict; clearer than a
    # literal `default={}` even though pydantic copies mutable defaults.
    kwargs: Dict[str, Any] = Field(
        description="Additional parameters", default_factory=dict
    )
class ToolType(Enum):
    """Tool type enum.

    Identifies which Tool implementation backs a registered tool.
    """

    MCP = "mcp"  # tool hosted on an MCP server
    FUNCTION = "function"  # local callable wrapped as a tool
    SESSION = "session"  # session-scoped tool
    CLIENT = "client"  # executed on the client side via a Channel
    AGENT = "agent"  # delegates to another agent via the orchestrator
    ASK_INPUT = "ask_input"  # asks the user for input (ClientTool-based)
class ExecutionType(Enum):
    """Execution type enum.

    Distinguishes synchronous from asynchronous tool execution.
    """

    SYNC = "sync"  # blocking execution
    ASYNC = "async"  # non-blocking / awaited execution
@dataclass
class ToolConfig:
    """Tool configuration.

    Declarative description of a tool, consumed by
    ToolSet.register_from_config to instantiate and register it.
    """

    # Unique tool name used for registration and lookup.
    name: str
    # Selects which Tool implementation to instantiate (see ToolType).
    tool_type: ToolType
    # Constructor keyword arguments forwarded to the tool factory.
    params: Optional[Dict[str, Any]] = None
    # Optional explicit schema. NOTE(review): not read by
    # ToolSet.register_from_config in this file — confirm downstream use.
    schema: Optional[Dict[str, Any]] = None
================================================
FILE: cortex/tools/ublock_agent_tool.py
================================================
"""Agent Tool implementation - calls agents through Orchestrator."""
import json
import logging
from typing import Any, Optional
from cortex.model.definition import ChatMessage
from .base import Tool, ToolSchema
from .channel import Channel
from .types import ToolParameters, ToolType
logger = logging.getLogger(__name__)
class UnblockAgentTool(Tool):
    """Agent tool implementation for calling other Agents through Channel communication with Orchestrator.

    Non-blocking ("unblock") variant: ``_call`` only dispatches the request
    over the Channel and returns ``None`` immediately instead of waiting for
    the called agent's response.
    """

    # Class-level default; overridden per instance in __init__.
    share_context: bool = False

    def __init__(
        self,
        name: str,
        description: str = "",
        channel: Optional[Channel] = None,
        timeout: Optional[float] = None,
        share_context: bool = False,
        **kwargs,
    ):
        """
        Initialize Agent tool.

        Args:
            name: Tool name.
            description: Tool description.
            channel: Channel instance (for async communication). Required.
            timeout: Default timeout in seconds (defaults to 30.0).
            share_context: Whether to share context with the called agent.
            **kwargs: Other parameters forwarded to the Tool base class.

        Raises:
            ValueError: If no Channel instance is provided.
        """
        super().__init__(
            name=name, description=description, tool_type=ToolType.AGENT, **kwargs
        )
        if channel is None:
            raise ValueError("AgentTool requires a Channel instance")
        self.channel = channel
        self.timeout = timeout or 30.0
        # Bug fix: the parameter was previously accepted but never stored, so
        # the class-level default always won. Persist it on the instance.
        self.share_context = share_context
        logger.debug(
            "AgentTool initialized: name=%s, description=%s, timeout=%s",
            name,
            description,
            self.timeout,
        )

    def _define_schema(self) -> ToolSchema:
        """Define Agent tool schema."""
        schema = ToolSchema(
            name=self.name,
            description=self.description,
            parameters={
                "type": "object",
                "properties": {
                    "content": {
                        "type": "string",
                        "description": "Specific instruction content to send to Agent (optional, ignored if messages is provided)",
                    },
                    "messages": {
                        "type": "array",
                        "description": "List of messages to send to Agent (optional, ChatMessage format with role and content fields)",
                        "items": {
                            "type": "object",
                            "properties": {
                                "role": {
                                    "type": "string",
                                    "description": "Message role, e.g., 'user', 'assistant', 'system'",
                                    "enum": ["user", "assistant", "system", "tool"],
                                },
                                "content": {
                                    "type": "string",
                                    "description": "Message content (string format)",
                                },
                            },
                            "required": ["role", "content"],
                        },
                    },
                    # NOTE(review): "timeout" is advertised in the schema but
                    # not consumed by _call below — confirm intended behavior.
                    "timeout": {
                        "type": "number",
                        "description": "Timeout in seconds (optional)",
                    },
                },
                "required": [],
            },
            return_type="agent_response",
            tool_type=self.tool_type,
        )
        logger.debug(
            "AgentTool schema defined: name=%s, parameters=%s", self.name, schema.parameters
        )
        return schema

    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Call Agent tool.

        Send request to Orchestrator through Channel, which creates a new
        Agent and executes the task. Does not wait for the response.

        Args:
            parameters: Tool parameters (JSON string format), containing:
                - content: Specific instruction content (optional, ignored if
                  messages is provided)
                - messages: Message list (optional, ChatMessage format with
                  role and content fields)
                - timeout: Timeout in seconds (optional)
            **kwargs: Merged into the request payload; may include
                "tool_call_id".

        Returns:
            None: the response is delivered asynchronously via the Channel.

        Raises:
            ValueError: Invalid parameters format
        """
        tool_call_id = kwargs.get("tool_call_id")
        agent_name = self.name
        # Process messages: prefer "messages", fall back to "content".
        messages = UnblockAgentTool.parse_messages(parameters)
        # Build request data in orchestrator expected format
        request_data = {
            "agent_name": agent_name,
        }
        if messages:
            request_data["messages"] = messages
        request_data.update(kwargs)
        logger.debug(
            "UnblockAgentTool._call: tool=%s, agent=%s, messages=%d, tool_call_id=%s",
            self.name,
            agent_name,
            len(messages) if messages else 0,
            tool_call_id,
        )
        # Dispatch through the Channel without awaiting the agent's result.
        # NOTE(review): reaches into Channel's private ``_on_send``; consider
        # a public send method on Channel.
        tool_parameters = ToolParameters(parameters=parameters, kwargs=request_data)
        await self.channel._on_send(self.name, self.get_schema(), tool_parameters)
        return None

    @staticmethod
    def parse_messages(parameters: str) -> list[ChatMessage]:
        """
        Prepare message list.

        Args:
            parameters: Tool parameters (JSON string format)

        Returns:
            list[ChatMessage]: Processed message list (empty when neither
            "messages" nor "content" is present)

        Raises:
            ValueError: If parameters is not valid JSON.
        """
        # Parse parameters JSON string
        try:
            params_dict = json.loads(parameters) if parameters else {}
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid parameters JSON format: {e}") from e
        # Prefer messages
        messages = params_dict.get("messages", [])
        if messages:
            processed = []
            for msg in messages:
                if isinstance(msg, dict):
                    processed.append(ChatMessage(**msg))
                elif isinstance(msg, ChatMessage):
                    processed.append(msg)
                else:
                    # Unknown entry types are passed through unchanged.
                    processed.append(msg)
            return processed
        # If no messages, try to use content
        content = params_dict.get("content")
        if content:
            return [ChatMessage(role="user", content=content)]
        return []
================================================
FILE: cortex/tools/unblock_client_tool.py
================================================
"""Client Tool implementation - uses Channel for communication."""
import logging
from typing import Any, Optional
from .base import Tool, ToolSchema
from .channel import Channel
from .types import ToolParameters, ToolType
logger = logging.getLogger(__name__)
class UnblockClientTool(Tool):
    """Client tool implementation using Channel for asynchronous communication.

    Non-blocking ("unblock") variant: ``_call`` only dispatches the request
    over the Channel and returns ``None`` immediately.
    """

    def __init__(
        self,
        name: str,
        description: str = "",
        tool_type: Optional[ToolType] = None,
        channel: Optional[Channel] = None,
        timeout: Optional[float] = None,
        **kwargs,
    ):
        """
        Initialize Client tool.

        Args:
            name: Tool name
            description: Tool description
            tool_type: Explicit tool type; inferred from the name when omitted.
            channel: Channel instance (for async communication). May be
                assigned later, but must be set before the tool is called.
            timeout: Default timeout in seconds (defaults to 30.0)
            **kwargs: Additional parameters; "client_params" may carry the
                JSON-schema "properties"/"required" used by _define_schema
        """
        # If tool_type is not specified, default to CLIENT — except for the
        # special name "ask_input", which maps to ASK_INPUT.
        if tool_type is None:
            if name == "ask_input":
                tool_type = ToolType.ASK_INPUT
            else:
                tool_type = ToolType.CLIENT
        super().__init__(
            name=name, description=description, tool_type=tool_type, **kwargs
        )
        self.timeout = timeout or 30.0
        self._client_params = kwargs.get("client_params", {})
        self.channel = channel

    def _define_schema(self) -> ToolSchema:
        """Define Client tool schema from the configured client_params."""
        properties = self._client_params.get("properties", {})
        required = self._client_params.get("required", [])
        return ToolSchema(
            name=self.name,
            description=self.description,
            parameters={
                "type": "object",
                "properties": properties,
                "required": required,
            },
            return_type="client_response",
            tool_type=self.tool_type,
        )

    async def _call(self, parameters: str, **kwargs) -> Any:
        """
        Call the Client tool.

        Send request through Channel; does not wait for the response.

        Args:
            parameters: Tool parameters (string format)
            **kwargs: Additional parameters (may include "tool_call_id")

        Returns:
            None: the response is delivered asynchronously via the Channel.

        Raises:
            ValueError: If no Channel has been configured for this tool.
        """
        if self.channel is None:
            # Fail with a clear message instead of an AttributeError below.
            raise ValueError(f"Tool '{self.name}' has no Channel configured")
        # Build request data, convert parameters and kwargs to ToolParameters
        tool_parameters = ToolParameters(parameters=parameters, kwargs=kwargs)
        tool_call_id = kwargs.get("tool_call_id")
        if tool_call_id is None:
            # Fallback id; note hash() is salted per process, so this id is
            # only stable within the current process run.
            tool_call_id = f"tool_call_{hash(self.name)}_{hash(parameters)}"
        logger.debug(
            "ClientTool._call: tool=%s, parameters=%s, tool_call_id=%s",
            self.name,
            parameters,
            tool_call_id,
        )
        # NOTE(review): reaches into Channel's private ``_on_send``.
        await self.channel._on_send(self.name, self.get_schema(), tool_parameters)
        return None
================================================
FILE: cortex/tui/__init__.py
================================================
"""Agent TUI module - Provides Textual-based TUI interface"""
from cortex.tui.tui import AgentTUIApp
__all__ = ["AgentTUIApp"]
================================================
FILE: cortex/tui/tui.py
================================================
"""Agent TUI - TUI interface built with Textual and Rich, displaying AgentEvent"""
import asyncio
import json
import logging
from pathlib import Path
from typing import Callable, Optional
from agentkit.trace import SpanContext, Tracer
from cortex.model.definition import ChatMessage
from rich.console import Group, RenderableType
from rich.markdown import Markdown
from rich.panel import Panel
from rich.text import Text
from textual import on, work
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, Vertical, VerticalScroll
from textual.events import Key
from textual.widgets import Input, Label, ListItem, ListView, Static
from cortex.agents.types import AgentConfig, AgentResponse, AgentRunningStatus
from cortex.orchestrator import AgentEvent
from cortex.orchestrator.orchestrator import Orchestrator
from cortex.orchestrator.types import AgentEventType, AgentRequest, ClientToolCallType
logger = logging.getLogger(__name__)
def _content_to_string(content) -> str:
"""Convert content to string
Args:
content: Can be string, list, or dict
Returns:
str: Converted string
"""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
# Handle list format content, e.g., [{"type": "text", "text": "..."}]
text_parts = []
for item in content:
if isinstance(item, dict):
# Try various possible key names to get text content
text = None
# Common key names
for key in ["text", "content", "value"]:
if key in item:
text = item[key]
break
# If not found, try to find fields with string values
if not text:
for key, value in item.items():
if key != "type" and isinstance(value, str):
text = value
break
if text:
text_parts.append(str(text))
elif isinstance(item, str):
text_parts.append(item)
return "\n".join(text_parts)
if isinstance(content, dict):
# Handle dict format
# Try various possible key names
for key in ["text", "content", "value"]:
if key in content:
return str(content[key])
# If not found, return string representation of entire dict
return str(content)
return str(content)
class EventItem(ListItem):
    """Event list item.

    Renders one AgentEvent as a rich Panel. Most content is plain Text, but
    long "context"/"prompt"/user-input payloads are rendered as Markdown:
    render() first collects everything into ``lines``, temporarily inserting
    ("markdown_*", str) marker tuples among the Text objects, then splits
    them back out at the end and combines Text + Markdown in a Group.
    """

    def __init__(self, event: AgentEvent, index: int):
        # index is the 1-based display position within the event list.
        self.event = event
        self.index = index
        super().__init__()

    def render(self) -> RenderableType:
        """Render event item"""
        event = self.event
        event_type = event.type.value
        agent_name = event.agent_name or "N/A"
        # Shorten long task ids for display.
        task_id = (
            event.task_id[:8] + "..."
            if event.task_id and len(event.task_id) > 8
            else (event.task_id or "N/A")
        )
        # Set color based on event type
        type_colors = {
            "request": "blue",
            "response": "green",
            "error": "red",
            "signal": "yellow",
            "client_tool_call": "cyan",
            "client_tool_result": "magenta",
        }
        color = type_colors.get(event_type, "white")
        # Build display text
        lines = [
            Text(f"[{self.index}] ", style="dim"),
            Text(f"[{event_type}]", style=f"bold {color}"),
            Text(f" Agent: {agent_name}", style="cyan"),
            Text(f" Task: {task_id}", style="dim"),
        ]
        # Add response content
        if event.response and event.response.message:
            # Check if there are tool calls
            tool_calls = getattr(event.response.message, "tool_calls", None)
            has_tool_calls = False
            if tool_calls:
                if isinstance(tool_calls, (list, tuple)) and len(tool_calls) > 0:
                    has_tool_calls = True
                elif not isinstance(tool_calls, (list, tuple)):
                    # A single non-sequence tool_call object also counts.
                    has_tool_calls = True
            raw_content = event.response.message.content or ""
            content = _content_to_string(raw_content)
            # If no tool calls, show full markdown content
            if not has_tool_calls and content:
                # Use markdown to render full content; early-return the panel.
                header = Text.assemble(
                    Text(f"[{self.index}] ", style="dim"),
                    Text(f"[{event_type}]", style=f"bold {color}"),
                    Text(f" Agent: {agent_name}", style="cyan"),
                    Text(f" Task: {task_id}", style="dim"),
                )
                markdown_content = Markdown(content)
                return Panel(
                    Group(header, markdown_content),
                    border_style=color,
                    title=f"Event #{self.index}",
                )
            else:
                # With tool calls, only show preview
                if content:
                    preview = content[:50] + "..." if len(content) > 50 else content
                    lines.append(Text(f"\n {preview}", style="dim"))
            # Add tool call info (names only; arguments are not expanded here)
            if tool_calls:
                if isinstance(tool_calls, (list, tuple)):
                    tool_names = []
                    for tool_call in tool_calls:
                        # Try to get tool name from either attribute or dict shape
                        if hasattr(tool_call, "function"):
                            if hasattr(tool_call.function, "name"):
                                tool_names.append(tool_call.function.name)
                            elif isinstance(tool_call.function, dict):
                                tool_names.append(
                                    tool_call.function.get("name", "unknown")
                                )
                        elif isinstance(tool_call, dict):
                            func = tool_call.get("function", {})
                            tool_names.append(
                                func.get("name", "unknown")
                                if isinstance(func, dict)
                                else "unknown"
                            )
                    if tool_names:
                        tools_str = ", ".join(tool_names)
                        lines.append(Text(f"\n 🔧 Tools: {tools_str}", style="yellow"))
                else:
                    # Single tool_call
                    if hasattr(tool_calls, "function"):
                        if hasattr(tool_calls.function, "name"):
                            lines.append(
                                Text(
                                    f"\n 🔧 Tool: {tool_calls.function.name}",
                                    style="yellow",
                                )
                            )
        # Add client_tool_call info
        if event.client_tool_call:
            tool_name = "unknown"
            tool_call_id = (
                getattr(event.client_tool_call, "tool_call_id", None) or "N/A"
            )
            if hasattr(event.client_tool_call, "function"):
                if hasattr(event.client_tool_call.function, "name"):
                    tool_name = event.client_tool_call.function.name
                elif isinstance(event.client_tool_call.function, dict):
                    tool_name = event.client_tool_call.function.get("name", "unknown")
            # Check if it's ask_input type
            tool_type = getattr(event.client_tool_call, "type", None)
            # Get function.arguments (accepts JSON string or dict)
            args_dict = {}
            context_content = None
            prompt_content = None
            if hasattr(event.client_tool_call, "function"):
                func = event.client_tool_call.function
                if hasattr(func, "arguments"):
                    args = func.arguments
                    if isinstance(args, str):
                        try:
                            args_dict = json.loads(args)
                        except json.JSONDecodeError:
                            # Keep unparseable arguments visible under "raw".
                            args_dict = {"raw": args}
                    elif isinstance(args, dict):
                        args_dict = args
                    else:
                        args_dict = {}
                    # Check if context field exists
                    if isinstance(args_dict, dict) and "context" in args_dict:
                        context_content = args_dict.get("context")
                        # Remove context from args_dict, handle separately
                        args_dict = {
                            k: v for k, v in args_dict.items() if k != "context"
                        }
                    # Check if prompt field exists
                    if isinstance(args_dict, dict) and "prompt" in args_dict:
                        prompt_content = args_dict.get("prompt")
                        # Remove prompt from args_dict, handle separately
                        args_dict = {
                            k: v for k, v in args_dict.items() if k != "prompt"
                        }
            # Set icon and style based on type
            if tool_type == ClientToolCallType.ASK_INPUT:
                icon = "❓"
                style_color = "bold yellow"
                type_label = "Ask Input"
            elif tool_type == ClientToolCallType.AGENT:
                icon = "🤖"
                style_color = "bold cyan"
                type_label = "Agent Tool"
            else:
                icon = "🔧"
                style_color = "bold cyan"
                type_label = "Client Tool"
            # Show tool call basic info
            lines.append(
                Text(f"\n {icon} {type_label}: {tool_name}", style=style_color)
            )
            lines.append(Text(f" Tool Call ID: {tool_call_id}", style="dim"))
            # Show all parameters (except context)
            if args_dict:
                lines.append(
                    Text("\n Parameters:", style=f"bold {style_color.split()[-1]}")
                )
                for key, value in args_dict.items():
                    value_str = (
                        json.dumps(value, ensure_ascii=False, indent=2)
                        if isinstance(value, (dict, list))
                        else str(value)
                    )
                    # If parameter value is too long, display on multiple lines
                    if len(value_str) > 200:
                        lines.append(Text(f" {key}:", style=style_color))
                        # Display long content on multiple lines
                        for line in value_str.split("\n"):
                            if line.strip():
                                lines.append(Text(f" {line}", style="dim"))
                    else:
                        lines.append(Text(f" {key}: {value_str}", style="dim"))
            # If context exists, display with markdown
            if context_content is not None:
                context_str = _content_to_string(context_content)
                if context_str:
                    lines.append(
                        Text("\n Context:", style=f"bold {style_color.split()[-1]}")
                    )
                    # Add context content to lines, will be rendered with Markdown later
                    # Add a marker here to indicate markdown rendering is needed
                    lines.append(("markdown_context", context_str))
            # If prompt exists, display with markdown
            if prompt_content is not None:
                prompt_str = _content_to_string(prompt_content)
                if prompt_str:
                    lines.append(
                        Text("\n Prompt:", style=f"bold {style_color.split()[-1]}")
                    )
                    # Add prompt content to lines, will be rendered with Markdown later
                    lines.append(("markdown_prompt", prompt_str))
            # Show extra info
            if (
                hasattr(event.client_tool_call, "extra")
                and event.client_tool_call.extra
            ):
                lines.append(
                    Text("\n Extra Info:", style=f"bold {style_color.split()[-1]}")
                )
                for key, value in event.client_tool_call.extra.items():
                    value_str = (
                        json.dumps(value, ensure_ascii=False, indent=2)
                        if isinstance(value, (dict, list))
                        else str(value)
                    )
                    if len(value_str) > 200:
                        lines.append(Text(f" {key}:", style=style_color))
                        for line in value_str.split("\n"):
                            if line.strip():
                                lines.append(Text(f" {line}", style="dim"))
                    else:
                        lines.append(Text(f" {key}: {value_str}", style="dim"))
        # Add client_tool_result info
        if event.client_tool_result and event.client_tool_result.message:
            result_content = None
            tool_call_id = None
            # Get tool_call_id
            if hasattr(event.client_tool_result.message, "tool_call_id"):
                tool_call_id = event.client_tool_result.message.tool_call_id
            # Get user input content
            raw_content = event.client_tool_result.message.content or ""
            result_content = _content_to_string(raw_content)
            # Show tool call result
            lines.append(Text("\n 📥 Tool Call Result:", style="bold magenta"))
            if tool_call_id:
                lines.append(Text(f" Tool Call ID: {tool_call_id}", style="dim"))
            # If content exists, display with markdown
            if result_content:
                lines.append(Text("\n User Input:", style="bold magenta"))
                # Add user input content to lines, will be rendered with Markdown later
                lines.append(("markdown_result", result_content))
        # Add error info
        if event.error:
            lines.append(Text(f"\n Error: {event.error}", style="red"))
        # Add completion signal info
        if event.type == AgentEventType.SIGNAL and event.metadata:
            status = event.metadata.get("status", "")
            message = event.metadata.get("message", "")
            if status == "completed":
                lines.append(Text(f"\n ✅ {message}", style="green"))
        # Check if there's markdown content (context, prompt or result):
        # separate the marker tuples inserted above from the plain Text lines.
        markdown_context = None
        markdown_prompt = None
        markdown_result = None
        text_lines = []
        for item in lines:
            if isinstance(item, tuple) and len(item) == 2:
                if item[0] == "markdown_context":
                    markdown_context = item[1]
                elif item[0] == "markdown_prompt":
                    markdown_prompt = item[1]
                elif item[0] == "markdown_result":
                    markdown_result = item[1]
                else:
                    text_lines.append(item)
            else:
                text_lines.append(item)
        # If there's markdown content, combine Text and Markdown
        markdown_parts = []
        if markdown_context:
            markdown_parts.append(Markdown(markdown_context))
        if markdown_prompt:
            markdown_parts.append(Markdown(markdown_prompt))
        if markdown_result:
            markdown_parts.append(Markdown(markdown_result))
        if markdown_parts:
            header = Text.assemble(*text_lines)
            # If there are multiple markdown contents, combine with Group
            if len(markdown_parts) == 1:
                return Panel(
                    Group(header, markdown_parts[0]),
                    border_style=color,
                    title=f"Event #{self.index}",
                )
            else:
                return Panel(
                    Group(header, *markdown_parts),
                    border_style=color,
                    title=f"Event #{self.index}",
                )
        else:
            return Panel(
                Text.assemble(*text_lines),
                border_style=color,
                title=f"Event #{self.index}",
            )
class CommandItem(ListItem):
    """Command list item"""

    def __init__(self, command: str):
        # Command name without the leading slash.
        self.command = command
        super().__init__()

    def render(self) -> RenderableType:
        """Render the command prefixed with a slash."""
        return Text("/" + self.command, style="cyan")
class AgentItem(ListItem):
    """Agent list item"""

    def __init__(self, agent_config: AgentConfig):
        self.agent_config = agent_config
        super().__init__()

    def render(self) -> RenderableType:
        """Render the agent's name, type and description inside a panel."""
        cfg = self.agent_config
        name = cfg.name or "N/A"
        body = Text.assemble(
            Text(f"Name: {name}", style="bold cyan"),
            Text(f" Type: {cfg.agent_type or 'N/A'}", style="dim"),
            Text(f" Description: {cfg.description or 'No description'}", style="dim"),
        )
        return Panel(
            body,
            border_style="blue",
            title=f"Agent: {name}",
        )
class ProcessView(Container):
    """Process view - displays event list"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.events: list[AgentEvent] = []
        self.list_view = ListView(id="event-list")
        self.json_view = Static("", id="process-json-view")

    def compose(self) -> ComposeResult:
        # Event list on the left, JSON detail pane on the right.
        with Horizontal():
            with Container(id="process-list-container"):
                yield Label("Agent Event Stream", id="process-title")
                yield self.list_view
            with VerticalScroll(id="process-json-container"):
                yield Label("Event Details (JSON)", id="process-json-title")
                yield self.json_view

    def add_event(self, event: AgentEvent) -> None:
        """Append an event and keep the list scrolled to the newest entry."""
        self.events.append(event)
        # EventItem indices are 1-based display positions.
        self.list_view.append(EventItem(event, len(self.events)))
        self.list_view.scroll_end(animate=False)

    @on(ListView.Selected, "#event-list")
    def on_list_item_selected(self, event: ListView.Selected) -> None:
        """Show the selected event as pretty-printed JSON."""
        if not (hasattr(event, "item") and isinstance(event.item, EventItem)):
            return
        selected = event.item.event
        self.json_view.update(
            Panel(
                selected.model_dump_json(indent=2, ensure_ascii=False),
                title="Event JSON",
                border_style="blue",
            )
        )
        # Let the app refresh the input placeholder for the new focus.
        app = self.app
        if app and hasattr(app, "_update_placeholder_by_focus"):
            app._update_placeholder_by_focus()

    def clear_events(self) -> None:
        """Drop all stored events and empty the list widget."""
        self.events.clear()
        self.list_view.clear()
class AgentsListView(Container):
    """Agents list view - displays registered Agents"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.list_view = ListView(id="agents-list")
        self.json_view = Static("", id="agents-json-view")
        self.selection_callback: Optional[Callable[[str], None]] = None

    def set_selection_callback(self, callback: Callable[[str], None]) -> None:
        """Register the callback invoked with the selected agent's name."""
        self.selection_callback = callback

    def compose(self) -> ComposeResult:
        # Agent list on the left, JSON detail pane on the right.
        with Horizontal():
            with Container(id="agents-list-container"):
                yield Label("Registered Agents", id="agents-title")
                yield self.list_view
            with VerticalScroll(id="agents-json-container"):
                yield Label("Agent Details (JSON)", id="agents-json-title")
                yield self.json_view

    def update_agents(self, agents: list[AgentConfig]) -> None:
        """Rebuild the list from the given agent configurations."""
        self.list_view.clear()
        for cfg in agents:
            self.list_view.append(AgentItem(cfg))

    @on(ListView.Selected, "#agents-list")
    def on_list_item_selected(self, event: ListView.Selected) -> None:
        """Show the selected agent's config as JSON and notify observers."""
        if not (hasattr(event, "item") and isinstance(event.item, AgentItem)):
            return
        cfg = event.item.agent_config
        self.json_view.update(
            Panel(
                cfg.model_dump_json(indent=2, ensure_ascii=False),
                title="Agent Config JSON",
                border_style="green",
            )
        )
        # Tell the app which agent is now active.
        if self.selection_callback:
            self.selection_callback(cfg.name)
        # Let the app refresh the input placeholder for the new focus.
        app = self.app
        if app and hasattr(app, "_update_placeholder_by_focus"):
            app._update_placeholder_by_focus()
class TaskItem(ListItem):
    """Task list item"""

    def __init__(self, task_id: str, request_data: dict):
        self.task_id = task_id
        self.request_data = request_data
        super().__init__()

    def render(self) -> RenderableType:
        """Render the task id plus a short preview of its first message."""

        def shorten(text: str, limit: int) -> str:
            # Append an ellipsis when the text exceeds the limit.
            return text[:limit] + "..." if len(text) > limit else text

        short_id = shorten(self.task_id, 16)
        preview = ""
        messages = self.request_data.get("messages", [])
        if messages and len(messages) > 0:
            head = messages[0] if isinstance(messages, list) else messages
            if isinstance(head, dict):
                raw = head.get("content", "")
            else:
                raw = getattr(head, "content", "")
            preview = shorten(_content_to_string(raw), 50)
        body = Text.assemble(
            Text(f"Task ID: {short_id}", style="bold cyan"),
            Text(f" Content: {preview}", style="dim"),
        )
        return Panel(
            body,
            border_style="green",
            title=f"Task: {short_id}",
        )
class TasksListView(Container):
    """Tasks list view - displays saved Tasks"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.list_view = ListView(id="tasks-list")
        self.json_view = Static("", id="tasks-json-view")
        self.selection_callback: Optional[Callable[[str], None]] = None

    def set_selection_callback(self, callback: Callable[[str], None]) -> None:
        """Register the callback invoked with the selected task id."""
        self.selection_callback = callback

    def compose(self) -> ComposeResult:
        # Task list on the left, JSON detail pane on the right.
        with Horizontal():
            with Container(id="tasks-list-container"):
                yield Label("Saved Tasks", id="tasks-title")
                yield self.list_view
            with VerticalScroll(id="tasks-json-container"):
                yield Label("Task Details (JSON)", id="tasks-json-title")
                yield self.json_view

    def update_tasks(self, tasks: list[tuple[str, dict]]) -> None:
        """Rebuild the list.

        Args:
            tasks: [(task_id, request_data), ...] list
        """
        self.list_view.clear()
        for task_id, request_data in tasks:
            self.list_view.append(TaskItem(task_id, request_data))

    @on(ListView.Selected, "#tasks-list")
    def on_list_item_selected(self, event: ListView.Selected) -> None:
        """Show the selected task's request payload as JSON and notify observers."""
        if not (hasattr(event, "item") and isinstance(event.item, TaskItem)):
            return
        self.json_view.update(
            Panel(
                json.dumps(event.item.request_data, indent=2, ensure_ascii=False),
                title="Request JSON",
                border_style="green",
            )
        )
        # Tell the app which task is now selected.
        if self.selection_callback:
            self.selection_callback(event.item.task_id)
class CommandSelector(Container):
    """Command selector - displays command list and supports up/down key selection"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.command_list_view = ListView(id="command-list")
        # Commands currently matching the user's input, in display order.
        self.matched_commands: list[str] = []

    def compose(self) -> ComposeResult:
        yield self.command_list_view

    def update_commands(self, commands: list[str]) -> None:
        """Replace the displayed command list.

        Args:
            commands: Matching command names (without the leading slash).
        """
        self.matched_commands = commands
        self.command_list_view.clear()
        for cmd in commands:
            self.command_list_view.append(CommandItem(cmd))
        # If there are commands, select the first one
        if commands:
            self.command_list_view.index = 0

    def get_selected_command(self) -> Optional[str]:
        """Get the currently selected command.

        Returns:
            The highlighted command, the first command when nothing valid is
            highlighted, or None when the list is empty.
        """
        if not self.matched_commands:
            return None
        # ListView.index may be None when nothing is highlighted; check
        # explicitly instead of relying on a TypeError swallowed by a broad
        # try/except (the previous implementation).
        index = self.command_list_view.index
        if index is not None and 0 <= index < len(self.matched_commands):
            return self.matched_commands[index]
        # Fall back to the first command, matching the original behavior.
        return self.matched_commands[0]
class AgentTUIApp(App):
    """Agent TUI Application.

    Top-level Textual App wiring together three switchable views
    (process events, agents list, tasks list), a slash-command selector,
    and an input box that serves both free-form agent instructions and
    replies to pending ``ask_input`` client tool calls. Requests and
    events are optionally persisted under ``workdir``.
    """

    # Stylesheet for the widgets yielded from compose(); selectors must stay
    # in sync with the widget ids assigned there and in on_mount().
    CSS = """
    #process-title, #process-json-title, #agents-title, #agents-json-title {
        text-style: dim;
        padding: 0 1;
        margin-bottom: 1;
        height: 1;
        text-align: left;
    }
    #process-view, #agents-view {
        height: 1fr;
    }
    #process-list-container, #agents-list-container {
        width: 50%;
        border-right: wide $primary;
        height: 1fr;
    }
    #process-json-container, #agents-json-container {
        width: 50%;
        height: 1fr;
    }
    #process-json-view, #agents-json-view {
        padding: 1;
    }
    #event-list, #agents-list {
        height: 1fr;
    }
    #input-container {
        height: auto;
        min-height: 3;
        border-top: wide $primary;
        margin-top: 1;
        padding-top: 1;
    }
    #selected-agent-label {
        height: 1;
        padding: 0 1;
        margin-bottom: 1;
        text-style: bold;
        background: $surface;
    }
    #input-box {
        width: 1fr;
    }
    #command-selector {
        height: auto;
        max-height: 10;
        border-top: wide $primary;
        background: $surface;
        margin-top: 1;
        margin-bottom: 1;
    }
    #command-list {
        height: auto;
        max-height: 10;
    }
    .hidden {
        display: none;
    }
    #main-container {
        height: 1fr;
        overflow: hidden;
    }
    Screen {
        layout: vertical;
    }
    """

    BINDINGS = [
        Binding("q", "quit", "Quit"),
        Binding("ctrl+c", "quit", "Quit"),
    ]

    def __init__(
        self,
        orchestrator: Orchestrator,
        workdir: Optional[str | Path] = None,
        tracer: Optional[Tracer] = None,
    ):
        """Create the app.

        Args:
            orchestrator: Orchestrator used to list agents and run requests.
            workdir: Optional directory for persisting requests/events
                (created if missing). Persistence is skipped when None.
            tracer: Optional tracer used to create a span around agent runs.
        """
        super().__init__()
        # Widget references; resolved in on_mount() after compose() runs.
        self.process_view: Optional[ProcessView] = None
        self.agents_view: Optional[AgentsListView] = None
        self.tasks_view: Optional[TasksListView] = None
        self.command_selector: Optional[CommandSelector] = None
        self.input_box: Optional[Input] = None
        self.selected_agent_label: Optional[Label] = None
        self.orchestrator: Orchestrator = orchestrator
        self.running_task: Optional[asyncio.Task] = None
        # Slash commands offered by the command selector / help text.
        self.commands = [
            "quit",
            "clear",
            "help",
            "ls agents",
            "ls tasks",
            "view process",
        ]
        self.show_command_selector = False
        self.current_view = (
            "agents"  # "process" or "agents" or "tasks", default to agents view
        )
        self.selected_agent_name: Optional[str] = None  # Currently selected agent name
        self.events: list[AgentEvent] = []  # Store all events in memory
        self.default_placeholder = "/command or input..."  # Default placeholder
        self.workdir: Optional[Path] = Path(workdir) if workdir else None
        if self.workdir:
            self.workdir.mkdir(parents=True, exist_ok=True)
        self.pending_ask_input_event: Optional[AgentEvent] = (
            None  # Pending ask_input event
        )
        self.tracer = tracer

    def compose(self) -> ComposeResult:
        """Compose application interface.

        Yields all three views up-front; only the current one is visible,
        the others carry the "hidden" CSS class.
        """
        with Vertical():
            with Container(id="main-container"):
                process_view = ProcessView(id="process-view")
                process_view.add_class("hidden")
                yield process_view
                agents_view = AgentsListView(id="agents-view")
                yield agents_view
                tasks_view = TasksListView(id="tasks-view")
                tasks_view.add_class("hidden")
                yield tasks_view
            command_selector = CommandSelector(id="command-selector")
            command_selector.add_class("hidden")
            yield command_selector
            with Container(id="input-container"):
                yield Label("Current Agent: Not selected", id="selected-agent-label")
                yield Input(placeholder=self.default_placeholder, id="input-box")

    def on_mount(self) -> None:
        """Initialize on app mount: resolve widgets, wire callbacks,
        populate the agents view and pre-select the first agent."""
        self.process_view = self.query_one("#process-view", ProcessView)
        self.agents_view = self.query_one("#agents-view", AgentsListView)
        self.tasks_view = self.query_one("#tasks-view", TasksListView)
        self.command_selector = self.query_one("#command-selector", CommandSelector)
        self.input_box = self.query_one("#input-box", Input)
        self.selected_agent_label = self.query_one("#selected-agent-label", Label)
        # Set callbacks
        self.agents_view.set_selection_callback(self._on_agent_selected)
        self.tasks_view.set_selection_callback(self._on_task_selected)
        # Initialize agents list (default to agents view)
        agents = self.orchestrator.list_agents()
        self.agents_view.update_agents(agents)
        # Select first agent by default
        if agents:
            first_agent = agents[0]
            if first_agent.name:
                self._on_agent_selected(first_agent.name)
            # Select first item in list
            self.agents_view.list_view.index = 0
        # Focus input box
        self._focus_input()

    @on(Input.Changed, "#input-box")
    def on_input_changed(self, event: Input.Changed) -> None:
        """Handle input changes: show/filter/hide the slash-command selector."""
        value = event.value
        if value == "/":
            # Show all commands
            self.show_command_selector = True
            self.command_selector.update_commands(self.commands)
            self.command_selector.remove_class("hidden")
            # Keep input focus, user can continue typing to filter
        elif value.startswith("/"):
            # Filter matching commands (prefix match, case-insensitive)
            query = value[1:].lower()
            matched = [cmd for cmd in self.commands if cmd.startswith(query)]
            if matched:
                self.show_command_selector = True
                self.command_selector.update_commands(matched)
                self.command_selector.remove_class("hidden")
            else:
                # No matching commands, hide selector
                self.show_command_selector = False
                self.command_selector.add_class("hidden")
        else:
            # Hide command selector
            self.show_command_selector = False
            self.command_selector.add_class("hidden")

    @on(ListView.Selected, "#command-list")
    def on_command_selected(self, event: ListView.Selected) -> None:
        """Handle command selection (mouse/enter on a list item)."""
        if hasattr(event, "item") and isinstance(event.item, CommandItem):
            command = event.item.command
            self._handle_command(command)
            self.input_box.value = ""
            self.command_selector.add_class("hidden")
            self.show_command_selector = False
            self._focus_input()

    def _focus_input(self) -> None:
        """Focus input box and update placeholder."""
        self.set_focus(self.input_box)
        if self.input_box:
            self.input_box.placeholder = self.default_placeholder

    def _update_placeholder_by_focus(self) -> None:
        """Update placeholder based on focus state."""
        if not self.input_box:
            return
        focused = self.screen.focused
        if focused == self.input_box:
            self.input_box.placeholder = self.default_placeholder
        else:
            self.input_box.placeholder = "ESC to return to input box"

    @on(Key)
    def on_key(self, event: Key) -> None:
        """Handle keyboard events (ESC focus return, command-selector nav)."""
        # Global ESC handling: if focus is not on input box, press ESC to return to input
        if event.key == "escape":
            focused = self.screen.focused
            if focused != self.input_box:
                self._focus_input()
                event.prevent_default()
                return
        # If command selector is shown
        if self.show_command_selector and not self.command_selector.has_class("hidden"):
            focused = self.screen.focused
            # If focus is on input box, transfer to command list on up/down keys
            if focused == self.input_box:
                if event.key == "up" or event.key == "down":
                    # Transfer focus to command list
                    self.set_focus(self.command_selector.command_list_view)
                    self._update_placeholder_by_focus()
                    event.prevent_default()
                    return
            # If focus is on command list
            elif focused == self.command_selector.command_list_view:
                # Press enter in command list to execute selected command
                if event.key == "enter":
                    selected = self.command_selector.get_selected_command()
                    if selected:
                        self._handle_command(selected)
                        self.input_box.value = ""
                        self.command_selector.add_class("hidden")
                        self.show_command_selector = False
                        self._focus_input()
                        event.prevent_default()
                        return
                # Press Esc to return to input box
                elif event.key == "escape":
                    self.command_selector.add_class("hidden")
                    self.show_command_selector = False
                    self._focus_input()
                    event.prevent_default()
                    return

    @on(Input.Submitted, "#input-box")
    def on_input_submitted(self, event: Input.Submitted) -> None:
        """Handle input submission.

        Priority order: pending ask_input reply, then slash command,
        then free-form instruction forwarded to the selected agent.
        """
        value = event.value.strip()
        if not value:
            return
        # Check if there's a pending ask_input event
        if self.pending_ask_input_event:
            # Send user input as tool_call result
            self._send_ask_input_result(value)
            self.input_box.value = ""
            self.pending_ask_input_event = None
            # Restore default placeholder
            if self.input_box:
                self.input_box.placeholder = self.default_placeholder
            return
        # Handle commands
        if value.startswith("/"):
            # If command selector is shown, use selected command
            if self.show_command_selector and not self.command_selector.has_class(
                "hidden"
            ):
                selected = self.command_selector.get_selected_command()
                if selected:
                    command = selected
                else:
                    command = value[1:].lower()
            else:
                command = value[1:].lower()
            self._handle_command(command)
            self.input_box.value = ""
            self.command_selector.add_class("hidden")
            self.show_command_selector = False
            return
        # Execute user instruction
        self._run_agent(value)
        self.input_box.value = ""
        return

    def _handle_command(self, command: str) -> None:
        """Handle command: exact match first, then loose fuzzy fallback."""
        if command == "quit" or command == "q":
            self.exit()
        elif command == "clear":
            self.events.clear()
            self.process_view.clear_events()
        elif command == "help":
            # Show help information
            help_text = "Available commands:\n"
            for cmd in self.commands:
                help_text += f"  /{cmd}\n"
            self.notify(help_text, title="Help", timeout=5)
        elif command == "ls agents" or command == "ls":
            # Switch to agents view
            self._switch_view("agents")
        elif command == "ls tasks":
            # Switch to tasks view
            self._switch_view("tasks")
        elif command == "view process" or command == "process":
            # Switch to process view
            self._switch_view("process")
        else:
            # Try fuzzy match
            # NOTE(review): these substring checks are very loose — e.g.
            # '"q" in command' matches any command containing the letter "q"
            # and would exit the app; confirm this is intended.
            if "quit" in command or "q" in command:
                self.exit()
            elif "clear" in command:
                self.events.clear()
                self.process_view.clear_events()
            elif "help" in command:
                help_text = "Available commands:\n"
                for cmd in self.commands:
                    help_text += f"  /{cmd}\n"
                self.notify(help_text, title="Help", timeout=5)
            elif "ls" in command and "agent" in command:
                self._switch_view("agents")
            elif "ls" in command and "task" in command:
                self._switch_view("tasks")
            elif ("view" in command and "process" in command) or command == "process":
                self._switch_view("process")

    def _switch_view(self, view_name: str) -> None:
        """Switch view by toggling the "hidden" class on the three views."""
        self.current_view = view_name
        if view_name == "agents":
            # Show agents view
            self.process_view.add_class("hidden")
            self.tasks_view.add_class("hidden")
            self.agents_view.remove_class("hidden")
            # Update agents list
            agents = self.orchestrator.list_agents()
            self.agents_view.update_agents(agents)
        elif view_name == "tasks":
            # Show tasks view
            self.process_view.add_class("hidden")
            self.agents_view.add_class("hidden")
            self.tasks_view.remove_class("hidden")
            # Update tasks list
            self._load_tasks()
            # Auto focus to task list
            self.set_focus(self.tasks_view.list_view)
            self._update_placeholder_by_focus()
        else:
            # Show process view
            self.process_view.remove_class("hidden")
            self.agents_view.add_class("hidden")
            self.tasks_view.add_class("hidden")

    def _on_agent_selected(self, agent_name: str) -> None:
        """Handle Agent selection: remember it and refresh the label."""
        self.selected_agent_name = agent_name
        if self.selected_agent_label:
            self.selected_agent_label.update(f"Current Agent: {agent_name}")

    def _send_ask_input_result(self, user_input: str) -> None:
        """Send ask_input result back to the orchestrator as a tool result."""
        if not self.pending_ask_input_event or not self.orchestrator:
            return
        ask_input_event = self.pending_ask_input_event
        # Get tool_call_id from ask_input_event
        tool_call_id = None
        if ask_input_event.client_tool_call:
            tool_call_id = ask_input_event.client_tool_call.tool_call_id
        # Create AgentResponse as tool_call result
        # role should be "tool" instead of "user", because this is the result of a tool call
        result_response = AgentResponse(
            agent_name=ask_input_event.agent_name,
            message=ChatMessage(
                role="tool", content=user_input, tool_call_id=tool_call_id
            ),
            status=AgentRunningStatus.FINISHED,
        )
        # Create CLIENT_TOOL_RESULT event
        result_event = AgentEvent(
            task_id=ask_input_event.task_id,
            parent_task_id=ask_input_event.parent_task_id,
            root_task_id=ask_input_event.root_task_id,
            type=AgentEventType.CLIENT_TOOL_RESULT,
            client_tool_result=result_response,
        )
        # Send event through orchestrator
        try:
            # Use send_event method to send event
            # NOTE(review): the created task is not retained anywhere; it may
            # be garbage-collected before completion — consider keeping a
            # reference (e.g. self.running_task or a task set).
            asyncio.create_task(self.orchestrator.send_event(result_event))
            # Save to memory and view
            self.events.append(result_event)
            self.process_view.add_event(result_event)
            # Save to file
            if ask_input_event.root_task_id and self.workdir:
                self._save_event(ask_input_event.root_task_id, result_event)
        except Exception as e:
            logger.error("Failed to send ask_input result: %s", e)
            self.notify(f"Failed to send ask_input result: {e}", title="Error", timeout=3)

    def _on_task_selected(self, task_id: str) -> None:
        """Handle Task selection: jump to the process view and replay events."""
        # Switch to process view
        self._switch_view("process")
        # Load all events for this task
        self._load_task_events(task_id)

    def _load_tasks(self) -> None:
        """Load all tasks from workdir, newest first (by file mtime)."""
        if not self.workdir:
            return
        tasks = []
        # Find all {task_id}_request.json files
        for request_file in self.workdir.glob("*_request.json"):
            task_id = request_file.stem.replace("_request", "")
            try:
                with open(request_file, "r", encoding="utf-8") as f:
                    request_data = json.load(f)
                # Get file modification time
                mtime = request_file.stat().st_mtime
                tasks.append((task_id, request_data, mtime))
            except Exception as e:
                logger.error(f"Failed to load task {task_id}: {e}")
        # Sort by file modification time (newest first)
        tasks.sort(key=lambda x: x[2], reverse=True)
        # Remove mtime, keep only (task_id, request_data)
        tasks = [(task_id, request_data) for task_id, request_data, _ in tasks]
        self.tasks_view.update_tasks(tasks)

    def _load_task_events(self, task_id: str) -> None:
        """Load all events for a task from file ({task_id}.jsonl, one event per line)."""
        if not self.workdir:
            return
        jsonl_file = self.workdir / f"{task_id}.jsonl"
        if not jsonl_file.exists():
            self.notify(f"Event file not found for task {task_id}", title="Error", timeout=3)
            return
        # Clear current events
        self.events.clear()
        self.process_view.clear_events()
        # Load events from jsonl file; malformed lines are logged and skipped.
        try:
            with open(jsonl_file, "r", encoding="utf-8") as f:
                for line in f:
                    strip_line = line.strip()
                    if not strip_line:
                        continue
                    try:
                        event_data = json.loads(strip_line)
                        event = AgentEvent(**event_data)
                        self.events.append(event)
                        self.process_view.add_event(event)
                    except Exception as e:
                        logger.error(f"Failed to parse event: {e}")
        except Exception as e:
            logger.error(f"Failed to load task events: {e}")
            self.notify(f"Failed to load task events: {e}", title="Error", timeout=3)

    def _save_request(self, root_task_id: str, request: AgentRequest) -> None:
        """Save request to file ({root_task_id}_request.json, overwritten)."""
        if not self.workdir:
            return
        request_file = self.workdir / f"{root_task_id}_request.json"
        try:
            request_data = request.model_dump()
            with open(request_file, "w", encoding="utf-8") as f:
                json.dump(request_data, f, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"Failed to save request: {e}")

    def _save_event(self, root_task_id: str, event: AgentEvent) -> None:
        """Save event to jsonl file (append one JSON line per event)."""
        if not self.workdir:
            return
        jsonl_file = self.workdir / f"{root_task_id}.jsonl"
        try:
            event_data = event.model_dump()
            with open(jsonl_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(event_data, ensure_ascii=False) + "\n")
        except Exception as e:
            logger.error(f"Failed to save event: {e}")

    @work(exclusive=True)
    async def _run_agent(self, user_input: str) -> None:
        """Run Agent.

        Streams events from the orchestrator for the selected agent,
        mirroring each event into memory, the process view, and (when a
        workdir is configured) the per-task jsonl file. Runs as an
        exclusive Textual worker, so a new submission cancels the previous run.
        """
        if not self.orchestrator:
            return
        # Check if agent is selected
        if not self.selected_agent_name:
            self.notify("Please select an Agent first (use /ls agents)", title="Error", timeout=3)
            return
        # Switch to process view to show execution
        self._switch_view("process")
        # Clear previous events
        self.events.clear()
        self.process_view.clear_events()
        messages = [ChatMessage(role="user", content=user_input)]
        request = AgentRequest(
            agent_name=self.selected_agent_name,
            messages=messages,
        )
        event = AgentEvent(
            type=AgentEventType.REQUEST,
            request=request,
        )
        # Get root_task_id (from the first event)
        root_task_id = None
        ctx = SpanContext(tracer=self.tracer, app_name=self.selected_agent_name)
        with ctx.span(name=f"{self.selected_agent_name}_run_agent"):
            try:
                async for agent_event in self.orchestrator.run(
                    agent_name=self.selected_agent_name,
                    event=event,
                    agent_config=None,
                ):
                    # Get root_task_id (from the first event)
                    if root_task_id is None:
                        root_task_id = agent_event.root_task_id or agent_event.task_id
                    # Save request to file
                    # NOTE(review): this runs on every iteration once
                    # root_task_id is known, rewriting the same file per
                    # event — probably intended to run only once.
                    if root_task_id and self.workdir:
                        self._save_request(root_task_id, request)
                    # Check if it's ask_input type client_tool_call
                    if (
                        agent_event.client_tool_call
                        and getattr(agent_event.client_tool_call, "type", None)
                        == ClientToolCallType.ASK_INPUT
                    ):
                        # Save as pending ask_input event
                        self.pending_ask_input_event = agent_event
                        # Update placeholder to prompt user input
                        if self.input_box:
                            self.input_box.placeholder = (
                                "Please enter content (as ask_input response)..."
                            )
                    # Save to memory
                    self.events.append(agent_event)
                    # Add to view
                    self.process_view.add_event(agent_event)
                    # Save to file
                    if root_task_id and self.workdir:
                        self._save_event(root_task_id, agent_event)
                # After execution completes, add completion event
                completion_event = AgentEvent(
                    type=AgentEventType.SIGNAL,
                    agent_name=self.selected_agent_name,
                    metadata={"status": "completed", "message": "Execution completed"},
                )
                self.events.append(completion_event)
                self.process_view.add_event(completion_event)
                # Save completion event to file
                if root_task_id and self.workdir:
                    self._save_event(root_task_id, completion_event)
            except Exception as e:
                logger.error("Error running Agent: %s", e, exc_info=True)
                error_event = AgentEvent(
                    type=AgentEventType.ERROR,
                    error=str(e),
                )
                # Save to memory
                self.events.append(error_event)
                # Add to view
                self.process_view.add_event(error_event)

    async def action_quit(self) -> None:
        """Quit application."""
        self.exit()
================================================
FILE: cortex/utils/__init__.py
================================================
from .generator_merger import GeneratorMerger
__all__ = ["GeneratorMerger"]
================================================
FILE: cortex/utils/generator_merger.py
================================================
"""Generator merger implementation using asyncio."""
import asyncio
from collections.abc import Generator
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional
class GeneratorMerger:
"""A merger that combines multiple generator functions using asyncio for concurrent execution."""
def __init__(
self,
on_generator_complete: Optional[
Callable[[str, str, Optional[Exception]], Awaitable[None]]
] = None,
):
"""
Initialize the merger.
Args:
on_generator_complete: Optional callback function called when a sub-generator completes.
Parameters: (generator_id, generator_type, error)
If not provided, completion events will be yielded.
"""
self._generators: dict[str, Callable[[], Generator[Any, None, None]]] = {}
self._async_generators: dict[str, Callable[[], AsyncGenerator[Any, None]]] = {}
self._running_tasks: list[asyncio.Task] = []
self._queue: asyncio.Queue = asyncio.Queue()
self._active_count: int = 0
self._lock: asyncio.Lock = asyncio.Lock()
self._processed_generators: set[str] = set() # Processed generator IDs
self._on_generator_complete = on_generator_complete
self._generator_id_counter: int = 0 # Counter for generating unique generator IDs
def add_generator(
self,
generator_func: Callable[[], Generator[Any, None, None]],
generator_id: Optional[str] = None,
):
"""
Dynamically add a synchronous generator function.
Note: This method is synchronous and does not use locks when adding.
Lock protection is primarily used in the async merge() and delete operations.
Args:
generator_func: A function that returns a synchronous generator
generator_id: Optional generator identifier used in completion events
"""
if generator_id is None:
generator_id = f"sync_gen_{len(self._generators)}"
if generator_id in self._generators or generator_id in self._async_generators:
raise ValueError(f"Generator {generator_id} already exists")
self._generators[generator_id] = generator_func
def add_async_generator(
self,
async_generator_func: Callable[[], AsyncGenerator[Any, None]],
generator_id: Optional[str] = None,
):
"""
Dynamically add an async generator function.
Args:
async_generator_func: A function that returns an async generator
generator_id: Optional generator identifier used in completion events
"""
if generator_id is None:
generator_id = f"async_gen_{len(self._async_generators)}"
# Check if generator_id already exists
if generator_id in self._generators or generator_id in self._async_generators:
raise ValueError(f"Generator {generator_id} already exists")
self._async_generators[generator_id] = async_generator_func
def _get_next_item(self, generator: Generator[Any, None, None]) -> tuple[bool, Any]:
"""
Get the next value from a generator (runs in thread).
Args:
generator: generator object
Returns:
(has_more, item) tuple, where has_more indicates if there are more values
"""
try:
item = next(generator)
return (True, item)
except StopIteration:
return (False, None)
async def _run_generator_with_wrapper(
self,
generator_id: str,
generator_type: str,
generator_executor: Callable[[], AsyncGenerator[Any, None]],
):
"""
Generic generator execution wrapper.
Args:
generator_id: generator identifier
generator_type: generator type ("sync" or "async")
generator_executor: async function that executes the generator and yields values
"""
async with self._lock:
self._active_count += 1
error = None
try:
# Execute specific generator logic
async for item in generator_executor():
# Put generator-produced values into queue
await self._queue.put(item)
except Exception as e:
error = e
# If generator errors, also put into queue for processing
await self._queue.put(("__error__", e))
finally:
async with self._lock:
self._active_count -= 1
# Notify generator completion
await self._notify_generator_complete(generator_id, generator_type, error)
async def _run_generator(
self,
generator_id: str,
generator_func: Callable[[], Generator[Any, None, None]],
):
"""
Run a synchronous generator in the event loop.
Args:
generator_id: generator identifier
generator_func: function that returns a synchronous generator
"""
async def sync_generator_executor():
# Create generator (this step is fast, doesn't need to run in thread)
generator = generator_func()
loop = asyncio.get_event_loop()
# Get values one by one, run in thread pool to avoid blocking event loop
while True:
has_more, item = await loop.run_in_executor(
None, self._get_next_item, generator
)
if not has_more:
break
yield item
await self._run_generator_with_wrapper(
generator_id, "sync", sync_generator_executor
)
async def _run_async_generator(
self,
generator_id: str,
async_generator_func: Callable[[], AsyncGenerator[Any, None]],
):
"""
Run an async generator in the event loop.
Args:
generator_id: generator identifier
async_generator_func: function that returns an async generator
"""
async def async_generator_executor():
# Create async generator
async_generator = async_generator_func()
# Get values directly from async generator
async for item in async_generator:
yield item
await self._run_generator_with_wrapper(
generator_id, "async", async_generator_executor
)
async def _notify_generator_complete(
self, generator_id: str, generator_type: str, error: Optional[Exception]
):
"""
Notify generator completion.
Args:
generator_id: Generator identifier
generator_type: Generator type ("sync" or "async")
error: Error object if any; otherwise None
"""
# Use lock to protect dictionary modification
async with self._lock:
# Mark generator as completed
self._processed_generators.add(generator_id)
# Auto-delete completed generator
if generator_type == "sync" and generator_id in self._generators:
del self._generators[generator_id]
elif generator_type == "async" and generator_id in self._async_generators:
del self._async_generators[generator_id]
event = {
"type": "generator_complete",
"generator_id": generator_id,
"generator_type": generator_type,
"status": "error" if error else "completed",
"error": str(error) if error else None,
}
if self._on_generator_complete:
# If there's a callback, call it
if asyncio.iscoroutinefunction(self._on_generator_complete):
await self._on_generator_complete(generator_id, generator_type, error)
else:
self._on_generator_complete(generator_id, generator_type, error)
else:
# If no callback, put event into queue, will be yielded
await self._queue.put(("__event__", event))
async def merge(self) -> AsyncGenerator[Any, None]:
"""
Merge all added generators (sync and async), returning an async generator.
Supports dynamically adding new generators during iteration.
Yields:
Values produced by the various generators
"""
if not self._generators and not self._async_generators:
return
# Reset state
self._queue = asyncio.Queue()
self._active_count = 0
self._running_tasks = []
self._processed_generators = set()
# Use lock to get initial generator snapshot
async with self._lock:
sync_generators_snapshot = list(self._generators.items())
async_generators_snapshot = list(self._async_generators.items())
# Create tasks for initial sync generators
for generator_id, generator_func in sync_generators_snapshot:
task = asyncio.create_task(
self._run_generator(generator_id, generator_func),
name=f"sync_{generator_id}",
)
self._running_tasks.append(task)
# Create tasks for initial async generators
for generator_id, async_generator_func in async_generators_snapshot:
task = asyncio.create_task(
self._run_async_generator(generator_id, async_generator_func),
name=f"async_{generator_id}",
)
self._running_tasks.append(task)
# Get values from queue and yield
while True:
# Use lock to get current generator snapshot
async with self._lock:
sync_generators_snapshot = list(self._generators.items())
async_generators_snapshot = list(self._async_generators.items())
processed_generators_snapshot = self._processed_generators.copy()
# Check for newly added sync generators
for generator_id, generator_func in sync_generators_snapshot:
if generator_id not in processed_generators_snapshot:
# Check if a task is already running this generator
task_exists = any(
task.get_name() == f"sync_{generator_id}" and not task.done()
for task in self._running_tasks
)
if not task_exists:
task = asyncio.create_task(
self._run_generator(generator_id, generator_func),
name=f"sync_{generator_id}",
)
self._running_tasks.append(task)
# Check for newly added async generators
for generator_id, async_generator_func in async_generators_snapshot:
if generator_id not in processed_generators_snapshot:
# Check if a task is already running this generator
task_exists = any(
task.get_name() == f"async_{generator_id}" and not task.done()
for task in self._running_tasks
)
if not task_exists:
task = asyncio.create_task(
self._run_async_generator(
generator_id, async_generator_func
),
name=f"async_{generator_id}",
)
self._running_tasks.append(task)
# Check if all generators have been processed and no active tasks
async with self._lock:
all_generators = set(self._generators.keys()) | set(
self._async_generators.keys()
)
all_done = self._active_count == 0 and all_generators.issubset(
self._processed_generators
)
if all_done:
# Re-check for newly added generators (avoid race condition)
await asyncio.sleep(0) # Yield control to allow other tasks to run
async with self._lock:
all_generators = set(self._generators.keys()) | set(
self._async_generators.keys()
)
if self._active_count == 0 and all_generators.issubset(
self._processed_generators
):
break
# Get value from queue (with timeout to avoid infinite wait while allowing new generator checks)
try:
item = await asyncio.wait_for(self._queue.get(), timeout=0.01)
except asyncio.TimeoutError:
# Continue checking for new generators after timeout
continue
# Check for special markers
if isinstance(item, tuple):
if item[0] == "__error__":
raise item[1]
elif item[0] == "__event__":
# If it's an event, yield the event object
yield item[1]
continue
yield item
# Wait for all tasks to complete
await asyncio.gather(*self._running_tasks, return_exceptions=True)
async def __aiter__(self):
"""Allow the merger to be used as an async iterator."""
async for item in self.merge():
yield item
async def example_usage():
    """Example usage - demonstrates generator completion events and deletion functionality."""
    print("=" * 60)
    print("Example 1: Using callback")
    print("=" * 60)

    # Completion callback invoked by the merger for every finished generator.
    async def on_complete(
        generator_id: str, generator_type: str, error: Optional[Exception]
    ):
        if error:
            print(
                f" [CALLBACK] Generator {generator_id} ({generator_type}) completed with error: {error}"
            )
        else:
            print(f" [CALLBACK] Generator {generator_id} ({generator_type}) completed")

    # Create merger with callback
    merger = GeneratorMerger(on_generator_complete=on_complete)

    # Factories so the four demo generators share one definition each.
    def make_sync(tag: str):
        def _gen():
            for i in range(1, 5):
                yield f"sync-{tag}-{i}"
        return _gen

    def make_async(tag: str, start: int):
        async def _gen():
            for i in range(start, start + 5):
                await asyncio.sleep(0.01)
                yield f"async-{tag}-{i}"
        return _gen

    # Add generators
    merger.add_generator(make_sync("gen1"), generator_id="gen1")
    merger.add_async_generator(make_async("gen1", 20), generator_id="async_gen1")
    merger.add_generator(make_sync("gen2"), generator_id="gen2")
    merger.add_async_generator(make_async("gen2", 40), generator_id="async_gen2")

    print("Merged generator output:")
    async for item in merger:
        if isinstance(item, dict) and item.get("type") == "generator_complete":
            # If callback exists, events won't be yielded, so this won't execute
            print(f" Event: {item}")
        else:
            print(f" Data: {item}")

    print("\n" + "=" * 60)
    print("Example 2: Dynamic addition and deletion of generators")
    print("=" * 60)
# Run the demo when this module is executed directly.
if __name__ == "__main__":
    asyncio.run(example_usage())
================================================
FILE: cortex/utils/generator_merger_examples.py
================================================
"""GeneratorMerger usage examples collection."""
import asyncio
import time
from .generator_merger import GeneratorMerger
async def example_1_basic_usage():
    """Example 1: Basic usage - merging multiple simple generators."""
    print("=" * 60)
    print("Example 1: Basic Usage")
    print("=" * 60)

    merger = GeneratorMerger()

    def number_generator(start: int, end: int, prefix: str):
        """Generator that produces number sequences."""
        for value in range(start, end):
            yield f"{prefix}-{value}"

    # Register three independent sequences (defaults bind loop variables).
    for prefix, lo, hi in (("A", 1, 4), ("B", 10, 13), ("C", 20, 23)):
        merger.add_generator(
            lambda lo=lo, hi=hi, prefix=prefix: number_generator(lo, hi, prefix)
        )

    print("Merged output (order may vary due to concurrent execution):")
    async for received in merger:
        print(f" Received: {received}")
    print()
async def example_2_dynamic_addition():
    """Example 2: Dynamic generator addition - adding new generators during iteration."""
    print("=" * 60)
    print("Example 2: Dynamic Generator Addition")
    print("=" * 60)

    merger = GeneratorMerger()

    def slow_generator(name: str, count: int, delay: float):
        """Generator with delay, simulating slow data source."""
        for seq in range(count):
            time.sleep(delay)  # Simulate I/O operation
            yield f"{name}-{seq}"

    # Initially add two generators
    merger.add_generator(lambda: slow_generator("Fast", 3, 0.1))
    merger.add_generator(lambda: slow_generator("Medium", 3, 0.15))

    print("Starting iteration, will dynamically add new generator later...")
    seen = 0
    async for item in merger:
        print(f" Received: {item}")
        seen += 1
        # Dynamically add new generator during iteration
        if seen == 3:
            print(" Dynamically adding new generator...")
            merger.add_generator(lambda: slow_generator("Slow", 3, 0.2))
    print()
async def example_3_different_data_types():
    """Example 3: Handling different data types."""
    print("=" * 60)
    print("Example 3: Handling Different Data Types")
    print("=" * 60)

    merger = GeneratorMerger()

    def string_generator():
        """Generate strings."""
        yield from ("hello", "world", "python")

    def number_generator():
        """Generate numbers."""
        yield from (1, 2, 3, 4, 5)

    def dict_generator():
        """Generate dictionaries."""
        for idx in range(3):
            yield {"id": idx, "name": f"item_{idx}", "value": idx * 10}

    for source in (string_generator, number_generator, dict_generator):
        merger.add_generator(source)

    print("Merging different data types:")
    async for item in merger:
        print(f" Type: {type(item).__name__}, Value: {item}")
    print()
async def example_4_data_streams():
    """Example 4: Simulating multiple data stream scenarios."""
    print("=" * 60)
    print("Example 4: Simulating Multiple Data Stream Scenarios")
    print("=" * 60)

    merger = GeneratorMerger()

    def log_stream(source: str):
        """Simulate log stream."""
        for seq in range(5):
            time.sleep(0.05)
            yield {"source": source, "level": "INFO", "message": f"Log entry {seq}"}

    def metric_stream(metric_name: str):
        """Simulate metric stream."""
        for seq in range(4):
            time.sleep(0.08)
            yield {"metric": metric_name, "value": seq * 10, "timestamp": time.time()}

    def event_stream(event_type: str):
        """Simulate event stream."""
        for seq in range(3):
            time.sleep(0.06)
            yield {"event": event_type, "id": seq, "data": f"event_data_{seq}"}

    # Add multiple data streams
    merger.add_generator(lambda: log_stream("server1"))
    merger.add_generator(lambda: log_stream("server2"))
    merger.add_generator(lambda: metric_stream("cpu_usage"))
    merger.add_generator(lambda: metric_stream("memory_usage"))
    merger.add_generator(lambda: event_stream("user_action"))

    print("Merging multiple data streams (real-time output):")
    async for record in merger:
        # Dispatch on the record's distinguishing key.
        if "source" in record:
            print(f" [LOG] {record['source']}: {record['message']}")
        elif "metric" in record:
            print(f" [METRIC] {record['metric']}: {record['value']}")
        elif "event" in record:
            print(f" [EVENT] {record['event']}: {record['data']}")
    print()
async def example_5_file_processing():
    """Example 5: merge the line streams of several simulated files."""
    banner = "=" * 60
    print(banner)
    print("Example 5: Simulating File Processing Scenario")
    print(banner)
    merger = GeneratorMerger()

    def file_reader(filename: str, lines: list[str]):
        """Yield one record per line of the simulated file."""
        for line_num, line in enumerate(lines, 1):
            time.sleep(0.02)  # Simulate read delay
            yield {"file": filename, "line": line_num, "content": line}

    # Simulated contents of three files, keyed by filename (insertion order
    # matches the original registration order).
    fake_files = {
        "file1.txt": ["Line 1", "Line 2", "Line 3"],
        "file2.txt": ["A", "B", "C", "D"],
        "file3.txt": ["Data 1", "Data 2"],
    }
    for name, content in fake_files.items():
        # Bind loop variables via defaults so each lambda keeps its own file.
        merger.add_generator(lambda name=name, content=content: file_reader(name, content))

    print("Merging and processing multiple files:")
    async for rec in merger:
        print(f" [{rec['file']}] Line {rec['line']}: {rec['content']}")
    print()
async def example_6_batch_processing():
    """Example 6: merge the outputs of several concurrent batch tasks."""
    banner = "=" * 60
    print(banner)
    print("Example 6: Batch Processing Scenario")
    print(banner)
    merger = GeneratorMerger()

    def task_processor(task_id: int, items: list[str]):
        """Yield one result record per processed item."""
        for item in items:
            time.sleep(0.03)  # Simulate processing time
            yield {
                "task_id": task_id,
                "item": item,
                "status": "processed",
                "result": f"result_{item}",
            }

    # Register three independent tasks with different workloads.
    workloads = [
        (1, ["item1", "item2", "item3"]),
        (2, ["itemA", "itemB"]),
        (3, ["data1", "data2", "data3", "data4"]),
    ]
    for tid, batch in workloads:
        # Default-arg binding keeps each lambda tied to its own task.
        merger.add_generator(lambda tid=tid, batch=batch: task_processor(tid, batch))

    print("Concurrently processing multiple tasks:")
    results = []
    async for outcome in merger:
        results.append(outcome)
        print(f" Task {outcome['task_id']} completed: {outcome['item']} -> {outcome['result']}")
    print(f"\nTotal processed: {len(results)} items")
    print()
async def example_7_error_handling():
    """Example 7: show what happens when one merged generator raises."""
    banner = "=" * 60
    print(banner)
    print("Example 7: Error Handling")
    print(banner)
    merger = GeneratorMerger()

    def normal_generator():
        """Well-behaved generator."""
        yield from (f"normal-{i}" for i in range(3))

    def error_generator():
        """Generator that raises after yielding two items."""
        yield "error-1"
        yield "error-2"
        raise ValueError("Simulated error")

    def another_normal_generator():
        """A second well-behaved generator."""
        yield from (f"another-{i}" for i in range(2))

    for gen in (normal_generator, error_generator, another_normal_generator):
        merger.add_generator(gen)

    print("Processing generators with errors:")
    try:
        async for item in merger:
            print(f" Received: {item}")
    except ValueError as exc:
        # The failing generator's exception propagates out of the iteration.
        print(f" Caught error: {exc}")
    print()
async def example_8_callback_usage():
    """Example 8: Using callback to monitor generator completion events.

    Demonstrates passing ``on_generator_complete`` to GeneratorMerger: the
    callback receives each generator's id, type, and terminal error (if any),
    and — per the note below — completion events are then NOT yielded by the
    iterator itself.
    """
    print("=" * 60)
    print("Example 8: Using Callback to Monitor Generator Completion Events")
    print("=" * 60)
    # Track completed generators
    completed_generators = []

    # Define callback function
    async def on_generator_complete(
        generator_id: str, generator_type: str, error: Exception | None
    ):
        """Called when generator completes."""
        # Record the outcome so it can be summarized after iteration ends.
        status = "Success" if error is None else f"Failed: {error}"
        completed_generators.append(
            {"id": generator_id, "type": generator_type, "status": status}
        )
        print(
            f" [CALLBACK] Generator '{generator_id}' ({generator_type}) completed: {status}"
        )

    # Create merger with callback
    merger = GeneratorMerger(on_generator_complete=on_generator_complete)

    def fast_generator():
        """Fast completing generator."""
        for i in range(1, 4):
            yield f"fast-{i}"

    def slow_generator():
        """Slow completing generator."""
        for i in range(10, 13):
            time.sleep(0.05)  # Simulate slow operation
            yield f"slow-{i}"

    async def async_generator():
        """Async generator."""
        for i in range(20, 23):
            await asyncio.sleep(0.03)
            yield f"async-{i}"

    def error_generator():
        """Generator that will error."""
        yield "error-1"
        yield "error-2"
        raise ValueError("Test error")

    # Add generators with generator_id specified
    merger.add_generator(fast_generator, generator_id="fast_gen")
    merger.add_generator(slow_generator, generator_id="slow_gen")
    merger.add_async_generator(async_generator, generator_id="async_gen")
    merger.add_generator(error_generator, generator_id="error_gen")
    print("Starting generator processing (callback will be called on completion):")
    print()
    try:
        async for item in merger:
            # Note: If there's a callback, completion events won't be yielded, only callback is called
            print(f" [DATA] Received data: {item}")
    except ValueError as e:
        # error_generator's exception escapes the merged iteration here.
        print(f" [ERROR] Caught error: {e}")
    print()
    print("Completed generators statistics:")
    for gen_info in completed_generators:
        print(f" - {gen_info['id']} ({gen_info['type']}): {gen_info['status']}")
    print()
async def example_9_callback_vs_event():
    """Example 9: contrast completion callbacks with yielded completion events."""
    banner = "=" * 60
    print(banner)
    print("Example 9: Comparing Callback vs Event Yield Differences")
    print(banner)

    def simple_generator(name: str):
        """Yield two items tagged with *name*."""
        yield from (f"{name}-{i}" for i in range(1, 3))

    divider = "-" * 60
    print("Method 1: Using callback (events won't appear in iteration)")
    print(divider)

    async def callback(
        generator_id: str, _generator_type: str, _error: Exception | None
    ):
        # With a callback installed, completion is reported here instead of
        # through the iterator.
        print(f" [CALLBACK] {generator_id} completed")

    merger1 = GeneratorMerger(on_generator_complete=callback)
    merger1.add_generator(lambda: simple_generator("A"), generator_id="gen_A")
    merger1.add_generator(lambda: simple_generator("B"), generator_id="gen_B")
    print(" Iteration output:")
    async for item in merger1:
        print(f" {item}")
    print()

    print("Method 2: Without callback (events will be yielded)")
    print(divider)
    merger2 = GeneratorMerger()
    merger2.add_generator(lambda: simple_generator("X"), generator_id="gen_X")
    merger2.add_generator(lambda: simple_generator("Y"), generator_id="gen_Y")
    print(" Iteration output (includes completion events):")
    async for item in merger2:
        # Without a callback, completion shows up as a dict event in the stream.
        is_event = isinstance(item, dict) and item.get("type") == "generator_complete"
        if is_event:
            print(
                f" [EVENT] {item['generator_id']} ({item['generator_type']}) "
                f"Status: {item['status']}"
            )
        else:
            print(f" [DATA] {item}")
    print()
async def run_all_examples():
    """Run every example in order, continuing past individual failures."""
    examples = (
        example_1_basic_usage,
        example_2_dynamic_addition,
        example_3_different_data_types,
        example_4_data_streams,
        example_5_file_processing,
        example_6_batch_processing,
        example_7_error_handling,
        example_8_callback_usage,
        example_9_callback_vs_event,
    )
    for example in examples:
        try:
            await example()
            await asyncio.sleep(0.5)  # Interval between examples
        except Exception as e:
            # Best-effort demo runner: report the failure, then move on.
            print(f"Example execution error: {e}\n")
if __name__ == "__main__":
    header = "=" * 60
    print("\n" + header)
    print("GeneratorMerger Usage Examples")
    print(header + "\n")
    # Run all examples
    asyncio.run(run_all_examples())
    print(header)
    print("All examples completed!")
    print(header)
================================================
FILE: cortex-ui/.gitignore
================================================
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
================================================
FILE: cortex-ui/.gitlab-ci.yml
================================================
variables:
APP: $CI_PROJECT_NAMESPACE
SERVER: $CI_PROJECT_NAME
VERSION: $CI_COMMIT_SHA
IMAGE_TAG_PROD: ${CI_COMMIT_TAG}-${CI_COMMIT_SHORT_SHA}
IMAGE_TAG_DEV: dev-${CI_COMMIT_BRANCH}${CI_MERGE_REQUEST_SOURCE_BRANCH_NAME}-${CI_COMMIT_SHORT_SHA}
IMAGE: catalyst-cn-shanghai.cr.volces.com/capy/${CI_PROJECT_PATH}
workflow:
rules:
- if: $CI_PIPELINE_SOURCE == "push" && $CI_COMMIT_TITLE == "edit by devops"
when: never
- if: $CI_COMMIT_BRANCH && $CI_OPEN_MERGE_REQUESTS
when: never
- if: $CI_PIPELINE_SOURCE == "merge_request_event"
- if: $CI_PIPELINE_SOURCE == "schedule"
- if: $CI_COMMIT_BRANCH
- if: $CI_COMMIT_TAG
default:
tags: [shell]
before_script:
- export PATH=$PWD/node_modules/.bin:$PATH
# - command -v moon || pnpm install
- pnpm install
- git fetch origin main
stages:
# - lint
- image
image-push-dev:
stage: image
script:
- docker build -f "Dockerfile" --platform=linux/amd64 -t ${IMAGE}:${IMAGE_TAG_DEV} ${CI_PROJECT_DIR}
- docker push ${IMAGE}:${IMAGE_TAG_DEV}
rules:
- when: manual
allow_failure: true
================================================
FILE: cortex-ui/index.html
================================================
Agent Cortex
================================================
FILE: cortex-ui/package.json
================================================
{
"name": "agent-cortex-frontend",
"version": "1.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview"
},
"dependencies": {
"@ant-design/icons": "^5.2.6",
"antd": "^5.12.0",
"axios": "^1.6.2",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^9.0.1",
"react-router-dom": "^6.20.0",
"rehype-raw": "^7.0.0",
"remark-gfm": "^4.0.0"
},
"devDependencies": {
"@types/react": "^18.2.43",
"@types/react-dom": "^18.2.17",
"@vitejs/plugin-react": "^4.2.1",
"typescript": "^5.3.3",
"vite": "^5.0.8"
}
}
================================================
FILE: cortex-ui/src/App.tsx
================================================
import React from 'react';
import { BrowserRouter as Router, Routes, Route } from 'react-router-dom';
import { ConfigProvider } from 'antd';
import zhCN from 'antd/locale/zh_CN';
import AgentList from './pages/AgentList';
import ChatPage from './pages/ChatPage';
import { ErrorBoundary } from './components/ErrorBoundary';
// Top-level application component: sets up antd ConfigProvider (zhCN locale),
// the router, and the two page routes (AgentList, ChatPage) per the imports above.
// NOTE(review): the JSX markup in this component was stripped by the text
// extraction — only the closing `} />` fragments of the <Route element={...} />
// entries remain. Restore the full JSX from the repository before editing;
// do not treat the lines below as valid source.
const App: React.FC = () => {
return (
} />
} />
);
};
export default App;
================================================
FILE: cortex-ui/src/components/EndpointConfig.tsx
================================================
import React, { useState } from 'react';
import { Input, Button, Typography, message } from 'antd';
import { EditOutlined, SaveOutlined, CloseOutlined } from '@ant-design/icons';
import { getStoredEndpoint } from '../services/api';
// antd Typography sub-component used in the render below.
const { Text } = Typography;

/** Props for the editable API-endpoint configuration widget. */
interface EndpointConfigProps {
  /** Invoked with the new endpoint value after a successful save. */
  onSave?: (endpoint: string) => void;
}
export const EndpointConfig: React.FC = ({ onSave }) => {
const [endpoint, setEndpoint] = useState(() => {
return getStoredEndpoint() || '';
});
const [isEditing, setIsEditing] = useState(false);
const handleSave = () => {
localStorage.setItem('api_endpoint', endpoint);
setIsEditing(false);
message.success('Endpoint saved');
if (onSave) {
onSave(endpoint);
}
};
const handleCancel = () => {
setEndpoint(getStoredEndpoint() || '');
setIsEditing(false);
};
return (